feat: Add UniCalli Chinese calligraphy generator
- Add inference.py with CalligraphyGenerator class
- Add Gradio app with support for 1-7 Chinese characters
- Add 90+ historical calligraphers (王羲之, 颜真卿, 赵佶, etc.)
- Support Regular (楷), Running (行), Cursive (草) scripts
- Add dataset files for author styles and fonts
- Add flux model source code
This view is limited to 50 files because it contains too many changes.
- .gitignore +15 -0
- README.md +29 -3
- app.py +287 -133
- dataset/author_fonts_summary.csv +95 -0
- dataset/calligraphy_styles_en.json +95 -0
- dataset/chirography.json +7 -0
- inference.py +860 -0
- requirements.txt +9 -4
- src/__init__.py +0 -0
- src/flux/__init__.py +11 -0
- src/flux/__main__.py +4 -0
- src/flux/annotator/canny/__init__.py +6 -0
- src/flux/annotator/ckpts/ckpts.txt +1 -0
- src/flux/annotator/dwpose/__init__.py +68 -0
- src/flux/annotator/dwpose/onnxdet.py +125 -0
- src/flux/annotator/dwpose/onnxpose.py +360 -0
- src/flux/annotator/dwpose/util.py +297 -0
- src/flux/annotator/dwpose/wholebody.py +48 -0
- src/flux/annotator/hed/__init__.py +95 -0
- src/flux/annotator/midas/LICENSE +21 -0
- src/flux/annotator/midas/__init__.py +42 -0
- src/flux/annotator/midas/api.py +168 -0
- src/flux/annotator/midas/midas/__init__.py +0 -0
- src/flux/annotator/midas/midas/base_model.py +16 -0
- src/flux/annotator/midas/midas/blocks.py +342 -0
- src/flux/annotator/midas/midas/dpt_depth.py +109 -0
- src/flux/annotator/midas/midas/midas_net.py +76 -0
- src/flux/annotator/midas/midas/midas_net_custom.py +128 -0
- src/flux/annotator/midas/midas/transforms.py +234 -0
- src/flux/annotator/midas/midas/vit.py +491 -0
- src/flux/annotator/midas/utils.py +189 -0
- src/flux/annotator/mlsd/LICENSE +201 -0
- src/flux/annotator/mlsd/__init__.py +40 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
- src/flux/annotator/mlsd/utils.py +580 -0
- src/flux/annotator/tile/__init__.py +26 -0
- src/flux/annotator/tile/guided_filter.py +280 -0
- src/flux/annotator/util.py +38 -0
- src/flux/annotator/zoe/LICENSE +21 -0
- src/flux/annotator/zoe/__init__.py +48 -0
- src/flux/annotator/zoe/zoedepth/data/__init__.py +24 -0
- src/flux/annotator/zoe/zoedepth/data/data_mono.py +573 -0
- src/flux/annotator/zoe/zoedepth/data/ddad.py +117 -0
- src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py +125 -0
- src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py +114 -0
- src/flux/annotator/zoe/zoedepth/data/diode.py +125 -0
- src/flux/annotator/zoe/zoedepth/data/hypersim.py +138 -0
- src/flux/annotator/zoe/zoedepth/data/ibims.py +81 -0
- src/flux/annotator/zoe/zoedepth/data/preprocess.py +154 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
__pycache__/
*.pyc
*.pyo
*.egg-info/
.eggs/
dist/
build/
*.egg
.ipynb_checkpoints/
.DS_Store
*.bin
*.pt
*.pth
*.ckpt
*.safetensors
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: UniCalli Dev
-emoji:
+emoji: 🖌️
 colorFrom: purple
 colorTo: red
 sdk: gradio
@@ -8,7 +8,33 @@ sdk_version: 5.44.0
 app_file: app.py
 pinned: false
 license: cc-by-nc-nd-4.0
-short_description:
+short_description: Chinese Calligraphy Generator with Historical Masters' Styles
 ---

-
+# 🖌️ UniCalli - Chinese Calligraphy Generator
+
+Generate beautiful Chinese calligraphy in various styles and by different historical masters.
+
+用不同历史书法大师的风格生成精美的中国书法。
+
+## Features
+
+- **1-7 Chinese Characters**: Supports generating 1 to 7 Chinese characters
+- **Historical Masters**: 90+ calligraphers including 王羲之, 颜真卿, 赵佶/宋徽宗, etc.
+- **Multiple Font Styles**: 楷 (Regular), 行 (Running), 草 (Cursive)
+- **4-bit Quantization**: Optimized for efficient inference
+
+## Model
+
+This demo uses the UniCalli-pro model from [TSXu/UniCalli-pro](https://huggingface.co/TSXu/UniCalli-pro).
+
+## Usage
+
+1. Enter 1-7 Chinese characters
+2. Select a calligrapher (or use synthetic style)
+3. Choose a font style
+4. Click "Generate Calligraphy"
+
+## Citation
+
+If you use this model, please cite the UniCalli paper.
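The same four-step flow can also be driven programmatically through `gradio_client`. A minimal sketch, assuming the Space is deployed under an ID like `TSXu/UniCalli-dev` and that Gradio's default endpoint naming (`/generate_calligraphy`, after the handler function in app.py below) applies; check the Space's "Use via API" panel for the actual values:

```python
from gradio_client import Client

# Space ID and api_name are assumptions, not confirmed by this diff.
client = Client("TSXu/UniCalli-dev")
image_path, seed_info = client.predict(
    "春风得意马蹄疾",           # text: 1-7 Chinese characters
    "赵佶\\宋徽宗",             # calligrapher (or "None (Synthetic / 合成风格)")
    "楷 (Regular Script)",      # font style display name
    39,                         # inference steps
    42,                         # seed
    False,                      # random seed off
    api_name="/generate_calligraphy",
)
print(image_path, seed_info)
```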
app.py
CHANGED
@@ -1,154 +1,308 @@

[Removed: the stock Gradio text-to-image template this Space was scaffolded from — prompt textbox, example prompts ("Astronaut in a jungle…", etc.), #col-container CSS, and seed/width/height/guidance_scale/num_inference_steps sliders feeding a pipeline call that returned (image, seed).]

The new app.py:

# -*- coding: utf-8 -*-
"""
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
"""

import gradio as gr
from inference import CalligraphyGenerator
import json
import csv
import spaces

# Load author and font mappings from CSV
def load_author_fonts_from_csv(csv_path):
    """
    Load author and their available fonts from CSV file
    Filters out authors that only support 隶 or 篆 fonts
    Returns: dict mapping author to list of font styles
    """
    author_fonts = {}
    excluded_fonts = {'隶', '篆'}  # Fonts we don't support

    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            author = row['书法家']
            fonts = row['字体类型'].split('|')  # Split multiple fonts by |

            # Filter out unsupported fonts (隶 and 篆)
            supported_fonts = [f for f in fonts if f not in excluded_fonts]

            # Only include author if they have at least one supported font
            if supported_fonts:
                author_fonts[author] = supported_fonts

    return author_fonts

# Load author-font mappings
AUTHOR_FONTS = load_author_fonts_from_csv('dataset/author_fonts_summary.csv')

# Available authors (sorted)
AUTHOR_LIST = sorted(AUTHOR_FONTS.keys())

# Font style display names (only supported styles)
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)"
}

# Load author descriptions if available
try:
    with open('dataset/calligraphy_styles_en.json', 'r', encoding='utf-8') as f:
        author_styles = json.load(f)
except:
    author_styles = {}

# Initialize generator (will be done lazily on first generation)
generator = None


def init_generator():
    """Initialize the generator (lazy loading)"""
    global generator

    if generator is None:
        generator = CalligraphyGenerator(
            model_name="flux-dev",
            device="cuda",
            offload=False,
            intern_vlm_path="OpenGVLab/InternVL3-1B",
            checkpoint_path="TSXu/UniCalli-pro",
            font_descriptions_path='dataset/chirography.json',
            author_descriptions_path='dataset/calligraphy_styles_en.json',
            use_deepspeed=False,
            use_4bit_quantization=True,
        )
    return generator


def update_font_choices(author: str):
    """
    Update available font choices based on selected author

    Args:
        author: Selected author name

    Returns:
        Updated dropdown with available fonts for the author
    """
    if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
        # If no author or synthetic, show all font types
        choices = list(FONT_STYLE_NAMES.values())
    else:
        # Show only fonts available for this author
        available_fonts = AUTHOR_FONTS[author]
        choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]

    # Return updated dropdown with first choice as default
    return gr.Dropdown(choices=choices, value=choices[0] if choices else None)


@spaces.GPU
def generate_calligraphy(
    text: str,
    author_dropdown: str,
    font_style: str,
    num_steps: int,
    seed: int,
    random_seed: bool,
):
    """
    Generate calligraphy based on user inputs

    Args:
        text: Input text (1-7 characters)
        author_dropdown: Selected author from dropdown
        font_style: Selected font style (display name)
        num_steps: Number of denoising steps
        seed: Random seed
        random_seed: Whether to use random seed

    Returns:
        Generated image and condition image
    """
    # Validate text - must be 1-7 characters
    if len(text) < 1:
        raise gr.Error("文本不能为空 / Text cannot be empty")
    if len(text) > 7:
        raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")

    # Extract font style value from display name
    font = None
    for font_key, font_display in FONT_STYLE_NAMES.items():
        if font_display == font_style:
            font = font_key
            break

    if font is None:
        raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")

    # Determine author
    author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None

    # Handle seed
    if random_seed:
        import torch
        seed = torch.randint(0, 2**32, (1,)).item()

    # Initialize generator if needed
    gen = init_generator()

    # Generate
    result_img, cond_img = gen.generate(
        text=text,
        font_style=font,
        author=author,
        num_steps=num_steps,
        seed=seed,
    )

    return result_img, f"Seed: {seed}"


# Create Gradio interface
with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生成器", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🖌️ UniCalli - 中国书法生成器 / Chinese Calligraphy Generator

    Generate beautiful Chinese calligraphy in various styles and by different historical masters.

    用不同历史书法大师的风格生成精美的中国书法。

    **注意 / Note**: 支持1-7个汉字输入 / Supports 1-7 Chinese characters.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📝 输入设置 / Input Settings")

            text_input = gr.Textbox(
                label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
                placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 春风得意马蹄疾",
                value="春风得意马蹄疾",
                max_lines=1
            )

            gr.Markdown("### 👤 书法家选择 / Calligrapher Selection")

            author_dropdown = gr.Dropdown(
                label="1. 选择书法家 / Select Calligrapher",
                choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
                value="赵佶\\宋徽宗",
                info="先选择历史书法家 / Choose a historical calligrapher first"
            )

            # Get initial fonts for default author (赵佶\宋徽宗)
            initial_author = "赵佶\\宋徽宗"
            initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
            initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]

            font_style = gr.Dropdown(
                label="2. 选择字体风格 / Select Font Style",
                choices=initial_font_choices,
                value="楷 (Regular Script)",
                info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
            )

            gr.Markdown("### ⚙️ 生成设置 / Generation Settings")

            num_steps = gr.Slider(
                label="生成步数 / Inference Steps",
                minimum=10,
                maximum=50,
                value=39,
                step=1,
                info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
            )

            with gr.Row():
                seed = gr.Number(
                    label="随机种子 / Seed",
                    value=42,
                    precision=0
                )
                random_seed = gr.Checkbox(
                    label="随机种子 / Random Seed",
                    value=False
                )

            generate_btn = gr.Button("🎨 生成书法 / Generate Calligraphy", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🖼️ 生成结果 / Generated Result")
            gr.Markdown("")  # Add spacing

            with gr.Row():
                gr.Column(scale=1)  # Left spacer
                with gr.Column(scale=2):
                    output_image = gr.Image(
                        show_label=False,
                        type="pil",
                        height=600
                    )
                gr.Column(scale=1)  # Right spacer

            seed_info = gr.Textbox(
                label="种子信息 / Seed Info",
                interactive=False
            )

    # Author info section
    with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
        author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
        for author in AUTHOR_LIST[:30]:
            fonts = " | ".join(AUTHOR_FONTS[author])
            desc = author_styles.get(author, "")
            desc_short = desc[:50] + "..." if len(desc) > 50 else desc
            author_info_md += f"| **{author}** | {fonts} |\n"
        if len(AUTHOR_LIST) > 30:
            author_info_md += f"\n*... 还有 {len(AUTHOR_LIST) - 30} 位书法家 / {len(AUTHOR_LIST) - 30} more calligraphers*"
        gr.Markdown(author_info_md)

    # Event handlers
    # Update font choices when author changes
    author_dropdown.change(
        fn=update_font_choices,
        inputs=[author_dropdown],
        outputs=[font_style]
    )

    # Generate button click
    generate_btn.click(
        fn=generate_calligraphy,
        inputs=[
            text_input,
            author_dropdown,
            font_style,
            num_steps,
            seed,
            random_seed,
        ],
        outputs=[output_image, seed_info]
    )

    # Examples
    gr.Markdown("### 📋 示例 / Examples")
    gr.Examples(
        examples=[
            ["春风得意马蹄疾", "赵佶\\宋徽宗", "楷 (Regular Script)", 39, 42, False],
            ["海内存知己", "黄庭坚", "行 (Running Script)", 39, 42, False],
            ["天道酬勤", "王羲之", "草 (Cursive Script)", 39, 42, False],
            ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 39, 42, False],
        ],
        inputs=[
            text_input,
            author_dropdown,
            font_style,
            num_steps,
            seed,
            random_seed,
        ],
    )


if __name__ == "__main__":
    demo.launch()
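The author dropdown drives the font dropdown via `update_font_choices` above. A self-contained sketch of that dependent-dropdown logic in isolation (the two dictionaries below are illustrative subsets of the mapping loaded from `dataset/author_fonts_summary.csv`):

```python
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)",
}
# Illustrative subset of AUTHOR_FONTS as loaded from the CSV below.
AUTHOR_FONTS = {"苏轼": ["行"], "赵孟頫": ["楷", "草", "行"]}

def font_choices(author: str) -> list:
    # Unknown authors and the synthetic option fall back to all supported styles.
    if author not in AUTHOR_FONTS:
        return list(FONT_STYLE_NAMES.values())
    return [FONT_STYLE_NAMES[f] for f in AUTHOR_FONTS[author] if f in FONT_STYLE_NAMES]

assert font_choices("苏轼") == ["行 (Running Script)"]
assert len(font_choices("None (Synthetic / 合成风格)")) == 3
```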
dataset/author_fonts_summary.csv
ADDED
@@ -0,0 +1,95 @@
书法家,字体类型,字体数量
author,chirography,1
乾隆,楷,1
仇靖,隶,1
何绍基,楷|行|隶,3
佚名,楷|篆|隶,3
傅山,行,1
冯子振,行,1
刘日升,行,1
吴昌硕,篆,1
吴琚,行,1
吴让之,篆,1
吴通微,楷,1
唐寅,行,1
墓志,楷,1
姚孟起,行,1
姜夔,楷,1
姜立纲,楷,1
孙过庭,草,1
宋克,草,1
宋珏,行,1
小野道风,行,1
康熙,楷,1
康里巎巎,草,1
张即之,楷|行,2
张弼,草,1
张旭,楷|草,2
张瑞图,草|行,2
张芝,草,1
徐渭,行,1
怀素,草,1
摩崖刻石,楷|隶,2
文天祥,草,1
文彭,行,1
文征明,楷|草|行,3
智永,楷|草,2
曾巩,楷,1
朱登,隶,1
朱耷\八大山人,楷|行,2
李倜,行,1
李孝光,行,1
李斯,篆,1
李邕,楷|行,2
李阳冰,篆,1
杨凝式,草|行,2
杨沂孙,篆,1
杨秀,楷,1
杨维桢,行|隶,2
柯九思,楷,1
柳公权,楷,1
欧阳询,楷|行,2
欧阳通,楷,1
沈尹默,行,1
王宠,草,1
王献之,楷|草|行,3
王珣,行,1
王福庵,篆,1
王羲之,草|行,2
王诜,行,1
王铎,草|行,2
皇象,篆,1
祝允明,草|行,2
空海,楷,1
简牍帛书盟书,隶,1
米芾,草|行,2
索靖,草,1
苏轼,行,1
董其昌,楷|草|行,3
蔡京,行,1
蔡襄,草|行,2
薛绍彭,草,1
虞世南,楷|行,2
褚遂良,楷,1
赵之谦,楷|篆,2
赵佶\宋徽宗,楷|草|行,3
赵孟頫,楷|草|行,3
赵构,行,1
造像记,楷,1
邓文原,草,1
邓石如,楷|篆|草|隶,4
金农,隶,1
金文刻石,篆,1
钟繇,楷,1
钟绍京,楷,1
陆机,草,1
陆柬之,行,1
陆游,行,1
陶弘景,楷,1
颜真卿,楷|行,2
鲜于枢,楷|草|行,3
黄养正,楷,1
黄庭坚,草|行,2
黄自元,楷,1
黄道周,行,1
龚贤,行,1
dataset/calligraphy_styles_en.json
ADDED
@@ -0,0 +1,95 @@
{
    "乾隆": "Qianlong (1711–1799), the sixth emperor of the Qing dynasty, was a patron of arts and calligraphy. He practiced various scripts, especially regular script (kaishu) and running script (xingshu), imitating earlier masters. His calligraphy was elegant but often criticized as lacking originality, serving more as an imperial cultural statement than a pursuit of innovation.",
    "仇靖": "Qiu Jing, a Ming dynasty calligrapher, was known for his fine regular and semi-cursive script. He emphasized structure and balance, drawing on the styles of Tang and Song masters, and contributed to the literati tradition of refined, scholarly brushwork.",
    "何绍基": "He Shaoji (1799–1873), a Qing dynasty calligrapher and scholar, excelled in seal script (zhuanshu) and clerical script (lishu), blending ancient stele styles with dynamic brushwork. His work combined strength and elegance, influencing later seal and clerical calligraphers.",
    "佚名": "Anonymous (Yi Ming) refers to works by unknown calligraphers, often found in historical manuscripts, inscriptions, or fragments. Such works are valued for their artistic or historical significance despite lacking an identified author.",
    "傅山": "Fu Shan (1607–1684), a late Ming and early Qing calligrapher, painter, and medical scholar, was renowned for his unconventional and vigorous style. He mastered multiple scripts, integrating ancient stele influences and free expression, reflecting his loyalist spirit against the Qing.",
    "冯子振": "Feng Zizhen, a Yuan dynasty calligrapher and poet, was adept in running and cursive scripts. His calligraphy displayed scholarly refinement and rhythmic flow, embodying literati aesthetics of elegance and poetic sentiment.",
    "刘日升": "Liu Risheng, a Qing dynasty calligrapher, was known for his regular and clerical script works. His brush control and structural precision reflected a dedication to classical models while maintaining personal expressiveness.",
    "吴昌硕": "Wu Changshuo (1844–1927), a leading figure of the late Qing and early Republic era, was a master in seal script and an accomplished painter and seal carver. His calligraphy combined monumental strength with artistic charm, often integrating inscriptions into his paintings.",
    "吴琚": "Wu Ju (1102–1180), a Southern Song calligrapher, excelled in running and cursive scripts. Influenced by Wang Xizhi and other Jin-Tang masters, his works were valued for their elegance, fluidity, and literati grace.",
    "吴让之": "Wu Rangzhi (1799–1870), a Qing dynasty calligrapher and seal carver, was prominent in the revival of seal script. His style was based on ancient bronze and stone inscriptions, merging archaeological precision with artistic vitality.",
    "吴通微": "Wu Tongwei, a Ming dynasty calligrapher, was known for his work in clerical and running scripts, blending structure with expressive brush movements, reflecting literati scholarly ideals.",
    "唐寅": "Tang Yin (1470–1524), also known as Tang Bohu, was a Ming dynasty painter, poet, and calligrapher. His calligraphy reflected elegance and charm, combining the styles of earlier masters with his own refined and poetic sensibility.",
    "墓志": "Epitaph inscriptions (muzhi) are stone or brick-carved texts placed in tombs, recording biographical details and virtues of the deceased. They serve as valuable historical sources and display diverse calligraphic styles from various dynasties.",
    "姚孟起": "Yao Mengqi (1807–1860), a Qing dynasty calligrapher, specialized in clerical script influenced by Han dynasty steles. His brushwork was strong and precise, contributing to the Qing stele studies movement.",
    "姜夔": "Jiang Kui (c. 1155–1221), a Southern Song poet, composer, and calligrapher, was famous for his refined running script. His calligraphy displayed purity and grace, matching the elegance of his poetry and musical compositions.",
    "姜立纲": "Jiang Ligang, a Yuan dynasty calligrapher, was celebrated for his monumental regular and clerical script works, emphasizing stability, clarity, and the influence of Tang models.",
    "孙过庭": "Sun Guoting (646–691), a Tang dynasty calligrapher, is best known for his 'Treatise on Calligraphy' (Shu Pu), an influential theoretical and artistic work written in running script, blending philosophical insight with masterful technique.",
    "宋克": "Song Ke (1327–1387), a Ming dynasty calligrapher, excelled in cursive script and seal script, integrating ancient inscriptions into his creative style with vigorous brush energy.",
    "宋珏": "Song Jue, a Ming dynasty calligrapher, was skilled in regular and running scripts, focusing on balanced structures and literati refinement.",
    "小野道风": "Ono no Michikaze (894–966), a Heian period Japanese calligrapher, is considered one of the 'Three Brush Saints' of Japan. He adapted Chinese calligraphy styles, particularly those of Wang Xizhi, into the distinctive Japanese wayo style.",
    "康熙": "Kangxi (1654–1722), the fourth emperor of the Qing dynasty, was a patron of arts and practiced calligraphy in regular and running scripts. His works imitated classical Tang models, symbolizing imperial authority and cultural refinement.",
    "康里巎巎": "Kangli Naonao (dates unknown), a Yuan dynasty calligrapher of Mongol heritage, was celebrated for his running and cursive scripts, blending vigor with refined structure, influenced by Zhao Mengfu.",
    "张即之": "Zhang Jizhi (1186–1266), a Southern Song calligrapher, excelled in regular and running scripts. His works were upright, powerful, and faithful to Tang models, often compared to Yan Zhenqing.",
    "张弼": "Zhang Bi (1425–1487), a Ming dynasty calligrapher, was known for his wild cursive script, demonstrating unrestrained energy and personal expression, breaking from orthodox conventions.",
    "张旭": "Zhang Xu (c. 675–750), a Tang dynasty master of wild cursive (kuangcao), was famed for his unbridled, vigorous brushwork, earning the nickname 'Crazy Zhang.' His style embodied spontaneity and emotion.",
    "张瑞图": "Zhang Ruitu (1570–1641), a Ming dynasty calligrapher, specialized in cursive and semi-cursive scripts. His work was bold, innovative, and structurally unconventional, making him one of the 'Four Masters of the Ming.'",
    "张芝": "Zhang Zhi (d. 192), an Eastern Han calligrapher, was revered as the 'Sage of Cursive Script.' He developed early cursive forms, influencing later masters like Wang Xizhi.",
    "徐渭": "Xu Wei (1521–1593), a Ming dynasty painter, poet, and calligrapher, was renowned for his expressive cursive script. His works reflected personal emotion and bold experimentation.",
    "怀素": "Huai Su (737–799), a Tang dynasty monk, was a master of wild cursive script. His brushwork was swift, forceful, and full of rhythm, leaving a lasting impact on the art of calligraphy.",
    "摩崖刻石": "Cliff inscriptions (moya kestone) refer to large-scale carvings of calligraphy on rock faces or cliffs, often for commemorative or religious purposes, showcasing monumental styles from various dynasties.",
    "文天祥": "Wen Tianxiang (1236–1283), a Southern Song loyalist, poet, and calligrapher, wrote in regular and running scripts. His works expressed moral integrity and patriotism, often in prison before his execution.",
    "文彭": "Wen Peng (1498–1573), a Ming dynasty calligrapher and seal carver, was a pioneer in literati seal carving, also excelling in running script with elegant brushwork.",
    "文征明": "Wen Zhengming (1470–1559), a leading Ming dynasty scholar-painter and calligrapher, mastered multiple scripts. His calligraphy was refined, balanced, and scholarly, influencing generations of literati.",
    "智永": "Zhiyong (dates uncertain, 6th–7th century), a Sui dynasty monk and descendant of Wang Xizhi, was famed for his 'Thousand Character Classic' in regular script, which became a key model for students.",
    "曾巩": "Zeng Gong (1019–1083), a Northern Song scholar and writer, was also skilled in calligraphy, producing dignified regular script reflecting his literary stature.",
    "朱登": "Zhu Deng, a Ming dynasty calligrapher, was known for his structured and disciplined regular script, following Tang exemplars.",
    "朱耷\\八大山人": "Zhu Da (1626–1705), known as Bada Shanren, was a Ming loyalist monk, painter, and calligrapher. His calligraphy was eccentric and bold, matching the unconventional spirit of his paintings.",
    "李倜": "Li Ti, a Tang dynasty calligrapher, was praised for his elegant running script, merging influence from Wang Xizhi and Tang court styles.",
    "李孝光": "Li Xiaoguang, a Ming dynasty calligrapher, was noted for his regular script, showing a balance of precision and grace.",
    "李斯": "Li Si (c. 280–208 BCE), Prime Minister of the Qin dynasty, standardized the Small Seal Script (xiaozhuan) used in the First Emperor’s inscriptions, shaping the foundation of Chinese written form.",
    "李邕": "Li Yong (678–747), a Tang dynasty calligrapher, excelled in regular and running scripts, producing monumental works with strong structure and elegance.",
    "李阳冰": "Li Yangbing, an 8th-century Tang calligrapher and relative of poet Li Bai, was renowned for his mastery of seal script, influencing later revival movements.",
    "杨凝式": "Yang Ningshi (873–954), a Five Dynasties period calligrapher, was celebrated for his running script marked by elegance and individuality, blending Tang and earlier traditions.",
    "杨沂孙": "Yang Yisun (1813–1897), a Qing dynasty calligrapher, specialized in clerical script, studying Han dynasty steles with precision and vigor.",
    "杨秀": "Yang Xiu (d. 219), an Eastern Han court official, was known for his literary talent and calligraphy in early clerical forms.",
    "杨维桢": "Yang Weizhen (1296–1370), a Yuan dynasty poet and calligrapher, favored dramatic, unconventional styles in large characters, reflecting his bold personality.",
    "柯九思": "Ke Jiusi (1290–1343), a Yuan dynasty scholar-painter and calligrapher, excelled in running script and was a close associate of Zhao Mengfu, blending elegance with scholarly restraint.",
    "柳公权": "Liu Gongquan (778–865), a Tang dynasty calligrapher, was a master of regular script. His style, upright and vigorous, became a standard model alongside Yan Zhenqing.",
    "欧阳询": "Ouyang Xun (557–641), a Tang dynasty calligrapher, produced highly disciplined and elegant regular script, influencing generations of scholars.",
    "欧阳通": "Ouyang Tong, son of Ouyang Xun, continued his father’s style in regular script, contributing to the Tang court's refined calligraphic tradition.",
    "沈尹默": "Shen Yinmo (1883–1971), a modern calligrapher, was noted for his running script influenced by Jin and Tang masters, bridging traditional aesthetics with modern sensibility.",
    "王宠": "Wang Chong (1494–1533), a Ming dynasty calligrapher, excelled in small regular script and running script, valued for precision and grace.",
    "王献之": "Wang Xianzhi (344–386), son of Wang Xizhi, was a master of running and cursive scripts. His 'one-stroke' technique in cursive writing was highly influential.",
    "王珣": "Wang Xun (349–400), a Jin dynasty calligrapher and cousin of Wang Xizhi, was famed for his running script letter 'Letter to Boyuan,' an important surviving work.",
    "王福庵": "Wang Fuan (1880–1960), a modern seal script calligrapher, was known for his meticulous study of ancient inscriptions, producing elegant and precise works.",
    "王羲之": "Wang Xizhi (303–361), the 'Sage of Calligraphy,' was a Jin dynasty master whose running script epitomized fluidity and grace. His 'Preface to the Orchid Pavilion' is revered as a pinnacle of the art.",
    "王诜": "Wang Shen (c. 1036–c. 1093), a Northern Song literati and official, was adept in painting and calligraphy, producing refined running script works.",
    "王铎": "Wang Duo (1592–1652), a Ming–Qing transitional calligrapher, was famous for his cursive script full of energy and grandeur, inspired by Zhang Xu and Huai Su.",
    "皇象": "Huang Xiang, a calligrapher of the Three Kingdoms period (Eastern Wu), was known for his elegant official script and is credited with works like the 'Jiucheng Palace Liquan Inscription.'",
    "祝允明": "Zhu Yunming (1460–1526), a Ming dynasty calligrapher, was a master of cursive script, producing works of great speed, energy, and individuality.",
    "空海": "Kūkai (774–835), a Japanese monk and founder of Shingon Buddhism, studied Chinese calligraphy during the Tang dynasty. He excelled in adapting Tang styles to Japanese aesthetics.",
    "简牍帛书盟书": "Bamboo slips, silk manuscripts, and covenant documents (jiandu, boshu, mengshu) are ancient writing media and forms from pre-Qin to Han periods, providing insight into early script forms such as seal and clerical script.",
    "米芾": "Mi Fu (1051–1107), a Northern Song calligrapher, painter, and connoisseur, was famous for his running and cursive scripts, characterized by elegance, speed, and rhythm.",
    "索靖": "Suo Jing (239–303), a Western Jin calligrapher, excelled in cursive script and was influential in the development of early running and cursive styles.",
    "苏轼": "Su Shi (1037–1101), a Northern Song poet, painter, and calligrapher, was known for his free and expressive running script, reflecting his literary genius.",
    "董其昌": "Dong Qichang (1555–1636), a Ming dynasty painter, theorist, and calligrapher, promoted literati aesthetics and excelled in running and cursive scripts with scholarly refinement.",
    "蔡京": "Cai Jing (1047–1126), a Northern Song chancellor, was skilled in calligraphy, especially running script, though his political career was controversial.",
    "蔡襄": "Cai Xiang (1012–1067), a Northern Song calligrapher and statesman, was one of the 'Four Masters of Song,' known for his dignified regular script.",
    "薛绍彭": "Xue Shaopeng (980–1050), a Northern Song calligrapher, was skilled in regular and running scripts, contributing to the refinement of literati styles.",
    "虞世南": "Yu Shinan (558–638), a Tang dynasty calligrapher, was renowned for his graceful regular script, blending strength with delicacy.",
    "褚遂良": "Chu Suiliang (596–658), a Tang dynasty calligrapher, developed a slender, elegant regular script, serving as a bridge between Ouyang Xun and later Tang styles.",
    "赵之谦": "Zhao Zhiqian (1829–1884), a Qing dynasty calligrapher, painter, and seal carver, combined seal script with running script, influencing modern Chinese art.",
    "赵佶\\宋徽宗": "Zhao Ji (1082–1135), Emperor Huizong of the Northern Song, created the slender-gold script (shoujin ti), a distinctive, elegant style marked by fine lines and angular turns.",
    "赵孟頫": "Zhao Mengfu (1254–1322), a Yuan dynasty prince and calligrapher, revived Jin-Tang styles with smooth, rounded brushwork in multiple scripts.",
    "赵构": "Zhao Gou (1107–1187), Emperor Gaozong of the Southern Song, was a skilled calligrapher whose works reflected refinement and scholarly taste.",
    "造像记": "Votive inscriptions (zaoxiang ji) are dedicatory texts carved on Buddhist statues or steles, often in clerical or regular script, providing valuable data on religious and artistic history.",
    "邓文原": "Deng Wenyuan (1258–1328), a Yuan dynasty calligrapher, was skilled in regular and running scripts, known for elegant and restrained brushwork.",
    "邓石如": "Deng Shiru (1743–1805), a Qing dynasty master of seal and clerical scripts, studied ancient inscriptions intensively, producing powerful and scholarly works.",
    "金农": "Jin Nong (1687–1764), a Qing dynasty painter and calligrapher, created the 'lacquer script' style, integrating personal creativity with archaic forms.",
    "金文刻石": "Bronze inscriptions (jinwen) are cast or carved texts on ancient Chinese ritual bronzes, primarily from the Shang and Zhou dynasties, serving as crucial sources for early Chinese writing.",
    "钟繇": "Zhong Yao (151–230), a Wei dynasty calligrapher, is regarded as the 'Father of Regular Script,' developing early forms that influenced generations.",
    "钟绍京": "Zhong Shaojing (fl. early 8th century), a Tang dynasty calligrapher, was famous for his monumental regular script inscriptions, such as the 'Huadu Temple Stele.' His style was upright, clear, and became a key model for stone inscription calligraphy.",
    "陆机": "Lu Ji (261–303), a Western Jin dynasty poet, essayist, and calligrapher, was admired for his refined brushwork in running and cursive scripts. His calligraphy paralleled the elegance of his literary works, though few authentic pieces survive.",
    "陆柬之": "Lu Jianzhi (585–638), a Tang dynasty calligrapher, was a relative of Lu Fayan and known for his formal regular script, contributing to the establishment of Tang calligraphic orthodoxy.",
    "陆游": "Lu You (1125–1210), a Southern Song dynasty poet and patriot, was also skilled in running script. His calligraphy was expressive and vigorous, reflecting his passionate literary voice.",
    "陶弘景": "Tao Hongjing (456–536), a scholar, Daoist master, and calligrapher of the Southern Dynasties, excelled in clerical and early regular scripts, blending elegance with scholarly restraint.",
    "颜真卿": "Yan Zhenqing (709–785), a Tang dynasty calligrapher and loyal official, was a master of regular script, creating a robust, upright style that emphasized moral integrity. Works like 'Yan Qinli Stele' remain exemplary.",
    "鲜于枢": "Xianyu Shu (1257–1302), a Yuan dynasty calligrapher, excelled in regular, running, and cursive scripts. His works were vigorous, angular, and influenced by both Jin and Tang traditions.",
    "黄养正": "Huang Yangzheng, a Ming dynasty calligrapher, specialized in running and clerical scripts, producing works of solid structure and scholarly elegance.",
    "黄庭坚": "Huang Tingjian (1045–1105), a Northern Song poet and calligrapher, was a master of running script, developing a distinctive, elongated style influenced by Su Shi and ancient models.",
    "黄自元": "Huang Ziyuan (1837–1917), a Qing dynasty calligrapher, was noted for his clerical script, compiling influential copybooks that became standard references for learners.",
    "黄道周": "Huang Daozhou (1585–1646), a late Ming scholar-official and calligrapher, excelled in regular and running scripts. His style was upright, forceful, and reflected his loyalist ideals.",
    "龚贤": "Gong Xian (1619–1689), a Ming–Qing transitional painter and calligrapher, was skilled in running script. His brushwork paralleled the dense, layered style of his paintings, embodying literati sensibility."
}
dataset/chirography.json
ADDED
@@ -0,0 +1,7 @@
{
    "楷": "symmetrical structure, stable center of gravity, rigorous layout of the frame, clear gaps between strokes, easy to identify the shape of the characters.",
    "篆": "an ancient Chinese calligraphic style marked by uniform and symmetrical strokes, rounded or angular lines, and a highly decorative, formal appearance.",
    "草": "coherent and smooth strokes, sparse and dense structure, and more free structure. Sometimes up, down, left and right compress or stretch to match the strokes.",
    "行": "a calligraphic style that blends structure and freedom, featuring fluid strokes, moderate simplification, and expressive rhythm while preserving readability.",
    "隶": "a calligraphic style characterized by flat, wide strokes, distinctive flared brush endings (“silkworm head and wild goose tail”), and a dignified yet graceful structure."
}
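These per-script descriptions, together with the per-author descriptions above, are loaded by `CalligraphyGenerator` as `self.font_style_des` and `self.author_style`. How they are stitched into the conditioning text is not visible in this truncated diff, so the template below is a purely illustrative sketch:

```python
import json

with open("dataset/chirography.json", encoding="utf-8") as fh:
    font_style_des = json.load(fh)
with open("dataset/calligraphy_styles_en.json", encoding="utf-8") as fh:
    author_style = json.load(fh)

# Hypothetical prompt assembly - the real template lives in the parts of
# inference.py not shown in this view.
prompt = (
    f"Chinese calligraphy, 楷 script: {font_style_des['楷']} "
    f"Style of 颜真卿: {author_style['颜真卿']}"
)
print(prompt)
```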
inference.py
ADDED
|
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Chinese Calligraphy Generation with Flux Model
|
| 4 |
+
Author and font style controllable generation
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
from optimum.quanto import quantize, freeze, qint4
|
| 11 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
+
from typing import Optional, List, Union, Dict, Any
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from pypinyin import lazy_pinyin
|
| 15 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 16 |
+
|
| 17 |
+
from src.flux.util import configs, load_ae, load_clip, load_t5
|
| 18 |
+
from src.flux.model import Flux
|
| 19 |
+
from src.flux.xflux_pipeline import XFluxSampler
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# HuggingFace Hub model IDs
|
| 23 |
+
HF_MODEL_ID = "TSXu/UniCalli-base"
|
| 24 |
+
HF_CHECKPOINT_FILENAME = "unicalli-base_cleaned.bin"
|
| 25 |
+
HF_INTERNVL_ID = "OpenGVLab/InternVL3-1B"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def download_model_from_hf(
|
| 29 |
+
model_id: str = HF_MODEL_ID,
|
| 30 |
+
filename: str = HF_CHECKPOINT_FILENAME,
|
| 31 |
+
local_dir: str = None,
|
| 32 |
+
force_download: bool = False
|
| 33 |
+
) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Download model checkpoint from HuggingFace Hub
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model_id: HuggingFace model repository ID
|
| 39 |
+
filename: Name of the checkpoint file to download
|
| 40 |
+
local_dir: Local directory to save the file (optional)
|
| 41 |
+
force_download: Whether to force re-download
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Path to the downloaded checkpoint file
|
| 45 |
+
"""
|
| 46 |
+
print(f"Downloading {filename} from HuggingFace Hub ({model_id})...")
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
checkpoint_path = hf_hub_download(
|
| 50 |
+
repo_id=model_id,
|
| 51 |
+
filename=filename,
|
| 52 |
+
local_dir=local_dir,
|
| 53 |
+
force_download=force_download
|
| 54 |
+
)
|
| 55 |
+
print(f"Model downloaded to: {checkpoint_path}")
|
| 56 |
+
return checkpoint_path
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"Error downloading model: {e}")
|
| 59 |
+
raise
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def ensure_checkpoint_exists(checkpoint_path: str) -> str:
|
| 63 |
+
"""
|
| 64 |
+
Ensure checkpoint exists locally, download from HF Hub if not
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
checkpoint_path: Local path or HF model ID
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Path to the local checkpoint file
|
| 71 |
+
"""
|
| 72 |
+
# If it's a local path and exists, return it
|
| 73 |
+
if os.path.exists(checkpoint_path):
|
| 74 |
+
return checkpoint_path
|
| 75 |
+
|
| 76 |
+
# If checkpoint_path looks like a filename (not a full path), try to download
|
| 77 |
+
if not os.path.dirname(checkpoint_path) or checkpoint_path == HF_CHECKPOINT_FILENAME:
|
| 78 |
+
print(f"Checkpoint not found locally, downloading from HuggingFace Hub...")
|
| 79 |
+
return download_model_from_hf(filename=checkpoint_path)
|
| 80 |
+
|
| 81 |
+
# If it looks like a HF repo ID (contains /)
|
| 82 |
+
if '/' in checkpoint_path and not os.path.exists(checkpoint_path):
|
| 83 |
+
print(f"Downloading from HuggingFace Hub: {checkpoint_path}")
|
| 84 |
+
return download_model_from_hf(model_id=checkpoint_path, filename=HF_CHECKPOINT_FILENAME)
|
| 85 |
+
|
| 86 |
+
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def convert_to_pinyin(text):
|
| 90 |
+
return ' '.join([item[0] if isinstance(item, list) else item for item in lazy_pinyin(text)])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class CalligraphyGenerator:
|
| 94 |
+
"""
|
| 95 |
+
Chinese Calligraphy Generator using Flux model
|
| 96 |
+
|
| 97 |
+
Attributes:
|
| 98 |
+
device: torch device for computation
|
| 99 |
+
model_name: name of the flux model (flux-dev or flux-schnell)
|
| 100 |
+
font_styles: available font styles for generation
|
| 101 |
+
authors: available calligrapher authors
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
model_name: str = "flux-dev",
|
| 107 |
+
device: str = "cuda",
|
| 108 |
+
offload: bool = True,
|
| 109 |
+
checkpoint_path: Optional[str] = None,
|
| 110 |
+
intern_vlm_path: Optional[str] = None,
|
| 111 |
+
ref_latent_path: Optional[str] = None,
|
| 112 |
+
font_descriptions_path: str = "chirography.json",
|
| 113 |
+
author_descriptions_path: str = "calligraphy_styles_en.json",
|
| 114 |
+
use_deepspeed: bool = False,
|
| 115 |
+
use_4bit_quantization: bool = False,
|
| 116 |
+
deepspeed_config: Optional[str] = None
|
| 117 |
+
):
|
| 118 |
+
"""
|
| 119 |
+
Initialize the calligraphy generator
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
model_name: flux model name (flux-dev or flux-schnell)
|
| 123 |
+
device: device for computation
|
| 124 |
+
offload: whether to offload model to CPU when not in use
|
| 125 |
+
checkpoint_path: path to model checkpoint if using fine-tuned model
|
| 126 |
+
intern_vlm_path: path to InternVLM model for text embedding
|
| 127 |
+
ref_latent_path: path to reference latents for recognition mode
|
| 128 |
+
font_descriptions_path: path to font style descriptions JSON
|
| 129 |
+
author_descriptions_path: path to author style descriptions JSON
|
| 130 |
+
use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
|
| 131 |
+
deepspeed_config: path to DeepSpeed config JSON file
|
| 132 |
+
"""
|
| 133 |
+
self.device = torch.device(device)
|
| 134 |
+
self.model_name = model_name
|
| 135 |
+
self.offload = offload
|
| 136 |
+
self.is_schnell = model_name == "flux-schnell"
|
| 137 |
+
self.use_deepspeed = use_deepspeed
|
| 138 |
+
self.deepspeed_config = deepspeed_config
|
| 139 |
+
self.use_4bit_quantization = use_4bit_quantization
|
| 140 |
+
|
| 141 |
+
# Load font and author style descriptions
|
| 142 |
+
if os.path.exists(font_descriptions_path):
|
| 143 |
+
with open(font_descriptions_path, 'r', encoding='utf-8') as f:
|
| 144 |
+
self.font_style_des = json.load(f)
|
| 145 |
+
else:
|
| 146 |
+
raise FileNotFoundError(f"Font descriptions file not found: {font_descriptions_path}")
|
| 147 |
+
|
| 148 |
+
if os.path.exists(author_descriptions_path):
|
| 149 |
+
with open(author_descriptions_path, 'r', encoding='utf-8') as f:
|
| 150 |
+
self.author_style = json.load(f)
|
| 151 |
+
else:
|
| 152 |
+
raise FileNotFoundError(f"Author descriptions file not found: {author_descriptions_path}")
|
| 153 |
+
|
| 154 |
+
# Load models
|
| 155 |
+
print("Loading models...")
|
| 156 |
+
# When using DeepSpeed, load text encoders on CPU first to save memory during initialization
|
| 157 |
+
# They will be moved to GPU after DeepSpeed initializes the main model
|
| 158 |
+
if self.use_deepspeed:
|
| 159 |
+
text_encoder_device = "cpu"
|
| 160 |
+
elif offload:
|
| 161 |
+
text_encoder_device = "cpu" # Will be moved to GPU during inference
|
| 162 |
+
else:
|
| 163 |
+
text_encoder_device = self.device
|
| 164 |
+
|
| 165 |
+
self.t5 = load_t5(text_encoder_device, max_length=256 if self.is_schnell else 512)
|
| 166 |
+
self.clip = load_clip(text_encoder_device)
|
| 167 |
+
self.clip.requires_grad_(False)
|
| 168 |
+
|
| 169 |
+
# Ensure checkpoint exists (download from HF Hub if needed)
|
| 170 |
+
if checkpoint_path:
|
| 171 |
+
checkpoint_path = ensure_checkpoint_exists(checkpoint_path)
|
| 172 |
+
print(f"Loading model from checkpoint: {checkpoint_path}")
|
| 173 |
+
# When using DeepSpeed, don't move to GPU yet - let DeepSpeed handle it
|
| 174 |
+
self.model = self._load_model_from_checkpoint(
|
| 175 |
+
checkpoint_path, model_name,
|
| 176 |
+
offload=offload,
|
| 177 |
+
use_deepspeed=self.use_deepspeed
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Initialize DeepSpeed if requested
|
| 181 |
+
if self.use_deepspeed:
|
| 182 |
+
self.model = self._init_deepspeed(self.model)
|
| 183 |
+
else:
|
| 184 |
+
# If no checkpoint path provided, download default from HF Hub
|
| 185 |
+
print("No checkpoint path provided, downloading from HuggingFace Hub...")
|
| 186 |
+
checkpoint_path = download_model_from_hf()
|
| 187 |
+
print(f"Loading model from checkpoint: {checkpoint_path}")
|
| 188 |
+
self.model = self._load_model_from_checkpoint(
|
| 189 |
+
checkpoint_path, model_name,
|
| 190 |
+
offload=offload,
|
| 191 |
+
use_deepspeed=self.use_deepspeed
|
| 192 |
+
)
|
| 193 |
+
if self.use_deepspeed:
|
| 194 |
+
self.model = self._init_deepspeed(self.model)
|
| 195 |
+
|
| 196 |
+
# Load VAE
|
| 197 |
+
if self.use_deepspeed or offload:
|
| 198 |
+
vae_device = "cpu"
|
| 199 |
+
else:
|
| 200 |
+
vae_device = self.device
|
| 201 |
+
|
| 202 |
+
self.vae = load_ae(model_name, device=vae_device)
|
| 203 |
+
|
| 204 |
+
# With offload, the VAE was loaded on CPU to keep peak memory low; move it to the GPU now (it is small enough to stay resident; DeepSpeed handles its own placement)
|
| 205 |
+
if offload and not self.use_deepspeed:
|
| 206 |
+
self.vae = self.vae.to(self.device)
|
| 207 |
+
|
| 208 |
+
# After DeepSpeed init, move text encoders to GPU
|
| 209 |
+
if self.use_deepspeed:
|
| 210 |
+
print("Moving text encoders to GPU...")
|
| 211 |
+
self.t5 = self.t5.to(self.device)
|
| 212 |
+
self.clip = self.clip.to(self.device)
|
| 213 |
+
self.vae = self.vae.to(self.device)
|
| 214 |
+
|
| 215 |
+
# Load reference latents if provided
|
| 216 |
+
self.ref_latent = None
|
| 217 |
+
if ref_latent_path and os.path.exists(ref_latent_path):
|
| 218 |
+
print(f"Loading reference latents from {ref_latent_path}")
|
| 219 |
+
self.ref_latent = torch.load(ref_latent_path, map_location='cpu')
|
| 220 |
+
|
| 221 |
+
# Create sampler
|
| 222 |
+
self.sampler = XFluxSampler(
|
| 223 |
+
clip=self.clip,
|
| 224 |
+
t5=self.t5,
|
| 225 |
+
ae=self.vae,
|
| 226 |
+
ref_latent=self.ref_latent,
|
| 227 |
+
model=self.model,
|
| 228 |
+
device=self.device,
|
| 229 |
+
intern_vlm_path=intern_vlm_path
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Font for generating condition images
|
| 233 |
+
self.font_path = self._ensure_font_exists("./FangZhengKaiTiFanTi-1.ttf")
|
| 234 |
+
self.default_font_size = 102 # 128 * 0.8
|
| 235 |
+
|
| 236 |
+
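A quick sanity sketch of the device placement the branches above produce, assuming CUDA is available (the checkpoint path is hypothetical; this only restates the init logic):

```python
gen = CalligraphyGenerator(checkpoint_path="ckpt.pt", offload=True)  # hypothetical checkpoint
assert next(gen.t5.parameters()).device.type == "cpu"     # text encoders wait on CPU until inference
assert next(gen.model.parameters()).device.type == "cuda" # flux model goes straight to GPU
assert next(gen.vae.parameters()).device.type == "cuda"   # VAE: loaded on CPU, then moved
```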
def _ensure_font_exists(self, font_path: str) -> str:
|
| 237 |
+
"""
|
| 238 |
+
Ensure font file exists locally, download from HF Hub if not
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
font_path: Local path to font file
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
Path to the local font file
|
| 245 |
+
"""
|
| 246 |
+
if os.path.exists(font_path):
|
| 247 |
+
return font_path
|
| 248 |
+
|
| 249 |
+
# Try to download from HF Hub
|
| 250 |
+
print(f"Font file not found locally, downloading from HuggingFace Hub...")
|
| 251 |
+
try:
|
| 252 |
+
font_path = hf_hub_download(
|
| 253 |
+
repo_id=HF_MODEL_ID,
|
| 254 |
+
filename="FangZhengKaiTiFanTi-1.ttf"
|
| 255 |
+
)
|
| 256 |
+
print(f"Font downloaded to: {font_path}")
|
| 257 |
+
return font_path
|
| 258 |
+
except Exception as e:
|
| 259 |
+
print(f"Warning: Could not download font: {e}")
|
| 260 |
+
return font_path # Return original path, may fail later
|
| 261 |
+
|
| 262 |
+
def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str, offload: bool, use_deepspeed: bool = False):
|
| 263 |
+
"""
|
| 264 |
+
Load model from checkpoint without loading flux pretrained weights.
|
| 265 |
+
This creates an empty model, initializes module embeddings, then loads your checkpoint.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
checkpoint_path: Path to your checkpoint file
|
| 269 |
+
model_name: flux model name (for config)
|
| 270 |
+
offload: whether to offload to CPU
|
| 271 |
+
use_deepspeed: whether using DeepSpeed (keeps model on CPU)
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
model with loaded checkpoint
|
| 275 |
+
"""
|
| 276 |
+
print(f"Creating empty flux model structure...")
|
| 277 |
+
# Load checkpoint on CPU first to save memory
|
| 278 |
+
# If using DeepSpeed, keep on CPU; otherwise move to GPU after loading
|
| 279 |
+
load_device = "cpu"
|
| 280 |
+
|
| 281 |
+
# Create model structure without loading pretrained weights (using "meta" device)
|
| 282 |
+
with torch.device("meta"):
|
| 283 |
+
model = Flux(configs[model_name].params)
|
| 284 |
+
|
| 285 |
+
# Initialize module embeddings (must be done before loading checkpoint)
|
| 286 |
+
print("Initializing module embeddings...")
|
| 287 |
+
model.init_module_embeddings(tokens_num=320, cond_txt_channel=896)
|
| 288 |
+
|
| 289 |
+
# Move model to loading device
|
| 290 |
+
print(f"Moving model to {load_device} for loading...")
|
| 291 |
+
model = model.to_empty(device=load_device)
|
| 292 |
+
|
| 293 |
+
# Load checkpoint
|
| 294 |
+
print(f"Loading checkpoint from {checkpoint_path}")
|
| 295 |
+
checkpoint = self._load_checkpoint_file(checkpoint_path)
|
| 296 |
+
|
| 297 |
+
# Load weights into model
|
| 298 |
+
model.load_state_dict(checkpoint, strict=False)
|
| 299 |
+
|
| 300 |
+
# Apply 4-bit quantization if requested
|
| 301 |
+
if hasattr(self, 'use_4bit_quantization') and self.use_4bit_quantization:
|
| 302 |
+
print("Applying 4-bit quantization...")
|
| 303 |
+
model = model.float()  # convert to float32 first
|
| 304 |
+
quantize(model, weights=qint4)
|
| 305 |
+
freeze(model)
|
| 306 |
+
model._is_quantized = True  # mark as quantized so xflux_pipeline can check for it
|
| 307 |
+
print("4-bit quantization complete!")
|
| 308 |
+
|
| 309 |
+
# Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
|
| 310 |
+
if not use_deepspeed:
|
| 311 |
+
print(f"Moving model to {self.device}...")
|
| 312 |
+
model = model.to(self.device)
|
| 313 |
+
|
| 314 |
+
return model
|
| 315 |
+
|
| 316 |
+
def _init_deepspeed(self, model):
|
| 317 |
+
"""
|
| 318 |
+
Initialize the DeepSpeed inference engine for the model.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
model: PyTorch model to wrap with DeepSpeed
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
DeepSpeed inference engine
|
| 325 |
+
"""
|
| 326 |
+
try:
|
| 327 |
+
import deepspeed
|
| 328 |
+
except ImportError:
|
| 329 |
+
raise ImportError("DeepSpeed is not installed. Install it with: pip install deepspeed")
|
| 330 |
+
|
| 331 |
+
# Load DeepSpeed config
|
| 332 |
+
if self.deepspeed_config is None:
|
| 333 |
+
self.deepspeed_config = "ds_config_zero2.json"
|
| 334 |
+
|
| 335 |
+
if not os.path.exists(self.deepspeed_config):
|
| 336 |
+
raise FileNotFoundError(f"DeepSpeed config not found: {self.deepspeed_config}")
|
| 337 |
+
|
| 338 |
+
print(f"Initializing DeepSpeed Inference with config: {self.deepspeed_config}")
|
| 339 |
+
|
| 340 |
+
# Initialize distributed environment for single GPU if not already initialized
|
| 341 |
+
if not torch.distributed.is_initialized():
|
| 342 |
+
import random
|
| 343 |
+
# Set environment variables for single-process mode
|
| 344 |
+
# Use a random port to avoid conflicts
|
| 345 |
+
port = random.randint(29500, 29600)
|
| 346 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 347 |
+
os.environ['MASTER_PORT'] = str(port)
|
| 348 |
+
os.environ['RANK'] = '0'
|
| 349 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 350 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 351 |
+
|
| 352 |
+
# Initialize process group
|
| 353 |
+
try:
|
| 354 |
+
torch.distributed.init_process_group(
|
| 355 |
+
backend='nccl',
|
| 356 |
+
init_method='env://',
|
| 357 |
+
world_size=1,
|
| 358 |
+
rank=0
|
| 359 |
+
)
|
| 360 |
+
print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}")
|
| 361 |
+
except RuntimeError as e:
|
| 362 |
+
if "address already in use" in str(e):
|
| 363 |
+
print(f"Port {port} in use, trying again...")
|
| 364 |
+
# Try a different port
|
| 365 |
+
port = random.randint(29600, 29700)
|
| 366 |
+
os.environ['MASTER_PORT'] = str(port)
|
| 367 |
+
torch.distributed.init_process_group(
|
| 368 |
+
backend='nccl',
|
| 369 |
+
init_method='env://',
|
| 370 |
+
world_size=1,
|
| 371 |
+
rank=0
|
| 372 |
+
)
|
| 373 |
+
print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}")
|
| 374 |
+
else:
|
| 375 |
+
raise
|
| 376 |
+
|
| 377 |
+
# Use DeepSpeed inference API instead of initialize
|
| 378 |
+
# This doesn't require an optimizer
|
| 379 |
+
with open(self.deepspeed_config) as f:
|
| 380 |
+
ds_config = json.load(f)
|
| 381 |
+
|
| 382 |
+
model_engine = deepspeed.init_inference(
|
| 383 |
+
model=model,
|
| 384 |
+
mp_size=1, # model parallel size
|
| 385 |
+
dtype=torch.bfloat16 if ds_config.get('bf16', {}).get('enabled', False) else torch.float16,
|
| 386 |
+
replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
print("DeepSpeed Inference initialized successfully")
|
| 390 |
+
return model_engine
|
| 391 |
+
|
| 392 |
+
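Since `_init_deepspeed` falls back to a `ds_config_zero2.json` that must already exist on disk, here is a minimal config sketch covering the only field the method actually reads (`bf16.enabled`); a real deployment would likely carry more options:

```python
import json

# Minimal DeepSpeed config sketch: _init_deepspeed only inspects bf16.enabled
# to choose the inference dtype; everything else is left to DeepSpeed defaults.
ds_config = {"bf16": {"enabled": True}}
with open("ds_config_zero2.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```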
def _load_checkpoint_file(self, checkpoint_path: str) -> dict:
|
| 393 |
+
"""
|
| 394 |
+
Load checkpoint file and extract state dict.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
checkpoint_path: Path to checkpoint file, can be:
|
| 398 |
+
- Full checkpoint with model, optimizer, etc. (from training)
|
| 399 |
+
- State dict only file
|
| 400 |
+
- Directory containing checkpoint files
|
| 401 |
+
|
| 402 |
+
Returns:
|
| 403 |
+
state_dict: model state dictionary
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
# Check if it's a directory containing checkpoint files
|
| 407 |
+
if os.path.isdir(checkpoint_path):
|
| 408 |
+
# Look for common checkpoint filenames
|
| 409 |
+
possible_files = [
|
| 410 |
+
'model.pt', 'model.pth', 'model.bin',
|
| 411 |
+
'checkpoint.pt', 'checkpoint.pth',
|
| 412 |
+
'pytorch_model.bin', 'model_state_dict.pt'
|
| 413 |
+
]
|
| 414 |
+
|
| 415 |
+
checkpoint_file = None
|
| 416 |
+
for filename in possible_files:
|
| 417 |
+
full_path = os.path.join(checkpoint_path, filename)
|
| 418 |
+
if os.path.exists(full_path):
|
| 419 |
+
checkpoint_file = full_path
|
| 420 |
+
print(f"Found checkpoint file: {filename}")
|
| 421 |
+
break
|
| 422 |
+
|
| 423 |
+
if checkpoint_file is None:
|
| 424 |
+
# Try to find any .pt or .pth file
|
| 425 |
+
import glob
|
| 426 |
+
pt_files = glob.glob(os.path.join(checkpoint_path, "*.pt")) + \
|
| 427 |
+
glob.glob(os.path.join(checkpoint_path, "*.pth")) + \
|
| 428 |
+
glob.glob(os.path.join(checkpoint_path, "*.bin"))
|
| 429 |
+
if pt_files:
|
| 430 |
+
checkpoint_file = pt_files[0]
|
| 431 |
+
print(f"Found checkpoint file: {os.path.basename(checkpoint_file)}")
|
| 432 |
+
else:
|
| 433 |
+
raise ValueError(f"No checkpoint files found in directory: {checkpoint_path}")
|
| 434 |
+
|
| 435 |
+
checkpoint_path = checkpoint_file
|
| 436 |
+
|
| 437 |
+
# Load the checkpoint
|
| 438 |
+
print(f"Loading checkpoint file: {checkpoint_path}")
|
| 439 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 440 |
+
|
| 441 |
+
# Handle different checkpoint formats
|
| 442 |
+
if isinstance(checkpoint, dict):
|
| 443 |
+
# Check for different keys that might contain the model
|
| 444 |
+
if 'model' in checkpoint:
|
| 445 |
+
state_dict = checkpoint['model']
|
| 446 |
+
elif 'model_state_dict' in checkpoint:
|
| 447 |
+
state_dict = checkpoint['model_state_dict']
|
| 448 |
+
elif 'state_dict' in checkpoint:
|
| 449 |
+
state_dict = checkpoint['state_dict']
|
| 450 |
+
else:
|
| 451 |
+
# Assume the dict itself is the state dict
|
| 452 |
+
state_dict = checkpoint
|
| 453 |
+
|
| 454 |
+
# Log additional info if available
|
| 455 |
+
if 'epoch' in checkpoint:
|
| 456 |
+
print(f"Checkpoint from epoch: {checkpoint['epoch']}")
|
| 457 |
+
if 'global_step' in checkpoint:
|
| 458 |
+
print(f"Checkpoint from step: {checkpoint['global_step']}")
|
| 459 |
+
if 'loss' in checkpoint:
|
| 460 |
+
print(f"Checkpoint loss: {checkpoint['loss']:.4f}")
|
| 461 |
+
else:
|
| 462 |
+
# If it's not a dict, assume it's directly the state dict
|
| 463 |
+
state_dict = checkpoint
|
| 464 |
+
|
| 465 |
+
# Handle potential prefix mismatches
|
| 466 |
+
# Remove 'module.' prefix if present (from DataParallel/DistributedDataParallel)
|
| 467 |
+
if any(key.startswith('module.') for key in state_dict.keys()):
|
| 468 |
+
state_dict = {key.replace('module.', ''): value
|
| 469 |
+
for key, value in state_dict.items()}
|
| 470 |
+
print("Removed 'module.' prefix from state dict keys")
|
| 471 |
+
|
| 472 |
+
return state_dict
|
| 473 |
+
|
| 474 |
+
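The loader above is deliberately permissive about checkpoint shape; as a sketch, all three of these hypothetical files resolve to the same state dict (keys are invented):

```python
import torch

sd = {"layer.weight": torch.zeros(2, 2)}  # invented key, for illustration only
torch.save(sd, "a.pt")                                         # bare state dict
torch.save({"model": sd, "epoch": 3, "loss": 0.1}, "b.pt")     # training-checkpoint wrapper
torch.save({f"module.{k}": v for k, v in sd.items()}, "c.pt")  # DDP-style 'module.' prefix
# _load_checkpoint_file unwraps 'model'/'model_state_dict'/'state_dict'
# and strips the 'module.' prefix, so all three load identically.
```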
def text_to_cond_image(
|
| 475 |
+
self,
|
| 476 |
+
text: str,
|
| 477 |
+
img_size: int = 128,
|
| 478 |
+
font_scale: float = 0.8,
|
| 479 |
+
font_path: Optional[str] = None,
|
| 480 |
+
fixed_chars: int = 7
|
| 481 |
+
) -> Image.Image:
|
| 482 |
+
"""
|
| 483 |
+
Convert text to condition image - always creates image for fixed_chars characters
|
| 484 |
+
Text is arranged from top to bottom.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
text: Chinese text to convert (must be <= fixed_chars characters)
|
| 488 |
+
img_size: size of each character block (default 128)
|
| 489 |
+
font_scale: scale of font relative to image size (default 0.8)
|
| 490 |
+
font_path: path to font file
|
| 491 |
+
fixed_chars: fixed number of character slots (default 7)
|
| 492 |
+
|
| 493 |
+
Returns:
|
| 494 |
+
PIL Image with text rendered (always fixed_chars * img_size height)
|
| 495 |
+
"""
|
| 496 |
+
if len(text) > fixed_chars:
|
| 497 |
+
raise ValueError(f"Text must be at most {fixed_chars} characters, got {len(text)}")
|
| 498 |
+
|
| 499 |
+
if font_path is None:
|
| 500 |
+
font_path = self.font_path
|
| 501 |
+
|
| 502 |
+
# Create font - font size is scaled down from img_size
|
| 503 |
+
font_size_scaled = int(font_scale * img_size)
|
| 504 |
+
font = ImageFont.truetype(font_path, font_size_scaled)
|
| 505 |
+
|
| 506 |
+
# Calculate image dimensions - always fixed_chars height
|
| 507 |
+
img_width = img_size
|
| 508 |
+
img_height = img_size * fixed_chars # Fixed height for 7 characters
|
| 509 |
+
|
| 510 |
+
# Create white background image
|
| 511 |
+
cond_img = Image.new("RGB", (img_width, img_height), (255, 255, 255))
|
| 512 |
+
cond_draw = ImageDraw.Draw(cond_img)
|
| 513 |
+
|
| 514 |
+
# Draw each character from top to bottom
|
| 515 |
+
# Note: font_size for positioning should be img_size, not the scaled font size
|
| 516 |
+
for i, char in enumerate(text):
|
| 517 |
+
font_space = (img_size - font_size_scaled) // 2  # margin that centers the glyph in its img_size block
|
| 518 |
+
# Position based on img_size blocks, not scaled font size
|
| 519 |
+
font_position = (font_space, img_size * i + font_space)
|
| 520 |
+
cond_draw.text(font_position, char, font=font, fill=(0, 0, 0))
|
| 521 |
+
|
| 522 |
+
return cond_img
|
| 523 |
+
|
| 524 |
+
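For intuition, the layout above always allocates a width × (7·width) white canvas and stacks the glyphs top to bottom; a standalone sketch of the same geometry (the font file ships with the Space and must exist locally):

```python
from PIL import Image, ImageDraw, ImageFont

img_size, fixed_chars, font_scale = 128, 7, 0.8
font_px = int(font_scale * img_size)  # 102
font = ImageFont.truetype("FangZhengKaiTiFanTi-1.ttf", font_px)

canvas = Image.new("RGB", (img_size, img_size * fixed_chars), (255, 255, 255))
draw = ImageDraw.Draw(canvas)
margin = (img_size - font_px) // 2  # center each glyph in its 128px block
for i, ch in enumerate("暴富且平安"):  # 5 chars still get the full 7-slot canvas
    draw.text((margin, img_size * i + margin), ch, font=font, fill=(0, 0, 0))
canvas.save("cond.png")  # 128 x 896 condition image
```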
def build_prompt(
|
| 525 |
+
self,
|
| 526 |
+
font_style: str = "楷",
|
| 527 |
+
author: Optional[str] = None,
|
| 528 |
+
is_traditional: bool = True,
|
| 529 |
+
) -> str:
|
| 530 |
+
"""
|
| 531 |
+
Build prompt for generation following dataset.py logic
|
| 532 |
+
|
| 533 |
+
Args:
|
| 534 |
+
font_style: font style (楷/草/行)
|
| 535 |
+
author: author name (Chinese or None for synthetic)
|
| 536 |
+
is_traditional: whether generating traditional calligraphy
|
| 537 |
+
|
| 538 |
+
Returns:
|
| 539 |
+
formatted prompt string
|
| 540 |
+
"""
|
| 541 |
+
# Validate font style
|
| 542 |
+
if font_style not in self.font_style_des:
|
| 543 |
+
raise ValueError(f"Font style must be one of: {list(self.font_style_des.keys())}")
|
| 544 |
+
|
| 545 |
+
# Convert font style to pinyin
|
| 546 |
+
font_style_pinyin = convert_to_pinyin(font_style)
|
| 547 |
+
|
| 548 |
+
# Build prompt based on traditional or synthetic
|
| 549 |
+
if is_traditional and author and author in self.author_style:
|
| 550 |
+
# Traditional calligraphy with specific author
|
| 551 |
+
prompt = f"Traditional Chinese calligraphy works, background: black, font: {font_style_pinyin}, "
|
| 552 |
+
prompt += self.font_style_des[font_style]
|
| 553 |
+
author_info = self.author_style[author]
|
| 554 |
+
prompt += f" author: {author_info}"
|
| 555 |
+
else:
|
| 556 |
+
# Synthetic calligraphy
|
| 557 |
+
prompt = f"Synthetic calligraphy data, background: black, font: {font_style_pinyin}, "
|
| 558 |
+
prompt += self.font_style_des[font_style]
|
| 559 |
+
|
| 560 |
+
return prompt
|
| 561 |
+
|
| 562 |
+
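A hedged illustration of the two prompt shapes `build_prompt` emits; the elided fragments come from chirography.json and calligraphy_styles_en.json, so they are placeholders here:

```python
# Continuing the hypothetical `gen` from the earlier sketch:
p1 = gen.build_prompt(font_style="楷", author="王羲之")
# -> "Traditional Chinese calligraphy works, background: black, font: <pinyin>,
#     <楷 style description> author: <王羲之 style description>"
p2 = gen.build_prompt(font_style="草")  # no author -> synthetic branch
# -> "Synthetic calligraphy data, background: black, font: <pinyin>, <草 style description>"
```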
@torch.no_grad()
|
| 563 |
+
def generate(
|
| 564 |
+
self,
|
| 565 |
+
text: str,
|
| 566 |
+
font_style: str = "楷",
|
| 567 |
+
author: Optional[str] = None,
|
| 568 |
+
width: int = 128,
|
| 569 |
+
height: Optional[int] = None,  # ignored: always recomputed as width * FIXED_CHARS
|
| 570 |
+
num_steps: int = 50,
|
| 571 |
+
guidance: float = 3.5,
|
| 572 |
+
seed: Optional[int] = None,
|
| 573 |
+
is_traditional: Optional[bool] = None,
|
| 574 |
+
save_path: Optional[str] = None
|
| 575 |
+
) -> tuple[Image.Image, Image.Image]:
|
| 576 |
+
"""
|
| 577 |
+
Generate calligraphy image from text
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
text: Chinese text to generate (1-7 characters)
|
| 581 |
+
font_style: font style (楷/草/行)
|
| 582 |
+
author: author/calligrapher name from the style list
|
| 583 |
+
width: image width (default 128)
|
| 584 |
+
height: image height (fixed to 7 * width)
|
| 585 |
+
num_steps: number of denoising steps
|
| 586 |
+
guidance: guidance scale
|
| 587 |
+
seed: random seed for generation
|
| 588 |
+
is_traditional: whether generating traditional calligraphy (auto-determined if None)
|
| 589 |
+
save_path: optional path to save the generated image
|
| 590 |
+
|
| 591 |
+
Returns:
|
| 592 |
+
tuple of (generated_image, condition_image)
|
| 593 |
+
"""
|
| 594 |
+
# Fixed number of characters
|
| 595 |
+
FIXED_CHARS = 7
|
| 596 |
+
|
| 597 |
+
# Validate text - must have 1-7 characters
|
| 598 |
+
if len(text) < 1:
|
| 599 |
+
raise ValueError(f"Text must have at least 1 character, got empty string")
|
| 600 |
+
if len(text) > FIXED_CHARS:
|
| 601 |
+
raise ValueError(f"Text must have at most {FIXED_CHARS} characters, got {len(text)}")
|
| 602 |
+
|
| 603 |
+
if seed is None:
|
| 604 |
+
seed = torch.randint(0, 2**32, (1,)).item()
|
| 605 |
+
|
| 606 |
+
# Fixed height for 7 characters
|
| 607 |
+
num_chars = len(text)
|
| 608 |
+
height = width * FIXED_CHARS # Always 7 characters height
|
| 609 |
+
|
| 610 |
+
# Auto-determine traditional vs synthetic
|
| 611 |
+
if is_traditional is None:
|
| 612 |
+
is_traditional = author is not None and author in self.author_style
|
| 613 |
+
|
| 614 |
+
# Generate condition image (fixed size for 7 characters)
|
| 615 |
+
cond_img = self.text_to_cond_image(text, img_size=width, fixed_chars=FIXED_CHARS)
|
| 616 |
+
|
| 617 |
+
# Build prompt
|
| 618 |
+
prompt = self.build_prompt(
|
| 619 |
+
font_style=font_style,
|
| 620 |
+
author=author,
|
| 621 |
+
is_traditional=is_traditional,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
print(f"Generating with prompt: {prompt}")
|
| 625 |
+
print(f"Text: {text} ({num_chars} chars), Seed: {seed}")
|
| 626 |
+
# Generate image
|
| 627 |
+
result_img, recognized_text = self.sampler(
|
| 628 |
+
prompt=prompt,
|
| 629 |
+
width=width,
|
| 630 |
+
height=height,
|
| 631 |
+
num_steps=num_steps,
|
| 632 |
+
controlnet_image=cond_img,
|
| 633 |
+
is_generation=True,
|
| 634 |
+
cond_text=text,
|
| 635 |
+
required_chars=FIXED_CHARS, # Always 7 characters
|
| 636 |
+
seed=seed
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
# Crop to actual text length if less than FIXED_CHARS
|
| 640 |
+
if num_chars < FIXED_CHARS:
|
| 641 |
+
actual_height = width * num_chars
|
| 642 |
+
# Crop result image (top portion only)
|
| 643 |
+
result_img = result_img.crop((0, 0, width, actual_height))
|
| 644 |
+
# Crop condition image as well
|
| 645 |
+
cond_img = cond_img.crop((0, 0, width, actual_height))
|
| 646 |
+
|
| 647 |
+
# Save if path provided
|
| 648 |
+
if save_path:
|
| 649 |
+
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 650 |
+
result_img.save(save_path)
|
| 651 |
+
print(f"Image saved to {save_path}")
|
| 652 |
+
|
| 653 |
+
return result_img, cond_img
|
| 654 |
+
|
| 655 |
+
def batch_generate(
|
| 656 |
+
self,
|
| 657 |
+
texts: List[str],
|
| 658 |
+
font_styles: Optional[List[str]] = None,
|
| 659 |
+
authors: Optional[List[str]] = None,
|
| 660 |
+
output_dir: str = "./outputs",
|
| 661 |
+
**kwargs
|
| 662 |
+
) -> List[tuple[Image.Image, Image.Image]]:
|
| 663 |
+
"""
|
| 664 |
+
Batch generate calligraphy images
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
texts: list of texts to generate (1-7 characters each)
|
| 668 |
+
font_styles: list of font styles (if None, use default)
|
| 669 |
+
authors: list of authors (if None, use synthetic)
|
| 670 |
+
output_dir: directory to save outputs
|
| 671 |
+
**kwargs: additional arguments for generate()
|
| 672 |
+
|
| 673 |
+
Returns:
|
| 674 |
+
list of (generated_image, condition_image) tuples
|
| 675 |
+
"""
|
| 676 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 677 |
+
results = []
|
| 678 |
+
|
| 679 |
+
# Default styles and authors if not provided
|
| 680 |
+
if font_styles is None:
|
| 681 |
+
font_styles = ["楷"] * len(texts)
|
| 682 |
+
if authors is None:
|
| 683 |
+
authors = [None] * len(texts)
|
| 684 |
+
|
| 685 |
+
for i, (text, font, author) in enumerate(zip(texts, font_styles, authors)):
|
| 686 |
+
# Clean author name for filename
|
| 687 |
+
author_name = author if author else "synthetic"
|
| 688 |
+
if author and author in self.author_style:
|
| 689 |
+
author_name = convert_to_pinyin(author)
|
| 690 |
+
|
| 691 |
+
save_path = os.path.join(
|
| 692 |
+
output_dir,
|
| 693 |
+
f"{text}_{font}_{author_name}_{i}.png"
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
result_img, cond_img = self.generate(
|
| 697 |
+
text=text,
|
| 698 |
+
font_style=font,
|
| 699 |
+
author=author,
|
| 700 |
+
save_path=save_path,
|
| 701 |
+
**kwargs
|
| 702 |
+
)
|
| 703 |
+
results.append((result_img, cond_img))
|
| 704 |
+
|
| 705 |
+
return results
|
| 706 |
+
|
| 707 |
+
def get_available_authors(self) -> List[str]:
|
| 708 |
+
"""Get list of available author styles"""
|
| 709 |
+
return list(self.author_style.keys())
|
| 710 |
+
|
| 711 |
+
def get_available_fonts(self) -> List[str]:
|
| 712 |
+
"""Get list of available font styles"""
|
| 713 |
+
return list(self.font_style_des.keys())
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
# Hugging Face Pipeline wrapper
|
| 717 |
+
class FluxCalligraphyPipeline:
|
| 718 |
+
"""Hugging Face compatible pipeline for calligraphy generation"""
|
| 719 |
+
|
| 720 |
+
def __init__(
|
| 721 |
+
self,
|
| 722 |
+
model_name: str = "flux-dev",
|
| 723 |
+
device: str = "cuda",
|
| 724 |
+
checkpoint_path: Optional[str] = None,
|
| 725 |
+
**kwargs
|
| 726 |
+
):
|
| 727 |
+
"""Initialize the pipeline"""
|
| 728 |
+
self.generator = CalligraphyGenerator(
|
| 729 |
+
model_name=model_name,
|
| 730 |
+
device=device,
|
| 731 |
+
checkpoint_path=checkpoint_path,
|
| 732 |
+
**kwargs
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
def __call__(
|
| 736 |
+
self,
|
| 737 |
+
text: Union[str, List[str]],
|
| 738 |
+
font_style: Union[str, List[str]] = "楷",
|
| 739 |
+
author: Optional[Union[str, List[str]]] = None,
|
| 740 |
+
num_inference_steps: int = 50,
|
| 741 |
+
guidance_scale: float = 3.5,
|
| 742 |
+
generator: Optional[torch.Generator] = None,
|
| 743 |
+
**kwargs
|
| 744 |
+
) -> Union[Image.Image, List[Image.Image]]:
|
| 745 |
+
"""
|
| 746 |
+
Generate calligraphy images
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
text: text or list of texts to generate (1-7 characters each)
|
| 750 |
+
font_style: font style(s) (楷/草/行)
|
| 751 |
+
author: author name(s) from the style list
|
| 752 |
+
num_inference_steps: number of denoising steps
|
| 753 |
+
guidance_scale: guidance scale for generation
|
| 754 |
+
generator: torch generator for reproducibility
|
| 755 |
+
|
| 756 |
+
Returns:
|
| 757 |
+
generated image(s)
|
| 758 |
+
"""
|
| 759 |
+
# Handle single text
|
| 760 |
+
if isinstance(text, str):
|
| 761 |
+
seed = None
|
| 762 |
+
if generator is not None:
|
| 763 |
+
seed = generator.initial_seed()
|
| 764 |
+
|
| 765 |
+
result, _ = self.generator.generate(
|
| 766 |
+
text=text,
|
| 767 |
+
font_style=font_style,
|
| 768 |
+
author=author,
|
| 769 |
+
num_steps=num_inference_steps,
|
| 770 |
+
guidance=guidance_scale,
|
| 771 |
+
seed=seed,
|
| 772 |
+
**kwargs
|
| 773 |
+
)
|
| 774 |
+
return result
|
| 775 |
+
|
| 776 |
+
# Handle batch
|
| 777 |
+
else:
|
| 778 |
+
if isinstance(font_style, str):
|
| 779 |
+
font_style = [font_style] * len(text)
|
| 780 |
+
if isinstance(author, str) or author is None:
|
| 781 |
+
author = [author] * len(text)
|
| 782 |
+
|
| 783 |
+
results = []
|
| 784 |
+
for t, f, a in zip(text, font_style, author):
|
| 785 |
+
seed = None
|
| 786 |
+
if generator is not None:
|
| 787 |
+
seed = generator.initial_seed()
|
| 788 |
+
|
| 789 |
+
result, _ = self.generator.generate(
|
| 790 |
+
text=t,
|
| 791 |
+
font_style=f,
|
| 792 |
+
author=a,
|
| 793 |
+
num_steps=num_inference_steps,
|
| 794 |
+
guidance=guidance_scale,
|
| 795 |
+
seed=seed,
|
| 796 |
+
**kwargs
|
| 797 |
+
)
|
| 798 |
+
results.append(result)
|
| 799 |
+
|
| 800 |
+
return results
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
if __name__ == "__main__":
|
| 804 |
+
# Example usage
|
| 805 |
+
import argparse
|
| 806 |
+
|
| 807 |
+
parser = argparse.ArgumentParser(description="Generate Chinese calligraphy")
|
| 808 |
+
parser.add_argument("--text", type=str, default="暴富且平安", help="Text to generate (1-7 characters)")
|
| 809 |
+
parser.add_argument("--font", type=str, default="楷", help="Font style (楷/草/行)")
|
| 810 |
+
parser.add_argument("--author", type=str, default=None, help="Author/calligrapher name")
|
| 811 |
+
parser.add_argument("--steps", type=int, default=50, help="Number of inference steps")
|
| 812 |
+
parser.add_argument("--seed", type=int, default=None, help="Random seed")
|
| 813 |
+
parser.add_argument("--output", type=str, default="output.png", help="Output path")
|
| 814 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to use")
|
| 815 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
|
| 816 |
+
parser.add_argument("--list-authors", action="store_true", help="List available authors")
|
| 817 |
+
parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
|
| 818 |
+
|
| 819 |
+
args = parser.parse_args()
|
| 820 |
+
|
| 821 |
+
# Initialize generator
|
| 822 |
+
generator = CalligraphyGenerator(
|
| 823 |
+
model_name="flux-dev",
|
| 824 |
+
device=args.device,
|
| 825 |
+
checkpoint_path=args.checkpoint
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
# List available options
|
| 829 |
+
if args.list_authors:
|
| 830 |
+
print("Available authors:")
|
| 831 |
+
for author in generator.get_available_authors()[:20]: # Show first 20
|
| 832 |
+
print(f" - {author}")
|
| 833 |
+
print(f" ... and {len(generator.get_available_authors()) - 20} more")
|
| 834 |
+
exit(0)
|
| 835 |
+
|
| 836 |
+
if args.list_fonts:
|
| 837 |
+
print("Available font styles:")
|
| 838 |
+
for font in generator.get_available_fonts():
|
| 839 |
+
print(f" - {font}: {generator.font_style_des[font]}")
|
| 840 |
+
exit(0)
|
| 841 |
+
|
| 842 |
+
# Validate text - must have 1-7 characters
|
| 843 |
+
if len(args.text) < 1:
|
| 844 |
+
print(f"Error: Text must have at least 1 character")
|
| 845 |
+
exit(1)
|
| 846 |
+
if len(args.text) > 7:
|
| 847 |
+
print(f"Error: Text must have at most 7 characters, got {len(args.text)}")
|
| 848 |
+
exit(1)
|
| 849 |
+
|
| 850 |
+
# Generate
|
| 851 |
+
result_img, cond_img = generator.generate(
|
| 852 |
+
text=args.text,
|
| 853 |
+
font_style=args.font,
|
| 854 |
+
author=args.author,
|
| 855 |
+
num_steps=args.steps,
|
| 856 |
+
seed=args.seed,
|
| 857 |
+
save_path=args.output
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
print(f"Generation complete! Saved to {args.output}")
|
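Taken together, a minimal end-to-end sketch of using this module as a library (paths, seed, and batch contents are illustrative, not pinned by this diff):

```python
from inference import CalligraphyGenerator, FluxCalligraphyPipeline

# Direct use of the generator:
gen = CalligraphyGenerator(model_name="flux-dev", device="cuda", offload=True)
img, cond = gen.generate(
    text="暴富且平安",   # 1-7 characters; the 7-slot canvas is cropped to the actual length
    font_style="行",
    author="王羲之",     # any key from calligraphy_styles_en.json, or None for synthetic
    num_steps=50,
    seed=42,
    save_path="outputs/demo.png",
)

# Or via the HF-style wrapper, including a small batch:
pipe = FluxCalligraphyPipeline(model_name="flux-dev", device="cuda")
images = pipe(["春", "福壽康寧"], font_style="楷", num_inference_steps=50)
```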
requirements.txt
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
accelerate
|
| 2 |
-
|
| 3 |
-
invisible_watermark
|
| 4 |
-
torch
|
| 5 |
transformers
|
| 6 |
-
|
| 1 |
accelerate
|
| 2 |
+
einops
|
|
|
|
|
|
|
| 3 |
transformers
|
| 4 |
+
huggingface-hub
|
| 5 |
+
optimum-quanto
|
| 6 |
+
sentencepiece
|
| 7 |
+
torch
|
| 8 |
+
torchvision
|
| 9 |
+
timm
|
| 10 |
+
pypinyin
|
| 11 |
+
gradio
|
src/__init__.py
ADDED
|
File without changes
|
src/flux/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
try:
|
| 2 |
+
from ._version import version as __version__ # type: ignore
|
| 3 |
+
from ._version import version_tuple
|
| 4 |
+
except ImportError:
|
| 5 |
+
__version__ = "unknown (no version information available)"
|
| 6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
PACKAGE = __package__.replace("_", "-")
|
| 11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
src/flux/__main__.py
ADDED
|
@@ -0,0 +1,4 @@
|
| 1 |
+
from .cli import app
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
|
| 4 |
+
app()
|
src/flux/annotator/canny/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CannyDetector:
|
| 5 |
+
def __call__(self, img, low_threshold, high_threshold):
|
| 6 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
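A small usage sketch for the detector above (the input path is hypothetical, and the repo root is assumed to be on sys.path):

```python
import cv2
from src.flux.annotator.canny import CannyDetector

detector = CannyDetector()
img = cv2.imread("input.png")                  # hypothetical input image
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)   # Canny expects an 8-bit image
edges = detector(gray, low_threshold=100, high_threshold=200)  # uint8 edge map, same H x W
cv2.imwrite("edges.png", edges)
```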
src/flux/annotator/ckpts/ckpts.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
Weights here.
|
src/flux/annotator/dwpose/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
| 1 |
+
# Openpose
|
| 2 |
+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
| 3 |
+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
| 4 |
+
# 3rd Edited by ControlNet
|
| 5 |
+
# 4th Edited by ControlNet (added face and correct hands)
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from . import util
|
| 13 |
+
from .wholebody import Wholebody
|
| 14 |
+
|
| 15 |
+
def draw_pose(pose, H, W):
|
| 16 |
+
bodies = pose['bodies']
|
| 17 |
+
faces = pose['faces']
|
| 18 |
+
hands = pose['hands']
|
| 19 |
+
candidate = bodies['candidate']
|
| 20 |
+
subset = bodies['subset']
|
| 21 |
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
| 22 |
+
|
| 23 |
+
canvas = util.draw_bodypose(canvas, candidate, subset)
|
| 24 |
+
|
| 25 |
+
canvas = util.draw_handpose(canvas, hands)
|
| 26 |
+
|
| 27 |
+
canvas = util.draw_facepose(canvas, faces)
|
| 28 |
+
|
| 29 |
+
return canvas
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DWposeDetector:
|
| 33 |
+
def __init__(self, device):
|
| 34 |
+
|
| 35 |
+
self.pose_estimation = Wholebody(device)
|
| 36 |
+
|
| 37 |
+
def __call__(self, oriImg):
|
| 38 |
+
oriImg = oriImg.copy()
|
| 39 |
+
H, W, C = oriImg.shape
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
candidate, subset = self.pose_estimation(oriImg)
|
| 42 |
+
nums, keys, locs = candidate.shape
|
| 43 |
+
candidate[..., 0] /= float(W)
|
| 44 |
+
candidate[..., 1] /= float(H)
|
| 45 |
+
body = candidate[:,:18].copy()
|
| 46 |
+
body = body.reshape(nums*18, locs)
|
| 47 |
+
score = subset[:,:18]
|
| 48 |
+
for i in range(len(score)):
|
| 49 |
+
for j in range(len(score[i])):
|
| 50 |
+
if score[i][j] > 0.3:
|
| 51 |
+
score[i][j] = int(18*i+j)
|
| 52 |
+
else:
|
| 53 |
+
score[i][j] = -1
|
| 54 |
+
|
| 55 |
+
un_visible = subset<0.3
|
| 56 |
+
candidate[un_visible] = -1
|
| 57 |
+
|
| 58 |
+
foot = candidate[:,18:24]
|
| 59 |
+
|
| 60 |
+
faces = candidate[:,24:92]
|
| 61 |
+
|
| 62 |
+
hands = candidate[:,92:113]
|
| 63 |
+
hands = np.vstack([hands, candidate[:,113:]])
|
| 64 |
+
|
| 65 |
+
bodies = dict(candidate=body, subset=score)
|
| 66 |
+
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
| 67 |
+
|
| 68 |
+
return draw_pose(pose, H, W)
|
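Hedged usage sketch for DWposeDetector; it assumes the Wholebody ONNX weights referenced by ckpts.txt are downloaded and the device string matches your setup:

```python
import cv2
from src.flux.annotator.dwpose import DWposeDetector

detector = DWposeDetector(device="cuda:0")  # Wholebody loads its ONNX models for this device
frame = cv2.imread("person.jpg")            # hypothetical (H, W, 3) uint8 image
pose_canvas = detector(frame)               # (H, W, 3) uint8 canvas with body/hand/face keypoints drawn
cv2.imwrite("pose.png", pose_canvas)
```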
src/flux/annotator/dwpose/onnxdet.py
ADDED
|
@@ -0,0 +1,125 @@
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import onnxruntime
|
| 5 |
+
|
| 6 |
+
def nms(boxes, scores, nms_thr):
|
| 7 |
+
"""Single class NMS implemented in Numpy."""
|
| 8 |
+
x1 = boxes[:, 0]
|
| 9 |
+
y1 = boxes[:, 1]
|
| 10 |
+
x2 = boxes[:, 2]
|
| 11 |
+
y2 = boxes[:, 3]
|
| 12 |
+
|
| 13 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
| 14 |
+
order = scores.argsort()[::-1]
|
| 15 |
+
|
| 16 |
+
keep = []
|
| 17 |
+
while order.size > 0:
|
| 18 |
+
i = order[0]
|
| 19 |
+
keep.append(i)
|
| 20 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
| 21 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
| 22 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
| 23 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
| 24 |
+
|
| 25 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
| 26 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
| 27 |
+
inter = w * h
|
| 28 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
| 29 |
+
|
| 30 |
+
inds = np.where(ovr <= nms_thr)[0]
|
| 31 |
+
order = order[inds + 1]
|
| 32 |
+
|
| 33 |
+
return keep
|
| 34 |
+
|
| 35 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
| 36 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
| 37 |
+
final_dets = []
|
| 38 |
+
num_classes = scores.shape[1]
|
| 39 |
+
for cls_ind in range(num_classes):
|
| 40 |
+
cls_scores = scores[:, cls_ind]
|
| 41 |
+
valid_score_mask = cls_scores > score_thr
|
| 42 |
+
if valid_score_mask.sum() == 0:
|
| 43 |
+
continue
|
| 44 |
+
else:
|
| 45 |
+
valid_scores = cls_scores[valid_score_mask]
|
| 46 |
+
valid_boxes = boxes[valid_score_mask]
|
| 47 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
| 48 |
+
if len(keep) > 0:
|
| 49 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
| 50 |
+
dets = np.concatenate(
|
| 51 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
| 52 |
+
)
|
| 53 |
+
final_dets.append(dets)
|
| 54 |
+
if len(final_dets) == 0:
|
| 55 |
+
return None
|
| 56 |
+
return np.concatenate(final_dets, 0)
|
| 57 |
+
|
| 58 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
| 59 |
+
grids = []
|
| 60 |
+
expanded_strides = []
|
| 61 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
| 62 |
+
|
| 63 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
| 64 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
| 65 |
+
|
| 66 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
| 67 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
| 68 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
| 69 |
+
grids.append(grid)
|
| 70 |
+
shape = grid.shape[:2]
|
| 71 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
| 72 |
+
|
| 73 |
+
grids = np.concatenate(grids, 1)
|
| 74 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
| 75 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
| 76 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
| 77 |
+
|
| 78 |
+
return outputs
|
| 79 |
+
|
| 80 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
| 81 |
+
if len(img.shape) == 3:
|
| 82 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
| 83 |
+
else:
|
| 84 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
| 85 |
+
|
| 86 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
| 87 |
+
resized_img = cv2.resize(
|
| 88 |
+
img,
|
| 89 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
| 90 |
+
interpolation=cv2.INTER_LINEAR,
|
| 91 |
+
).astype(np.uint8)
|
| 92 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
| 93 |
+
|
| 94 |
+
padded_img = padded_img.transpose(swap)
|
| 95 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
| 96 |
+
return padded_img, r
|
| 97 |
+
|
| 98 |
+
def inference_detector(session, oriImg):
|
| 99 |
+
input_shape = (640,640)
|
| 100 |
+
img, ratio = preprocess(oriImg, input_shape)
|
| 101 |
+
|
| 102 |
+
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
| 103 |
+
output = session.run(None, ort_inputs)
|
| 104 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
| 105 |
+
|
| 106 |
+
boxes = predictions[:, :4]
|
| 107 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
| 108 |
+
|
| 109 |
+
boxes_xyxy = np.ones_like(boxes)
|
| 110 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
| 111 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
| 112 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
| 113 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
| 114 |
+
boxes_xyxy /= ratio
|
| 115 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
| 116 |
+
if dets is not None:
|
| 117 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
| 118 |
+
isscore = final_scores>0.3
|
| 119 |
+
iscat = final_cls_inds == 0
|
| 120 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
| 121 |
+
final_boxes = final_boxes[isbbox]
|
| 122 |
+
else:
|
| 123 |
+
final_boxes = np.array([])
|
| 124 |
+
|
| 125 |
+
return final_boxes
|
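The NMS helper above is easy to sanity-check in isolation; with the invented boxes below, the near-duplicate of the top-scoring box is suppressed while the distant box survives:

```python
import numpy as np
from src.flux.annotator.dwpose.onnxdet import nms

boxes = np.array([
    [10, 10, 100, 100],
    [12, 12, 102, 102],    # IoU with the first box is ~0.92
    [200, 200, 260, 260],
], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

keep = nms(boxes, scores, nms_thr=0.45)
print(list(map(int, keep)))  # [0, 2] -- the lower-scoring duplicate is dropped
```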
src/flux/annotator/dwpose/onnxpose.py
ADDED
|
@@ -0,0 +1,360 @@
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
|
| 7 |
+
def preprocess(
|
| 8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
| 9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 10 |
+
"""Do preprocessing for RTMPose model inference.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
img (np.ndarray): Input image of shape (H, W, 3).
out_bbox: Detected person bounding boxes, each as (x0, y0, x1, y1).
|
| 14 |
+
input_size (tuple): Input image size in shape (w, h).
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
tuple:
|
| 18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
| 19 |
+
- center (np.ndarray): Center of image.
|
| 20 |
+
- scale (np.ndarray): Scale of image.
|
| 21 |
+
"""
|
| 22 |
+
# get shape of image
|
| 23 |
+
img_shape = img.shape[:2]
|
| 24 |
+
out_img, out_center, out_scale = [], [], []
|
| 25 |
+
if len(out_bbox) == 0:
|
| 26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
| 27 |
+
for i in range(len(out_bbox)):
|
| 28 |
+
x0 = out_bbox[i][0]
|
| 29 |
+
y0 = out_bbox[i][1]
|
| 30 |
+
x1 = out_bbox[i][2]
|
| 31 |
+
y1 = out_bbox[i][3]
|
| 32 |
+
bbox = np.array([x0, y0, x1, y1])
|
| 33 |
+
|
| 34 |
+
# get center and scale
|
| 35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
| 36 |
+
|
| 37 |
+
# do affine transformation
|
| 38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
| 39 |
+
|
| 40 |
+
# normalize image
|
| 41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
| 42 |
+
std = np.array([58.395, 57.12, 57.375])
|
| 43 |
+
resized_img = (resized_img - mean) / std
|
| 44 |
+
|
| 45 |
+
out_img.append(resized_img)
|
| 46 |
+
out_center.append(center)
|
| 47 |
+
out_scale.append(scale)
|
| 48 |
+
|
| 49 |
+
return out_img, out_center, out_scale
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def inference(sess: ort.InferenceSession, img: List[np.ndarray]) -> List[List[np.ndarray]]:
|
| 53 |
+
"""Inference RTMPose model.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
| 57 |
+
img (List[np.ndarray]): Preprocessed image crops, each of shape (H, W, 3).
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
outputs (np.ndarray): Output of RTMPose model.
|
| 61 |
+
"""
|
| 62 |
+
all_out = []
|
| 63 |
+
# build input
|
| 64 |
+
for i in range(len(img)):
|
| 65 |
+
input = [img[i].transpose(2, 0, 1)]
|
| 66 |
+
|
| 67 |
+
# build output
|
| 68 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
| 69 |
+
sess_output = []
|
| 70 |
+
for out in sess.get_outputs():
|
| 71 |
+
sess_output.append(out.name)
|
| 72 |
+
|
| 73 |
+
# run model
|
| 74 |
+
outputs = sess.run(sess_output, sess_input)
|
| 75 |
+
all_out.append(outputs)
|
| 76 |
+
|
| 77 |
+
return all_out
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def postprocess(outputs: List[np.ndarray],
|
| 81 |
+
model_input_size: Tuple[int, int],
|
| 82 |
+
center: Tuple[int, int],
|
| 83 |
+
scale: Tuple[int, int],
|
| 84 |
+
simcc_split_ratio: float = 2.0
|
| 85 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 86 |
+
"""Postprocess for RTMPose model output.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
outputs (np.ndarray): Output of RTMPose model.
|
| 90 |
+
model_input_size (tuple): RTMPose model Input image size.
|
| 91 |
+
center (tuple): Center of bbox in shape (x, y).
|
| 92 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
| 93 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
tuple:
|
| 97 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
| 98 |
+
- scores (np.ndarray): Model predict scores.
|
| 99 |
+
"""
|
| 100 |
+
all_key = []
|
| 101 |
+
all_score = []
|
| 102 |
+
for i in range(len(outputs)):
|
| 103 |
+
# use simcc to decode
|
| 104 |
+
simcc_x, simcc_y = outputs[i]
|
| 105 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
| 106 |
+
|
| 107 |
+
# rescale keypoints
|
| 108 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
| 109 |
+
all_key.append(keypoints[0])
|
| 110 |
+
all_score.append(scores[0])
|
| 111 |
+
|
| 112 |
+
return np.array(all_key), np.array(all_score)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
| 116 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
| 117 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
| 121 |
+
as (left, top, right, bottom)
|
| 122 |
+
padding (float): BBox padding factor that will be multiplied with the scale.
|
| 123 |
+
Default: 1.0
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
tuple: A tuple containing center and scale.
|
| 127 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
| 128 |
+
(n, 2)
|
| 129 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
| 130 |
+
(n, 2)
|
| 131 |
+
"""
|
| 132 |
+
# convert single bbox from (4, ) to (1, 4)
|
| 133 |
+
dim = bbox.ndim
|
| 134 |
+
if dim == 1:
|
| 135 |
+
bbox = bbox[None, :]
|
| 136 |
+
|
| 137 |
+
# get bbox center and scale
|
| 138 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
| 139 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
| 140 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
| 141 |
+
|
| 142 |
+
if dim == 1:
|
| 143 |
+
center = center[0]
|
| 144 |
+
scale = scale[0]
|
| 145 |
+
|
| 146 |
+
return center, scale
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
| 150 |
+
aspect_ratio: float) -> np.ndarray:
|
| 151 |
+
"""Extend the scale to match the given aspect ratio.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
bbox_scale (np.ndarray): The bbox scale (w, h) in shape (2, )
|
| 155 |
+
aspect_ratio (float): The ratio of ``w/h``
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
np.ndarray: The reshaped image scale in (2, )
|
| 159 |
+
"""
|
| 160 |
+
w, h = np.hsplit(bbox_scale, [1])
|
| 161 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
| 162 |
+
np.hstack([w, w / aspect_ratio]),
|
| 163 |
+
np.hstack([h * aspect_ratio, h]))
|
| 164 |
+
return bbox_scale
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
| 168 |
+
"""Rotate a point by an angle.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
| 172 |
+
angle_rad (float): rotation angle in radian
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
np.ndarray: Rotated point in shape (2, )
|
| 176 |
+
"""
|
| 177 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
| 178 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
| 179 |
+
return rot_mat @ pt
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 183 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
| 184 |
+
function is used to get the 3rd point, given 2D points a & b.
|
| 185 |
+
|
| 186 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
| 187 |
+
anticlockwise, using b as the rotation center.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
| 191 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
np.ndarray: The 3rd point.
|
| 195 |
+
"""
|
| 196 |
+
direction = a - b
|
| 197 |
+
c = b + np.r_[-direction[1], direction[0]]
|
| 198 |
+
return c
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_warp_matrix(center: np.ndarray,
|
| 202 |
+
scale: np.ndarray,
|
| 203 |
+
rot: float,
|
| 204 |
+
output_size: Tuple[int, int],
|
| 205 |
+
shift: Tuple[float, float] = (0., 0.),
|
| 206 |
+
inv: bool = False) -> np.ndarray:
|
| 207 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
| 208 |
+
in the input image to the output size.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
| 212 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
| 213 |
+
wrt [width, height].
|
| 214 |
+
rot (float): Rotation angle (degree).
|
| 215 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
| 216 |
+
destination heatmaps.
|
| 217 |
+
shift (tuple): Shift translation ratio wrt the width/height, in [0, 1].
|
| 218 |
+
Default (0., 0.).
|
| 219 |
+
inv (bool): Option to inverse the affine transform direction.
|
| 220 |
+
(inv=False: src->dst or inv=True: dst->src)
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
np.ndarray: A 2x3 transformation matrix
|
| 224 |
+
"""
|
| 225 |
+
shift = np.array(shift)
|
| 226 |
+
src_w = scale[0]
|
| 227 |
+
dst_w = output_size[0]
|
| 228 |
+
dst_h = output_size[1]
|
| 229 |
+
|
| 230 |
+
# compute transformation matrix
|
| 231 |
+
rot_rad = np.deg2rad(rot)
|
| 232 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
| 233 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
| 234 |
+
|
| 235 |
+
# get four corners of the src rectangle in the original image
|
| 236 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
| 237 |
+
src[0, :] = center + scale * shift
|
| 238 |
+
src[1, :] = center + src_dir + scale * shift
|
| 239 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
| 240 |
+
|
| 241 |
+
# get four corners of the dst rectangle in the input image
|
| 242 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
| 243 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
| 244 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
| 245 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
| 246 |
+
|
| 247 |
+
if inv:
|
| 248 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
| 249 |
+
else:
|
| 250 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
| 251 |
+
|
| 252 |
+
return warp_mat
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def top_down_affine(input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray,
|
| 256 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 257 |
+
"""Get the bbox image as the model input by affine transform.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
input_size (tuple): The model input size (w, h).
|
| 261 |
+
bbox_scale (np.ndarray): The bbox scale (w, h).
|
| 262 |
+
bbox_center (np.ndarray): The bbox center (x, y).
|
| 263 |
+
img (np.ndarray): The original image.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
tuple: A tuple containing the warped image and the adjusted bbox scale.
|
| 267 |
+
- np.ndarray[float32]: img after affine transform.
|
| 268 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
| 269 |
+
"""
|
| 270 |
+
w, h = input_size
|
| 271 |
+
warp_size = (int(w), int(h))
|
| 272 |
+
|
| 273 |
+
# reshape bbox to fixed aspect ratio
|
| 274 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
| 275 |
+
|
| 276 |
+
# get the affine matrix
|
| 277 |
+
center = bbox_center
|
| 278 |
+
scale = bbox_scale
|
| 279 |
+
rot = 0
|
| 280 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
| 281 |
+
|
| 282 |
+
# do affine transform
|
| 283 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
| 284 |
+
|
| 285 |
+
return img, bbox_scale
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
| 289 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 290 |
+
"""Get maximum response location and value from simcc representations.
|
| 291 |
+
|
| 292 |
+
Note:
|
| 293 |
+
instance number: N
|
| 294 |
+
num_keypoints: K
|
| 295 |
+
heatmap height: H
|
| 296 |
+
heatmap width: W
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
| 300 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
tuple:
|
| 304 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
| 305 |
+
(K, 2) or (N, K, 2)
|
| 306 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
| 307 |
+
(K,) or (N, K)
|
| 308 |
+
"""
|
| 309 |
+
N, K, Wx = simcc_x.shape
|
| 310 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
| 311 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
| 312 |
+
|
| 313 |
+
# get maximum value locations
|
| 314 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
| 315 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
| 316 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
| 317 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
| 318 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
| 319 |
+
|
| 320 |
+
# use the smaller of the two axis maxima as the keypoint confidence
|
| 321 |
+
mask = max_val_x > max_val_y
|
| 322 |
+
max_val_x[mask] = max_val_y[mask]
|
| 323 |
+
vals = max_val_x
|
| 324 |
+
locs[vals <= 0.] = -1
|
| 325 |
+
|
| 326 |
+
# reshape
|
| 327 |
+
locs = locs.reshape(N, K, 2)
|
| 328 |
+
vals = vals.reshape(N, K)
|
| 329 |
+
|
| 330 |
+
return locs, vals
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
| 334 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
| 335 |
+
"""Modulate simcc distribution with Gaussian.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
| 339 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
| 340 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
tuple: A tuple containing keypoints and scores.
|
| 344 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
| 345 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
| 346 |
+
"""
|
| 347 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
| 348 |
+
keypoints /= simcc_split_ratio
|
| 349 |
+
|
| 350 |
+
return keypoints, scores
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def inference_pose(session, out_bbox, oriImg):
|
| 354 |
+
h, w = session.get_inputs()[0].shape[2:]
|
| 355 |
+
model_input_size = (w, h)
|
| 356 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
| 357 |
+
outputs = inference(session, resized_img)
|
| 358 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
| 359 |
+
|
| 360 |
+
return keypoints, scores
|
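A quick numeric check of bbox_xyxy2cs with the 1.25 padding that preprocess() uses (box values invented):

```python
import numpy as np
from src.flux.annotator.dwpose.onnxpose import bbox_xyxy2cs

bbox = np.array([40.0, 60.0, 140.0, 260.0])     # (left, top, right, bottom)
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
print(center)  # [ 90. 160.]  -- midpoint of the box
print(scale)   # [125. 250.]  -- (w, h) = (100, 200) scaled by 1.25
```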
src/flux/annotator/dwpose/util.py
ADDED
|
@@ -0,0 +1,297 @@
|
import math
import numpy as np
import matplotlib
import cv2


eps = 0.01


def smart_resize(x, s):
    Ht, Wt = s
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)


def smart_resize_k(x, fx, fy):
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    Ht, Wt = Ho * fy, Wo * fx
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)


def padRightDownCorner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


def transfer(model, model_weights):
    # strip the first component of each state-dict key (e.g. "module.")
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
    return transfered_model_weights


def draw_bodypose(canvas, candidate, subset):
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    canvas = (canvas * 0.6).astype(np.uint8)

    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)

    return canvas


def draw_handpose(canvas, all_hand_peaks):
    H, W, C = canvas.shape

    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    for peaks in all_hand_peaks:
        peaks = np.array(peaks)

        for ie, e in enumerate(edges):
            x1, y1 = peaks[e[0]]
            x2, y2 = peaks[e[1]]
            x1 = int(x1 * W)
            y1 = int(y1 * H)
            x2 = int(x2 * W)
            y2 = int(y2 * H)
            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)

        for i, keypoint in enumerate(peaks):
            x, y = keypoint
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas


def draw_facepose(canvas, all_lmks):
    H, W, C = canvas.shape
    for lmks in all_lmks:
        lmks = np.array(lmks)
        for lmk in lmks:
            x, y = lmk
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
    return canvas


# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    # right hand: wrist 4, elbow 3, shoulder 2
    # left hand: wrist 7, elbow 6, shoulder 5
    ratioWristElbow = 0.33
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        # skip a side if any of its wrist/elbow/shoulder is not detected
        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
        if not (has_left or has_right):
            continue
        hands = []
        # left hand
        if has_left:
            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
            x1, y1 = candidate[left_shoulder_index][:2]
            x2, y2 = candidate[left_elbow_index][:2]
            x3, y3 = candidate[left_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, True])
        # right hand
        if has_right:
            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
            x1, y1 = candidate[right_shoulder_index][:2]
            x2, y2 = candidate[right_elbow_index][:2]
            x3, y3 = candidate[right_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, False])

        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x-y refers to the center --> offset to topLeft point
            # handRectangle.x -= handRectangle.width / 2.f;
            # handRectangle.y -= handRectangle.height / 2.f;
            x -= width / 2
            y -= width / 2  # width = height
            # clip the box so it does not overflow the image
            if x < 0: x = 0
            if y < 0: y = 0
            width1 = width
            width2 = width
            if x + width > image_width: width1 = image_width - x
            if y + width > image_height: width2 = image_height - y
            width = min(width1, width2)
            # discard hand boxes smaller than 20 pixels
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    '''
    return value: [[x, y, w, True if left hand else False]].
    width=height since the network requires squared input.
    x, y is the coordinate of the top-left corner.
    '''
    return detect_result


# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
    # eye/ear keypoints: left eye 14, right eye 15, left ear 16, right ear 17
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        has_head = person[0] > -1
        if not has_head:
            continue

        has_left_eye = person[14] > -1
        has_right_eye = person[15] > -1
        has_left_ear = person[16] > -1
        has_right_ear = person[17] > -1

        if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
            continue

        head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]

        width = 0.0
        x0, y0 = candidate[head][:2]

        if has_left_eye:
            x1, y1 = candidate[left_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_right_eye:
            x1, y1 = candidate[right_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_left_ear:
            x1, y1 = candidate[left_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        if has_right_ear:
            x1, y1 = candidate[right_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        x, y = x0, y0

        x -= width
        y -= width

        if x < 0:
            x = 0

        if y < 0:
            y = 0

        width1 = width * 2
        width2 = width * 2

        if x + width > image_width:
            width1 = image_width - x

        if y + width > image_height:
            width2 = image_height - y

        width = min(width1, width2)

        if width >= 20:
            detect_result.append([int(x), int(y), int(width)])

    return detect_result


# get the (row, col) index of the maximum of a 2d array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j
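A quick sanity check of the two small helpers above (editor's sketch, values illustrative): smart_resize picks INTER_AREA when shrinking and INTER_LANCZOS4 when enlarging, and npmax returns the 2-D argmax as a (row, column) pair.

import numpy as np

a = np.zeros((4, 5), dtype=np.float32)
a[2, 3] = 7.0
print(npmax(a))                       # row 2, column 3

img = np.random.randint(0, 255, (100, 80, 3), dtype=np.uint8)
small = smart_resize(img, (50, 40))   # shrinking -> INTER_AREA
big = smart_resize(img, (200, 160))   # enlarging -> INTER_LANCZOS4
print(small.shape, big.shape)         # (50, 40, 3) (200, 160, 3)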
src/flux/annotator/dwpose/wholebody.py
ADDED
|
@@ -0,0 +1,48 @@
import cv2
import numpy as np

import onnxruntime as ort
from huggingface_hub import hf_hub_download
from .onnxdet import inference_detector
from .onnxpose import inference_pose


class Wholebody:
    def __init__(self, device="cuda:0"):
        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
        onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx")
        onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx")

        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

    def __call__(self, oriImg):
        det_result = inference_detector(self.session_det, oriImg)
        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)

        keypoints_info = np.concatenate(
            (keypoints, scores[..., None]), axis=-1)
        # compute the neck joint as the midpoint of the two shoulders
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # neck score when visualizing pred: visible only if both shoulders are confident
        neck[:, 2:4] = np.logical_and(
            keypoints_info[:, 5, 2:4] > 0.3,
            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(
            keypoints_info, 17, neck, axis=1)
        # remap joints from MMPose ordering to OpenPose ordering
        mmpose_idx = [
            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
        ]
        openpose_idx = [
            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
        ]
        new_keypoints_info[:, openpose_idx] = \
            new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[
            ..., :2], keypoints_info[..., 2]

        return keypoints, scores
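A minimal sketch (editor's addition) of driving the whole-body estimator end to end. The ONNX weights are fetched from yzd-v/DWPose on first use, so this needs network access; the image path is illustrative.

import cv2

estimator = Wholebody(device="cpu")   # CPU providers; use "cuda:0" for GPU
img = cv2.imread("person.jpg")
keypoints, scores = estimator(img)    # joints follow OpenPose order after the remap above
print(keypoints.shape, scores.shape)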
src/flux/annotator/hed/__init__.py
ADDED
|
@@ -0,0 +1,95 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products.
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Different from the official models and other implementations, this is an RGB-input model (rather than BGR),
# and in this way it works better with gradio's RGB protocol.

import os
import cv2
import torch
import numpy as np

from huggingface_hub import hf_hub_download
from einops import rearrange
from ...annotator.util import annotator_ckpts_path


class DoubleConvBlock(torch.nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5


class HEDdetector:
    def __init__(self):
        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
        if not os.path.exists(modelpath):
            modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
        self.netNetwork.load_state_dict(torch.load(modelpath))

    def __call__(self, input_image):
        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
        return edge


def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z
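A sketch of the typical annotator flow (editor's addition; the class as written requires CUDA, and the paths are illustrative): run HED, then thin the soft edge map with the directional nms above.

import cv2

hed = HEDdetector()                              # downloads ControlNetHED.pth if missing
img = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)  # RGB-input model
soft = hed(img)                                  # uint8 soft edge map, same H x W
thin = nms(soft, t=127, s=3.0)                   # binarized, non-maximum-suppressed edges
cv2.imwrite("edges.png", thin)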
src/flux/annotator/midas/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/flux/annotator/midas/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
# Midas Depth Estimation
# From https://github.com/isl-org/MiDaS
# MIT LICENSE

import cv2
import numpy as np
import torch

from einops import rearrange
from .api import MiDaSInference


class MidasDetector:
    def __init__(self):
        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()

    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float().cuda()
            image_depth = image_depth / 127.5 - 1.0
            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
            depth = self.model(image_depth)[0]

            depth_pt = depth.clone()
            depth_pt -= torch.min(depth_pt)
            depth_pt /= torch.max(depth_pt)
            depth_pt = depth_pt.cpu().numpy()
            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

            depth_np = depth.cpu().numpy()
            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
            z = np.ones_like(x) * a
            x[depth_pt < bg_th] = 0
            y[depth_pt < bg_th] = 0
            normal = np.stack([x, y, z], axis=2)
            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

        return depth_image, normal_image
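A calling sketch (editor's addition; needs a CUDA device and downloads the dpt_hybrid checkpoint on first run): depth gradients outside the bg_th background mask are packed into a pseudo normal map.

import cv2

midas = MidasDetector()
img = cv2.imread("room.jpg")              # HWC uint8; illustrative path
depth_image, normal_image = midas(img)
cv2.imwrite("depth.png", depth_image)     # normalized relative depth
cv2.imwrite("normal.png", normal_image)   # Sobel-derived pseudo normals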
src/flux/annotator/midas/api.py
ADDED
|
@@ -0,0 +1,168 @@
# based on https://github.com/isl-org/MiDaS

import cv2
import os
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from huggingface_hub import hf_hub_download

from .midas.dpt_depth import DPTDepthModel
from .midas.midas_net import MidasNet
from .midas.midas_net_custom import MidasNet_small
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
from ...annotator.util import annotator_ckpts_path


ISL_PATHS = {
    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
    "midas_v21": "",
    "midas_v21_small": "",
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        if not os.path.exists(model_path):
            model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt")

        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        self.model.train = disabled_train

    def forward(self, x):
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
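A small sketch (editor's addition) of using this api module directly: load the matching transform, wrap one RGB image into the sample dict the MiDaS transforms expect, and run it through the frozen model.

import cv2
import torch

transform = load_midas_transform("dpt_hybrid")
model = MiDaSInference(model_type="dpt_hybrid")

img = cv2.cvtColor(cv2.imread("room.jpg"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})        # Resize + NormalizeImage + PrepareForNet
x = torch.from_numpy(sample["image"]).unsqueeze(0)
depth = model(x)                          # (1, H', W') relative depth, no grad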
src/flux/annotator/midas/midas/__init__.py
ADDED
|
File without changes
|
src/flux/annotator/midas/midas/base_model.py
ADDED
|
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
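A behavior sketch (editor's addition; TinyModel is a hypothetical subclass used only for the demo): load() accepts both a bare state dict and a full training checkpoint, unwrapping the latter via its "optimizer" key.

import torch

class TinyModel(BaseModel):               # hypothetical subclass, demo only
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

net = TinyModel()
torch.save(net.state_dict(), "plain.pt")                              # bare weights
torch.save({"model": net.state_dict(), "optimizer": {}}, "full.pt")   # training ckpt

net.load("plain.pt")   # state dict used as-is
net.load("full.pt")    # "optimizer" key detected -> parameters["model"] is loaded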
src/flux/annotator/midas/midas/blocks.py
ADDED
|
@@ -0,0 +1,342 @@
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)


def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand == True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            # unreachable as written: self.groups is fixed to 1 above,
            # and no conv_merge layer is ever defined
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
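A shape check (editor's sketch) for the fusion block that DPT builds on: each call fuses an optional skip input and doubles the spatial resolution.

import torch
import torch.nn as nn

block = FeatureFusionBlock_custom(64, nn.ReLU(False), align_corners=True)
deep = torch.randn(1, 64, 12, 12)     # coarser decoder feature
skip = torch.randn(1, 64, 12, 12)     # encoder skip at the same scale
out = block(deep, skip)
print(out.shape)                      # torch.Size([1, 64, 24, 24])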
src/flux/annotator/midas/midas/dpt_depth.py
ADDED
|
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last == True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
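An instantiation sketch (editor's addition; assumes timm provides the ViT backbone, as wired up in blocks.py/vit.py): building the hybrid model without a checkpoint and checking that forward squeezes the channel dimension.

import torch

net = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()
x = torch.randn(1, 3, 384, 384)       # normalized input, side lengths multiples of 32
with torch.no_grad():
    depth = net(x)
print(depth.shape)                    # torch.Size([1, 384, 384])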
src/flux/annotator/midas/midas/midas_net.py
ADDED
|
@@ -0,0 +1,76 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
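A sketch for the v2.1 ResNeXt variant (editor's addition): note that even with path=None the backbone is fetched through torch.hub, so the first run needs network access.

import torch

net = MidasNet(path=None, features=256, non_negative=True).eval()
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    depth = net(x)
print(depth.shape)                    # torch.Size([1, 384, 384])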
src/flux/annotator/midas/midas/midas_net_custom.py
ADDED
|
@@ -0,0 +1,128 @@
| 1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
| 2 |
+
This file contains code that is adapted from
|
| 3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from .base_model import BaseModel
|
| 9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MidasNet_small(BaseModel):
|
| 13 |
+
"""Network for monocular depth estimation.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
        blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to "efficientnet_lite3".
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand']:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last:
            print("self.channels_last = ", self.channels_last)
            # contiguous() returns a new tensor; the result must be reassigned
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #     print("FUSED ", previous_name, name)
        #     torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
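
Usage sketch (not part of the commit): instantiating MidasNet_small and running a dummy forward pass. This assumes the sibling blocks.py provides _make_encoder and FeatureFusionBlock_custom as in upstream MiDaS, and that timm can fetch the efficientnet_lite3 encoder weights on first run (pass a local `path` to load a saved checkpoint instead).

import torch
from src.flux.annotator.midas.midas.midas_net_custom import MidasNet_small

model = MidasNet_small(features=64, backbone="efficientnet_lite3").eval()
with torch.no_grad():
    depth = model(torch.randn(1, 3, 256, 256))  # H and W should be multiples of 32
print(depth.shape)  # torch.Size([1, 256, 256]): one inverse-depth map per image
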
src/flux/annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
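
Usage sketch (not part of the commit): composing the three transforms the way MiDaS-style pipelines do. The mean/std values and torchvision's Compose are illustrative choices, not mandated by this file.

import numpy as np
from torchvision.transforms import Compose
from src.flux.annotator.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="upper_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
sample = transform({"image": np.random.rand(480, 640, 3).astype(np.float32)})
print(sample["image"].shape)  # (3, 288, 384): CHW float32, sides multiples of 32
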
src/flux/annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
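
Usage sketch (not part of the commit): building the ViT-B/16 backbone and pulling the four multi-scale feature maps with forward_vit. pretrained=False skips the timm weight download (the feature shapes are unaffected); this assumes a timm version that still registers "vit_base_patch16_384".

import torch
from src.flux.annotator.midas.midas.vit import _make_pretrained_vitb16_384, forward_vit

backbone = _make_pretrained_vitb16_384(pretrained=False, use_readout="project")
x = torch.randn(1, 3, 384, 384)
layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)
print([tuple(t.shape) for t in (layer_1, layer_2, layer_3, layer_4)])
# [(1, 96, 96, 96), (1, 192, 48, 48), (1, 384, 24, 24), (1, 768, 12, 12)]
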
src/flux/annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        # encode() must apply to either header, not only the greyscale branch
        file.write(("PF\n" if color else "Pf\n").encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or (endian == "=" and sys.byteorder == "little"):
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")

    depth_resized = cv2.resize(
        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
    )

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): byte depth of the png (1 = 8-bit, 2 = 16-bit). Defaults to 1.
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return
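
Usage sketch (not part of the commit): a PFM round trip through write_pfm/read_pfm. Both functions flip the array vertically, so the data comes back unchanged; the /tmp path is illustrative.

import numpy as np
from src.flux.annotator.midas.utils import write_pfm, read_pfm

depth = np.random.rand(4, 5).astype(np.float32)
write_pfm("/tmp/depth_example.pfm", depth)
data, scale = read_pfm("/tmp/depth_example.pfm")
assert np.allclose(data, depth) and scale == 1.0
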
src/flux/annotator/mlsd/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021-present NAVER Corp.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
src/flux/annotator/mlsd/__init__.py
ADDED
@@ -0,0 +1,40 @@
# MLSD Line Detection
# From https://github.com/navervision/mlsd
# Apache-2.0 license

import cv2
import numpy as np
import torch
import os

from einops import rearrange
from huggingface_hub import hf_hub_download
from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
from .utils import pred_lines

from ...annotator.util import annotator_ckpts_path


class MLSDdetector:
    def __init__(self):
        model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
        if not os.path.exists(model_path):
            model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth")
        model = MobileV2_MLSD_Large()
        model.load_state_dict(torch.load(model_path), strict=True)
        self.model = model.cuda().eval()

    def __call__(self, input_image, thr_v, thr_d):
        assert input_image.ndim == 3
        img = input_image
        img_output = np.zeros_like(img)
        try:
            with torch.no_grad():
                lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
                for line in lines:
                    x_start, y_start, x_end, y_end = [int(val) for val in line]
                    cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
        except Exception:
            # fall back to an empty line map if prediction fails
            pass
        return img_output[:, :, 0]
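
Usage sketch (not part of the commit): running MLSDdetector on an RGB numpy image. The constructor calls .cuda(), so a GPU is required, and the checkpoint is fetched from lllyasviel/Annotators on first use; thr_v and thr_d are the score and segment-distance thresholds passed through to pred_lines.

import numpy as np
from src.flux.annotator.mlsd import MLSDdetector

detector = MLSDdetector()
image = np.zeros((512, 512, 3), dtype=np.uint8)  # any HxWx3 uint8 image
line_map = detector(image, thr_v=0.1, thr_d=0.1)
print(line_map.shape, line_map.dtype)  # (512, 512) uint8 single-channel line drawing
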
src/flux/annotator/mlsd/models/mbv2_mlsd_large.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import sys
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F


class BlockTypeA(nn.Module):
    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
        super(BlockTypeA, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c2, out_c2, kernel_size=1),
            nn.BatchNorm2d(out_c2),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c1, out_c1, kernel_size=1),
            nn.BatchNorm2d(out_c1),
            nn.ReLU(inplace=True)
        )
        self.upscale = upscale

    def forward(self, a, b):
        b = self.conv1(b)
        a = self.conv2(a)
        if self.upscale:
            b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
        return torch.cat((a, b), dim=1)


class BlockTypeB(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeB, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x) + x
        x = self.conv2(x)
        return x


class BlockTypeC(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeC, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        self.channel_pad = out_planes - in_planes
        self.stride = stride

        # TFLite uses slightly different padding than PyTorch
        if stride == 2:
            padding = 0
        else:
            padding = (kernel_size - 1) // 2

        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        # TFLite uses different padding
        if self.stride == 2:
            x = F.pad(x, (0, 1, 0, 1), "constant", 0)

        for module in self:
            if not isinstance(module, nn.MaxPool2d):
                x = module(x)
        return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
        """
        super(MobileNetV2, self).__init__()

        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # only check the first element, assuming user knows t, c, n, s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(4, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        self.features = nn.Sequential(*features)
        self.fpn_selected = [1, 3, 6, 10, 13]
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
        if pretrained:
            self._load_pretrained_model()

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            if i > self.fpn_selected[-1]:
                break
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)

        c1, c2, c3, c4, c5 = fpn_features
        return c1, c2, c3, c4, c5

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


class MobileV2_MLSD_Large(nn.Module):
    def __init__(self):
        super(MobileV2_MLSD_Large, self).__init__()

        self.backbone = MobileNetV2(pretrained=False)
        ## A, B
        self.block15 = BlockTypeA(in_c1=64, in_c2=96,
                                  out_c1=64, out_c2=64,
                                  upscale=False)
        self.block16 = BlockTypeB(128, 64)

        ## A, B
        self.block17 = BlockTypeA(in_c1=32, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block18 = BlockTypeB(128, 64)

        ## A, B
        self.block19 = BlockTypeA(in_c1=24, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block20 = BlockTypeB(128, 64)

        ## A, B, C
        self.block21 = BlockTypeA(in_c1=16, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block22 = BlockTypeB(128, 64)

        self.block23 = BlockTypeC(64, 16)

    def forward(self, x):
        c1, c2, c3, c4, c5 = self.backbone(x)

        x = self.block15(c4, c5)
        x = self.block16(x)

        x = self.block17(c3, x)
        x = self.block18(x)

        x = self.block19(c2, x)
        x = self.block20(x)

        x = self.block21(c1, x)
        x = self.block22(x)
        x = self.block23(x)
        x = x[:, 7:, :, :]

        return x
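A quick shape trace of `MobileV2_MLSD_Large` (hand-derived, not in the commit): with a 512×512 four-channel input, the head emits a 9-channel map at 256×256; `pred_lines` in utils.py reads channel 0 as the line-center heatmap and channels 1–4 as endpoint displacements.

# Sanity sketch with randomly initialized weights (pretrained=False), so CPU is fine:
import torch
from src.flux.annotator.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large

net = MobileV2_MLSD_Large().eval()
x = torch.randn(1, 4, 512, 512)   # RGB plus the all-ones channel added in utils.pred_lines
with torch.no_grad():
    y = net(x)
print(y.shape)                    # torch.Size([1, 9, 256, 256])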
src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py
ADDED
@@ -0,0 +1,275 @@
import os
import sys
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F


class BlockTypeA(nn.Module):
    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
        super(BlockTypeA, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c2, out_c2, kernel_size=1),
            nn.BatchNorm2d(out_c2),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c1, out_c1, kernel_size=1),
            nn.BatchNorm2d(out_c1),
            nn.ReLU(inplace=True)
        )
        self.upscale = upscale

    def forward(self, a, b):
        b = self.conv1(b)
        a = self.conv2(a)
        b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
        return torch.cat((a, b), dim=1)


class BlockTypeB(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeB, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x) + x
        x = self.conv2(x)
        return x


class BlockTypeC(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeC, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        self.channel_pad = out_planes - in_planes
        self.stride = stride

        # TFLite uses slightly different padding than PyTorch
        if stride == 2:
            padding = 0
        else:
            padding = (kernel_size - 1) // 2

        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        # TFLite uses different padding
        if self.stride == 2:
            x = F.pad(x, (0, 1, 0, 1), "constant", 0)

        for module in self:
            if not isinstance(module, nn.MaxPool2d):
                x = module(x)
        return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
        """
        super(MobileNetV2, self).__init__()

        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            # [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # only check the first element, assuming user knows t, c, n, s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(4, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        self.features = nn.Sequential(*features)

        self.fpn_selected = [3, 6, 10]
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        # if pretrained:
        #     self._load_pretrained_model()

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            if i > self.fpn_selected[-1]:
                break
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)

        c2, c3, c4 = fpn_features
        return c2, c3, c4

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


class MobileV2_MLSD_Tiny(nn.Module):
    def __init__(self):
        super(MobileV2_MLSD_Tiny, self).__init__()

        self.backbone = MobileNetV2(pretrained=True)

        self.block12 = BlockTypeA(in_c1=32, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block13 = BlockTypeB(128, 64)

        self.block14 = BlockTypeA(in_c1=24, in_c2=64,
                                  out_c1=32, out_c2=32)
        self.block15 = BlockTypeB(64, 64)

        self.block16 = BlockTypeC(64, 16)

    def forward(self, x):
        c2, c3, c4 = self.backbone(x)

        x = self.block12(c3, c4)
        x = self.block13(x)
        x = self.block14(c2, x)
        x = self.block15(x)
        x = self.block16(x)
        x = x[:, 7:, :, :]
        x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)

        return x
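Both backbone files share `_make_divisible`, which snaps channel counts to multiples of `divisor` while never shrinking a layer by more than 10%. Hand-computed examples (not in the commit):

# Hand-checked behavior of the shared channel-rounding helper:
from src.flux.annotator.mlsd.models.mbv2_mlsd_tiny import _make_divisible

assert _make_divisible(32, 8) == 32   # already a multiple of 8
assert _make_divisible(20, 8) == 24   # snapped to the nearest multiple of 8
assert _make_divisible(27, 8) == 32   # 24 would undershoot 27 by more than 10%, so round up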
src/flux/annotator/mlsd/utils.py
ADDED
@@ -0,0 +1,580 @@
'''
modified by lihaoweicv
pytorch version
'''

'''
M-LSD
Copyright 2021-present NAVER Corp.
Apache License v2.0
'''

import os
import numpy as np
import cv2
import torch
from torch.nn import functional as F


def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
    '''
    tpMap:
    center: tpMap[1, 0, :, :]
    displacement: tpMap[1, 1:5, :, :]
    '''
    b, c, h, w = tpMap.shape
    assert b == 1, 'only support bsize==1'
    displacement = tpMap[:, 1:5, :, :][0]
    center = tpMap[:, 0, :, :]
    heat = torch.sigmoid(center)
    hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2)
    keep = (hmax == heat).float()
    heat = heat * keep
    heat = heat.reshape(-1, )

    scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
    yy = torch.floor_divide(indices, w).unsqueeze(-1)
    xx = torch.fmod(indices, w).unsqueeze(-1)
    ptss = torch.cat((yy, xx), dim=-1)

    ptss = ptss.detach().cpu().numpy()
    scores = scores.detach().cpu().numpy()
    displacement = displacement.detach().cpu().numpy()
    displacement = displacement.transpose((1, 2, 0))
    return ptss, scores, displacement


def pred_lines(image, model,
               input_shape=[512, 512],
               score_thr=0.10,
               dist_thr=20.0):
    h, w, _ = image.shape
    h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]

    resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)

    resized_image = resized_image.transpose((2, 0, 1))
    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
    batch_image = (batch_image / 127.5) - 1.0

    # was hard-coded to "cuda:4"; use the default CUDA device, matching the model in MLSDdetector
    batch_image = torch.from_numpy(batch_image).float().cuda()
    outputs = model(batch_image)
    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
    start = vmap[:, :, :2]
    end = vmap[:, :, 2:]
    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))

    segments_list = []
    for center, score in zip(pts, pts_score):
        y, x = center
        distance = dist_map[y, x]
        if score > score_thr and distance > dist_thr:
            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
            x_start = x + disp_x_start
            y_start = y + disp_y_start
            x_end = x + disp_x_end
            y_end = y + disp_y_end
            segments_list.append([x_start, y_start, x_end, y_end])

    lines = 2 * np.array(segments_list)  # 256 > 512
    lines[:, 0] = lines[:, 0] * w_ratio
    lines[:, 1] = lines[:, 1] * h_ratio
    lines[:, 2] = lines[:, 2] * w_ratio
    lines[:, 3] = lines[:, 3] * h_ratio

    return lines


def pred_squares(image,
                 model,
                 input_shape=[512, 512],
                 params={'score': 0.06,
                         'outside_ratio': 0.28,
                         'inside_ratio': 0.45,
                         'w_overlap': 0.0,
                         'w_degree': 1.95,
                         'w_length': 0.0,
                         'w_area': 1.86,
                         'w_center': 0.14}):
    '''
    shape = [height, width]
    '''
    h, w, _ = image.shape
    original_shape = [h, w]

    resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
    resized_image = resized_image.transpose((2, 0, 1))
    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
    batch_image = (batch_image / 127.5) - 1.0

    batch_image = torch.from_numpy(batch_image).float().cuda()
    outputs = model(batch_image)

    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
    start = vmap[:, :, :2]  # (x, y)
    end = vmap[:, :, 2:]  # (x, y)
    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))

    junc_list = []
    segments_list = []
    for junc, score in zip(pts, pts_score):
        y, x = junc
        distance = dist_map[y, x]
        if score > params['score'] and distance > 20.0:
            junc_list.append([x, y])
            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
            d_arrow = 1.0
            x_start = x + d_arrow * disp_x_start
            y_start = y + d_arrow * disp_y_start
            x_end = x + d_arrow * disp_x_end
            y_end = y + d_arrow * disp_y_end
            segments_list.append([x_start, y_start, x_end, y_end])

    segments = np.array(segments_list)

    ####### post processing for squares
    # 1. get unique lines
    point = np.array([[0, 0]])
    point = point[0]
    start = segments[:, :2]
    end = segments[:, 2:]
    diff = start - end
    a = diff[:, 1]
    b = -diff[:, 0]
    c = a * start[:, 0] + b * start[:, 1]

    d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
    theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
    theta[theta < 0.0] += 180
    hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)

    d_quant = 1
    theta_quant = 2
    hough[:, 0] //= d_quant
    hough[:, 1] //= theta_quant
    _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)

    acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
    idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
    yx_indices = hough[indices, :].astype('int32')
    acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
    idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices

    acc_map_np = acc_map
    # acc_map = acc_map[None, :, :, None]
    #
    # ### fast suppression using tensorflow op
    # acc_map = tf.constant(acc_map, dtype=tf.float32)
    # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
    # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
    # flatten_acc_map = tf.reshape(acc_map, [1, -1])
    # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
    # _, h, w, _ = acc_map.shape
    # y = tf.expand_dims(topk_indices // w, axis=-1)
    # x = tf.expand_dims(topk_indices % w, axis=-1)
    # yx = tf.concat([y, x], axis=-1)

    ### fast suppression using pytorch op
    acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
    _, _, h, w = acc_map.shape
    max_acc_map = F.max_pool2d(acc_map, kernel_size=5, stride=1, padding=2)
    acc_map = acc_map * ((acc_map == max_acc_map).float())
    flatten_acc_map = acc_map.reshape([-1, ])

    scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
    yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
    xx = torch.fmod(indices, w).unsqueeze(-1)
    yx = torch.cat((yy, xx), dim=-1)

    yx = yx.detach().cpu().numpy()

    topk_values = scores.detach().cpu().numpy()
    indices = idx_map[yx[:, 0], yx[:, 1]]
    basis = 5 // 2

    merged_segments = []
    for yx_pt, max_indice, value in zip(yx, indices, topk_values):
        y, x = yx_pt
        if max_indice == -1 or value == 0:
            continue
        segment_list = []
        for y_offset in range(-basis, basis + 1):
            for x_offset in range(-basis, basis + 1):
                indice = idx_map[y + y_offset, x + x_offset]
                cnt = int(acc_map_np[y + y_offset, x + x_offset])
                if indice != -1:
                    segment_list.append(segments[indice])
                if cnt > 1:
                    check_cnt = 1
                    current_hough = hough[indice]
                    for new_indice, new_hough in enumerate(hough):
                        if (current_hough == new_hough).all() and indice != new_indice:
                            segment_list.append(segments[new_indice])
                            check_cnt += 1
                        if check_cnt == cnt:
                            break
        group_segments = np.array(segment_list).reshape([-1, 2])
        sorted_group_segments = np.sort(group_segments, axis=0)
        x_min, y_min = sorted_group_segments[0, :]
        x_max, y_max = sorted_group_segments[-1, :]

        deg = theta[max_indice]
        if deg >= 90:
            merged_segments.append([x_min, y_max, x_max, y_min])
        else:
            merged_segments.append([x_min, y_min, x_max, y_max])

    # 2. get intersections
    new_segments = np.array(merged_segments)  # (x1, y1, x2, y2)
    start = new_segments[:, :2]  # (x1, y1)
    end = new_segments[:, 2:]  # (x2, y2)
    new_centers = (start + end) / 2.0
    diff = start - end
    dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))

    # ax + by = c
    a = diff[:, 1]
    b = -diff[:, 0]
    c = a * start[:, 0] + b * start[:, 1]
    pre_det = a[:, None] * b[None, :]
    det = pre_det - np.transpose(pre_det)

    pre_inter_y = a[:, None] * c[None, :]
    inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
    pre_inter_x = c[:, None] * b[None, :]
    inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
    inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')

    # 3. get corner information
    # 3.1 get distance
    '''
    dist_segments:
        | dist(0), dist(1), dist(2), ...|
    dist_inter_to_segment1:
        | dist(inter,0), dist(inter,0), dist(inter,0), ... |
        | dist(inter,1), dist(inter,1), dist(inter,1), ... |
        ...
    dist_inter_to_segment2:
        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
        ...
    '''

    dist_inter_to_segment1_start = np.sqrt(
        np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment1_end = np.sqrt(
        np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment2_start = np.sqrt(
        np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment2_end = np.sqrt(
        np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]

    # sort ascending
    dist_inter_to_segment1 = np.sort(
        np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
        axis=-1)  # [n_batch, n_batch, 2]
    dist_inter_to_segment2 = np.sort(
        np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
        axis=-1)  # [n_batch, n_batch, 2]

    # 3.2 get degree
    inter_to_start = new_centers[:, None, :] - inter_pts
    deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
    deg_inter_to_start[deg_inter_to_start < 0.0] += 360
    inter_to_end = new_centers[None, :, :] - inter_pts
    deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
    deg_inter_to_end[deg_inter_to_end < 0.0] += 360

    '''
    B -- G
    |    |
    C -- R
    B: blue / G: green / C: cyan / R: red

    0 -- 1
    |    |
    3 -- 2
    '''
    # rename variables
    deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
    # sort deg ascending
    deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)

    deg_diff_map = np.abs(deg1_map - deg2_map)
    # we only consider the smallest degree of intersect
    deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]

    # define available degree range
    deg_range = [60, 120]

    corner_dict = {corner_info: [] for corner_info in range(4)}
    inter_points = []
    for i in range(inter_pts.shape[0]):
        for j in range(i + 1, inter_pts.shape[1]):
            # i, j > line index, always i < j
            x, y = inter_pts[i, j, :]
            deg1, deg2 = deg_sort[i, j, :]
            deg_diff = deg_diff_map[i, j]

            check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]

            outside_ratio = params['outside_ratio']  # over ratio >>> drop it!
            inside_ratio = params['inside_ratio']  # over ratio >>> drop it!
            check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
                              (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
                             ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
                              (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))

            if check_degree and check_distance:
                corner_info = None

                if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
                        (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
                    corner_info, color_info = 0, 'blue'
                elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
                    corner_info, color_info = 1, 'green'
                elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
                    corner_info, color_info = 2, 'black'
                elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
                        (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
                    corner_info, color_info = 3, 'cyan'
                else:
                    corner_info, color_info = 4, 'red'  # we don't use it
                    continue

                corner_dict[corner_info].append([x, y, i, j])
                inter_points.append([x, y])

    square_list = []
    connect_list = []
    segments_list = []
    for corner0 in corner_dict[0]:
        for corner1 in corner_dict[1]:
            connect01 = False
            for corner0_line in corner0[2:]:
                if corner0_line in corner1[2:]:
                    connect01 = True
                    break
            if connect01:
                for corner2 in corner_dict[2]:
                    connect12 = False
                    for corner1_line in corner1[2:]:
                        if corner1_line in corner2[2:]:
                            connect12 = True
                            break
                    if connect12:
                        for corner3 in corner_dict[3]:
                            connect23 = False
                            for corner2_line in corner2[2:]:
                                if corner2_line in corner3[2:]:
                                    connect23 = True
                                    break
                            if connect23:
                                for corner3_line in corner3[2:]:
                                    if corner3_line in corner0[2:]:
                                        # SQUARE!!!
                                        '''
                                        0 -- 1
                                        |    |
                                        3 -- 2
                                        square_list:
                                            order: 0 > 1 > 2 > 3
                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
                                            ...
                                        connect_list:
                                            order: 01 > 12 > 23 > 30
                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
                                            ...
                                        segments_list:
                                            order: 0 > 1 > 2 > 3
                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
                                            ...
                                        '''
                                        square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
                                        connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
                                        segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])

    def check_outside_inside(segments_info, connect_idx):
        # return 'outside or inside', min distance, cover_param, peri_param
        if connect_idx == segments_info[0]:
            check_dist_mat = dist_inter_to_segment1
        else:
            check_dist_mat = dist_inter_to_segment2

        i, j = segments_info
        min_dist, max_dist = check_dist_mat[i, j, :]
        connect_dist = dist_segments[connect_idx]
        if max_dist > connect_dist:
            return 'outside', min_dist, 0, 1
        else:
            return 'inside', min_dist, -1, -1

    top_square = None

    try:
        map_size = input_shape[0] / 2
        squares = np.array(square_list).reshape([-1, 4, 2])
        score_array = []
        connect_array = np.array(connect_list)
        segments_array = np.array(segments_list).reshape([-1, 4, 2])

        # get degree of corners:
        squares_rollup = np.roll(squares, 1, axis=1)
        squares_rolldown = np.roll(squares, -1, axis=1)
        vec1 = squares_rollup - squares
        normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
        vec2 = squares_rolldown - squares
        normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
        inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1)  # [n_squares, 4]
        squares_degree = np.arccos(inner_products) * 180 / np.pi  # [n_squares, 4]

        # get square score
        overlap_scores = []
        degree_scores = []
        length_scores = []

        for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
            '''
            0 -- 1
            |    |
            3 -- 2

            # segments: [4, 2]
            # connects: [4]
            '''

            ###################################### OVERLAP SCORES
            cover = 0
            perimeter = 0
            # check 0 > 1 > 2 > 3
            square_length = []

            for start_idx in range(4):
                end_idx = (start_idx + 1) % 4

                connect_idx = connects[start_idx]  # segment idx of segment01
                start_segments = segments[start_idx]
                end_segments = segments[end_idx]

                start_point = square[start_idx]
                end_point = square[end_idx]

                # check whether outside or inside
                start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
                                                                                                      connect_idx)
                end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)

                cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
                perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min

                square_length.append(
                    dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)

            overlap_scores.append(cover / perimeter)
            ######################################
            ###################################### DEGREE SCORES
            '''
            deg0 vs deg2
            deg1 vs deg3
            '''
            deg0, deg1, deg2, deg3 = degree
            deg_ratio1 = deg0 / deg2
            if deg_ratio1 > 1.0:
                deg_ratio1 = 1 / deg_ratio1
            deg_ratio2 = deg1 / deg3
            if deg_ratio2 > 1.0:
                deg_ratio2 = 1 / deg_ratio2
            degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
            ######################################
            ###################################### LENGTH SCORES
            '''
            len0 vs len2
            len1 vs len3
            '''
            len0, len1, len2, len3 = square_length
            len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
            len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
            length_scores.append((len_ratio1 + len_ratio2) / 2)

            ######################################

        overlap_scores = np.array(overlap_scores)
        overlap_scores /= np.max(overlap_scores)

        degree_scores = np.array(degree_scores)
        # degree_scores /= np.max(degree_scores)

        length_scores = np.array(length_scores)

        ###################################### AREA SCORES
        area_scores = np.reshape(squares, [-1, 4, 2])
        area_x = area_scores[:, :, 0]
        area_y = area_scores[:, :, 1]
        correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
        area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
        area_scores = 0.5 * np.abs(area_scores + correction)
        area_scores /= (map_size * map_size)  # np.max(area_scores)
        ######################################

        ###################################### CENTER SCORES
        centers = np.array([[256 // 2, 256 // 2]], dtype='float32')  # [1, 2]
        # squares: [n, 4, 2]
        square_centers = np.mean(squares, axis=1)  # [n, 2]
        center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
        center_scores = center2center / (map_size / np.sqrt(2.0))

        '''
        score_w = [overlap, degree, area, center, length]
        '''
        score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
        score_array = params['w_overlap'] * overlap_scores \
                      + params['w_degree'] * degree_scores \
                      + params['w_area'] * area_scores \
                      - params['w_center'] * center_scores \
                      + params['w_length'] * length_scores

        best_square = []

        sorted_idx = np.argsort(score_array)[::-1]
        score_array = score_array[sorted_idx]
        squares = squares[sorted_idx]

    except Exception as e:
        pass

    '''return list
    merged_lines, squares, scores
    '''

    try:
        new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
        new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
        new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
        new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
    except:
        new_segments = []

    try:
        squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
        squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
    except:
        squares = []
        score_array = []

    try:
        inter_points = np.array(inter_points)
        inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
        inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
    except:
        inter_points = []

    return new_segments, squares, score_array, inter_points
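The `# ax + by = c` block in `pred_squares` intersects every pair of merged segments in closed form with a vectorized Cramer's rule. A hand-checked two-segment example (not in the commit; values chosen so the crossing point is obvious):

import numpy as np

seg = np.array([[0., 0., 4., 0.],     # horizontal line y = 0
                [2., -1., 2., 3.]])   # vertical line x = 2
start, end = seg[:, :2], seg[:, 2:]
diff = start - end
a, b = diff[:, 1], -diff[:, 0]
c = a * start[:, 0] + b * start[:, 1]

pre_det = a[:, None] * b[None, :]
det = pre_det - pre_det.T
pre_x = c[:, None] * b[None, :]
pre_y = a[:, None] * c[None, :]
inter_x = (pre_x - pre_x.T) / (det + 1e-10)
inter_y = (pre_y - pre_y.T) / (det + 1e-10)
print(inter_x[0, 1], inter_y[0, 1])   # 2.0 0.0 -> the two lines cross at (2, 0)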
src/flux/annotator/tile/__init__.py
ADDED
@@ -0,0 +1,26 @@
import random
import cv2
from .guided_filter import FastGuidedFilter


class TileDetector:
    # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0
    def __init__(self):
        pass

    def __call__(self, image):
        blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0]
        radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
        eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0]
        scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0]

        ksize = int(blur_strength)
        if ksize % 2 == 0:
            ksize += 1

        if random.random() > 0.5:
            image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2)
        if random.random() > 0.5:
            filter = FastGuidedFilter(image, radius, eps, scale_factor)
            image = filter.filter(image)
        return image
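Note that `TileDetector` draws its blur and guided-filter parameters at random on every call, so two calls on the same image generally differ. An illustrative driver (not in the commit; assumes the repo root on `PYTHONPATH` and an image on disk):

# Seed Python's RNG if you need a reproducible conditioning image.
import random
import cv2
from src.flux.annotator.tile import TileDetector

random.seed(0)
tile = TileDetector()
img = cv2.imread("input.jpg")
cond = tile(img)   # randomly blurred and/or guided-filtered copy of the input
cv2.imwrite("tile_cond.png", cond)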
src/flux/annotator/tile/guided_filter.py
ADDED
@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
## @package guided_filter.core.filters
#
# Implementation of guided filter.
# * GuidedFilter: Original guided filter.
# * FastGuidedFilter: Fast version of the guided filter.
# @author tody
# @date 2015/08/26

import numpy as np
import cv2

## Convert image into float32 type.
def to32F(img):
    if img.dtype == np.float32:
        return img
    return (1.0 / 255.0) * np.float32(img)

## Convert image into uint8 type.
def to8U(img):
    if img.dtype == np.uint8:
        return img
    return np.clip(np.uint8(255.0 * img), 0, 255)

## Return if the input image is gray or not.
def _isGray(I):
    return len(I.shape) == 2


## Return down sampled image.
# @param scale (w/s, h/s) image will be created.
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
def _downSample(I, scale=4, shape=None):
    if shape is not None:
        h, w = shape
        return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)

    h, w = I.shape[:2]
    return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)


## Return up sampled image.
# @param scale (w*s, h*s) image will be created.
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
def _upSample(I, scale=2, shape=None):
    if shape is not None:
        h, w = shape
        return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)

    h, w = I.shape[:2]
    return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)

## Fast guided filter.
class FastGuidedFilter:
    ## Constructor.
    # @param I Input guidance image. Color or gray.
    # @param radius Radius of Guided Filter.
    # @param epsilon Regularization term of Guided Filter.
    # @param scale Down sampled scale.
    def __init__(self, I, radius=5, epsilon=0.4, scale=4):
        I_32F = to32F(I)
        self._I = I_32F
        h, w = I.shape[:2]

        I_sub = _downSample(I_32F, scale)

        self._I_sub = I_sub
        radius = int(radius / scale)

        if _isGray(I):
            self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
        else:
            self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)

    ## Apply filter for the input image.
    # @param p Input image for the filtering.
    def filter(self, p):
        p_32F = to32F(p)
        shape_original = p.shape[:2]

        p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])

        if _isGray(p_sub):
            return self._filterGray(p_sub, shape_original)

        cs = p.shape[2]
        q = np.array(p_32F)

        for ci in range(cs):
            q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
        return to8U(q)

    def _filterGray(self, p_sub, shape_original):
        ab_sub = self._guided_filter._computeCoefficients(p_sub)
        ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
        return self._guided_filter._computeOutput(ab, self._I)


## Guided filter.
class GuidedFilter:
    ## Constructor.
    # @param I Input guidance image. Color or gray.
    # @param radius Radius of Guided Filter.
    # @param epsilon Regularization term of Guided Filter.
    def __init__(self, I, radius=5, epsilon=0.4):
        I_32F = to32F(I)

        if _isGray(I):
            self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
        else:
            self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)

    ## Apply filter for the input image.
    # @param p Input image for the filtering.
    def filter(self, p):
        return to8U(self._guided_filter.filter(p))


## Common parts of guided filter.
#
# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
# Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
# GuidedFilterCommon.filter computes filtered image for color and gray.
class GuidedFilterCommon:
    def __init__(self, guided_filter):
        self._guided_filter = guided_filter

    ## Apply filter for the input image.
    # @param p Input image for the filtering.
    def filter(self, p):
        p_32F = to32F(p)
        if _isGray(p_32F):
            return self._filterGray(p_32F)

        cs = p.shape[2]
        q = np.array(p_32F)

        for ci in range(cs):
            q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
        return q

    def _filterGray(self, p):
        ab = self._guided_filter._computeCoefficients(p)
        return self._guided_filter._computeOutput(ab, self._guided_filter._I)


## Guided filter for gray guidance image.
class GuidedFilterGray:
    # @param I Input gray guidance image.
    # @param radius Radius of Guided Filter.
    # @param epsilon Regularization term of Guided Filter.
    def __init__(self, I, radius=5, epsilon=0.4):
        self._radius = 2 * radius + 1
        self._epsilon = epsilon
        self._I = to32F(I)
        self._initFilter()
        self._filter_common = GuidedFilterCommon(self)

    ## Apply filter for the input image.
    # @param p Input image for the filtering.
    def filter(self, p):
        return self._filter_common.filter(p)

    def _initFilter(self):
        I = self._I
        r = self._radius
        self._I_mean = cv2.blur(I, (r, r))
        I_mean_sq = cv2.blur(I ** 2, (r, r))
        self._I_var = I_mean_sq - self._I_mean ** 2

    def _computeCoefficients(self, p):
        r = self._radius
        p_mean = cv2.blur(p, (r, r))
        p_cov = p_mean - self._I_mean * p_mean
        a = p_cov / (self._I_var + self._epsilon)
        b = p_mean - a * self._I_mean
        a_mean = cv2.blur(a, (r, r))
        b_mean = cv2.blur(b, (r, r))
        return a_mean, b_mean

    def _computeOutput(self, ab, I):
        a_mean, b_mean = ab
        return a_mean * I + b_mean


## Guided filter for color guidance image.
class GuidedFilterColor:
    # @param I Input color guidance image.
    # @param radius Radius of Guided Filter.
    # @param epsilon Regularization term of Guided Filter.
    def __init__(self, I, radius=5, epsilon=0.2):
        self._radius = 2 * radius + 1
        self._epsilon = epsilon
        self._I = to32F(I)
        self._initFilter()
        self._filter_common = GuidedFilterCommon(self)

    ## Apply filter for the input image.
    # @param p Input image for the filtering.
    def filter(self, p):
        return self._filter_common.filter(p)

    def _initFilter(self):
        I = self._I
        r = self._radius
        eps = self._epsilon

        Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]

        self._Ir_mean = cv2.blur(Ir, (r, r))
        self._Ig_mean = cv2.blur(Ig, (r, r))
        self._Ib_mean = cv2.blur(Ib, (r, r))

        Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
        Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
        Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
        Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
        Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
        Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps

        Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
        Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
        Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
        Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
        Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
        Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var

        I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
        Irr_inv /= I_cov
        Irg_inv /= I_cov
        Irb_inv /= I_cov
        Igg_inv /= I_cov
        Igb_inv /= I_cov
        Ibb_inv /= I_cov

        self._Irr_inv = Irr_inv
        self._Irg_inv = Irg_inv
        self._Irb_inv = Irb_inv
        self._Igg_inv = Igg_inv
        self._Igb_inv = Igb_inv
        self._Ibb_inv = Ibb_inv

    def _computeCoefficients(self, p):
        r = self._radius
        I = self._I
        Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]

        p_mean = cv2.blur(p, (r, r))

        Ipr_mean = cv2.blur(Ir * p, (r, r))
        Ipg_mean = cv2.blur(Ig * p, (r, r))
        Ipb_mean = cv2.blur(Ib * p, (r, r))

        Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
        Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
        Ipb_cov = Ipb_mean - self._Ib_mean * p_mean

        ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
        ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
        ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
        b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean

        ar_mean = cv2.blur(ar, (r, r))
        ag_mean = cv2.blur(ag, (r, r))
        ab_mean = cv2.blur(ab, (r, r))
        b_mean = cv2.blur(b, (r, r))

        return ar_mean, ag_mean, ab_mean, b_mean

    def _computeOutput(self, ab, I):
        ar_mean, ag_mean, ab_mean, b_mean = ab

        Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]

        q = (ar_mean * Ir +
             ag_mean * Ig +
             ab_mean * Ib +
             b_mean)

        return q
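
A minimal usage sketch for the filters above. The file names, radius, and epsilon values are illustrative assumptions (not used anywhere in this commit), and the import path assumes the repository root is on sys.path:

# Illustrative sketch only -- "guidance.png" / "input.png" are assumed placeholders.
import cv2
from src.flux.annotator.tile.guided_filter import GuidedFilter, FastGuidedFilter

I = cv2.imread("guidance.png")   # guidance image (color or gray), uint8
p = cv2.imread("input.png")      # image to be filtered, same size as I

# Exact guided filter: edge-preserving smoothing of p, guided by I.
q = GuidedFilter(I, radius=5, epsilon=0.4).filter(p)   # uint8 output

# Fast variant: coefficients are computed on a 4x down-sampled copy,
# then up-sampled and applied at full resolution.
q_fast = FastGuidedFilter(I, radius=5, epsilon=0.4, scale=4).filter(p)

cv2.imwrite("filtered.png", q)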
src/flux/annotator/util.py
ADDED
|
@@ -0,0 +1,38 @@
import numpy as np
import cv2
import os


annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
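
A short sketch of how these two helpers are typically chained before an annotator call; the input file name and target resolution are assumed placeholders:

# Illustrative sketch -- "photo.png" is an assumed placeholder.
import cv2
from src.flux.annotator.util import HWC3, resize_image

img = cv2.imread("photo.png", cv2.IMREAD_UNCHANGED)
img = HWC3(img)               # normalize gray/RGBA inputs to 3-channel uint8
img = resize_image(img, 512)  # short side ~512, both sides rounded to multiples of 64
print(img.shape)              # e.g. (512, 768, 3), depending on source aspect ratio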
src/flux/annotator/zoe/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Intelligent Systems Lab Org

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/flux/annotator/zoe/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
# ZoeDepth
# https://github.com/isl-org/ZoeDepth

import os
import cv2
import numpy as np
import torch

from einops import rearrange
from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth
from .zoedepth.utils.config import get_config
from ...annotator.util import annotator_ckpts_path
from huggingface_hub import hf_hub_download


class ZoeDetector:
    def __init__(self):
        model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt")
        if not os.path.exists(model_path):
            model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt")
        conf = get_config("zoedepth", "infer")
        model = ZoeDepth.build_from_config(conf)
        model.load_state_dict(torch.load(model_path)['model'], strict=False)
        model = model.cuda()
        model.device = 'cuda'
        model.eval()
        self.model = model

    def __call__(self, input_image):
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float().cuda()
            image_depth = image_depth / 255.0
            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
            depth = self.model.infer(image_depth)

            depth = depth[0, 0].cpu().numpy()

            vmin = np.percentile(depth, 2)
            vmax = np.percentile(depth, 85)

            depth -= vmin
            depth /= vmax - vmin
            depth = 1.0 - depth
            depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)

            return depth_image
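
A minimal sketch of driving the detector above. It assumes a CUDA-capable machine (the class hard-codes .cuda()), an assumed input file name, and the repository root on sys.path:

# Illustrative sketch -- assumes a CUDA device and an HxWx3 uint8 input image.
import cv2
from src.flux.annotator.zoe import ZoeDetector

zoe = ZoeDetector()            # downloads ZoeD_M12_N.pt on first use
img = cv2.imread("scene.png")  # HxWx3 uint8
depth_map = zoe(img)           # HxW uint8; near objects map to brighter values after the 1-depth flip
cv2.imwrite("depth.png", depth_map)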
src/flux/annotator/zoe/zoedepth/data/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat
ADDED
|
@@ -0,0 +1,573 @@
|
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee

import itertools
import os
import random

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.utils.data.distributed
from zoedepth.utils.easydict import EasyDict as edict
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from zoedepth.utils.config import change_dataset

from .ddad import get_ddad_loader
from .diml_indoor_test import get_diml_indoor_loader
from .diml_outdoor_test import get_diml_outdoor_loader
from .diode import get_diode_loader
from .hypersim import get_hypersim_loader
from .ibims import get_ibims_loader
from .sun_rgbd_loader import get_sunrgbd_loader
from .vkitti import get_vkitti_loader
from .vkitti2 import get_vkitti2_loader

from .preprocess import CropParams, get_white_border, get_black_border


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def preprocessing_transforms(mode, **kwargs):
    return transforms.Compose([
        ToTensor(mode=mode, **kwargs)
    ])


class DepthDataLoader(object):
    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train" or "online_eval"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """

        self.config = config

        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset+"_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(config.ddad_root, resize_shape=(
                352, 1216), batch_size=1, num_workers=1)
            return

        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None

        if transform is None:
            transform = preprocessing_transforms(mode, size=img_size)

        if mode == 'train':

            Dataset = DataLoadPreprocess
            self.training_samples = Dataset(
                config, mode, transform=transform, device=device)

            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   persistent_workers=True,
                                   # prefetch_factor=2,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            if config.distributed:  # redundant. here only for readability and to be more explicit
                # Give whole test set to all processes (and report evaluation only on one) regardless
                self.eval_sampler = None
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=kwargs.get("shuffle_test", False),
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            print(
                'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))


def repetitive_roundrobin(*iterables):
    """
    cycles through iterables but sample wise
    first yield first sample from first iterable then first sample from second iterable and so on
    then second sample from first iterable then second sample from second iterable and so on

    If one iterable is shorter than the others, it is repeated until all iterables are exhausted
    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
    """
    # Repetitive roundrobin
    iterables_ = [iter(it) for it in iterables]
    exhausted = [False] * len(iterables)
    while not all(exhausted):
        for i, it in enumerate(iterables_):
            try:
                yield next(it)
            except StopIteration:
                exhausted[i] = True
                iterables_[i] = itertools.cycle(iterables[i])
                # First elements may get repeated if one iterable is shorter than the others
                yield next(iterables_[i])


class RepetitiveRoundRobinDataLoader(object):
    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # First samples get repeated, that's why the plus one
        return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)


class MixedNYUKITTI(object):
    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        config.workers = config.workers // 2
        self.config = config
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # make nyu default for testing
        self.config = config = nyu_conf
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # It has been changed to repetitive roundrobin
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data


def remove_leading_slash(s):
    if s[0] == '/' or s[0] == '\\':
        return s[1:]
    return s


class CachedReader:
    def __init__(self, shared_dict=None):
        if shared_dict:
            self._cache = shared_dict
        else:
            self._cache = {}

    def open(self, fpath):
        im = self._cache.get(fpath, None)
        if im is None:
            im = self._cache[fpath] = Image.open(fpath)
        return im


class ImReader:
    def __init__(self):
        pass

    # @cache
    def open(self, fpath):
        return Image.open(fpath)


class DataLoadPreprocess(Dataset):
    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        self.config = config
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        return sample

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        focal = float(sample_path.split()[2])
        sample = {}

        if self.mode == 'train':
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size

            if self.config.do_kb_crop:
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # Avoid blank boundaries due to pixel registration?
            # Train images have white border. Test images have black border.
            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                # print("Avoiding Blank Boundaries!")
                # We just crop and pad again with reflect padding to original size
                # original_size = image.size
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                # Use reflect padding to fill the blank
                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            if self.config.do_random_rotate and (self.config.aug):
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if self.config.aug and (self.config.random_crop):
                image, depth_gt = self.random_crop(
                    image, depth_gt, self.config.input_height, self.config.input_width)

            if self.config.aug and self.config.random_translate:
                # print("Random Translation!")
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

            image, depth_gt = self.train_preprocess(image, depth_gt)
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal,
                      'mask': mask, **sample}

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    depth_gt = False
                    # print('Missing gt for {}'.format(image_path))

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask}
            else:
                sample = {'image': image, 'focal': focal}

        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        # print(img.shape, depth.shape)
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it
        # print("after", img.shape, depth.shape)
        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random flipping
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random gamma, brightness, color augmentation
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)


class ToTensor(object):
    def __init__(self, mode, do_normalize=False, size=None):
        self.mode = mode
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            has_valid_depth = sample['has_valid_depth']
            image = self.resize(image)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
src/flux/annotator/zoe/zoedepth/data/ddad.py
ADDED
|
@@ -0,0 +1,117 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self, resize_shape):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x
        self.resize = transforms.Resize(resize_shape)

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "ddad"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()

        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DDAD(Dataset):
    def __init__(self, data_dir_root, resize_shape):
        import glob

        # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
        self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
        self.depth_files = [r.replace("_rgb.png", "_depth.npy")
                            for r in self.image_files]
        self.transform = ToTensor(resize_shape)

    def __getitem__(self, idx):

        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.load(depth_path)  # meters

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
    dataset = DDAD(data_dir_root, resize_shape)
    return DataLoader(dataset, batch_size, **kwargs)
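
A minimal sketch of the loader above; the data root is an assumed local path and the import path is the same installation assumption as before:

# Illustrative sketch -- "datasets/ddad" is an assumed local path.
from zoedepth.data.ddad import get_ddad_loader

loader = get_ddad_loader("datasets/ddad", resize_shape=(352, 1216), batch_size=1)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([1, 3, 352, 1216])
print(batch["depth"].shape)  # depth keeps its original resolution (only the image is resized)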
src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py
ADDED
|
@@ -0,0 +1,125 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x
        self.resize = transforms.Resize((480, 640))

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIML_Indoor(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
        self.image_files = glob.glob(os.path.join(
            data_dir_root, "LR", '*', 'color', '*.png'))
        self.depth_files = [r.replace("color", "depth_filled").replace(
            "_c.png", "_depth_filled.png") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype='uint16') / 1000.0  # mm to meters

        # print(np.shape(image))
        # print(np.shape(depth))

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)

        # return sample
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIML_Indoor(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py
ADDED
|
@@ -0,0 +1,114 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}

    def to_tensor(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIML_Outdoor(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/*/{outleft, depthmap}/*.png
        self.image_files = glob.glob(os.path.join(
            data_dir_root, "*", 'outleft', '*.png'))
        self.depth_files = [r.replace("outleft", "depthmap")
                            for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype='uint16') / 1000.0  # mm to meters

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth, dataset="diml_outdoor")

        # return sample
        return self.transform(sample)

    def __len__(self):
        return len(self.image_files)


def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIML_Outdoor(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
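The ToTensor helper above takes a fast path for NumPy inputs: an (H, W, C) array is transposed to (C, H, W), and the ImageNet normalization is intentionally disabled (the Normalize is commented out and replaced by an identity lambda). A tiny sketch of that path on synthetic data:

import numpy as np
import torch

# hypothetical 4x5 RGB array in [0, 1], as produced by __getitem__
image = np.random.rand(4, 5, 3).astype(np.float32)
tensor = torch.from_numpy(image.transpose((2, 0, 1)))  # same op as ToTensor.to_tensor
print(tensor.shape)  # torch.Size([3, 4, 5])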
src/flux/annotator/zoe/zoedepth/data/diode.py
ADDED
@@ -0,0 +1,125 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x
        self.resize = transforms.Resize(480)

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "diode"}

    def to_tensor(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()

        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIODE(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
        self.image_files = glob.glob(
            os.path.join(data_dir_root, '*', '*', '*.png'))
        self.depth_files = [r.replace(".png", "_depth.npy")
                            for r in self.image_files]
        self.depth_mask_files = [
            r.replace(".png", "_depth_mask.npy") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]
        depth_mask_path = self.depth_mask_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.load(depth_path)  # in meters
        valid = np.load(depth_mask_path)  # binary

        # depth[depth > 8] = -1
        # depth = depth[..., None]

        sample = dict(image=image, depth=depth, valid=valid)

        # return sample
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIODE(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
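One subtlety worth flagging: __getitem__ loads the DIODE validity mask into the sample as valid, but the ToTensor transform above only forwards image, depth, and dataset, so the mask is silently dropped. A sketch of applying it to the raw arrays before the transform, with hypothetical file names:

import numpy as np

depth = np.load("00001_depth.npy")       # (H, W, 1), meters
valid = np.load("00001_depth_mask.npy")  # (H, W), nonzero where depth is trustworthy
depth = np.where(valid[..., None] > 0, depth, -1)  # flag invalid pixels, mirroring the ibims loader below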
src/flux/annotator/zoe/zoedepth/data/hypersim.py
ADDED
@@ -0,0 +1,138 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import glob
import os

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


def hypersim_distance_to_depth(npyDistance):
    intWidth, intHeight, fltFocal = 1024, 768, 886.81

    npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
        1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
    npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
                                 intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
    npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
    npyImageplane = np.concatenate(
        [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)

    npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
    return npyDepth


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x
        self.resize = transforms.Resize((480, 640))

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "hypersim"}

    def to_tensor(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class HyperSim(Dataset):
    def __init__(self, data_dir_root):
        # image paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.tonemap.jpg
        # depth paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_geometry_hdf5/*.depth_meters.hdf5
        self.image_files = glob.glob(os.path.join(
            data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg'))
        self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace(
            ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0

        # depth from hdf5
        depth_fd = h5py.File(depth_path, "r")
        # in meters (Euclidean distance)
        distance_meters = np.array(depth_fd['dataset'])
        depth = hypersim_distance_to_depth(
            distance_meters)  # in meters (planar depth)

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = HyperSim(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)
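hypersim_distance_to_depth converts Hypersim's per-pixel Euclidean ray distance into planar depth by scaling with cos(theta) = focal / ||(x, y, focal)||, using the 1024x768 resolution and 886.81 focal length hard-coded above. A quick sanity check: at the image center the ray is nearly parallel to the optical axis, so depth should stay close to the distance, while oblique corner rays come out shorter:

import numpy as np
from src.flux.annotator.zoe.zoedepth.data.hypersim import hypersim_distance_to_depth

distance = np.ones((768, 1024), dtype=np.float32)  # a constant 1 m distance field
depth = hypersim_distance_to_depth(distance)
print(depth[384, 512])  # ~1.0 at the center pixel
print(depth[0, 0])      # ~0.81 in the corner, where the ray is oblique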
src/flux/annotator/zoe/zoedepth/data/ibims.py
ADDED
@@ -0,0 +1,81 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T


class iBims(Dataset):
    def __init__(self, config):
        root_folder = config.ibims_root
        with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f:
            imglist = f.read().split()

        samples = []
        for basename in imglist:
            img_path = os.path.join(root_folder, 'rgb', basename + ".png")
            depth_path = os.path.join(root_folder, 'depth', basename + ".png")
            valid_mask_path = os.path.join(
                root_folder, 'mask_invalid', basename + ".png")
            transp_mask_path = os.path.join(
                root_folder, 'mask_transp', basename + ".png")

            samples.append(
                (img_path, depth_path, valid_mask_path, transp_mask_path))

        self.samples = samples
        # self.normalize = T.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x

    def __getitem__(self, idx):
        img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx]

        img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype=np.uint16).astype('float') * 50.0 / 65535

        mask_valid = np.asarray(Image.open(valid_mask_path))
        mask_transp = np.asarray(Image.open(transp_mask_path))

        # depth = depth * mask_valid * mask_transp
        depth = np.where(mask_valid * mask_transp, depth, -1)

        img = torch.from_numpy(img).permute(2, 0, 1)
        img = self.normalize(img)
        depth = torch.from_numpy(depth).unsqueeze(0)
        return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims')

    def __len__(self):
        return len(self.samples)


def get_ibims_loader(config, batch_size=1, **kwargs):
    dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs)
    return dataloader
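Unlike the path-based loaders above, iBims takes a config object and reads its ibims_root attribute. A minimal sketch with a stand-in namespace (the real config object comes from the surrounding zoedepth code, so the namespace here is only an assumption for illustration):

from types import SimpleNamespace
from src.flux.annotator.zoe.zoedepth.data.ibims import get_ibims_loader

config = SimpleNamespace(ibims_root="datasets/ibims")  # hypothetical extraction path
loader = get_ibims_loader(config, batch_size=1)
sample = next(iter(loader))
print(sample["image"].shape, sample["depth"].shape)  # e.g. (1, 3, H, W) and (1, 1, H, W)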
src/flux/annotator/zoe/zoedepth/data/preprocess.py
ADDED
@@ -0,0 +1,154 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import numpy as np
from dataclasses import dataclass
from typing import Tuple, List


# dataclass to store the crop parameters
@dataclass
class CropParams:
    top: int
    bottom: int
    left: int
    right: int


def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams:
    gray_image = np.mean(rgb_image, axis=channel_axis)
    h, w = gray_image.shape

    def num_value_pixels(arr):
        return np.sum(np.abs(arr - value) < level_diff_threshold)

    def is_above_tolerance(arr, total_pixels):
        return (num_value_pixels(arr) / total_pixels) > tolerance

    # Advance the top border until the fraction of near-value pixels falls below tolerance
    top = min_border
    while is_above_tolerance(gray_image[top, :], w) and top < h - 1:
        top += 1
        if top > cut_off:
            break

    # Advance the bottom border until the fraction of near-value pixels falls below tolerance
    bottom = h - min_border
    while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0:
        bottom -= 1
        if h - bottom > cut_off:
            break

    # Advance the left border until the fraction of near-value pixels falls below tolerance
    left = min_border
    while is_above_tolerance(gray_image[:, left], h) and left < w - 1:
        left += 1
        if left > cut_off:
            break

    # Advance the right border until the fraction of near-value pixels falls below tolerance
    right = w - min_border
    while is_above_tolerance(gray_image[:, right], h) and right > 0:
        right -= 1
        if w - right > cut_off:
            break

    return CropParams(top, bottom, left, right)


def get_white_border(rgb_image, value=255, **kwargs) -> CropParams:
    """Computes crop parameters for the white border of the RGB image.

    Args:
        rgb_image: RGB image, shape (H, W, 3).

    Returns:
        Crop parameters.
    """
    if value == 255:
        # assert range of values in rgb image is [0, 255]
        assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]."
        assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]."
    elif value == 1:
        # assert range of values in rgb image is [0, 1]
        assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]."

    return get_border_params(rgb_image, value=value, **kwargs)


def get_black_border(rgb_image, **kwargs) -> CropParams:
    """Computes crop parameters for the black border of the RGB image.

    Args:
        rgb_image: RGB image, shape (H, W, 3).

    Returns:
        Crop parameters.
    """
    return get_border_params(rgb_image, value=0, **kwargs)


def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray:
    """Crops the image according to the crop parameters.

    Args:
        image: RGB or depth image, shape (H, W, 3) or (H, W).
        crop_params: Crop parameters.

    Returns:
        Cropped image.
    """
    return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]


def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]:
    """Crops the images according to the crop parameters.

    Args:
        images: RGB or depth images, shape (H, W, 3) or (H, W).
        crop_params: Crop parameters.

    Returns:
        Cropped images.
    """
    return tuple(crop_image(image, crop_params) for image in images)


def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]:
    """Crops the white and black borders of the RGB and companion images.

    Args:
        rgb_image: RGB image, shape (H, W, 3). This image is used to determine the border.
        other_images: The other images to crop according to the border of the RGB image.

    Returns:
        Cropped RGB and other images.
    """
    # crop black border
    crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
    cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params)

    # crop white border
    crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
    cropped_images = crop_images(*cropped_images, crop_params=crop_params)

    return cropped_images
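crop_black_or_white_border makes two passes: it derives CropParams from the black border of the RGB image, crops the RGB and all companion images with them, then repeats the procedure for the white border of the already-cropped RGB. A minimal usage sketch, assuming an RGB frame in [0, 255] and an aligned depth map:

import numpy as np
from src.flux.annotator.zoe.zoedepth.data.preprocess import crop_black_or_white_border

rgb = np.full((480, 640, 3), 255, dtype=np.float32)  # white frame...
rgb[20:-20, 20:-20] = 128.0                          # ...around gray content
depth = np.random.rand(480, 640).astype(np.float32)

rgb_c, depth_c = crop_black_or_white_border(rgb, depth)
print(rgb_c.shape, depth_c.shape)  # both trimmed with the same CropParams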