Spaces:
Running
Running
1st commit
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .gitattributes +2 -0
- app.py +1786 -0
- depth_anything_v2/__pycache__/dinov2.cpython-310.pyc +0 -0
- depth_anything_v2/__pycache__/dpt.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2.py +415 -0
- depth_anything_v2/dinov2_layers/__init__.py +11 -0
- depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/attention.py +83 -0
- depth_anything_v2/dinov2_layers/block.py +252 -0
- depth_anything_v2/dinov2_layers/drop_path.py +35 -0
- depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
- depth_anything_v2/dinov2_layers/mlp.py +41 -0
- depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
- depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
- depth_anything_v2/dpt.py +221 -0
- depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc +0 -0
- depth_anything_v2/util/__pycache__/transform.cpython-310.pyc +0 -0
- depth_anything_v2/util/blocks.py +148 -0
- depth_anything_v2/util/transform.py +158 -0
- models/FCN.py +55 -0
- models/SegNet.py +33 -0
- models/__pycache__/FCN.cpython-37.pyc +0 -0
- models/__pycache__/FCN.cpython-39.pyc +0 -0
- models/__pycache__/SegNet.cpython-37.pyc +0 -0
- models/__pycache__/SegNet.cpython-39.pyc +0 -0
- models/__pycache__/deeplab.cpython-310.pyc +0 -0
- models/__pycache__/deeplab.cpython-313.pyc +0 -0
- models/__pycache__/deeplab.cpython-37.pyc +0 -0
- models/__pycache__/deeplab.cpython-39.pyc +0 -0
- models/__pycache__/unets.cpython-37.pyc +0 -0
- models/__pycache__/unets.cpython-39.pyc +0 -0
- models/deeplab.py +539 -0
- models/unets.py +171 -0
- requirements.txt +151 -0
- temp_files/Final_workig_cpu.txt +1000 -0
- temp_files/README.md +12 -0
- temp_files/fw2.txt +1175 -0
- temp_files/predict.py +64 -0
- temp_files/requirements.txt +109 -0
- temp_files/run_gradio_app.py +92 -0
- temp_files/segmentation_app.py +222 -0
- temp_files/test1.txt +843 -0
- temp_files/test2.txt +1063 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
training_history/2019-12-19[[:space:]]01%3A53%3A15.480800.hdf5 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
training_history/2025-08-07_16-25-27.hdf5 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,1786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import tempfile
|
| 8 |
+
from gradio_imageslider import ImageSlider
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
import os
|
| 14 |
+
import tensorflow as tf
|
| 15 |
+
from tensorflow.keras.models import load_model
|
| 16 |
+
|
| 17 |
+
# Classification imports
|
| 18 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 19 |
+
import google.generativeai as genai
|
| 20 |
+
|
| 21 |
+
import gdown
|
| 22 |
+
import spaces
|
| 23 |
+
import cv2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Import actual segmentation model components
|
| 27 |
+
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
|
| 28 |
+
from utils.learning.metrics import dice_coef, precision, recall
|
| 29 |
+
from utils.io.data import normalize
|
| 30 |
+
|
| 31 |
+
# --- Classification Model Setup ---
|
| 32 |
+
# Load classification model and processor
|
| 33 |
+
# Load the pretrained wound-type classifier and its matching preprocessor
# from the Hugging Face hub (downloaded on first run, cached afterwards).
_WOUND_CLS_REPO = "Hemg/Wound-classification"
classification_processor = AutoImageProcessor.from_pretrained(_WOUND_CLS_REPO)
classification_model = AutoModelForImageClassification.from_pretrained(_WOUND_CLS_REPO)
|
| 35 |
+
|
| 36 |
+
# Configure Gemini AI
|
| 37 |
+
try:
    # FIX: the code previously read only GOOGLE_API_KEY while every message
    # (and the ValueError below) told the operator to set GEMINI_API_KEY.
    # Accept either variable name so both documented spellings work.
    gemini_api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
    if not gemini_api_key:
        raise ValueError("GEMINI_API_KEY not found in environment variables")

    genai.configure(api_key=gemini_api_key)
    gemini_model = genai.GenerativeModel("gemini-2.5-pro")
    print("✅ Gemini AI configured successfully with API key from secrets")
except Exception as e:
    # Keep the app importable without a key: every caller checks
    # `gemini_model is None` before use and degrades gracefully.
    print(f"❌ Error configuring Gemini AI: {e}")
    print("Please make sure GEMINI_API_KEY is set in your Hugging Face Space secrets")
    gemini_model = None
|
| 50 |
+
|
| 51 |
+
# --- Classification Functions ---
|
| 52 |
+
def analyze_wound_with_gemini(image, predicted_label):
    """Produce an educational Gemini analysis of a wound image.

    Args:
        image: PIL Image to analyze.
        predicted_label: Wound type predicted by the classification model.

    Returns:
        str: Gemini's analysis text, or a human-readable error message.
    """
    # Guard clauses: nothing to do without an image or a configured model.
    if image is None:
        return "No image provided for analysis."
    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."

    try:
        # Gemini expects RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # The classification result is folded into the prompt so Gemini can
        # tailor its educational commentary to the predicted wound type.
        prompt = f"""You are assisting in a medical education and research task.

Based on the wound classification model, this image has been identified as: {predicted_label}

Please provide an educational analysis of this wound image focusing on:
1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue)
2. Educational explanation about this type of wound based on the classification: {predicted_label}
3. General wound healing stages if applicable
4. Key features that are typically associated with this wound type

Important guidelines:
- This is for educational and research purposes only
- Do not provide medical advice or diagnosis
- Keep the analysis objective and educational
- Focus on visible features and general wound characteristics
- Do not recommend treatments or medical interventions

Please provide a comprehensive educational analysis."""

        return gemini_model.generate_content([prompt, image]).text

    except Exception as e:
        return f"Error analyzing image with Gemini: {str(e)}"
|
| 99 |
+
|
| 100 |
+
def analyze_wound_depth_with_gemini(image, depth_map, depth_stats):
    """
    Analyze wound depth and severity using Gemini AI with depth analysis context.

    Args:
        image: Original wound image (PIL Image or numpy array).
        depth_map: Depth map (numpy array; arbitrary scale — it is renormalized
            to 0-255 for visualization before being sent to Gemini).
        depth_stats: Dictionary of depth statistics. Expected keys include
            'total_area_cm2', 'mean_depth_mm', 'max_depth_mm', 'depth_std_mm',
            'wound_volume_cm3', 'deep_ratio', 'analysis_quality',
            'depth_consistency', the four *_area_cm2 tissue buckets, and
            'depth_percentiles' with string keys '25'/'50'/'75'.

    Returns:
        str: Gemini AI medical assessment, or a human-readable error message.
    """
    if image is None or depth_map is None:
        return "No image or depth map provided for analysis."

    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."

    try:
        # Normalize the photo to a PIL RGB image for Gemini.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Convert the depth map to a visualizable grayscale PIL image.
        if isinstance(depth_map, np.ndarray):
            # FIX: guard against a constant depth map — the previous
            # (max - min) division raised a divide-by-zero / produced NaNs
            # when every pixel had the same depth.
            depth_min = depth_map.min()
            depth_range = depth_map.max() - depth_min
            if depth_range > 0:
                norm_depth = (depth_map - depth_min) / depth_range * 255.0
            else:
                norm_depth = np.zeros_like(depth_map, dtype=np.float64)
            depth_image = Image.fromarray(norm_depth.astype(np.uint8))
        else:
            depth_image = depth_map

        # Embed the quantitative measurements directly in the prompt so the
        # model can correlate them with the two images it receives.
        prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data.

DEPTH ANALYSIS DATA PROVIDED:
- Total Wound Area: {depth_stats['total_area_cm2']:.2f} cm²
- Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm
- Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm
- Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm
- Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cm³
- Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}%
- Analysis Quality: {depth_stats['analysis_quality']}
- Depth Consistency: {depth_stats['depth_consistency']}

TISSUE DEPTH DISTRIBUTION:
- Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cm²
- Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cm²
- Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cm²
- Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cm²

STATISTICAL DEPTH ANALYSIS:
- 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm
- Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm
- 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm

Please provide a comprehensive medical assessment focusing on:

1. **WOUND CHARACTERISTICS ANALYSIS**
   - Visible wound features from the original image
   - Correlation between visual appearance and depth measurements
   - Tissue quality assessment based on color, texture, and depth data

2. **DEPTH-BASED SEVERITY ASSESSMENT**
   - Clinical significance of the measured depths
   - Tissue layer involvement based on depth measurements
   - Risk assessment based on deep tissue involvement percentage

3. **HEALING PROGNOSIS**
   - Expected healing timeline based on depth and area measurements
   - Factors that may affect healing based on depth distribution
   - Complexity assessment based on wound volume and depth variation

4. **CLINICAL CONSIDERATIONS**
   - Significance of depth consistency/inconsistency
   - Areas of particular concern based on depth analysis
   - Educational insights about this type of wound presentation

5. **MEASUREMENT INTERPRETATION**
   - Clinical relevance of the statistical depth measurements
   - What the depth distribution tells us about wound progression
   - Comparison to typical wound depth classifications

IMPORTANT GUIDELINES:
- This is for educational and research purposes only
- Do not provide specific medical advice or treatment recommendations
- Focus on objective analysis of the provided measurements
- Correlate visual findings with quantitative depth data
- Maintain educational and clinical terminology
- Emphasize the relationship between depth measurements and clinical significance

Provide a detailed, structured medical assessment that integrates both visual and quantitative depth analysis."""

        # Send both the photo and the depth visualization to Gemini.
        response = gemini_model.generate_content([prompt, image, depth_image])
        return response.text

    except Exception as e:
        return f"Error analyzing wound with Gemini AI: {str(e)}"
|
| 202 |
+
|
| 203 |
+
def classify_wound(image):
    """Classify the wound type shown in an uploaded image.

    Args:
        image: PIL Image or numpy array.

    Returns:
        dict: label -> confidence score, or an error string.
    """
    if image is None:
        return "Please upload an image"

    # Normalize the input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    try:
        inputs = classification_processor(images=image, return_tensors="pt")

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = classification_model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits[0], dim=-1)

        confidence_scores = predictions.numpy()

        # Map every score to its label, falling back to a positional name
        # when the model config carries no id2label table.
        has_labels = hasattr(classification_model.config, 'id2label')
        return {
            (classification_model.config.id2label[i] if has_labels else f"Class {i}"): float(score)
            for i, score in enumerate(confidence_scores)
        }

    except Exception as e:
        return f"Error processing image: {str(e)}"
|
| 247 |
+
|
| 248 |
+
def classify_and_analyze_wound(image):
    """Classify a wound and obtain the matching Gemini commentary.

    Args:
        image: PIL Image or numpy array.

    Returns:
        tuple: (classification_results, gemini_analysis)
    """
    if image is None:
        return "Please upload an image", "Please upload an image for analysis"

    classification_results = classify_wound(image)

    # classify_wound returns an error string on failure; only a non-empty
    # dict can yield a top label for the Gemini prompt.
    if not (isinstance(classification_results, dict) and classification_results):
        return classification_results, "Unable to analyze due to classification error"

    top_label = max(classification_results, key=classification_results.get)
    return classification_results, analyze_wound_with_gemini(image, top_label)
|
| 276 |
+
|
| 277 |
+
def format_gemini_analysis(analysis):
    """Render the Gemini wound analysis as a styled HTML card."""
    # Empty responses, or any text containing "Error", get a red error card.
    if not analysis or "Error" in analysis:
        return f"""
        <div style="
            background-color: #fee2e2;
            border-radius: 12px;
            padding: 16px;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            font-family: Arial, sans-serif;
            min-height: 300px;
            border-left: 4px solid #ef4444;
        ">
            <h4 style="color: #dc2626; margin-top: 0;">Analysis Error</h4>
            <p style="color: #991b1b;">{analysis}</p>
        </div>
        """

    # Successful responses are converted from markdown into a scrollable card.
    body_html = parse_markdown_to_html(analysis)
    return f"""
    <div style="
        border-radius: 12px;
        padding: 25px;
        box-shadow: 0 4px 12px rgba(0,0,0,0.1);
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        min-height: 300px;
        border-left: 4px solid #d97706;
        max-height: 600px;
        overflow-y: auto;
    ">
        <h3 style="color: #d97706; margin-top: 0; margin-bottom: 20px; display: flex; align-items: center; gap: 8px;">
            Initial Wound Analysis
        </h3>
        <div style="color: white; line-height: 1.7;">
            {body_html}
        </div>
    </div>
    """
|
| 317 |
+
|
| 318 |
+
def format_gemini_depth_analysis(analysis):
    """Render the Gemini depth-based medical assessment as dark-theme HTML."""
    # Empty responses, or any text containing "Error", get an error panel.
    if not analysis or "Error" in analysis:
        return f"""
        <div style="color: #ffffff; line-height: 1.6;">
            <div style="font-size: 16px; font-weight: bold; margin-bottom: 10px; color: #f44336;">
                ❌ AI Analysis Error
            </div>
            <div style="color: #cccccc;">
                {analysis}
            </div>
        </div>
        """

    # Successful responses are converted from markdown into a scrollable panel.
    body_html = parse_markdown_to_html(analysis)
    return f"""
    <div style="color: #ffffff; line-height: 1.6;">
        <div style="font-size: 16px; font-weight: bold; margin-bottom: 15px; color: #4CAF50;">
            🤖 AI-Powered Medical Assessment
        </div>
        <div style="color: #cccccc; max-height: 400px; overflow-y: auto; padding-right: 10px;">
            {body_html}
        </div>
    </div>
    """
|
| 345 |
+
|
| 346 |
+
def parse_markdown_to_html(text):
    """Convert markdown-style text to HTML"""
    import re

    # Substitutions applied in recognition order: headers before emphasis,
    # bold (**) before italic (*) so the double-asterisk form is consumed
    # first, bullets before the <ul> wrapper that groups them.
    substitutions = [
        (r'^### \*\*(.*?)\*\*$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', re.MULTILINE),
        (r'^#### \*\*(.*?)\*\*$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', re.MULTILINE),
        (r'^### (.*?)$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', re.MULTILINE),
        (r'^#### (.*?)$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', re.MULTILINE),
        (r'\*\*(.*?)\*\*', r'<strong style="color: #fbbf24;">\1</strong>', 0),
        (r'\*(.*?)\*', r'<em style="color: #fde68a;">\1</em>', 0),
        (r'^\* (.*?)$', r'<li style="margin: 5px 0; color: white;">\1</li>', re.MULTILINE),
        (r'^ \* (.*?)$', r'<li style="margin: 3px 0; margin-left: 20px; color: white;">\1</li>', re.MULTILINE),
        (r'(<li.*?</li>(?:\s*<li.*?</li>)*)', r'<ul style="margin: 10px 0; padding-left: 20px;">\1</ul>', re.DOTALL),
        (r'^(\d+)\.\s+(.*?)$', r'<div style="margin: 8px 0; color: white;"><strong style="color: #d97706;">\1.</strong> \2</div>', re.MULTILINE),
    ]
    for pattern, repl, flags in substitutions:
        text = re.sub(pattern, repl, text, flags=flags)

    # Wrap each remaining plain-text paragraph (split on blank lines) in a
    # styled <p>, leaving spans that already start/end with HTML untouched.
    rendered = []
    for block in text.split('\n\n'):
        block = block.strip()
        if not block:
            continue
        if not (block.startswith('<') or block.endswith('>')):
            block = f'<p style="margin: 12px 0; color: white; text-align: justify;">{block}</p>'
        rendered.append(block)

    return '\n'.join(rendered)
|
| 385 |
+
|
| 386 |
+
def combined_analysis(image):
    """Run classification plus Gemini analysis and return both, formatted for the UI."""
    label, raw_analysis = classify_and_analyze_wound(image)
    return label, format_gemini_analysis(raw_analysis)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# --- Model checkpoint bootstrap ---
# Ensure the checkpoints directory exists, then fetch the Depth-Anything
# weights from Google Drive on first run.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Skip the download when the checkpoint is already cached locally.
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
print("TensorFlow is using GPU" if gpus else "TensorFlow is using CPU")
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# --- Load Actual Wound Segmentation Model ---
|
| 418 |
+
# --- Load Actual Wound Segmentation Model ---
class WoundSegmentationModel:
    """Wrapper around the Keras wound-segmentation network.

    Handles checkpoint loading (with a legacy fallback), image
    pre-processing to the fixed 224x224 input, and post-processing of the
    probability map into a binary 0/255 mask.
    """

    def __init__(self):
        # Fixed network input resolution (width x height).
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def _custom_objects(self):
        """Custom metrics/layers needed to deserialize the trained checkpoint."""
        return {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling
        }

    def load_model(self):
        """Load the trained wound segmentation model"""
        try:
            # Prefer the most recently trained checkpoint.
            model_path = './training_history/2025-08-07_16-25-27.hdf5'
            self.model = load_model(model_path, custom_objects=self._custom_objects())
            print(f"Segmentation model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading segmentation model: {e}")
            # Fall back to the legacy 2019 checkpoint.
            try:
                model_path = './training_history/2019-12-19 01%3A53%3A15.480800.hdf5'
                self.model = load_model(model_path, custom_objects=self._custom_objects())
                print(f"Segmentation model loaded successfully from {model_path}")
            except Exception as e2:
                print(f"Error loading fallback segmentation model: {e2}")
                self.model = None

    def preprocess_image(self, image):
        """Resize, normalise, and batch an uploaded image for the network."""
        if image is None:
            return None

        # Swap BGR -> RGB for 3-channel inputs.
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        resized = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        scaled = resized.astype(np.float32) / 255.0

        # Prepend the batch axis expected by Keras.
        return np.expand_dims(scaled, axis=0)

    def postprocess_prediction(self, prediction):
        """Strip the batch axis and threshold the probability map into a 0/255 mask."""
        probabilities = prediction[0]
        return (probabilities > 0.5).astype(np.uint8) * 255

    def segment_wound(self, input_image):
        """Main function to segment wound from uploaded image"""
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."

        if input_image is None:
            return None, "Please upload an image."

        try:
            batch = self.preprocess_image(input_image)
            if batch is None:
                return None, "Error processing image."

            raw_prediction = self.model.predict(batch, verbose=0)
            return self.postprocess_prediction(raw_prediction), "Segmentation completed successfully!"
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"
|
| 520 |
+
|
| 521 |
+
# Initialize the segmentation model once at import time.
segmentation_model = WoundSegmentationModel()

# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Encoder configurations for the Depth-Anything-V2 model family.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'

# Build the depth network, restore its weights, and switch to eval mode.
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location=map_device)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# --- Custom CSS for unified dark theme ---
# Injected into the Gradio app via gr.Blocks(css=...). Covers: overall dark
# palette, buttons, HTML panes, image display sizing, headings, tab styling,
# fixed-height "wound cards", and a small loading-spinner keyframe animation.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
/* Card styling for consistent heights */
.wound-card {
    min-height: 200px !important;
    display: flex !important;
    flex-direction: column !important;
    justify-content: space-between !important;
}
.wound-card-content {
    flex-grow: 1 !important;
    display: flex !important;
    flex-direction: column !important;
    justify-content: center !important;
}
/* Loading animation */
.loading-spinner {
    display: inline-block;
    width: 20px;
    height: 20px;
    border: 3px solid #f3f3f3;
    border-top: 3px solid #3498db;
    border-radius: 50%;
    animation: spin 1s linear infinite;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
"""
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# --- Enhanced Wound Severity Estimation Functions ---
|
| 639 |
+
|
| 640 |
+
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with proper calibration and medical standards.

    Depth bands used for classification (wound-care convention):
      - Superficial: 0-2mm (epidermis only)
      - Partial thickness: 2-4mm (epidermis + partial dermis)
      - Full thickness: 4-6mm (epidermis + full dermis)
      - Deep: >6mm (involving subcutaneous tissue)

    Returns a dict with per-band areas, depth statistics/percentiles,
    quality metrics, and the nearest-point-normalised depth map. When the
    mask contains no wound pixels, an all-zero dict (reduced key set) is
    returned instead.
    """
    pixel_spacing_mm = float(pixel_spacing_mm)
    # Footprint of one pixel in cm^2.
    px_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Binarise the mask and close small holes with an elliptical kernel.
    wound_mask = (mask > 127).astype(np.uint8)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, ellipse)

    inside = wound_mask > 0
    if not np.any(inside):
        # Empty mask: nothing to measure.
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0}
        }

    # Re-base depth so the camera-nearest wound pixel sits at 0, then scale
    # the relative values to millimetres using the calibration reference.
    norm_map, ref_coords, max_rel = normalize_depth_relative_to_nearest_point(depth_map, wound_mask)
    calibrated = calibrate_depth_map(norm_map, reference_depth_mm=depth_calibration_mm)
    depths_mm = calibrated[inside]
    n = len(depths_mm)

    # Areas by clinical depth band.
    total_area_cm2 = np.sum(inside) * px_area_cm2
    superficial_area_cm2 = np.sum(depths_mm < 2.0) * px_area_cm2
    partial_thickness_area_cm2 = np.sum((depths_mm >= 2.0) & (depths_mm < 4.0)) * px_area_cm2
    full_thickness_area_cm2 = np.sum((depths_mm >= 4.0) & (depths_mm < 6.0)) * px_area_cm2
    deep_area_cm2 = np.sum(depths_mm >= 6.0) * px_area_cm2

    # Summary statistics over the wound region.
    mean_mm = np.mean(depths_mm)
    max_mm = np.max(depths_mm)
    std_mm = np.std(depths_mm)
    percentiles = {
        '25': np.percentile(depths_mm, 25),
        '50': np.percentile(depths_mm, 50),
        '75': np.percentile(depths_mm, 75)
    }

    # Distribution buckets (n > 0 is guaranteed on this path).
    distribution = {
        'shallow_ratio': np.sum(depths_mm < 2.0) / n,
        'moderate_ratio': np.sum((depths_mm >= 2.0) & (depths_mm < 5.0)) / n,
        'deep_ratio': np.sum(depths_mm >= 5.0) / n
    }

    # Approximate volume = area * mean depth (cm).
    wound_volume_cm3 = total_area_cm2 * (mean_mm / 10.0)
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0

    # Heuristic quality labels: more pixels -> better analysis quality;
    # lower depth spread -> more consistent measurement.
    quality = "High" if n > 1000 else "Medium" if n > 500 else "Low"
    consistency = "High" if std_mm < 2.0 else "Medium" if std_mm < 4.0 else "Low"

    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_mm,
        'max_depth_mm': max_mm,
        'depth_std_mm': std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': percentiles,
        'depth_distribution': distribution,
        'analysis_quality': quality,
        'depth_consistency': consistency,
        'wound_pixel_count': n,
        'nearest_point_coords': ref_coords,
        'max_relative_depth': max_rel,
        'normalized_depth_map': norm_map
    }
|
| 757 |
+
|
| 758 |
+
def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Enhanced wound severity classification based on medical standards.

    Scores five criteria (max depth, mean depth, deep-tissue ratio, total
    area, wound volume), sums the points, and maps the total onto a
    severity label. Returns "Unknown" for a zero-area wound.
    """
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"

    total_area = depth_stats['total_area_cm2']
    deep_area = depth_stats['deep_area_cm2']                  # kept: mirrors required key set
    full_thickness_area = depth_stats['full_thickness_area_cm2']  # kept: mirrors required key set
    mean_depth = depth_stats['mean_depth_mm']
    max_depth = depth_stats['max_depth_mm']
    wound_volume = depth_stats['wound_volume_cm3']
    deep_ratio = depth_stats['deep_ratio']

    def band_score(value, bands):
        """Points for the first (threshold, points) band the value reaches."""
        for threshold, points in bands:
            if value >= threshold:
                return points
        return 0

    score = (
        band_score(max_depth, [(10.0, 3), (6.0, 2), (4.0, 1)])        # maximum depth
        + band_score(mean_depth, [(5.0, 2), (3.0, 1)])                # mean depth
        + band_score(deep_ratio, [(0.5, 3), (0.25, 2), (0.1, 1)])     # deep-tissue fraction
        + band_score(total_area, [(10.0, 2), (5.0, 1)])               # wound footprint
        + band_score(wound_volume, [(5.0, 2), (2.0, 1)])              # approximate volume
    )

    # Map the cumulative score to a severity label.
    for cutoff, label in ((8, "Very Severe"), (6, "Severe"), (4, "Moderate"), (2, "Mild")):
        if score >= cutoff:
            return label
    return "Superficial"
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """Enhanced wound severity analysis based on depth measurements.

    Pipeline: validate inputs -> align mask to the depth map -> compute depth
    statistics -> score severity -> request a Gemini narrative -> render an
    HTML report for the Gradio HTML component.

    Args:
        image: uploaded wound photo (numpy array).
        depth_map: estimated depth map for the same image.
        wound_mask: segmentation mask (grayscale or RGB).
        pixel_spacing_mm: assumed physical size of one pixel, in mm.
        depth_calibration_mm: reference depth used to scale relative depths.

    Returns:
        HTML string (report), or an error message string on missing inputs.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "β Please upload image, depth map, and wound mask."

    # Convert wound mask to grayscale if needed (RGB -> channel mean).
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have same dimensions.
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize mask to match depth map (PIL expects (width, height)).
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    # Compute enhanced statistics with relative depth normalization.
    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)

    # Get severity based on enhanced metrics.
    severity_level = classify_wound_severity_by_enhanced_metrics(stats)
    severity_description = get_enhanced_severity_description(severity_level)

    # Get Gemini AI analysis based on depth data.
    gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats)

    # Format Gemini analysis for display.
    formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis)

    # Create depth analysis visualization (reference + deepest points marked).
    # NOTE(review): the visualization is computed but not included in the
    # returned report — confirm whether a caller consumes it.
    depth_visualization = create_depth_analysis_visualization(
        stats['normalized_depth_map'], wound_mask,
        stats['nearest_point_coords'], stats['max_relative_depth']
    )

    # Enhanced severity color coding for the report header/border.
    severity_color = {
        "Superficial": "#4CAF50",  # Green
        "Mild": "#8BC34A",         # Light Green
        "Moderate": "#FF9800",     # Orange
        "Severe": "#F44336",       # Red
        "Very Severe": "#9C27B0"   # Purple
    }.get(severity_level, "#9E9E9E")  # Gray for unknown

    # Create comprehensive medical report (HTML; some glyphs are mojibake
    # from an earlier encoding issue and are preserved as-is).
    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            π©Ή Enhanced Wound Severity Analysis
        </div>

        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px;'>
            <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 15px; text-align: center;'>
                π Depth & Quality Analysis
            </div>
            <div style='color: #cccccc; line-height: 1.6; display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 20px;'>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #ff9800; margin-bottom: 8px;'>οΏ½ Basic Measurements</div>
                    <div>οΏ½π <b>Mean Relative Depth:</b> {stats['mean_depth_mm']:.1f} mm</div>
                    <div>π <b>Max Relative Depth:</b> {stats['max_depth_mm']:.1f} mm</div>
                    <div>π <b>Depth Std Dev:</b> {stats['depth_std_mm']:.1f} mm</div>
                    <div>π¦ <b>Wound Volume:</b> {stats['wound_volume_cm3']:.2f} cmΒ³</div>
                    <div>π₯ <b>Deep Tissue Ratio:</b> {stats['deep_ratio']*100:.1f}%</div>
                </div>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #4CAF50; margin-bottom: 8px;'>π Statistical Analysis</div>
                    <div>οΏ½ <b>25th Percentile:</b> {stats['depth_percentiles']['25']:.1f} mm</div>
                    <div>π <b>Median (50th):</b> {stats['depth_percentiles']['50']:.1f} mm</div>
                    <div>π <b>75th Percentile:</b> {stats['depth_percentiles']['75']:.1f} mm</div>
                    <div>π <b>Shallow Areas:</b> {stats['depth_distribution']['shallow_ratio']*100:.1f}%</div>
                    <div>π <b>Moderate Areas:</b> {stats['depth_distribution']['moderate_ratio']*100:.1f}%</div>
                </div>
                <div>
                    <div style='font-size: 16px; font-weight: bold; color: #2196F3; margin-bottom: 8px;'>π Quality Metrics</div>
                    <div>π <b>Analysis Quality:</b> {stats['analysis_quality']}</div>
                    <div>π <b>Depth Consistency:</b> {stats['depth_consistency']}</div>
                    <div>π <b>Data Points:</b> {stats['wound_pixel_count']:,}</div>
                    <div>π <b>Deep Areas:</b> {stats['depth_distribution']['deep_ratio']*100:.1f}%</div>
                    <div>π― <b>Reference Point:</b> Nearest to camera</div>
                </div>
            </div>
        </div>

        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 18px; font-weight: bold; color: {severity_color}; margin-bottom: 10px;'>
                π Medical Assessment Based on Depth Analysis
            </div>
            {formatted_gemini_analysis}
        </div>
    </div>
    """

    return report
|
| 922 |
+
|
| 923 |
+
def normalize_depth_relative_to_nearest_point(depth_map, wound_mask):
    """
    Shift depth values so the wound's camera-nearest pixel sits at depth 0.

    Assumes a top-down camera: the minimum raw depth inside the mask is the
    reference plane; every other pixel becomes a non-negative offset from it.

    Args:
        depth_map: raw depth map.
        wound_mask: mask of the wound region (values > 127 count as wound).

    Returns:
        (normalized_depth, nearest_point_coords, max_relative_depth), where
        nearest_point_coords is the (row, col) of the reference pixel.
        Falls back to (depth_map, None, 0) on missing inputs or empty mask.
    """
    if depth_map is None or wound_mask is None:
        return depth_map, None, 0

    inside = (wound_mask > 127).astype(np.uint8)
    rows, cols = np.where(inside > 0)
    if rows.size == 0:
        return depth_map, None, 0

    region = depth_map[rows, cols]
    reference = np.min(region)

    # First pixel (in scan order) attaining the minimum depth.
    first = np.flatnonzero(region == reference)[0]
    ref_coords = (rows[first], cols[first])

    # Re-base on the reference and clamp so all values are non-negative.
    shifted = np.maximum(depth_map - reference, 0)
    max_relative = np.max(shifted[rows, cols])

    return shifted, ref_coords, max_relative
|
| 972 |
+
|
| 973 |
+
def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """
    Linearly rescale a depth map so its full range spans [0, reference_depth_mm].

    The map's maximum is assumed to correspond to reference_depth_mm, turning
    unitless relative depths into approximate millimetres.

    Returns the input unchanged when it is None or perfectly flat.
    """
    if depth_map is None:
        return depth_map

    lo = np.min(depth_map)
    hi = np.max(depth_map)
    if hi == lo:
        # Flat map: no range to stretch, nothing to calibrate.
        return depth_map

    return (depth_map - lo) / (hi - lo) * reference_depth_mm
|
| 993 |
+
|
| 994 |
+
def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth):
    """
    Render the depth map as a coloured image with analysis landmarks drawn on:
    a red circle ("REF") at the camera-nearest reference pixel, a blue circle
    ("DEEP") at the deepest wound pixel, and a green contour around the wound.

    Args:
        depth_map: 2-D depth array (typically already reference-normalised).
        wound_mask: grayscale mask; values > 127 are wound.
        nearest_point_coords: (row, col) of the reference pixel, or None.
        max_relative_depth: currently unused; kept for signature compatibility.

    Returns:
        uint8 RGB visualisation, or None when inputs are missing.
    """
    if depth_map is None or wound_mask is None:
        return None

    vis_depth = depth_map.copy()

    # Normalise to [0, 1] before applying the colormap. Guard against a flat
    # depth map (max == min), which previously divided by zero and produced
    # NaNs in the colormap input.
    lo, hi = np.min(vis_depth), np.max(vis_depth)
    if hi > lo:
        normalized_depth = (vis_depth - lo) / (hi - lo)
    else:
        normalized_depth = np.zeros_like(vis_depth, dtype=np.float64)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8)

    # Defensive: promote single-channel output to RGB.
    if len(colored_depth.shape) == 3 and colored_depth.shape[2] == 1:
        colored_depth = cv2.cvtColor(colored_depth, cv2.COLOR_GRAY2RGB)

    # Highlight the nearest (reference) point with a red circle.
    if nearest_point_coords is not None:
        y, x = int(nearest_point_coords[0]), int(nearest_point_coords[1])  # cv2 needs ints
        cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2)
        cv2.putText(colored_depth, "REF", (x + 15, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

    # Find and highlight the deepest wound pixel, then outline the wound.
    binary_mask = (wound_mask > 127).astype(np.uint8)
    wound_coords = np.where(binary_mask > 0)

    if len(wound_coords[0]) > 0:
        wound_depths = vis_depth[wound_coords]
        max_depth_idx = np.argmax(wound_depths)
        y = int(wound_coords[0][max_depth_idx])
        x = int(wound_coords[1][max_depth_idx])

        # Blue circle at the deepest point.
        cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3)
        cv2.putText(colored_depth, "DEEP", (x + 15, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

        # Green outline for the wound boundary.
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2)

    return colored_depth
|
| 1038 |
+
|
| 1039 |
+
def get_enhanced_severity_description(severity):
    """Map a severity label to its clinical summary sentence."""
    fallback = "Severity assessment unavailable."
    return {
        "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.",
        "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.",
        "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.",
        "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.",
        "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data or poor image quality."
    }.get(severity, fallback)
|
| 1050 |
+
|
| 1051 |
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a filled circular test mask (255 inside the circle, 0 outside).

    Args:
        image_shape: (H, W, ...) shape of the target image.
        center: (x, y) circle centre; defaults to the image midpoint.
        radius: circle radius in pixels.
    """
    h, w = image_shape[0], image_shape[1]
    cx, cy = center if center is not None else (w // 2, h // 2)

    y, x = np.ogrid[:h, :w]
    inside = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) <= radius

    mask = np.zeros((h, w), dtype=np.uint8)
    mask[inside] = 255
    return mask
|
| 1064 |
+
|
| 1065 |
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Generate a synthetic, irregular wound mask for testing.

    method='elliptical' draws a noisy ellipse; method='irregular' draws a
    circle with three satellite lobes. Any other method yields an empty
    mask. The result is smoothed with a morphological close.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    yy, xx = np.ogrid[:h, :w]
    cx, cy = w // 2, h // 2

    if method == 'elliptical':
        rx = min(w, h) // 3
        ry = min(w, h) // 4
        ellipse = ((xx - cx) ** 2 / (rx ** 2) +
                   (yy - cy) ** 2 / (ry ** 2)) <= 1

        # Sprinkle random speckle so the boundary looks organic.
        speckle = np.random.random((h, w)) > 0.8
        mask = (ellipse | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        r = min(w, h) // 4
        core = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2) <= r

        # Three satellite lobes spaced 120 degrees apart on the core circle.
        lobes = np.zeros_like(core)
        for k in range(3):
            theta = k * 2 * np.pi / 3
            lx = int(cx + r * 0.8 * np.cos(theta))
            ly = int(cy + r * 0.8 * np.sin(theta))
            lobes = lobes | (np.sqrt((xx - lx) ** 2 + (yy - ly) ** 2) <= r // 3)

        mask = (core | lobes).astype(np.uint8) * 255

    # Smooth ragged edges with a morphological close.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
| 1111 |
+
|
| 1112 |
+
# --- Depth Estimation Functions ---
|
| 1113 |
+
|
| 1114 |
+
def predict_depth(image):
    """Run monocular depth inference on `image`.

    Thin wrapper around the module-level `depth_model` (Depth Anything V2);
    returns whatever `infer_image` produces (a per-pixel depth array).
    """
    return depth_model.infer_image(image)
|
| 1116 |
+
|
| 1117 |
+
def calculate_max_points(image):
    """Return the 3D point budget for an image: 3x its pixel count,
    clamped to the range [1000, 300000].

    Falls back to 10000 when no image is supplied.
    """
    if image is None:
        return 10000  # default budget with no image to measure
    height, width = image.shape[:2]
    raw_budget = 3 * height * width
    # Clamp into a sane range so the UI slider stays usable.
    return min(max(raw_budget, 1000), 300000)
|
| 1125 |
+
|
| 1126 |
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider to match a freshly uploaded image.

    The slider maximum is derived from the image size via
    `calculate_max_points`; the default value is 10% of that maximum,
    capped at 10000.
    """
    upper = calculate_max_points(image)
    initial = min(10000, upper // 10)  # 10% of max points as default
    return gr.Slider(minimum=1000, maximum=upper, value=initial, step=1000,
                     label=f"Number of 3D points (max: {upper:,})")
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a colored Open3D point cloud.

    A pinhole camera model centered on the image is assumed. The sampling
    stride is halved relative to the naive budget, so the resulting cloud
    intentionally keeps more detail than `max_points` alone would allow.
    """
    rows, cols = depth_map.shape

    # Halve the naive stride for extra density (roughly 4x the points).
    stride = max(1, int(np.sqrt(rows * cols / max_points) * 0.5))

    grid_y, grid_x = np.mgrid[0:rows:stride, 0:cols:stride]

    # Normalized camera-plane coordinates (principal point at image center).
    cam_x = (grid_x - cols / 2) / focal_length_x
    cam_y = (grid_y - rows / 2) / focal_length_y

    z = depth_map[::stride, ::stride]

    # 3D points: (x_cam * depth, y_cam * depth, depth), flattened Nx3.
    xyz = np.stack([(cam_x * z).flatten(),
                    (cam_y * z).flatten(),
                    z.flatten()], axis=1)

    # Matching per-point RGB, scaled into [0, 1] for Open3D.
    rgb = image[::stride, ::stride, :].reshape(-1, 3) / 255.0

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Poisson-reconstruct a triangle mesh from a point cloud.

    Normals are estimated with a tight hybrid KD-tree search and made
    globally consistent before reconstruction; `depth=12` produces a very
    dense mesh. Low-density vertices are deliberately left untrimmed.
    """
    # High-precision normal estimation, then consistent orientation.
    search = o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)
    pcd.estimate_normals(search_param=search)
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # Densities are discarded: no low-density filtering is applied.
    mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    return mesh
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Build an interactive Plotly 3D scatter of the back-projected depth map.

    Pixels are lifted into camera space with a fixed focal length and
    colored by their RGB values; the returned figure is shown in the
    Gradio Plot component.
    """
    rows, cols = depth_map.shape

    # Subsample so roughly `max_points` survive, for interactivity.
    stride = max(1, int(np.sqrt(rows * cols / max_points)))

    grid_y, grid_x = np.mgrid[0:rows:stride, 0:cols:stride]

    # Pinhole back-projection with a hard-coded focal length.
    focal_length = 470.4  # Default focal length
    cam_x = (grid_x - cols / 2) / focal_length
    cam_y = (grid_y - rows / 2) / focal_length

    z = depth_map[::stride, ::stride]
    xs = (cam_x * z).flatten()
    ys = (cam_y * z).flatten()
    zs = z.flatten()

    # Per-point RGB colors matching the sampled pixels.
    rgb = image[::stride, ::stride, :].reshape(-1, 3)

    scatter = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=dict(
            size=1.5,
            color=rgb,
            opacity=0.9
        ),
        hovertemplate=('<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>'
                       '<b>Depth:</b> %{z:.2f}<br>'
                       '<extra></extra>')
    )

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig
|
| 1250 |
+
|
| 1251 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Full depth pipeline for Tab 2.

    Predicts depth, writes raw 16-bit and grayscale depth PNGs plus a
    Poisson-reconstructed .ply mesh to temp files, and builds the
    interactive 3D scatter figure.

    Returns:
        [(original_image, colored_depth), gray_png_path, raw_png_path,
         mesh_ply_path, plotly_figure]
    """
    original_image = image.copy()

    h, w = image.shape[:2]

    # Predict depth using the model
    depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed

    # Save raw 16-bit depth
    # NOTE(review): delete=False temp files are never cleaned up here.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize and convert to grayscale for display
    # NOTE(review): assumes depth.max() > depth.min(); a flat depth map
    # would produce NaNs here — confirm upstream guarantees variation.
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Create point cloud (from the normalized uint8 depth, not metric depth)
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)

    # Reconstruct mesh from point cloud
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Create enhanced 3D scatter plot visualization
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
| 1287 |
+
|
| 1288 |
+
# --- Actual Wound Segmentation Functions ---
|
| 1289 |
+
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate wound mask from image using the actual deep learning model

    Args:
        image: Input image (numpy array)
        method: Segmentation method (currently only 'deep_learning' supported)

    Returns:
        mask: Binary wound mask, or None when no image is provided
    """
    if image is None:
        return None

    # Only the deep-learning segmenter is implemented. The original code
    # duplicated the exact same call in both the `if` and the `else`
    # (fallback) branch, so the branches collapse to a single path;
    # every `method` value routes to segmentation_model.segment_wound.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
|
| 1311 |
+
|
| 1312 |
+
def post_process_wound_mask(mask, min_area=100):
    """Clean a raw segmentation mask.

    Smooths it morphologically, drops connected components smaller than
    `min_area` pixels, then closes remaining holes. Returns the cleaned
    uint8 mask, or None when the input mask is None.
    """
    if mask is None:
        return None

    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # 10x10 elliptical kernel: close gaps first, then open away speckle.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse)

    # Keep only sufficiently large external contours.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cleaned = np.zeros_like(mask)
    for contour in contours:
        if cv2.contourArea(contour) >= min_area:
            cv2.fillPoly(cleaned, [contour], 255)

    # Final closing pass fills interior holes left after filtering.
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, ellipse)
    return cleaned
|
| 1339 |
+
|
| 1340 |
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Run the full automatic severity pipeline.

    Segments the wound with the deep-learning model, cleans the mask,
    then delegates to analyze_wound_severity. Returns a user-facing
    error string at each failure point.
    """
    if image is None or depth_map is None:
        return "β Please provide both image and depth map."

    # Segment with the actual model, then clean up the raw mask.
    raw_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if raw_mask is None:
        return "β Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    clean_mask = post_process_wound_mask(raw_mask, min_area=500)
    if clean_mask is None or not np.any(clean_mask > 0):
        return "β No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, clean_mask, pixel_spacing_mm)
|
| 1359 |
+
|
| 1360 |
+
# --- Main Gradio Interface ---
|
| 1361 |
+
with gr.Blocks(css=css, title="Wound Analysis System") as demo:
|
| 1362 |
+
gr.HTML("<h1>Wound Analysis System</h1>")
|
| 1363 |
+
#gr.Markdown("### Complete workflow: Classification β Depth Estimation β Wound Severity Analysis")
|
| 1364 |
+
|
| 1365 |
+
# Shared states
|
| 1366 |
+
shared_image = gr.State()
|
| 1367 |
+
shared_depth_map = gr.State()
|
| 1368 |
+
|
| 1369 |
+
with gr.Tabs():
|
| 1370 |
+
|
| 1371 |
+
# Tab 1: Wound Classification
|
| 1372 |
+
with gr.Tab("1. π Wound Classification & Initial Analysis"):
|
| 1373 |
+
gr.Markdown("### Step 1: Classify wound type and get initial AI analysis")
|
| 1374 |
+
#gr.Markdown("Upload an image to identify the wound type and receive detailed analysis from our Vision AI.")
|
| 1375 |
+
|
| 1376 |
+
|
| 1377 |
+
with gr.Row():
|
| 1378 |
+
# Left Column - Image Upload
|
| 1379 |
+
with gr.Column(scale=1):
|
| 1380 |
+
gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Upload Wound Image</h2>')
|
| 1381 |
+
classification_image_input = gr.Image(
|
| 1382 |
+
label="",
|
| 1383 |
+
type="pil",
|
| 1384 |
+
height=400
|
| 1385 |
+
)
|
| 1386 |
+
# Place Clear and Analyse buttons side by side
|
| 1387 |
+
with gr.Row():
|
| 1388 |
+
classify_clear_btn = gr.Button(
|
| 1389 |
+
"Clear",
|
| 1390 |
+
variant="secondary",
|
| 1391 |
+
size="lg",
|
| 1392 |
+
scale=1
|
| 1393 |
+
)
|
| 1394 |
+
analyse_btn = gr.Button(
|
| 1395 |
+
"Analyse",
|
| 1396 |
+
variant="primary",
|
| 1397 |
+
size="lg",
|
| 1398 |
+
scale=1
|
| 1399 |
+
)
|
| 1400 |
+
# Right Column - Classification Results
|
| 1401 |
+
with gr.Column(scale=1):
|
| 1402 |
+
gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Classification Results</h2>')
|
| 1403 |
+
classification_output = gr.Label(
|
| 1404 |
+
label="",
|
| 1405 |
+
num_top_classes=5,
|
| 1406 |
+
show_label=False
|
| 1407 |
+
)
|
| 1408 |
+
|
| 1409 |
+
# Second Row - Full Width AI Analysis
|
| 1410 |
+
with gr.Row():
|
| 1411 |
+
with gr.Column(scale=1):
|
| 1412 |
+
gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 2rem; margin-bottom: 1rem; font-weight: bold; font-size: 1.8rem;">Wound Visual Analysis</h2>')
|
| 1413 |
+
gemini_output = gr.HTML(
|
| 1414 |
+
value="""
|
| 1415 |
+
<div style="
|
| 1416 |
+
border-radius: 12px;
|
| 1417 |
+
padding: 20px;
|
| 1418 |
+
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
|
| 1419 |
+
font-family: Arial, sans-serif;
|
| 1420 |
+
min-height: 200px;
|
| 1421 |
+
display: flex;
|
| 1422 |
+
align-items: center;
|
| 1423 |
+
justify-content: center;
|
| 1424 |
+
color: white;
|
| 1425 |
+
width: 100%;
|
| 1426 |
+
border-left: 4px solid #d97706;
|
| 1427 |
+
font-weight: bold;
|
| 1428 |
+
">
|
| 1429 |
+
Upload an image to get AI-powered wound analysis
|
| 1430 |
+
</div>
|
| 1431 |
+
"""
|
| 1432 |
+
)
|
| 1433 |
+
|
| 1434 |
+
# Event handlers for classification tab
|
| 1435 |
+
classify_clear_btn.click(
|
| 1436 |
+
fn=lambda: (None, None, """
|
| 1437 |
+
<div style="
|
| 1438 |
+
border-radius: 12px;
|
| 1439 |
+
padding: 20px;
|
| 1440 |
+
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
|
| 1441 |
+
font-family: Arial, sans-serif;
|
| 1442 |
+
min-height: 200px;
|
| 1443 |
+
display: flex;
|
| 1444 |
+
align-items: center;
|
| 1445 |
+
justify-content: center;
|
| 1446 |
+
color: white;
|
| 1447 |
+
width: 100%;
|
| 1448 |
+
border-left: 4px solid #d97706;
|
| 1449 |
+
font-weight: bold;
|
| 1450 |
+
">
|
| 1451 |
+
Upload an image to get AI-powered wound analysis
|
| 1452 |
+
</div>
|
| 1453 |
+
"""),
|
| 1454 |
+
inputs=None,
|
| 1455 |
+
outputs=[classification_image_input, classification_output, gemini_output]
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
# Only run classification on image upload
|
| 1459 |
+
def classify_and_store(image):
    """Run the wound classifier on the uploaded image (Tab 1 handler)."""
    return classify_wound(image)
|
| 1462 |
+
|
| 1463 |
+
classification_image_input.change(
|
| 1464 |
+
fn=classify_and_store,
|
| 1465 |
+
inputs=classification_image_input,
|
| 1466 |
+
outputs=classification_output
|
| 1467 |
+
)
|
| 1468 |
+
|
| 1469 |
+
# Store image in shared state for next tabs
|
| 1470 |
+
def store_shared_image(image):
    """Pass the uploaded image through unchanged into shared state,
    so later tabs can reuse it."""
    return image
|
| 1472 |
+
|
| 1473 |
+
classification_image_input.change(
|
| 1474 |
+
fn=store_shared_image,
|
| 1475 |
+
inputs=classification_image_input,
|
| 1476 |
+
outputs=shared_image
|
| 1477 |
+
)
|
| 1478 |
+
|
| 1479 |
+
# Run Gemini analysis only when Analyse button is clicked
|
| 1480 |
+
def run_gemini_on_click(image, classification):
    """Launch the Gemini visual analysis using the top classifier label.

    `classification` is the label->score mapping from the classifier;
    when empty or not a dict, "Unknown" is used as the wound type.
    """
    top_label = "Unknown"
    if isinstance(classification, dict) and classification:
        # Highest-scoring class name (first maximum in iteration order).
        top_label = max(classification, key=classification.get)
    raw_analysis = analyze_wound_with_gemini(image, top_label)
    return format_gemini_analysis(raw_analysis)
|
| 1489 |
+
|
| 1490 |
+
analyse_btn.click(
|
| 1491 |
+
fn=run_gemini_on_click,
|
| 1492 |
+
inputs=[classification_image_input, classification_output],
|
| 1493 |
+
outputs=gemini_output
|
| 1494 |
+
)
|
| 1495 |
+
|
| 1496 |
+
# Tab 2: Depth Estimation
|
| 1497 |
+
with gr.Tab("2. π Depth Estimation & 3D Visualization"):
|
| 1498 |
+
gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
|
| 1499 |
+
gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
|
| 1500 |
+
|
| 1501 |
+
with gr.Row():
|
| 1502 |
+
load_from_classification_btn = gr.Button("π Load Image from Classification Tab", variant="secondary")
|
| 1503 |
+
|
| 1504 |
+
with gr.Row():
|
| 1505 |
+
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
| 1506 |
+
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
|
| 1507 |
+
|
| 1508 |
+
with gr.Row():
|
| 1509 |
+
depth_submit = gr.Button(value="Compute Depth", variant="primary")
|
| 1510 |
+
|
| 1511 |
+
points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
|
| 1512 |
+
label="Number of 3D points (upload image to update max)")
|
| 1513 |
+
|
| 1514 |
+
with gr.Row():
|
| 1515 |
+
focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 1516 |
+
label="Focal Length X (pixels)")
|
| 1517 |
+
focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 1518 |
+
label="Focal Length Y (pixels)")
|
| 1519 |
+
|
| 1520 |
+
# Reorganized layout: 2 columns - 3D visualization on left, file outputs stacked on right
|
| 1521 |
+
with gr.Row():
|
| 1522 |
+
with gr.Column(scale=2):
|
| 1523 |
+
# 3D Visualization
|
| 1524 |
+
gr.Markdown("### 3D Point Cloud Visualization")
|
| 1525 |
+
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
|
| 1526 |
+
depth_3d_plot = gr.Plot(label="3D Point Cloud")
|
| 1527 |
+
|
| 1528 |
+
with gr.Column(scale=1):
|
| 1529 |
+
gr.Markdown("### Download Files")
|
| 1530 |
+
gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
|
| 1531 |
+
raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
|
| 1532 |
+
point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
|
| 1536 |
+
# Tab 3: Wound Severity Analysis
|
| 1537 |
+
with gr.Tab("3. π©Ή Wound Severity Analysis"):
|
| 1538 |
+
gr.Markdown("### Step 3: Analyze wound severity using depth maps")
|
| 1539 |
+
gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
|
| 1540 |
+
|
| 1541 |
+
with gr.Row():
|
| 1542 |
+
# Load depth map from previous tab
|
| 1543 |
+
load_depth_btn = gr.Button("π Load Depth Map from Tab 2", variant="secondary")
|
| 1544 |
+
|
| 1545 |
+
with gr.Row():
|
| 1546 |
+
severity_input_image = gr.Image(label="Original Image", type='numpy')
|
| 1547 |
+
severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
|
| 1548 |
+
|
| 1549 |
+
with gr.Row():
|
| 1550 |
+
wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
|
| 1551 |
+
|
| 1552 |
+
with gr.Row():
|
| 1553 |
+
severity_output = gr.HTML(
|
| 1554 |
+
label="π€ AI-Powered Medical Assessment",
|
| 1555 |
+
value="""
|
| 1556 |
+
<div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
|
| 1557 |
+
<div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
|
| 1558 |
+
π©Ή Wound Severity Analysis
|
| 1559 |
+
</div>
|
| 1560 |
+
<div style='font-size: 18px; color: #cccccc; margin-bottom: 20px;'>
|
| 1561 |
+
β³ Waiting for Input...
|
| 1562 |
+
</div>
|
| 1563 |
+
<div style='color: #888888; font-size: 14px;'>
|
| 1564 |
+
Please upload an image and depth map, then click "π€ Analyze Severity with Auto-Generated Mask" to begin AI-powered medical assessment.
|
| 1565 |
+
</div>
|
| 1566 |
+
</div>
|
| 1567 |
+
"""
|
| 1568 |
+
)
|
| 1569 |
+
|
| 1570 |
+
gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
|
| 1571 |
+
|
| 1572 |
+
with gr.Row():
|
| 1573 |
+
auto_severity_button = gr.Button("π€ Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
|
| 1574 |
+
pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
|
| 1575 |
+
label="Pixel Spacing (mm/pixel)")
|
| 1576 |
+
depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
|
| 1577 |
+
label="Depth Calibration (mm)",
|
| 1578 |
+
info="Adjust based on expected maximum wound depth")
|
| 1579 |
+
|
| 1580 |
+
#gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
|
| 1581 |
+
#gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.")
|
| 1582 |
+
|
| 1583 |
+
#gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
|
| 1584 |
+
|
| 1585 |
+
# Update slider when image is uploaded
|
| 1586 |
+
depth_input_image.change(
|
| 1587 |
+
fn=update_slider_on_image_upload,
|
| 1588 |
+
inputs=[depth_input_image],
|
| 1589 |
+
outputs=[points_slider]
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
# Modified depth submit function to store depth map
|
| 1593 |
+
def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
    """Wrap on_depth_submit and append a normalized uint8 depth map so
    Tab 3 (severity analysis) can reuse it via shared state.
    """
    results = on_depth_submit(image, num_points, focal_x, focal_y)
    # Extract depth map from results for severity analysis
    depth_map = None
    if image is not None:
        # NOTE(review): this runs the depth model a second time on the
        # same image; consider returning the normalized depth from
        # on_depth_submit instead of re-inferring it here.
        depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
        # Normalize depth for severity analysis
        # NOTE(review): assumes depth.max() > depth.min(); a flat map
        # would produce NaNs — confirm upstream.
        norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth_map = norm_depth.astype(np.uint8)
    return results + [depth_map]
|
| 1603 |
+
|
| 1604 |
+
depth_submit.click(on_depth_submit_with_state,
|
| 1605 |
+
inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
|
| 1606 |
+
outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map])
|
| 1607 |
+
|
| 1608 |
+
# Function to load image from classification to depth tab
|
| 1609 |
+
def load_image_from_classification(shared_img):
    """Copy the Tab-1 image (shared state) into the depth tab's input.

    Returns (image_or_None, status_message). PIL images are converted to
    numpy arrays, which the depth components expect.

    NOTE(review): the status-glyph characters below appear mojibake-garbled
    in this revision ("β"); confirm the intended emoji.
    """
    if shared_img is None:
        return None, "β No image available from classification tab. Please upload an image in Tab 1 first."

    # Convert PIL image to numpy array for depth estimation
    if hasattr(shared_img, 'convert'):
        # It's a PIL image, convert to numpy
        img_array = np.array(shared_img)
        return img_array, "β Image loaded from classification tab successfully!"
    else:
        # Already numpy array
        return shared_img, "β Image loaded from classification tab successfully!"
|
| 1621 |
+
|
| 1622 |
+
# Connect the load button
|
| 1623 |
+
load_from_classification_btn.click(
|
| 1624 |
+
fn=load_image_from_classification,
|
| 1625 |
+
inputs=shared_image,
|
| 1626 |
+
outputs=[depth_input_image, gr.HTML()]
|
| 1627 |
+
)
|
| 1628 |
+
|
| 1629 |
+
# Load depth map to severity tab and auto-generate mask
|
| 1630 |
+
def load_depth_to_severity(depth_map, original_image):
    """Carry the Tab-2 depth map into Tab 3 and auto-segment the wound.

    Returns (depth_map, image, mask_or_None, status_message). A mask is
    only produced when segmentation succeeds AND the post-processed mask
    contains at least one foreground pixel.

    NOTE(review): the status-glyph characters below appear mojibake-garbled
    in this revision ("β"); confirm the intended emoji.
    """
    if depth_map is None:
        return None, None, None, "β No depth map available. Please compute depth in Tab 2 first."

    # Auto-generate wound mask using segmentation model
    if original_image is not None:
        auto_mask, _ = segmentation_model.segment_wound(original_image)
        if auto_mask is not None:
            # Post-process the mask
            processed_mask = post_process_wound_mask(auto_mask, min_area=500)
            if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                return depth_map, original_image, processed_mask, "β Depth map loaded and wound mask auto-generated!"
            else:
                return depth_map, original_image, None, "β Depth map loaded but no wound detected. Try uploading a different image."
        else:
            return depth_map, original_image, None, "β Depth map loaded but segmentation failed. Try uploading a different image."
    else:
        return depth_map, original_image, None, "β Depth map loaded successfully!"
|
| 1648 |
+
|
| 1649 |
+
load_depth_btn.click(
|
| 1650 |
+
fn=load_depth_to_severity,
|
| 1651 |
+
inputs=[shared_depth_map, depth_input_image],
|
| 1652 |
+
outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
|
| 1653 |
+
)
|
| 1654 |
+
|
| 1655 |
+
# Loading state function
|
| 1656 |
+
def show_loading_state():
    """Return the static HTML spinner panel shown while the severity
    analysis runs (used as the first step of a chained click handler)."""
    return """
    <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
        <div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
            🩹 Wound Severity Analysis
        </div>
        <div style='font-size: 18px; color: #4CAF50; margin-bottom: 20px;'>
            🔄 AI Analysis in Progress...
        </div>
        <div style='color: #cccccc; font-size: 14px; margin-bottom: 15px;'>
            • Generating wound mask with deep learning model<br>
            • Computing depth measurements and statistics<br>
            • Analyzing wound characteristics with Gemini AI<br>
            • Preparing comprehensive medical assessment
        </div>
        <div style='display: inline-block; width: 30px; height: 30px; border: 3px solid #f3f3f3; border-top: 3px solid #4CAF50; border-radius: 50%; animation: spin 1s linear infinite;'></div>
        <style>
            @keyframes spin {
                0% { transform: rotate(0deg); }
                100% { transform: rotate(360deg); }
            }
        </style>
    </div>
    """
|
| 1680 |
+
|
| 1681 |
+
# Automatic severity analysis function
|
| 1682 |
+
def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
    """Tab-3 click handler: segment the wound automatically, clean the
    mask, and run the severity analysis; returns an HTML report (or an
    HTML error panel at each failure point).

    NOTE(review): unlike analyze_wound_severity_auto, this passes
    `depth_calibration` through to analyze_wound_severity — confirm that
    function accepts it as an optional fifth argument.
    """
    if depth_map is None:
        return """
        <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
            <div style='font-size: 24px; font-weight: bold; color: #f44336; margin-bottom: 15px;'>
                β Error
            </div>
            <div style='font-size: 16px; color: #cccccc;'>
                Please load depth map from Tab 1 first.
            </div>
        </div>
        """

    # Generate automatic wound mask using the actual model
    auto_mask = create_automatic_wound_mask(image, method='deep_learning')

    if auto_mask is None:
        return """
        <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
            <div style='font-size: 24px; font-weight: bold; color: #f44336; margin-bottom: 15px;'>
                β Error
            </div>
            <div style='font-size: 16px; color: #cccccc;'>
                Failed to generate automatic wound mask. Please check if the segmentation model is loaded.
            </div>
        </div>
        """

    # Post-process the mask with fixed minimum area
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)

    if processed_mask is None or np.sum(processed_mask > 0) == 0:
        return """
        <div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
            <div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
                ⚠️ No Wound Detected
            </div>
            <div style='font-size: 16px; color: #cccccc;'>
                No wound region detected by the segmentation model. Try uploading a different image or use manual mask.
            </div>
        </div>
        """

    # Analyze severity using the automatic mask
    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)
|
| 1727 |
+
|
| 1728 |
+
# Connect event handler with loading state
|
| 1729 |
+
auto_severity_button.click(
|
| 1730 |
+
fn=show_loading_state,
|
| 1731 |
+
inputs=[],
|
| 1732 |
+
outputs=[severity_output]
|
| 1733 |
+
).then(
|
| 1734 |
+
fn=run_auto_severity_analysis,
|
| 1735 |
+
inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
|
| 1736 |
+
outputs=[severity_output]
|
| 1737 |
+
)
|
| 1738 |
+
|
| 1739 |
+
|
| 1740 |
+
|
| 1741 |
+
# Auto-generate mask when image is uploaded
|
| 1742 |
+
def auto_generate_mask_on_image_upload(image):
    """Auto-segment a wound mask whenever a new image lands in Tab 3.

    Returns (mask_or_None, status_message); the mask is post-processed
    and only returned when it contains at least one foreground pixel.

    NOTE(review): the status-glyph characters below appear mojibake-garbled
    in this revision ("β"); confirm the intended emoji.
    """
    if image is None:
        return None, "β No image uploaded."

    # Generate automatic wound mask using segmentation model
    auto_mask, _ = segmentation_model.segment_wound(image)
    if auto_mask is not None:
        # Post-process the mask
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is not None and np.sum(processed_mask > 0) > 0:
            return processed_mask, "β Wound mask auto-generated using deep learning model!"
        else:
            return None, "β Image uploaded but no wound detected. Try uploading a different image."
    else:
        return None, "β Image uploaded but segmentation failed. Try uploading a different image."
|
| 1757 |
+
|
| 1758 |
+
# Load shared image from classification tab
|
| 1759 |
+
def load_shared_image(shared_img):
    """Load the Tab-1 image from shared state, converting PIL to numpy.

    Returns (image_or_empty_component, status_message).

    NOTE(review): near-duplicate of load_image_from_classification except
    for the None-case return value (gr.Image() vs None) — consider
    consolidating.
    """
    if shared_img is None:
        return gr.Image(), "β No image available from classification tab"

    # Convert PIL image to numpy array for depth estimation
    if hasattr(shared_img, 'convert'):
        # It's a PIL image, convert to numpy
        img_array = np.array(shared_img)
        return img_array, "β Image loaded from classification tab"
    else:
        # Already numpy array
        return shared_img, "β Image loaded from classification tab"
|
| 1771 |
+
|
| 1772 |
+
# Auto-generate mask when image is uploaded to severity tab
|
| 1773 |
+
severity_input_image.change(
|
| 1774 |
+
fn=auto_generate_mask_on_image_upload,
|
| 1775 |
+
inputs=[severity_input_image],
|
| 1776 |
+
outputs=[wound_mask_input, gr.HTML()]
|
| 1777 |
+
)
|
| 1778 |
+
|
| 1779 |
+
|
| 1780 |
+
|
| 1781 |
+
if __name__ == '__main__':
    # Queue requests so long-running model calls don't block the UI;
    # share=True additionally exposes a public Gradio link besides
    # binding to all interfaces on port 7860.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
|
depth_anything_v2/__pycache__/dinov2.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
depth_anything_v2/__pycache__/dpt.cpython-310.pyc
ADDED
|
Binary file (5.97 kB). View file
|
|
|
depth_anything_v2/dinov2.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn`` to ``module`` and every submodule.

    ``fn`` is invoked as ``fn(module=m, name=qualified_name)`` where the
    qualified name is dot-joined from the root.  With ``depth_first=True``
    children are visited before their parent (post-order); otherwise the
    parent is visited first.  ``include_root`` controls whether the
    top-level ``module`` itself is visited; recursive calls always visit
    their subtree roots.  Returns ``module`` for chaining.
    """
    if include_root and not depth_first:
        # Pre-order: visit this node before descending.
        fn(module=module, name=name)
    for sub_name, sub_module in module.named_children():
        qualified = f"{name}.{sub_name}" if name else sub_name
        named_apply(fn=fn, module=sub_module, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        # Post-order: visit this node after all children.
        fn(module=module, name=name)
    return module
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BlockChunk(nn.ModuleList):
    """A chunk of transformer blocks executed sequentially.

    Used when the block list is split into chunks (e.g. for FSDP
    wrapping); calling the chunk feeds the input through each contained
    block in order.
    """

    def forward(self, x):
        out = x
        for block in self:
            out = block(out)
        return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DinoVisionTransformer(nn.Module):
    # DINOv2 Vision Transformer backbone: patch embedding + [CLS] token
    # (+ optional "register" tokens) + a stack of transformer blocks.
    # Supports masked-token training and interpolation of the positional
    # embedding to arbitrary input resolutions.
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1  # the single [CLS] token accounted for in pos_embed
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding covers the CLS token plus all patch tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        # Register tokens get no positional embedding (inserted after pos-add).
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                # (leading nn.Identity() placeholders pad each chunk to its offset)
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned embedding substituted for masked patch tokens during training.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        # Truncated-normal pos embed, near-zero CLS/register tokens,
        # then timm-style init on every Linear layer.
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        # Bicubically resample the patch positional grid to match the current
        # input resolution; the CLS position is passed through untouched.
        # NOTE: assumes the pretraining grid was square (sqrt_N x sqrt_N).
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            # (int(w0), int(h0)), # to solve the upsampling shape issue
            mode="bicubic",
            antialias=self.interpolate_antialias
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        # Patchify, optionally replace masked patches with the mask token,
        # prepend CLS, add (interpolated) positional embeddings, then insert
        # register tokens (which deliberately receive no positional signal).
        # masks: presumably a boolean (B, num_patches) tensor -- TODO confirm.
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        # Batched variant of forward_features over a list of inputs; relies
        # on the blocks accepting a list (NestedTensorBlock path).
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        # Token layout after prepare: [CLS | registers | patches].
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # Global block index i runs across chunks; each chunk's leading
        # nn.Identity() padding is skipped via block_chunk[i:].
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        # Returns per-layer patch tokens (register tokens stripped), optionally
        # reshaped to (B, C, H/ps, W/ps) feature maps and/or paired with the
        # CLS token of each selected layer.
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        # Training: full feature dict; inference: head applied to the
        # normalized CLS token (head is nn.Identity here).
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    # Only Linear layers are (re)initialized; everything else is untouched.
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-S backbone (embed dim 384, 6 heads, depth 12)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keywords in **kwargs raise TypeError, as with a direct call.
    return DinoVisionTransformer(**base_kwargs, **kwargs)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-B backbone (embed dim 768, 12 heads, depth 12)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keywords in **kwargs raise TypeError, as with a direct call.
    return DinoVisionTransformer(**base_kwargs, **kwargs)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-L backbone (embed dim 1024, 16 heads, depth 24)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keywords in **kwargs raise TypeError, as with a direct call.
    return DinoVisionTransformer(**base_kwargs, **kwargs)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keywords in **kwargs raise TypeError, as with a direct call.
    return DinoVisionTransformer(**base_kwargs, **kwargs)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def DINOv2(model_name):
    """Instantiate a DINOv2 backbone with the Depth-Anything defaults.

    ``model_name`` is one of ``"vits"``, ``"vitb"``, ``"vitl"``, ``"vitg"``;
    an unknown name raises KeyError.  All variants use 518px pretraining
    resolution, 14px patches and LayerScale init 1.0; only the giant model
    uses the fused SwiGLU FFN.
    """
    model_zoo = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2
    }

    builder = model_zoo[model_name]
    ffn = "swiglufused" if model_name == "vitg" else "mlp"
    return builder(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer=ffn,
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    )
|
depth_anything_v2/dinov2_layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (429 Bytes). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (2.4 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
depth_anything_v2/dinov2_layers/attention.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Optional dependency: xFormers provides the memory-efficient fused
# attention kernel.  When missing, MemEffAttention falls back to the plain
# PyTorch implementation in Attention.forward.
try:
    from xformers.ops import memory_efficient_attention, unbind, fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Attention(nn.Module):
    """Standard multi-head self-attention, as used in timm's ViT."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """dim: token embedding size; num_heads must divide dim."""
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Softmax temperature 1/sqrt(d_head).
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(batch, seq_len, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        scores = q @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(out))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MemEffAttention(Attention):
    """Attention variant using xFormers' memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # The plain PyTorch fallback cannot honour a block-diagonal bias.
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        batch, seq_len, channels = x.shape
        # (B, N, 3, heads, head_dim) layout expected by xFormers.
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
        q, k, v = unbind(qkv, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, seq_len, channels])

        return self.proj_drop(self.proj(out))
|
| 82 |
+
|
| 83 |
+
|
depth_anything_v2/dinov2_layers/block.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
|
| 17 |
+
from .attention import Attention, MemEffAttention
|
| 18 |
+
from .drop_path import DropPath
|
| 19 |
+
from .layer_scale import LayerScale
|
| 20 |
+
from .mlp import Mlp
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Optional dependency: xFormers supplies the fused kernels used by the
# nested-tensor / stochastic-depth fast paths in this module.
try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Block(nn.Module):
    # Pre-norm transformer block: x + LS(Attn(LN(x))) followed by
    # x + LS(MLP(LN(x))), with each residual branch optionally wrapped
    # in stochastic depth (DropPath / batch-subset sampling).
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy (None/0 disable it).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Drop probability reused by the batch-subset stochastic-depth path.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            # Inference / no drop-path: plain residual additions.
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-wise stochastic depth.

    Computes the residual branch only on a random subset of batch rows
    and adds it back (rescaled by batch/subset so the expected update is
    unchanged); dropped rows pass through untouched.
    """
    # 1) extract subset using permutation
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]
    subset = x[kept_rows]

    # 2) apply residual_func to get residual
    branch_out = residual_func(subset).flatten(1)

    # Rescale to keep the residual's expectation equal to the full-batch case.
    scale = batch / keep

    # 3) add the residual back into the selected rows only
    combined = torch.index_add(x.flatten(1), 0, kept_rows, branch_out.to(dtype=x.dtype), alpha=scale)
    return combined.view_as(x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the batch rows that survive stochastic depth.

    Returns (row_indices, scale) where ``scale = batch/kept`` compensates
    the residual magnitude for the dropped rows.  At least one row is
    always kept.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]
    return kept_rows, batch / keep
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add scaled residuals back into the rows of ``x`` selected by ``brange``.

    Without a ``scaling_vector`` this returns a *flattened* (B, N*D) tensor
    (the caller reshapes); with one, the fused xFormers ``scaled_index_add``
    additionally applies the per-channel LayerScale vector and keeps the
    original shape.
    """
    if scaling_vector is None:
        # Pure PyTorch path: flatten and scatter-add the scaled residual rows.
        return torch.index_add(
            x.flatten(1), 0, brange, residual.flatten(1).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# Process-wide cache of block-diagonal attention biases, keyed by the tuple
# of (batch_size, seq_len) shapes of the concatenated inputs, so each mask
# is built only once per shape combination.
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Per-tensor batch sizes: subset sizes when stochastic depth supplies
    # `branges`, otherwise the full batch of each tensor.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # Build a block-diagonal mask with one block per (sample, seqlen), so
        # the concatenated sequence attends only within each original sample.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused index-select + concat (xFormers) into a single (1, total, d) tensor.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # Flatten each tensor's batch into its sequence dim, then concat along seq.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual add applied jointly to a list of tensors.

    The kept subsets of every tensor are concatenated into one nested batch,
    run through ``residual_func`` under a block-diagonal attention bias, and
    the per-tensor residuals are scattered back with the matching rescale.
    NOTE(review): the annotated return type is Tensor, but a list of tensors
    (one per input) is returned — confirm against callers.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) scatter each per-tensor residual back onto its full batch.
    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NestedTensorBlock(Block):
    """Transformer block that can also process a list of tensors ("nested")
    by concatenating them under a block-diagonal attention bias."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        # The nested path relies on xFormers memory-efficient attention,
        # which accepts the block-diagonal attn_bias.
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Training with stochastic depth: LayerScale gamma is folded into
            # the fused scatter-add, so the residual funcs omit ls1/ls2.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fix: gate on ls2 (was ls1). Both are constructed together so
                # the effective behavior is unchanged, but the guard now
                # matches the gamma that is actually used.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:
            # Inference / no drop: plain residual adds with LayerScale applied
            # inside the residual functions.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Concatenate all tensors once, run attention + FFN, split back.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Dispatch: plain tensor -> standard Block forward; list -> nested path.
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
depth_anything_v2/dinov2_layers/drop_path.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Per-sample stochastic depth: zero whole samples with probability
    ``drop_prob`` and rescale survivors by ``1 / keep_prob``.

    A no-op (returns ``x`` unchanged) when ``drop_prob`` is 0 or when not
    training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expected activation magnitude is preserved.
        mask.div_(keep_prob)
    return x * mask
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping an entire sample's residual path.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; active only while training.
        return drop_path(x, self.drop_prob, self.training)
|
depth_anything_v2/dinov2_layers/layer_scale.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 8 |
+
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (CaiT-style)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts near zero so the scaled branch is initially ~identity.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
depth_anything_v2/dinov2_layers/mlp.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
    """Standard transformer MLP: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        # A single Dropout module is shared by both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
depth_anything_v2/dinov2_layers/patch_embed.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_2tuple(size):
    """Normalize an int or 2-tuple into a (height, width) style 2-tuple."""
    if isinstance(size, tuple):
        assert len(size) == 2
        return size
    assert isinstance(size, int)
    return (size, size)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, forward returns (B, H', W', D) instead of (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patchification as a strided convolution.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        # nn.Identity() stands in when no norm layer is requested.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Approximate multiply-accumulate count for one forward pass."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # Fix: `self.norm` is never None (it is nn.Identity() when no norm
        # layer was given), so the previous `is not None` check always added
        # the norm cost. Only count it when a real norm layer is present.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
depth_anything_v2/dinov2_layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: a single fused projection produces gate and value
    halves, combined as silu(gate) * value, then projected back out."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # w12 packs both halves of the gated unit into one matmul.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Prefer the fused xFormers SwiGLU kernel when available; otherwise fall back
# to the pure-PyTorch SwiGLUFFN defined above under the same name.
try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is shrunk to 2/3 (to keep parameter count
    comparable to a conventional MLP) and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 scaling compensates for SwiGLU's extra projection; the round-up
        # to a multiple of 8 helps hardware efficiency.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
depth_anything_v2/dpt.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.transforms import Compose
|
| 6 |
+
|
| 7 |
+
from .dinov2 import DINOv2
|
| 8 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
| 9 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _make_fusion_block(features, use_bn, size=None):
    """Build a FeatureFusionBlock with the fixed settings used by the DPT head."""
    activation = nn.ReLU(False)
    return FeatureFusionBlock(
        features,
        activation,
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ConvBlock(nn.Module):
    """3x3 conv -> BatchNorm -> ReLU, preserving spatial resolution."""

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.conv_block(x)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DPTHead(nn.Module):
    """DPT decoder head: projects multi-level ViT tokens back into feature
    maps, fuses them coarse-to-fine with RefineNet-style blocks, and regresses
    a single-channel depth map.

    Args:
        in_channels: embedding dim of the ViT tokens feeding every level.
        features: channel width of the fusion trunk.
        use_bn: enable BatchNorm inside the fusion blocks.
        out_channels: per-level channel widths after the 1x1 projections.
        use_clstoken: if True, each level also carries a class token that is
            broadcast and fused into the patch tokens via a readout MLP.
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False
    ):
        super(DPTHead, self).__init__()

        self.use_clstoken = use_clstoken

        # 1x1 convs mapping the shared ViT width to per-level widths.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Bring the four levels to a 4x / 2x / 1x / 0.5x resolution pyramid
        # relative to the ViT patch grid.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            # One readout MLP per level folds the class token into each patch.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        # NOTE(review): the trailing ReLU keeps the 1-channel output
        # non-negative; the final Identity looks like a leftover from a
        # removed activation — confirm against upstream before changing.
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )

    def forward(self, out_features, patch_h, patch_w):
        """Decode 4 levels of token maps (each `(tokens, cls_token)` when
        use_clstoken, else `(tokens,)`) into a (B, 1, 14*patch_h, 14*patch_w)
        depth map."""
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                # Broadcast the class token across all patches, then project.
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            # Tokens -> 2D feature map at the ViT patch grid resolution.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion; each stage upsamples to the next level's size.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        # Restore the full input resolution (14 px per ViT patch).
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)

        return out
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class DepthAnythingV2(nn.Module):
    """Monocular relative-depth estimator: DINOv2 ViT backbone + DPT head.

    Args:
        encoder: backbone size key ('vits' | 'vitb' | 'vitl' | 'vitg').
        features: channel width of the DPT fusion trunk.
        out_channels: per-level widths for the DPT projections.
        use_bn: enable BatchNorm in the fusion blocks.
        use_clstoken: fuse the class token into patch tokens in the head.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False
    ):
        super(DepthAnythingV2, self).__init__()

        # Which transformer blocks feed the 4 DPT levels, per backbone size.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitb': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23],
            'vitg': [9, 19, 29, 39]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        """Predict relative depth for a (B, 3, H, W) batch; H and W must be
        multiples of 14 (the ViT patch size). Returns a (B, H, W) tensor >= 0."""
        patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14

        features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)

        depth = self.depth_head(features, patch_h, patch_w)
        # Clamp to non-negative depth values.
        depth = F.relu(depth)

        return depth.squeeze(1)

    @torch.no_grad()
    def infer_image(self, raw_image, input_size=518):
        """Run inference on a raw BGR uint8 image of shape (H, W, 3); returns
        a numpy depth map resized back to the original (H, W)."""
        image, (h, w) = self.image2tensor(raw_image, input_size)

        depth = self.forward(image)

        # Upsample the prediction back to the original image resolution.
        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]

        return depth.cpu().numpy()

    def image2tensor(self, raw_image, input_size=518):
        """Preprocess a raw BGR image into a normalized (1, 3, H', W') tensor
        (H', W' multiples of 14, each >= input_size). Returns the tensor on
        the selected device and the original (h, w)."""
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        h, w = raw_image.shape[:2]

        # OpenCV loads BGR; convert to RGB and scale to [0, 1].
        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0

        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0)

        # NOTE(review): device is chosen per call rather than read from the
        # model's own parameters — assumes the model lives on the same device.
        DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        image = image.to(DEVICE)

        return image, (h, w)
|
depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc
ADDED
|
Binary file (3.29 kB). View file
|
|
|
depth_anything_v2/util/__pycache__/transform.cpython-310.pyc
ADDED
|
Binary file (4.73 kB). View file
|
|
|
depth_anything_v2/util/blocks.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
| 5 |
+
scratch = nn.Module()
|
| 6 |
+
|
| 7 |
+
out_shape1 = out_shape
|
| 8 |
+
out_shape2 = out_shape
|
| 9 |
+
out_shape3 = out_shape
|
| 10 |
+
if len(in_shape) >= 4:
|
| 11 |
+
out_shape4 = out_shape
|
| 12 |
+
|
| 13 |
+
if expand:
|
| 14 |
+
out_shape1 = out_shape
|
| 15 |
+
out_shape2 = out_shape * 2
|
| 16 |
+
out_shape3 = out_shape * 4
|
| 17 |
+
if len(in_shape) >= 4:
|
| 18 |
+
out_shape4 = out_shape * 8
|
| 19 |
+
|
| 20 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
| 21 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
| 22 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
| 23 |
+
if len(in_shape) >= 4:
|
| 24 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
| 25 |
+
|
| 26 |
+
return scratch
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Two pre-activated 3x3 convolutions (optionally batch-normed) whose output
    is added back to the input.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation applied before each conv
            bn (bool): insert BatchNorm after each conv
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # Quantization-friendly addition for the skip connection.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.conv1(self.activation(x))
        if self.bn:
            out = self.bn1(out)

        out = self.conv2(self.activation(out))
        if self.bn:
            out = self.bn2(out)

        # NOTE(review): self.conv_merge is never defined and self.groups is
        # always 1, so this branch is dead code kept for upstream parity.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally merges a lateral feature with the incoming path, refines it,
    upsamples (x2 or to an explicit size), and projects with a 1x1 conv.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1

        self.expand = expand
        # Halve the output width when expansion mode is requested.
        out_features = features // 2 if expand else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        # Quantization-friendly addition for merging the lateral input.
        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        # A second input is a lateral skip feature, refined before merging.
        if len(xs) == 2:
            output = self.skip_add.add(output, self.resConfUnit1(xs[1]))

        output = self.resConfUnit2(output)

        # Target resolution: explicit call-time size takes precedence over the
        # construction-time size; default is 2x upsampling.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        output = nn.functional.interpolate(output, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(output)
|
depth_anything_v2/util/transform.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample; the output
                may then deviate from (width, height) according to
                'resize_method'. Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to multiples of this
                value. Defaults to 1.
            resize_method (str, optional):
                "lower_bound": output at least as large as the given size.
                "upper_bound": output at most as large as the given size.
                "minimal": scale as little as possible.
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap x to a multiple of self.__multiple_of within [min_val, max_val]."""
        step = self.__multiple_of
        snapped = (np.round(x / step) * step).astype(int)

        # Rounding overshot the upper bound: snap down instead.
        if max_val is not None and snapped > max_val:
            snapped = (np.floor(x / step) * step).astype(int)

        # Fell below the lower bound: snap up.
        if snapped < min_val:
            snapped = (np.ceil(x / step) * step).astype(int)

        return snapped

    def get_size(self, width, height):
        """Compute the (width, height) an input of this size is resized to."""
        scale_h = self.__height / height
        scale_w = self.__width / width

        if self.__keep_aspect_ratio:
            # Collapse both scales to a single factor chosen by resize_method.
            if self.__resize_method == "lower_bound":
                chosen = max(scale_w, scale_h)
            elif self.__resize_method == "upper_bound":
                chosen = min(scale_w, scale_h)
            elif self.__resize_method == "minimal":
                # Pick the scale closest to 1 (least distortion of size).
                chosen = scale_w if abs(1 - scale_w) < abs(1 - scale_h) else scale_h
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")
            scale_w = scale_h = chosen

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_h * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_w * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_h * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_w * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_h * height)
            new_width = self.constrain_to_multiple_of(scale_w * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize sample["image"] (and depth/mask if enabled) in place."""
        new_w, new_h = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"], (new_w, new_h),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            # Nearest-neighbor keeps depth/mask values unblended.
            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (new_w, new_h), interpolation=cv2.INTER_NEAREST
                )

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32), (new_w, new_h),
                    interpolation=cv2.INTER_NEAREST,
                )

        return sample
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        # Per-channel (or scalar) normalization constants.
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        """Apply (image - mean) / std to sample["image"] in place."""
        centered = sample["image"] - self.__mean
        sample["image"] = centered / self.__std
        return sample
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class PrepareForNet(object):
    """Prepare sample for usage as network input.

    Converts the image to a contiguous CHW float32 array and casts optional
    depth/mask entries to contiguous float32.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        # HWC -> CHW, contiguous float32 for the network.
        chw = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)

        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            mask = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(mask)

        return sample
|
models/FCN.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from keras.models import Model
|
| 3 |
+
from keras.layers import Input
|
| 4 |
+
from keras.layers import Conv2D, MaxPooling2D, Dropout, UpSampling2D
|
| 5 |
+
from utils.BilinearUpSampling import BilinearUpSampling2D
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def FCN_Vgg16_16s(input_shape=None, weight_decay=0., batch_momentum=0.9, batch_shape=None, classes=1):
    """Build a VGG16-based fully convolutional network with 16x upsampling.

    Args:
        input_shape: (H, W, C) of the input image; used when batch_shape is None.
        weight_decay: L2 regularization factor applied to every conv kernel.
        batch_momentum: currently unused; kept for interface compatibility.
        batch_shape: full (N, H, W, C) batch shape; takes precedence over input_shape.
        classes: number of output channels of the classifying layer. Defaults to 1.

    Returns:
        tuple: (keras Model, model name string 'FCN_Vgg16_16').
    """
    # Local import keeps this fix self-contained.
    from keras.regularizers import l2

    if batch_shape:
        img_input = Input(batch_shape=batch_shape)
        image_size = batch_shape[1:3]
    else:
        img_input = Input(shape=input_shape)
        image_size = input_shape[0:2]

    # BUG FIX: the original passed kernel_regularizer='l2', which always used
    # the Keras default factor and silently ignored the weight_decay argument.
    def reg():
        return l2(weight_decay)

    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=reg())(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=reg())(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=reg())(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=reg())(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=reg())(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=reg())(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=reg())(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=reg())(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=reg())(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=reg())(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5 (no pooling: keeps the output stride at 16)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=reg())(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=reg())(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=reg())(x)

    # Convolutional layers transferred from fully-connected layers
    x = Conv2D(4096, (7, 7), activation='relu', padding='same', dilation_rate=(2, 2),
               name='fc1', kernel_regularizer=reg())(x)
    x = Dropout(0.5)(x)
    x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=reg())(x)
    x = Dropout(0.5)(x)

    # Classifying layer (linear activation; per-pixel logits)
    x = Conv2D(classes, (1, 1), kernel_initializer='he_normal', activation='linear',
               padding='valid', strides=(1, 1), kernel_regularizer=reg())(x)

    # Upsample the stride-16 feature map back to the input resolution.
    x = BilinearUpSampling2D(size=(16, 16))(x)

    model = Model(img_input, x)
    model_name = 'FCN_Vgg16_16'
    return model, model_name
|
models/SegNet.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.models import Model
|
| 2 |
+
from keras.layers import Input
|
| 3 |
+
from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Concatenate, UpSampling2D
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SegNet:
    """Encoder-decoder segmentation network with a single sigmoid output channel."""

    def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels):
        # Input spatial dims / channel count; n_filters sets the base width.
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.n_filters = n_filters
        self.num_channels = num_channels

    def get_SegNet(self):
        """Build the model.

        Returns:
            tuple: (keras Model, 'SegNet')
        """
        net_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))

        # Encoder: four conv + 2x2 max-pool stages.
        x = net_input
        for filters, ksize in ((self.n_filters, 9),
                               (self.n_filters, 5),
                               (self.n_filters * 2, 5),
                               (self.n_filters * 2, 5)):
            x = Conv2D(filters, kernel_size=ksize, activation='relu', padding='same')(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)

        # Bottleneck conv at the lowest resolution.
        x = Conv2D(self.n_filters, kernel_size=5, activation='relu', padding='same')(x)

        # Decoder: three upsample-then-conv stages.
        for ksize in (7, 5, 5):
            x = Conv2D(self.n_filters, kernel_size=ksize, activation='relu', padding='same')(
                UpSampling2D(size=(2, 2))(x))

        # Final upsample and 1x1 sigmoid projection to a single mask channel.
        mask = Conv2D(1, kernel_size=1, activation='sigmoid', padding='same')(
            UpSampling2D(size=(2, 2))(x))

        return Model(outputs=mask, inputs=net_input), 'SegNet'
|
models/__pycache__/FCN.cpython-37.pyc
ADDED
|
Binary file (1.91 kB). View file
|
|
|
models/__pycache__/FCN.cpython-39.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
models/__pycache__/SegNet.cpython-37.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
models/__pycache__/SegNet.cpython-39.pyc
ADDED
|
Binary file (1.6 kB). View file
|
|
|
models/__pycache__/deeplab.cpython-310.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
models/__pycache__/deeplab.cpython-313.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
models/__pycache__/deeplab.cpython-37.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
models/__pycache__/deeplab.cpython-39.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
models/__pycache__/unets.cpython-37.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
models/__pycache__/unets.cpython-39.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
models/deeplab.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
""" Deeplabv3+ model for Keras.
|
| 4 |
+
This model is based on this repo:
|
| 5 |
+
https://github.com/bonlime/keras-deeplab-v3-plus
|
| 6 |
+
|
| 7 |
+
MobileNetv2 backbone is based on this repo:
|
| 8 |
+
https://github.com/JonathanCMitchell/mobilenet_v2_keras
|
| 9 |
+
|
| 10 |
+
# Reference
|
| 11 |
+
- [Encoder-Decoder with Atrous Separable Convolution
|
| 12 |
+
for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
|
| 13 |
+
- [Xception: Deep Learning with Depthwise Separable Convolutions]
|
| 14 |
+
(https://arxiv.org/abs/1610.02357)
|
| 15 |
+
- [Inverted Residuals and Linear Bottlenecks: Mobile Networks for
|
| 16 |
+
Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import absolute_import
|
| 20 |
+
from __future__ import division
|
| 21 |
+
from __future__ import print_function
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
|
| 26 |
+
from keras.models import Model
|
| 27 |
+
from keras import layers
|
| 28 |
+
from keras.layers import Input
|
| 29 |
+
from keras.layers import Activation
|
| 30 |
+
from keras.layers import Concatenate
|
| 31 |
+
from keras.layers import Add
|
| 32 |
+
from keras.layers import Dropout
|
| 33 |
+
from keras.layers import BatchNormalization
|
| 34 |
+
from keras.layers import Conv2D
|
| 35 |
+
from keras.layers import DepthwiseConv2D
|
| 36 |
+
from keras.layers import ZeroPadding2D
|
| 37 |
+
from keras.layers import AveragePooling2D
|
| 38 |
+
from keras.layers import Layer
|
| 39 |
+
from tensorflow.keras.layers import InputSpec
|
| 40 |
+
from tensorflow.keras.utils import get_source_inputs
|
| 41 |
+
from keras import backend as K
|
| 42 |
+
from keras.applications import imagenet_utils
|
| 43 |
+
from keras.utils import conv_utils
|
| 44 |
+
from keras.utils.data_utils import get_file
|
| 45 |
+
|
| 46 |
+
WEIGHTS_PATH_X = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5"
|
| 47 |
+
WEIGHTS_PATH_MOBILE = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5"
|
| 48 |
+
WEIGHTS_PATH_X_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5"
|
| 49 |
+
WEIGHTS_PATH_MOBILE_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5"
|
| 50 |
+
|
| 51 |
+
class BilinearUpsampling(Layer):
    """Just a simple bilinear upsampling layer. Works only with TF.

    Args:
        upsampling: tuple of 2 numbers > 0. The upsampling ratio for h and w
        output_size: used instead of upsampling arg if passed!
    """

    def __init__(self, upsampling=(2, 2), output_size=None, data_format=None, **kwargs):
        super(BilinearUpsampling, self).__init__(**kwargs)

        # NOTE(review): the data_format argument is ignored; the backend
        # default is always used (matches the original behavior).
        self.data_format = K.image_data_format()
        self.input_spec = InputSpec(ndim=4)
        if output_size:
            # Absolute target size takes precedence over a ratio.
            self.output_size = conv_utils.normalize_tuple(output_size, 2, 'output_size')
            self.upsampling = None
        else:
            self.output_size = None
            self.upsampling = conv_utils.normalize_tuple(upsampling, 2, 'upsampling')

    def compute_output_shape(self, input_shape):
        """Return (batch, new_h, new_w, channels) for NHWC input."""
        if self.upsampling:
            in_h, in_w = input_shape[1], input_shape[2]
            out_h = self.upsampling[0] * in_h if in_h is not None else None
            out_w = self.upsampling[1] * in_w if in_w is not None else None
        else:
            out_h = self.output_size[0]
            out_w = self.output_size[1]
        return (input_shape[0], out_h, out_w, input_shape[3])

    def call(self, inputs):
        # Resolve the target size, then do one bilinear resize.
        if self.upsampling:
            target = (inputs.shape[1] * self.upsampling[0],
                      inputs.shape[2] * self.upsampling[1])
        else:
            target = (self.output_size[0], self.output_size[1])
        return tf.compat.v1.image.resize_bilinear(inputs, target, align_corners=True)

    def get_config(self):
        merged = super(BilinearUpsampling, self).get_config()
        merged = dict(merged)
        merged.update({'upsampling': self.upsampling,
                       'output_size': self.output_size,
                       'data_format': self.data_format})
        return merged
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def SepConv_BN(x, filters, prefix, stride=1, kernel_size=3, rate=1, depth_activation=False, epsilon=1e-3):
    """SepConv with BN between depthwise & pointwise. Optionally add activation after BN.

    Implements right "same" padding for even kernel sizes.

    Args:
        x: input tensor
        filters: num of filters in pointwise convolution
        prefix: prefix before name
        stride: stride at depthwise conv
        kernel_size: kernel size for depthwise convolution
        rate: atrous rate for depthwise convolution
        depth_activation: flag to use activation between depthwise & pointwise convs
        epsilon: epsilon to use in BN layer
    """
    if stride == 1:
        depth_padding = 'same'
    else:
        # Manual symmetric padding so a stride-2 conv stays aligned.
        effective = kernel_size + (kernel_size - 1) * (rate - 1)
        pad_total = effective - 1
        pad_beg = pad_total // 2
        x = ZeroPadding2D((pad_beg, pad_total - pad_beg))(x)
        depth_padding = 'valid'

    # Without per-stage activations, apply a single ReLU up front.
    if not depth_activation:
        x = Activation('relu')(x)

    x = DepthwiseConv2D((kernel_size, kernel_size), strides=(stride, stride),
                        dilation_rate=(rate, rate), padding=depth_padding,
                        use_bias=False, name=prefix + '_depthwise')(x)
    x = BatchNormalization(name=prefix + '_depthwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)

    x = Conv2D(filters, (1, 1), padding='same', use_bias=False,
               name=prefix + '_pointwise')(x)
    x = BatchNormalization(name=prefix + '_pointwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)

    return x
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _conv2d_same(x, filters, prefix, stride=1, kernel_size=3, rate=1):
    """Implements right 'same' padding for even kernel sizes.

    Without this there is a 1 pixel drift when stride = 2.

    Args:
        x: input tensor
        filters: num of filters in pointwise convolution
        prefix: prefix before name
        stride: stride at depthwise conv
        kernel_size: kernel size for depthwise convolution
        rate: atrous rate for depthwise convolution
    """
    shared = dict(strides=(stride, stride), use_bias=False,
                  dilation_rate=(rate, rate), name=prefix)

    if stride == 1:
        return Conv2D(filters, (kernel_size, kernel_size),
                      padding='same', **shared)(x)

    # Explicit symmetric padding, then a 'valid' conv.
    effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = effective - 1
    pad_beg = pad_total // 2
    x = ZeroPadding2D((pad_beg, pad_total - pad_beg))(x)
    return Conv2D(filters, (kernel_size, kernel_size),
                  padding='valid', **shared)(x)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _xception_block(inputs, depth_list, prefix, skip_connection_type, stride,
                    rate=1, depth_activation=False, return_skip=False):
    """Basic building block of modified Xception network.

    Args:
        inputs: input tensor
        depth_list: number of filters in each SepConv layer. len(depth_list) == 3
        prefix: prefix before name
        skip_connection_type: one of {'conv','sum','none'}
        stride: stride at last depthwise conv
        rate: atrous rate for depthwise convolution
        depth_activation: flag to use activation between depthwise & pointwise convs
        return_skip: flag to return additional tensor after 2 SepConvs for decoder
    """
    residual = inputs
    for i in range(3):
        residual = SepConv_BN(residual, depth_list[i],
                              '{}_separable_conv{}'.format(prefix, i + 1),
                              stride=stride if i == 2 else 1,
                              rate=rate,
                              depth_activation=depth_activation)
        if i == 1:
            # Tensor after the second SepConv, used by the decoder.
            skip = residual

    # Shortcut path depends on the requested skip-connection type.
    if skip_connection_type == 'conv':
        shortcut = _conv2d_same(inputs, depth_list[-1], prefix + '_shortcut',
                                kernel_size=1, stride=stride)
        shortcut = BatchNormalization(name=prefix + '_shortcut_BN')(shortcut)
        outputs = layers.add([residual, shortcut])
    elif skip_connection_type == 'sum':
        outputs = layers.add([residual, inputs])
    elif skip_connection_type == 'none':
        outputs = residual

    return (outputs, skip) if return_skip else outputs
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def relu6(x):
    """ReLU activation capped at 6 (the MobileNet nonlinearity)."""
    return K.relu(x, max_value=6)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _make_divisible(v, divisor, min_value=None):
|
| 221 |
+
if min_value is None:
|
| 222 |
+
min_value = divisor
|
| 223 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
| 224 |
+
# Make sure that round down does not go down by more than 10%.
|
| 225 |
+
if new_v < 0.9 * v:
|
| 226 |
+
new_v += divisor
|
| 227 |
+
return new_v
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1):
    """MobileNetV2 inverted residual block: expand -> depthwise -> project."""
    in_channels = inputs.shape[-1]
    # Width-multiplied output channels, rounded to a multiple of 8.
    pointwise_filters = _make_divisible(int(filters * alpha), 8)
    x = inputs
    prefix = 'expanded_conv_{}_'.format(block_id)

    if block_id:
        # Expand: 1x1 conv widening to expansion * in_channels.
        x = Conv2D(expansion * in_channels, kernel_size=1, padding='same',
                   use_bias=False, activation=None,
                   name=prefix + 'expand')(x)
        x = BatchNormalization(epsilon=1e-3, momentum=0.999,
                               name=prefix + 'expand_BN')(x)
        x = Activation(relu6, name=prefix + 'expand_relu')(x)
    else:
        # First block has no expand stage and uses the bare prefix.
        prefix = 'expanded_conv_'

    # Depthwise
    x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None,
                        use_bias=False, padding='same', dilation_rate=(rate, rate),
                        name=prefix + 'depthwise')(x)
    x = BatchNormalization(epsilon=1e-3, momentum=0.999,
                           name=prefix + 'depthwise_BN')(x)
    x = Activation(relu6, name=prefix + 'depthwise_relu')(x)

    # Project: linear 1x1 conv back down to pointwise_filters channels.
    x = Conv2D(pointwise_filters, kernel_size=1, padding='same',
               use_bias=False, activation=None,
               name=prefix + 'project')(x)
    x = BatchNormalization(epsilon=1e-3, momentum=0.999,
                           name=prefix + 'project_BN')(x)

    if skip_connection:
        return Add(name=prefix + 'add')([inputs, x])

    # if in_channels == pointwise_filters and stride == 1:
    #     return Add(name='res_connect_' + str(block_id))([inputs, x])

    return x
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3), classes=21, backbone='mobilenetv2'
              , OS=16, alpha=1.):
    """Instantiates the Deeplabv3+ architecture.

    Optionally loads weights pre-trained on PASCAL VOC or Cityscapes.
    This model is available for TensorFlow only, and can only be used
    with inputs following the TensorFlow data format
    `(width, height, channels)`.

    # Arguments
        weights: one of 'pascal_voc' (pre-trained on PASCAL VOC),
            'cityscapes', or None (random initialization)
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: shape of input image. format HxWxC.
            PASCAL VOC model was trained on (512,512,3) images
        classes: number of desired classes. If classes != 21,
            last layer is initialized randomly
        backbone: backbone to use. one of {'xception','mobilenetv2'}
        OS: determines input_shape/feature_extractor_output ratio. One of {8,16}.
            Used only for xception backbone (mobilenetv2 forces OS=8 below).
        alpha: controls the width of the MobileNetV2 network. This is known as the
            width multiplier in the MobileNetV2 paper.
            - If `alpha` < 1.0, proportionally decreases the number
              of filters in each layer.
            - If `alpha` > 1.0, proportionally increases the number
              of filters in each layer.
            - If `alpha` = 1, default number of filters from the paper
              are used at each layer.
            Used only for mobilenetv2 backbone

    # Returns
        A Keras model instance.

    # Raises
        RuntimeError: If attempting to run this model with a
            backend that does not support separable convolutions.
        ValueError: in case of invalid argument for `weights` or `backbone`
    """

    # --- Argument validation -------------------------------------------------
    if not (weights in {'pascal_voc', 'cityscapes', None}):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `pascal_voc`, or `cityscapes` '
                         '(pre-trained on PASCAL VOC)')

    if K.backend() != 'tensorflow':
        raise RuntimeError('The Deeplabv3+ model is only available with '
                           'the TensorFlow backend.')

    if not (backbone in {'xception', 'mobilenetv2'}):
        raise ValueError('The `backbone` argument should be either '
                         '`xception` or `mobilenetv2` ')

    # --- Input handling: accept either a shape or an existing tensor ---------
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            # Wrap a raw (non-Keras) tensor in an Input layer
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    # --- Feature extractor ---------------------------------------------------
    if backbone == 'xception':
        # Atrous/dilation rates depend on the chosen output stride (OS).
        if OS == 8:
            entry_block3_stride = 1
            middle_block_rate = 2  # ! Not mentioned in paper, but required
            exit_block_rates = (2, 4)
            atrous_rates = (12, 24, 36)
        else:
            entry_block3_stride = 2
            middle_block_rate = 1
            exit_block_rates = (1, 2)
            atrous_rates = (6, 12, 18)

        # Entry flow: stem convolutions
        x = Conv2D(32, (3, 3), strides=(2, 2),
                   name='entry_flow_conv1_1', use_bias=False, padding='same')(img_input)
        x = BatchNormalization(name='entry_flow_conv1_1_BN')(x)
        x = Activation('relu')(x)

        x = _conv2d_same(x, 64, 'entry_flow_conv1_2', kernel_size=3, stride=1)
        x = BatchNormalization(name='entry_flow_conv1_2_BN')(x)
        x = Activation('relu')(x)

        x = _xception_block(x, [128, 128, 128], 'entry_flow_block1',
                            skip_connection_type='conv', stride=2,
                            depth_activation=False)
        # skip1 feeds the decoder's low-level feature projection below.
        x, skip1 = _xception_block(x, [256, 256, 256], 'entry_flow_block2',
                                   skip_connection_type='conv', stride=2,
                                   depth_activation=False, return_skip=True)

        x = _xception_block(x, [728, 728, 728], 'entry_flow_block3',
                            skip_connection_type='conv', stride=entry_block3_stride,
                            depth_activation=False)
        # Middle flow: 16 identical residual units
        for i in range(16):
            x = _xception_block(x, [728, 728, 728], 'middle_flow_unit_{}'.format(i + 1),
                                skip_connection_type='sum', stride=1, rate=middle_block_rate,
                                depth_activation=False)

        # Exit flow
        x = _xception_block(x, [728, 1024, 1024], 'exit_flow_block1',
                            skip_connection_type='conv', stride=1, rate=exit_block_rates[0],
                            depth_activation=False)
        x = _xception_block(x, [1536, 1536, 2048], 'exit_flow_block2',
                            skip_connection_type='none', stride=1, rate=exit_block_rates[1],
                            depth_activation=True)

    else:
        # MobileNetV2 backbone always runs at output stride 8 (OS forced here;
        # the OS argument is ignored for this backbone).
        OS = 8
        first_block_filters = _make_divisible(32 * alpha, 8)
        x = Conv2D(first_block_filters,
                   kernel_size=3,
                   strides=(2, 2), padding='same',
                   use_bias=False, name='Conv')(img_input)
        x = BatchNormalization(
            epsilon=1e-3, momentum=0.999, name='Conv_BN')(x)
        x = Activation(relu6, name='Conv_Relu6')(x)

        x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1,
                                expansion=1, block_id=0, skip_connection=False)

        x = _inverted_res_block(x, filters=24, alpha=alpha, stride=2,
                                expansion=6, block_id=1, skip_connection=False)
        x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1,
                                expansion=6, block_id=2, skip_connection=True)

        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2,
                                expansion=6, block_id=3, skip_connection=False)
        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
                                expansion=6, block_id=4, skip_connection=True)
        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
                                expansion=6, block_id=5, skip_connection=True)

        # stride in block 6 changed from 2 -> 1, so we need to use rate = 2
        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1,  # 1!
                                expansion=6, block_id=6, skip_connection=False)
        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=7, skip_connection=True)
        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=8, skip_connection=True)
        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=9, skip_connection=True)

        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=10, skip_connection=False)
        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=11, skip_connection=True)
        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
                                expansion=6, block_id=12, skip_connection=True)

        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=2,  # 1!
                                expansion=6, block_id=13, skip_connection=False)
        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4,
                                expansion=6, block_id=14, skip_connection=True)
        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4,
                                expansion=6, block_id=15, skip_connection=True)

        x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, rate=4,
                                expansion=6, block_id=16, skip_connection=False)

    # end of feature extractor

    # branching for Atrous Spatial Pyramid Pooling

    # Image Feature branch: global average pool -> 1x1 conv -> upsample back
    #out_shape = int(np.ceil(input_shape[0] / OS))
    b4 = AveragePooling2D(pool_size=(int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(x)
    b4 = Conv2D(256, (1, 1), padding='same',
                use_bias=False, name='image_pooling')(b4)
    b4 = BatchNormalization(name='image_pooling_BN', epsilon=1e-5)(b4)
    b4 = Activation('relu')(b4)
    b4 = BilinearUpsampling((int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(b4)

    # simple 1x1
    b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp0')(x)
    b0 = BatchNormalization(name='aspp0_BN', epsilon=1e-5)(b0)
    b0 = Activation('relu', name='aspp0_activation')(b0)

    # there are only 2 branches in mobilenetV2. not sure why
    if backbone == 'xception':
        # rate = 6 (12)
        b1 = SepConv_BN(x, 256, 'aspp1',
                        rate=atrous_rates[0], depth_activation=True, epsilon=1e-5)
        # rate = 12 (24)
        b2 = SepConv_BN(x, 256, 'aspp2',
                        rate=atrous_rates[1], depth_activation=True, epsilon=1e-5)
        # rate = 18 (36)
        b3 = SepConv_BN(x, 256, 'aspp3',
                        rate=atrous_rates[2], depth_activation=True, epsilon=1e-5)

        # concatenate ASPP branches & project
        x = Concatenate()([b4, b0, b1, b2, b3])
    else:
        x = Concatenate()([b4, b0])

    # Project the concatenated ASPP features back down to 256 channels.
    x = Conv2D(256, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(x)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(0.1)(x)

    # DeepLab v.3+ decoder (only used with the xception backbone)

    if backbone == 'xception':
        # Feature projection
        # x4 (x2) block
        x = BilinearUpsampling(output_size=(int(np.ceil(input_shape[0] / 4)),
                                            int(np.ceil(input_shape[1] / 4))))(x)
        dec_skip1 = Conv2D(48, (1, 1), padding='same',
                           use_bias=False, name='feature_projection0')(skip1)
        dec_skip1 = BatchNormalization(
            name='feature_projection0_BN', epsilon=1e-5)(dec_skip1)
        dec_skip1 = Activation('relu')(dec_skip1)
        x = Concatenate()([x, dec_skip1])
        x = SepConv_BN(x, 256, 'decoder_conv0',
                       depth_activation=True, epsilon=1e-5)
        x = SepConv_BN(x, 256, 'decoder_conv1',
                       depth_activation=True, epsilon=1e-5)

    # you can use it with arbitary number of classes
    # (a non-default class count gets a different layer name so that
    # load_weights(by_name=True) skips the pretrained logits layer)
    if classes == 21:
        last_layer_name = 'logits_semantic'
    else:
        last_layer_name = 'custom_logits_semantic'

    x = Conv2D(classes, (1, 1), padding='same', name=last_layer_name)(x)
    # Final bilinear upsampling to the original input resolution.
    x = BilinearUpsampling(output_size=(input_shape[0], input_shape[1]))(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = Model(inputs, x, name='deeplabv3plus')

    # load weights

    if weights == 'pascal_voc':
        if backbone == 'xception':
            weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels.h5',
                                    WEIGHTS_PATH_X,
                                    cache_subdir='models')
        else:
            weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5',
                                    WEIGHTS_PATH_MOBILE,
                                    cache_subdir='models')
        model.load_weights(weights_path, by_name=True)
    elif weights == 'cityscapes':
        if backbone == 'xception':
            weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5',
                                    WEIGHTS_PATH_X_CS,
                                    cache_subdir='models')
        else:
            weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5',
                                    WEIGHTS_PATH_MOBILE_CS,
                                    cache_subdir='models')
        model.load_weights(weights_path, by_name=True)
    return model
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def preprocess_input(x):
    """Preprocesses a numpy array encoding a batch of images.

    # Arguments
        x: a 4D numpy array consists of RGB values within [0, 255].
    # Returns
        Input array scaled to [-1., 1.]
    """
    # Delegate to Keras' imagenet utilities; mode='tf' rescales 0..255 to -1..1.
    scaled = imagenet_utils.preprocess_input(x, mode='tf')
    return scaled
|
models/unets.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.models import Model
|
| 2 |
+
from keras.layers import Input
|
| 3 |
+
from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Concatenate, UpSampling2D
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Unet2D:
    """Builders for 2D U-Net segmentation models (Keras functional API).

    Each ``get_*`` method returns ``(model, name)`` where ``name`` is a
    string identifier for the architecture variant.
    """

    def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels):
        # Base filter count and the expected input image geometry.
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.n_filters = n_filters
        self.num_channels = num_channels

    def _input_layer(self):
        # Fresh input tensor matching the configured image geometry.
        return Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))

    @staticmethod
    def _double_conv(tensor, filters, batch_norm=True):
        # Two 3x3 ReLU convolutions, optionally followed by batch normalization.
        out = Conv2D(filters, kernel_size=3, activation='relu', padding='same')(tensor)
        out = Conv2D(filters, kernel_size=3, activation='relu', padding='same')(out)
        if batch_norm:
            out = BatchNormalization()(out)
        return out

    @staticmethod
    def _up_concat(tensor, skip, filters):
        # 2x upsampling followed by a 2x2 conv, then concatenation with the
        # encoder skip connection (skip first, matching the original wiring).
        up = Conv2D(filters, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(tensor))
        return Concatenate()([skip, up])

    def get_unet_model_5_levels(self):
        """Classic 5-level U-Net with batch norm; 3-channel sigmoid output."""
        base = self.n_filters
        net_in = self._input_layer()

        # Encoder: filter count doubles at each level; dropout at the two
        # deepest levels.
        enc1 = self._double_conv(net_in, base)
        enc2 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc1), base * 2)
        enc3 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc2), base * 4)
        enc4 = Dropout(0.5)(self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc3), base * 8))
        bottom = Dropout(0.5)(self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc4), base * 16))

        # Decoder: upsample, merge with the matching encoder level, refine.
        dec6 = self._double_conv(self._up_concat(bottom, enc4, base * 16), base * 8)
        dec7 = self._double_conv(self._up_concat(dec6, enc3, base * 8), base * 4)
        dec8 = self._double_conv(self._up_concat(dec7, enc2, base * 4), base * 2)
        dec9 = self._double_conv(self._up_concat(dec8, enc1, base * 2), base)

        mask = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(dec9)

        return Model(outputs=mask, inputs=net_in), 'unet_model_5_levels'

    def get_unet_model_4_levels(self):
        """Shallower 4-level U-Net; filters start at 2x base."""
        base = self.n_filters
        net_in = self._input_layer()

        enc1 = self._double_conv(net_in, base * 2)
        enc2 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc1), base * 4)
        enc3 = Dropout(0.5)(self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc2), base * 8))
        bottom = Dropout(0.5)(self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc3), base * 16))

        dec5 = self._double_conv(self._up_concat(bottom, enc3, base * 16), base * 8)
        dec6 = self._double_conv(self._up_concat(dec5, enc2, base * 8), base * 4)
        dec7 = self._double_conv(self._up_concat(dec6, enc1, base * 4), base * 2)

        mask = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(dec7)

        return Model(outputs=mask, inputs=net_in), 'unet_model_4_levels'

    def get_unet_model_yuanqing(self):
        # Model inspired by https://github.com/yuanqing811/ISIC2018
        """U-Net variant without batch norm; skip features are re-convolved
        before merging; single-channel sigmoid output."""
        base = self.n_filters
        net_in = self._input_layer()

        # Encoder: two convs at the shallow levels, three at the deeper ones.
        enc1 = self._double_conv(net_in, base, batch_norm=False)
        enc2 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc1), base * 2, batch_norm=False)

        enc3 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc2), base * 4, batch_norm=False)
        enc3 = Conv2D(base * 4, kernel_size=3, activation='relu', padding='same')(enc3)

        enc4 = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc3), base * 8, batch_norm=False)
        enc4 = Conv2D(base * 8, kernel_size=3, activation='relu', padding='same')(enc4)

        bottom = self._double_conv(MaxPooling2D(pool_size=(2, 2))(enc4), base * 8, batch_norm=False)
        bottom = Conv2D(base * 8, kernel_size=3, activation='relu', padding='same')(bottom)

        # Decoder: each skip connection passes through an extra 3x3 conv
        # before concatenation (layer creation order preserved: up, skip, concat).
        up6 = Conv2D(base * 4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(bottom))
        skip4 = Conv2D(base * 4, kernel_size=3, activation='relu', padding='same')(enc4)
        dec6 = self._double_conv(Concatenate()([skip4, up6]), base * 4, batch_norm=False)

        up7 = Conv2D(base * 2, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(dec6))
        skip3 = Conv2D(base * 2, kernel_size=3, activation='relu', padding='same')(enc3)
        dec7 = self._double_conv(Concatenate()([skip3, up7]), base * 2, batch_norm=False)

        up8 = Conv2D(base * 1, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(dec7))
        skip2 = Conv2D(base * 1, kernel_size=3, activation='relu', padding='same')(enc2)
        dec8 = self._double_conv(Concatenate()([skip2, up8]), base * 1, batch_norm=False)

        half = int(base / 2)
        up9 = Conv2D(half, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(dec8))
        skip1 = Conv2D(half, kernel_size=3, activation='relu', padding='same')(enc1)
        dec9 = self._double_conv(Concatenate()([skip1, up9]), half, batch_norm=False)
        dec9 = Conv2D(3, kernel_size=3, activation='relu', padding='same')(dec9)

        mask = Conv2D(1, kernel_size=1, activation='sigmoid')(dec9)

        return Model(outputs=mask, inputs=net_in), 'unet_model_yuanqing'
|
requirements.txt
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.3.1
|
| 2 |
+
aiofiles==24.1.0
|
| 3 |
+
annotated-types==0.7.0
|
| 4 |
+
anyio==4.10.0
|
| 5 |
+
asttokens==3.0.0
|
| 6 |
+
astunparse==1.6.3
|
| 7 |
+
attrs==25.3.0
|
| 8 |
+
beautifulsoup4==4.13.4
|
| 9 |
+
blinker==1.9.0
|
| 10 |
+
Brotli==1.1.0
|
| 11 |
+
cachetools==5.5.2
|
| 12 |
+
certifi==2025.8.3
|
| 13 |
+
charset-normalizer==3.4.2
|
| 14 |
+
click==8.2.1
|
| 15 |
+
colorama==0.4.6
|
| 16 |
+
comm==0.2.3
|
| 17 |
+
ConfigArgParse==1.7.1
|
| 18 |
+
contourpy==1.3.2
|
| 19 |
+
cycler==0.12.1
|
| 20 |
+
dash==3.2.0
|
| 21 |
+
decorator==5.2.1
|
| 22 |
+
exceptiongroup==1.3.0
|
| 23 |
+
executing==2.2.0
|
| 24 |
+
fastapi==0.116.1
|
| 25 |
+
fastjsonschema==2.21.1
|
| 26 |
+
ffmpy==0.6.1
|
| 27 |
+
filelock==3.18.0
|
| 28 |
+
Flask==3.1.1
|
| 29 |
+
flatbuffers==25.2.10
|
| 30 |
+
fonttools==4.59.0
|
| 31 |
+
fsspec==2025.7.0
|
| 32 |
+
gast==0.4.0
|
| 33 |
+
google-generativeai
|
| 34 |
+
gdown==5.2.0
|
| 35 |
+
google-auth==2.40.3
|
| 36 |
+
google-auth-oauthlib==0.4.6
|
| 37 |
+
google-pasta==0.2.0
|
| 38 |
+
gradio==5.41.1
|
| 39 |
+
gradio_client==1.11.0
|
| 40 |
+
gradio_imageslider==0.0.20
|
| 41 |
+
groovy==0.1.2
|
| 42 |
+
grpcio==1.74.0
|
| 43 |
+
h11==0.16.0
|
| 44 |
+
h5py==3.14.0
|
| 45 |
+
httpcore==1.0.9
|
| 46 |
+
httpx==0.28.1
|
| 47 |
+
huggingface-hub==0.34.3
|
| 48 |
+
idna==3.10
|
| 49 |
+
imageio==2.37.0
|
| 50 |
+
importlib_metadata==8.7.0
|
| 51 |
+
ipython==8.37.0
|
| 52 |
+
ipywidgets==8.1.7
|
| 53 |
+
itsdangerous==2.2.0
|
| 54 |
+
jedi==0.19.2
|
| 55 |
+
Jinja2==3.1.6
|
| 56 |
+
jsonschema==4.25.0
|
| 57 |
+
jsonschema-specifications==2025.4.1
|
| 58 |
+
jupyter_core==5.8.1
|
| 59 |
+
jupyterlab_widgets==3.0.15
|
| 60 |
+
keras==2.10.0
|
| 61 |
+
Keras-Preprocessing==1.1.2
|
| 62 |
+
kiwisolver==1.4.8
|
| 63 |
+
lazy_loader==0.4
|
| 64 |
+
libclang==18.1.1
|
| 65 |
+
Markdown==3.8.2
|
| 66 |
+
markdown-it-py==3.0.0
|
| 67 |
+
MarkupSafe==3.0.2
|
| 68 |
+
matplotlib==3.10.5
|
| 69 |
+
matplotlib-inline==0.1.7
|
| 70 |
+
mdurl==0.1.2
|
| 71 |
+
mpmath==1.3.0
|
| 72 |
+
narwhals==2.0.1
|
| 73 |
+
nbformat==5.10.4
|
| 74 |
+
nest-asyncio==1.6.0
|
| 75 |
+
networkx==3.4.2
|
| 76 |
+
numpy==1.26.4
|
| 77 |
+
oauthlib==3.3.1
|
| 78 |
+
open3d==0.19.0
|
| 79 |
+
opencv-python==4.11.0.86
|
| 80 |
+
opt_einsum==3.4.0
|
| 81 |
+
orjson==3.11.1
|
| 82 |
+
packaging==25.0
|
| 83 |
+
pandas==2.3.1
|
| 84 |
+
parso==0.8.4
|
| 85 |
+
pillow==11.3.0
|
| 86 |
+
platformdirs==4.3.8
|
| 87 |
+
plotly==6.2.0
|
| 88 |
+
prompt_toolkit==3.0.51
|
| 89 |
+
protobuf==3.19.6
|
| 90 |
+
psutil==5.9.8
|
| 91 |
+
pure_eval==0.2.3
|
| 92 |
+
pyasn1==0.6.1
|
| 93 |
+
pyasn1_modules==0.4.2
|
| 94 |
+
pydantic==2.10.6
|
| 95 |
+
pydantic_core==2.27.2
|
| 96 |
+
pydub==0.25.1
|
| 97 |
+
Pygments==2.19.2
|
| 98 |
+
pyparsing==3.2.3
|
| 99 |
+
PySocks==1.7.1
|
| 100 |
+
python-dateutil==2.9.0.post0
|
| 101 |
+
python-multipart==0.0.20
|
| 102 |
+
pytz==2025.2
|
| 103 |
+
PyYAML==6.0.2
|
| 104 |
+
referencing==0.36.2
|
| 105 |
+
requests==2.32.4
|
| 106 |
+
requests-oauthlib==2.0.0
|
| 107 |
+
retrying==1.4.2
|
| 108 |
+
rich==14.1.0
|
| 109 |
+
rpds-py==0.27.0
|
| 110 |
+
rsa==4.9.1
|
| 111 |
+
ruff==0.12.7
|
| 112 |
+
safehttpx==0.1.6
|
| 113 |
+
scikit-image==0.25.2
|
| 114 |
+
scipy==1.15.3
|
| 115 |
+
semantic-version==2.10.0
|
| 116 |
+
shellingham==1.5.4
|
| 117 |
+
six==1.17.0
|
| 118 |
+
sniffio==1.3.1
|
| 119 |
+
soupsieve==2.7
|
| 120 |
+
spaces==0.39.0
|
| 121 |
+
stack-data==0.6.3
|
| 122 |
+
starlette==0.47.2
|
| 123 |
+
sympy==1.14.0
|
| 124 |
+
tensorboard==2.10.1
|
| 125 |
+
tensorboard-data-server==0.6.1
|
| 126 |
+
tensorboard-plugin-wit==1.8.1
|
| 127 |
+
tensorflow==2.10.1
|
| 128 |
+
tensorflow-estimator==2.10.0
|
| 129 |
+
tensorflow-hub==0.16.1
|
| 130 |
+
tensorflow-io-gcs-filesystem==0.31.0
|
| 131 |
+
termcolor==3.1.0
|
| 132 |
+
tf-keras==2.15.0
|
| 133 |
+
tifffile==2025.5.10
|
| 134 |
+
tomlkit==0.13.3
|
| 135 |
+
torch==2.8.0
|
| 136 |
+
torchvision==0.23.0
|
| 137 |
+
tqdm==4.67.1
|
| 138 |
+
traitlets==5.14.3
|
| 139 |
+
typer==0.16.0
|
| 140 |
+
typing-inspection==0.4.1
|
| 141 |
+
typing_extensions==4.14.1
|
| 142 |
+
tzdata==2025.2
|
| 143 |
+
urllib3==2.5.0
|
| 144 |
+
uvicorn==0.35.0
|
| 145 |
+
wcwidth==0.2.13
|
| 146 |
+
websockets==15.0.1
|
| 147 |
+
Werkzeug==3.1.3
|
| 148 |
+
widgetsnbextension==4.0.14
|
| 149 |
+
wrapt==1.17.2
|
| 150 |
+
zipp==3.23.0
|
| 151 |
+
transformers
|
temp_files/Final_workig_cpu.txt
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import tempfile
|
| 8 |
+
from gradio_imageslider import ImageSlider
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
import os
|
| 14 |
+
import tensorflow as tf
|
| 15 |
+
from tensorflow.keras.models import load_model
|
| 16 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
| 17 |
+
import base64
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import gdown
|
| 20 |
+
import spaces
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
# Import actual segmentation model components
|
| 24 |
+
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
|
| 25 |
+
from utils.learning.metrics import dice_coef, precision, recall
|
| 26 |
+
from utils.io.data import normalize
|
| 27 |
+
|
| 28 |
+
# Define path and file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Depth-Anything-V2 ViT-L weights, fetched once from Google Drive and cached
# under ./checkpoints for subsequent runs.
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)
|
| 39 |
+
|
| 40 |
+
# --- TensorFlow: Check GPU Availability ---
# Purely informational: logs whether TF sees a GPU. No device placement is
# forced here; TF will still pick devices automatically.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")
|
| 46 |
+
|
| 47 |
+
# --- Load Wound Classification Model and Class Labels ---
# Keras classifier expected next to this script; fails at import time if absent.
wound_model = load_model("keras_model.h5")
# labels.txt lines look like "<index> <label>"; keep only the label text.
with open("labels.txt", "r") as f:
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
|
| 51 |
+
|
| 52 |
+
# --- Load Actual Wound Segmentation Model ---
|
| 53 |
+
class WoundSegmentationModel:
    """Wrapper around the trained DeeplabV3-style wound-segmentation network.

    Loads the newest available checkpoint from ./training_history (falling
    back to an older one), and exposes preprocess -> predict -> postprocess
    helpers plus a single `segment_wound` entry point used by the Gradio UI.
    """

    # Candidate weight files, newest first; the first one that loads wins.
    # (Previously this was two copy-pasted try/except blocks.)
    _WEIGHT_FILES = (
        '2025-08-07_16-25-27.hdf5',
        '2019-12-19 01%3A53%3A15.480800.hdf5',
    )

    def __init__(self):
        # Input resolution (width x height) expected by the network.
        self.input_dim_x = 224
        self.input_dim_y = 224
        # Set by load_model(); stays None when every checkpoint fails to load.
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model.

        Tries each candidate checkpoint in order; leaves ``self.model`` as
        ``None`` if every attempt fails (segment_wound reports this to the
        user instead of crashing).
        """
        # Custom layers/metrics the checkpoint was saved with; Keras needs
        # them to deserialize the graph.
        custom_objects = {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling,
        }
        for weight_file_name in self._WEIGHT_FILES:
            model_path = f'./training_history/{weight_file_name}'
            try:
                self.model = load_model(model_path, custom_objects=custom_objects)
                print(f"Segmentation model loaded successfully from {model_path}")
                return
            except Exception as e:
                print(f"Error loading segmentation model: {e}")
        # All candidates failed: leave the model unset.
        self.model = None

    def preprocess_image(self, image):
        """Preprocess the uploaded image for model input.

        Resizes to the network resolution, scales to [0, 1] and adds a batch
        axis. Returns ``None`` when no image was provided.
        """
        if image is None:
            return None

        # NOTE(review): assumes 3-channel input is BGR (OpenCV convention)
        # and converts to RGB — confirm what the Gradio input actually yields.
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize to model input size
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))

        # Normalize the image to [0, 1]
        image = image.astype(np.float32) / 255.0

        # Add batch dimension -> (1, H, W, C)
        image = np.expand_dims(image, axis=0)

        return image

    def postprocess_prediction(self, prediction):
        """Convert the raw network output into a binary 0/255 uint8 mask."""
        # Remove batch dimension
        prediction = prediction[0]

        # Apply threshold to get binary mask
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255

        return binary_mask

    def segment_wound(self, input_image):
        """Segment the wound in ``input_image``.

        Returns ``(mask, status_message)``; ``mask`` is ``None`` on any
        failure so the UI can surface the message instead.
        """
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."

        if input_image is None:
            return None, "Please upload an image."

        try:
            # Preprocess the image
            processed_image = self.preprocess_image(input_image)

            if processed_image is None:
                return None, "Error processing image."

            # Make prediction
            prediction = self.model.predict(processed_image, verbose=0)

            # Postprocess the prediction
            segmented_mask = self.postprocess_prediction(prediction)

            return segmented_mask, "Segmentation completed successfully!"

        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"
|
| 155 |
+
|
| 156 |
+
# Initialize the segmentation model
# Instantiated once at import time; loads checkpoint weights from disk.
segmentation_model = WoundSegmentationModel()
|
| 158 |
+
|
| 159 |
+
# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Architecture hyper-parameters for each Depth-Anything-V2 encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'  # must match the checkpoint downloaded at startup
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
# Inference-only: move to the chosen device and disable train-mode layers.
depth_model = depth_model.to(map_device).eval()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# --- Custom CSS for unified dark theme ---
# Injected into the Gradio app; styles buttons, tabs, headings and the
# image-display containers referenced via elem_id elsewhere in the file.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
|
| 241 |
+
|
| 242 |
+
# --- Wound Classification Functions ---
|
| 243 |
+
def preprocess_input(img):
    """Resize a PIL image to 224x224, scale to [0, 1] and add a batch axis."""
    resized = img.resize((224, 224))
    pixels = keras_image.img_to_array(resized) / 255.0
    return np.expand_dims(pixels, axis=0)
|
| 248 |
+
|
| 249 |
+
def get_reasoning_from_gemini(img, prediction):
    """Return an explanation string for the predicted wound class.

    Currently serves canned text (the Gemini API call is stubbed out to avoid
    typing issues); the ``img`` argument is accepted for future use and is
    ignored. Unknown classes fall back to a generic consult-a-doctor message.
    """
    fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
    try:
        canned_explanations = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues.",
        }
        return canned_explanations.get(prediction, fallback)
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
|
| 265 |
+
|
| 266 |
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound photo and return two HTML cards (prediction, reasoning)."""
    if img is None:
        return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""

    # 224x224, [0,1]-scaled batch of one (see preprocess_input).
    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning from Gemini
    # (currently canned text; see get_reasoning_from_gemini)
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    predicted_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Predicted Wound Type
        </div>
        <div style='font-size: 26px; color: white;'>
            {pred_class}
        </div>
    </div>
    """

    # Reasoning Card
    reasoning_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Reasoning
        </div>
        <div style='font-size: 16px; color: white; min-height: 80px;'>
            {reasoning_text}
        </div>
    </div>
    """

    return predicted_card, reasoning_card
|
| 306 |
+
|
| 307 |
+
# --- Wound Severity Estimation Functions ---
|
| 308 |
+
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Summarise wound area by depth band.

    Pixels where ``mask > 127`` count as wound; their ``depth_map`` values are
    bucketed into shallow (<3), moderate (3-6) and deep (>=6) bands, and each
    band's area is reported in cm^2 using the given pixel spacing.
    NOTE(review): assumes depth values are in millimetres — confirm upstream.
    """
    # Area of one pixel in cm^2 (spacing is given in mm).
    px_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    inside = mask > 127
    depths = depth_map[inside]
    area_total = np.sum(inside) * px_area_cm2

    band_shallow = depths < 3
    band_deep = depths >= 6
    # Everything that is neither shallow nor deep: [3, 6).
    band_moderate = ~band_shallow & ~band_deep

    area_shallow = np.sum(band_shallow) * px_area_cm2
    area_moderate = np.sum(band_moderate) * px_area_cm2
    area_deep = np.sum(band_deep) * px_area_cm2

    return {
        'total_area_cm2': area_total,
        'shallow_area_cm2': area_shallow,
        'moderate_area_cm2': area_moderate,
        'deep_area_cm2': area_deep,
        'deep_ratio': area_deep / area_total if area_total > 0 else 0,
        'max_depth': np.max(depths) if len(depths) > 0 else 0,
    }
|
| 337 |
+
|
| 338 |
+
def classify_wound_severity_by_area(depth_stats):
    """Map area statistics to a severity label: Mild/Moderate/Severe/Unknown."""
    area_total = depth_stats['total_area_cm2']
    if area_total == 0:
        # No measurable wound region at all.
        return "Unknown"

    area_deep = depth_stats['deep_area_cm2']
    area_moderate = depth_stats['moderate_area_cm2']

    # Severe: >2 cm^2 of deep tissue, or deep covering >30% of the wound.
    if area_deep > 2 or area_deep / area_total > 0.3:
        return "Severe"
    # Moderate: >1.5 cm^2 of moderate depth, or >40% moderate coverage.
    if area_moderate > 1.5 or area_moderate / area_total > 0.4:
        return "Moderate"
    return "Mild"
|
| 354 |
+
|
| 355 |
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Analyze wound severity from depth map and wound mask.

    Returns an HTML report (per-band areas, deep-coverage ratio, max depth and
    a color-coded verdict), or a plain error message when any input is missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "β Please upload image, depth map, and wound mask."

    # Collapse RGB masks to a single channel before thresholding.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have same dimensions
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize mask to match depth map
        from PIL import Image  # local import; shadows the module-level Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    # Compute statistics
    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Create severity report with color coding
    severity_color = {
        "Mild": "#4CAF50",      # Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336"     # Red
    }.get(severity, "#9E9E9E")  # Gray for unknown

    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            π©Ή Wound Severity Analysis
        </div>

        <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Area Measurements
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π’ <b>Total Area:</b> {stats['total_area_cm2']:.2f} cmΒ²</div>
                    <div>π© <b>Shallow (0-3mm):</b> {stats['shallow_area_cm2']:.2f} cmΒ²</div>
                    <div>π¨ <b>Moderate (3-6mm):</b> {stats['moderate_area_cm2']:.2f} cmΒ²</div>
                    <div>π₯ <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cmΒ²</div>
                </div>
            </div>

            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Depth Analysis
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π₯ <b>Deep Coverage:</b> {stats['deep_ratio']*100:.1f}%</div>
                    <div>π <b>Max Depth:</b> {stats['max_depth']:.1f} mm</div>
                    <div>β‘ <b>Pixel Spacing:</b> {pixel_spacing_mm} mm</div>
                </div>
            </div>
        </div>

        <div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
                π― Predicted Severity: {severity}
            </div>
            <div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
                {get_severity_description(severity)}
            </div>
        </div>
    </div>
    """

    return report
|
| 426 |
+
|
| 427 |
+
def get_severity_description(severity):
    """Return a short human-readable note for a severity label."""
    if severity == "Mild":
        return "Superficial wound with minimal tissue damage. Usually heals well with basic care."
    if severity == "Moderate":
        return "Moderate tissue involvement requiring careful monitoring and proper treatment."
    if severity == "Severe":
        return "Deep tissue damage requiring immediate medical attention and specialized care."
    if severity == "Unknown":
        return "Unable to determine severity due to insufficient data."
    # Any label outside the known set.
    return "Severity assessment unavailable."
|
| 436 |
+
|
| 437 |
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a filled-circle test mask: uint8, 255 inside the circle, 0 outside."""
    height, width = image_shape[0], image_shape[1]
    if center is None:
        # Default to the image centre, expressed as (x, y).
        center = (width // 2, height // 2)

    rows, cols = np.ogrid[:height, :width]
    distance = np.hypot(cols - center[0], rows - center[1])

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[distance <= radius] = 255
    return mask
|
| 450 |
+
|
| 451 |
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Generate a synthetic, irregular wound mask (uint8, 0/255) for testing.

    ``method`` is 'elliptical' (noisy ellipse) or 'irregular' (circle with
    three lobes); any other value yields an all-zero mask (still closed below).
    Output is non-deterministic for 'elliptical' due to random speckle.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    centre = (w // 2, h // 2)  # (x, y)

    yy, xx = np.ogrid[:h, :w]

    if method == 'elliptical':
        rx = min(w, h) // 3
        ry = min(w, h) // 4
        # Points satisfying the ellipse equation form the core region.
        core = ((xx - centre[0]) ** 2 / (rx ** 2) +
                (yy - centre[1]) ** 2 / (ry ** 2)) <= 1
        # Random speckle over the whole frame adds irregularity; the
        # morphological close below smooths it out.
        speckle = np.random.random((h, w)) > 0.8
        mask = (core | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        r = min(w, h) // 4
        core = np.sqrt((xx - centre[0]) ** 2 + (yy - centre[1]) ** 2) <= r
        # Three lobes spaced 120 degrees apart around the rim.
        lobes = np.zeros_like(core)
        for k in range(3):
            theta = k * 2 * np.pi / 3
            lx = int(centre[0] + r * 0.8 * np.cos(theta))
            ly = int(centre[1] + r * 0.8 * np.sin(theta))
            lobes = lobes | (np.sqrt((xx - lx) ** 2 + (yy - ly) ** 2) <= r // 3)
        mask = (core | lobes).astype(np.uint8) * 255

    # Smooth ragged edges with a 5x5 elliptical morphological close.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask
|
| 497 |
+
|
| 498 |
+
# --- Depth Estimation Functions ---
|
| 499 |
+
@spaces.GPU
def predict_depth(image):
    """Run Depth-Anything-V2 on one image and return its raw depth map."""
    depth_map = depth_model.infer_image(image)
    return depth_map
|
| 502 |
+
|
| 503 |
+
def calculate_max_points(image):
    """Upper bound for the 3D-point slider: 3x the pixel count, clamped to [1000, 300000]."""
    if image is None:
        return 10000  # sensible default before any upload
    height, width = image.shape[:2]
    requested = height * width * 3
    # Clamp into a usable range.
    return min(max(requested, 1000), 300000)
|
| 511 |
+
|
| 512 |
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider so its maximum tracks the uploaded image size."""
    upper = calculate_max_points(image)
    # Default to 10% of the maximum, but never above 10k points.
    initial = min(10000, upper // 10)
    return gr.Slider(
        minimum=1000,
        maximum=upper,
        value=initial,
        step=1000,
        label=f"Number of 3D points (max: {upper:,})",
    )
|
| 518 |
+
|
| 519 |
+
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a colored Open3D point cloud.

    Uses a pinhole model with the principal point at the image centre.  The
    grid is subsampled so roughly ``max_points`` survive, with the stride
    halved relative to the naive estimate to retain extra detail.
    """
    h, w = depth_map.shape

    # Halving the estimated stride keeps ~4x more samples than max_points
    # alone would suggest — deliberate detail/memory trade-off.
    stride = max(1, int(np.sqrt(h * w / max_points) * 0.5))

    grid_y, grid_x = np.mgrid[0:h:stride, 0:w:stride]

    # Normalised camera rays (principal point at the image centre).
    ray_x = (grid_x - w / 2) / focal_length_x
    ray_y = (grid_y - h / 2) / focal_length_y

    z = depth_map[::stride, ::stride]

    # Scale the rays by depth to get 3D coordinates, then flatten to (N, 3).
    xyz = np.stack(
        [(ray_x * z).flatten(), (ray_y * z).flatten(), z.flatten()],
        axis=1,
    )

    # Matching per-point colors in [0, 1].
    rgb = image[::stride, ::stride, :].reshape(-1, 3) / 255.0

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud
|
| 555 |
+
|
| 556 |
+
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Poisson-reconstruct a triangle mesh from a colored point cloud.

    Mutates ``pcd`` in place by estimating and orienting normals (required by
    Poisson reconstruction), then builds a high-detail mesh.
    """
    search = o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)
    pcd.estimate_normals(search_param=search)
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # depth=12 gives a very fine octree; low-density vertices are deliberately
    # kept rather than trimmed.
    mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    return mesh
|
| 568 |
+
|
| 569 |
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection.

    Back-projects the depth map through a fixed-focal pinhole model and
    renders the points as a Plotly Scatter3d colored by the source image.
    """
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (normalized by focal length)
    focal_length = 470.4  # Default focal length (matches create_point_cloud's default)
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Get depth values
    depth_values = depth_map[::step, ::step]

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    # Create 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
|
| 636 |
+
|
| 637 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Run depth estimation on an image and build all downstream artifacts.

    Args:
        image: Input RGB image as a numpy array of shape (H, W, 3).
        num_points: Maximum number of points for the point cloud / 3D plot.
        focal_x: Camera focal length along x, in pixels.
        focal_y: Camera focal length along y, in pixels.

    Returns:
        A list: [(original, colored_depth) pair for the slider widget,
        grayscale-depth PNG path, raw 16-bit depth PNG path,
        reconstructed mesh .ply path, plotly 3D figure].
    """
    original_image = image.copy()

    # Predict depth using the model (channel order flipped RGB -> BGR).
    depth = predict_depth(image[:, :, ::-1])

    # Save raw 16-bit depth. Close the NamedTemporaryFile handle before
    # writing by name: the handle was never closed in the original code,
    # which leaks descriptors and fails on Windows where an open file
    # cannot be re-opened for writing.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    tmp_raw_depth.close()
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to [0, 255] for display; guard against a constant depth
    # map, which would otherwise divide by zero.
    depth_range = depth.max() - depth.min()
    if depth_range == 0:
        norm_depth = np.zeros_like(depth, dtype=np.uint8)
    else:
        norm_depth = ((depth - depth.min()) / depth_range * 255.0).astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    tmp_gray_depth.close()
    gray_depth.save(tmp_gray_depth.name)

    # Create point cloud and reconstruct a triangle mesh from it.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply (handle closed before writing by name).
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    tmp_pointcloud.close()
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Interactive 3D scatter visualization (camera-projected point cloud).
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
| 674 |
+
# --- Actual Wound Segmentation Functions ---
|
| 675 |
+
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate wound mask from image using the actual deep learning model

    Args:
        image: Input image (numpy array)
        method: Segmentation method; only 'deep_learning' is implemented, and
            any other value falls back to the same deep-learning path.

    Returns:
        mask: Binary wound mask, or None when no image is given.
    """
    if image is None:
        return None

    # The original branched on `method`, but both branches were identical,
    # so dispatch unconditionally to the deep-learning segmentation model.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
| 698 |
+
def post_process_wound_mask(mask, min_area=100):
    """Clean a binary wound mask: bridge gaps, drop small blobs, fill holes.

    Args:
        mask: Binary wound mask (any integer dtype; treated as 8-bit).
        min_area: Minimum contour area (in pixels) a region must have to
            survive the cleanup.

    Returns:
        Cleaned 0/255 mask, or None when no mask is given.
    """
    if mask is None:
        return None

    # OpenCV morphology/contour routines require an 8-bit mask.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Close-then-open with an elliptical kernel removes speckle noise and
    # bridges small gaps along the wound boundary.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        mask = cv2.morphologyEx(mask, morph_op, ellipse)

    # Keep only connected components whose area reaches min_area.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cleaned = np.zeros_like(mask)
    for component in contours:
        if cv2.contourArea(component) >= min_area:
            cv2.fillPoly(cleaned, [component], 255)

    # A final close fills any remaining interior holes.
    return cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, ellipse)
| 726 |
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Analyze wound severity with an automatically generated wound mask.

    Runs the deep-learning segmentation model, cleans its output, and then
    delegates the actual measurement to analyze_wound_severity. Returns an
    HTML/status string in every case.
    """
    if image is None or depth_map is None:
        return "β Please provide both image and depth map."

    # Segment the wound with the deep-learning model.
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "β Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    # Denoise the raw prediction before measuring anything.
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    mask_has_pixels = processed_mask is not None and np.sum(processed_mask > 0) > 0
    if not mask_has_pixels:
        return "β No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
| 746 |
+
# --- Main Gradio Interface ---
# Three-tab workflow: classify -> depth/3D -> severity. State components pass
# data between tabs without re-uploading.
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Shared image state: lets the classification tab hand its image to tab 2.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)

                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            # Button to pass image to depth estimation
            with gr.Row():
                pass_to_depth_btn = gr.Button("π Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            # Classify automatically on every upload/change.
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])

            # Store image when uploaded for classification (second listener on
            # the same event keeps the shared state in sync).
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("π Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            # Downloadable artifacts produced by on_depth_submit.
            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            # 3D Visualization
            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Store depth map for severity analysis (filled by the submit handler).
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. π©Ή Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")

            with gr.Row():
                auto_severity_button = gr.Button("π€ Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
                manual_severity_button = gr.Button("π Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                # Load depth map from previous tab
                load_depth_btn = gr.Button("π Load Depth Map from Tab 2", variant="secondary")

            gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")

    # Update the point-count slider whenever a new image is uploaded.
    depth_input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[depth_input_image],
        outputs=[points_slider]
    )

    # Modified depth submit function that also stores the depth map in state.
    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        """Wrap on_depth_submit and append a normalized depth map for tab 3.

        NOTE(review): this re-runs predict_depth a second time on the same
        image — confirm whether the depth computed inside on_depth_submit
        could be reused instead.
        """
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        # Extract depth map from results for severity analysis
        depth_map = None
        if image is not None:
            depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
            # Normalize depth for severity analysis
            norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
            depth_map = norm_depth.astype(np.uint8)
        return results + [depth_map]

    depth_submit.click(on_depth_submit_with_state,
                       inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                       outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])

    # Load depth map to severity tab and auto-generate mask
    def load_depth_to_severity(depth_map, original_image):
        """Copy the stored depth map into tab 3 and auto-segment the wound.

        Returns (depth_map, original_image, mask_or_None, status_message).
        """
        if depth_map is None:
            return None, None, None, "β No depth map available. Please compute depth in Tab 2 first."

        # Auto-generate wound mask using segmentation model
        if original_image is not None:
            auto_mask, _ = segmentation_model.segment_wound(original_image)
            if auto_mask is not None:
                # Post-process the mask
                processed_mask = post_process_wound_mask(auto_mask, min_area=500)
                if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                    return depth_map, original_image, processed_mask, "β Depth map loaded and wound mask auto-generated!"
                else:
                    return depth_map, original_image, None, "β Depth map loaded but no wound detected. Try uploading a different image."
            else:
                return depth_map, original_image, None, "β Depth map loaded but segmentation failed. Try uploading a different image."
        else:
            return depth_map, original_image, None, "β Depth map loaded successfully!"

    # NOTE(review): passing a fresh gr.HTML() inside the outputs list creates
    # a component that is never rendered, so the status message is invisible
    # to the user — confirm whether a rendered status component was intended.
    load_depth_btn.click(
        fn=load_depth_to_severity,
        inputs=[depth_map_state, depth_input_image],
        outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
    )

    # Automatic severity analysis function
    def run_auto_severity_analysis(image, depth_map, pixel_spacing):
        """Segment, clean, and score the wound; returns a report/status string."""
        if depth_map is None:
            return "β Please load depth map from Tab 2 first."

        # Generate automatic wound mask using the actual model
        auto_mask = create_automatic_wound_mask(image, method='deep_learning')

        if auto_mask is None:
            return "β Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

        # Post-process the mask with fixed minimum area
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)

        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return "β No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

        # Analyze severity using the automatic mask
        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)

    # Manual severity analysis function
    def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
        """Score the wound with a user-supplied binary mask."""
        if depth_map is None:
            return "β Please load depth map from Tab 2 first."
        if wound_mask is None:
            return "β Please upload a wound mask (binary image where white pixels represent the wound area)."

        return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)

    # Connect event handlers
    auto_severity_button.click(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider],
        outputs=[severity_output]
    )

    manual_severity_button.click(
        fn=run_manual_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
        outputs=[severity_output]
    )

    # Auto-generate mask when image is uploaded
    def auto_generate_mask_on_image_upload(image):
        """Segment the uploaded image; returns (mask_or_None, status message)."""
        if image is None:
            return None, "β No image uploaded."

        # Generate automatic wound mask using segmentation model
        auto_mask, _ = segmentation_model.segment_wound(image)
        if auto_mask is not None:
            # Post-process the mask
            processed_mask = post_process_wound_mask(auto_mask, min_area=500)
            if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                return processed_mask, "β Wound mask auto-generated using deep learning model!"
            else:
                return None, "β Image uploaded but no wound detected. Try uploading a different image."
        else:
            return None, "β Image uploaded but segmentation failed. Try uploading a different image."

    # Load shared image from classification tab
    def load_shared_image(shared_img):
        """Return the shared image as a numpy array plus a status message."""
        if shared_img is None:
            return gr.Image(), "οΏ½οΏ½ No image available from classification tab"

        # Convert PIL image to numpy array for depth estimation
        if hasattr(shared_img, 'convert'):
            # It's a PIL image, convert to numpy
            img_array = np.array(shared_img)
            return img_array, "β Image loaded from classification tab"
        else:
            # Already numpy array
            return shared_img, "β Image loaded from classification tab"

    # Auto-generate mask when image is uploaded to severity tab
    severity_input_image.change(
        fn=auto_generate_mask_on_image_upload,
        inputs=[severity_input_image],
        outputs=[wound_mask_input, gr.HTML()]
    )

    load_shared_btn.click(
        fn=load_shared_image,
        inputs=[shared_image],
        outputs=[depth_input_image, gr.HTML()]
    )

    # Pass image to depth tab function
    def pass_image_to_depth(img):
        """Confirm an image is queued for tab 2; the state itself is already set."""
        if img is None:
            return "β No image uploaded in classification tab"
        return "β Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"

    pass_to_depth_btn.click(
        fn=pass_image_to_depth,
        inputs=[shared_image],
        outputs=[pass_status]
    )
|
| 994 |
+
|
| 995 |
+
if __name__ == '__main__':
    # Queue requests and serve on all interfaces with a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.queue().launch(**launch_options)
|
temp_files/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Wound Analysis V22
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.41.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
temp_files/fw2.txt
ADDED
|
@@ -0,0 +1,1175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import tempfile
|
| 8 |
+
from gradio_imageslider import ImageSlider
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
import os
|
| 14 |
+
import tensorflow as tf
|
| 15 |
+
from tensorflow.keras.models import load_model
|
| 16 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
| 17 |
+
import base64
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import gdown
|
| 20 |
+
import spaces
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
# Import actual segmentation model components
|
| 24 |
+
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
|
| 25 |
+
from utils.learning.metrics import dice_coef, precision, recall
|
| 26 |
+
from utils.io.data import normalize
|
| 27 |
+
|
| 28 |
+
# Define path and file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Depth-Anything-V2 (vitl) checkpoint hosted on Google Drive.
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present (runs once per fresh container).
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("keras_model.h5")
# Each labels.txt line is "<index> <label>"; keep only the label text.
with open("labels.txt", "r") as f:
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
|
| 52 |
+
# --- Load Actual Wound Segmentation Model ---
|
| 53 |
+
class WoundSegmentationModel:
|
| 54 |
+
def __init__(self):
|
| 55 |
+
self.input_dim_x = 224
|
| 56 |
+
self.input_dim_y = 224
|
| 57 |
+
self.model = None
|
| 58 |
+
self.load_model()
|
| 59 |
+
|
| 60 |
+
def load_model(self):
|
| 61 |
+
"""Load the trained wound segmentation model"""
|
| 62 |
+
try:
|
| 63 |
+
# Try to load the most recent model
|
| 64 |
+
weight_file_name = '2025-08-07_16-25-27.hdf5'
|
| 65 |
+
model_path = f'./training_history/{weight_file_name}'
|
| 66 |
+
|
| 67 |
+
self.model = load_model(model_path,
|
| 68 |
+
custom_objects={
|
| 69 |
+
'recall': recall,
|
| 70 |
+
'precision': precision,
|
| 71 |
+
'dice_coef': dice_coef,
|
| 72 |
+
'relu6': relu6,
|
| 73 |
+
'DepthwiseConv2D': DepthwiseConv2D,
|
| 74 |
+
'BilinearUpsampling': BilinearUpsampling
|
| 75 |
+
})
|
| 76 |
+
print(f"Segmentation model loaded successfully from {model_path}")
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"Error loading segmentation model: {e}")
|
| 79 |
+
# Fallback to the older model
|
| 80 |
+
try:
|
| 81 |
+
weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
|
| 82 |
+
model_path = f'./training_history/{weight_file_name}'
|
| 83 |
+
|
| 84 |
+
self.model = load_model(model_path,
|
| 85 |
+
custom_objects={
|
| 86 |
+
'recall': recall,
|
| 87 |
+
'precision': precision,
|
| 88 |
+
'dice_coef': dice_coef,
|
| 89 |
+
'relu6': relu6,
|
| 90 |
+
'DepthwiseConv2D': DepthwiseConv2D,
|
| 91 |
+
'BilinearUpsampling': BilinearUpsampling
|
| 92 |
+
})
|
| 93 |
+
print(f"Segmentation model loaded successfully from {model_path}")
|
| 94 |
+
except Exception as e2:
|
| 95 |
+
print(f"Error loading fallback segmentation model: {e2}")
|
| 96 |
+
self.model = None
|
| 97 |
+
|
| 98 |
+
def preprocess_image(self, image):
|
| 99 |
+
"""Preprocess the uploaded image for model input"""
|
| 100 |
+
if image is None:
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
# Convert to RGB if needed
|
| 104 |
+
if len(image.shape) == 3 and image.shape[2] == 3:
|
| 105 |
+
# Convert BGR to RGB if needed
|
| 106 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 107 |
+
|
| 108 |
+
# Resize to model input size
|
| 109 |
+
image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
|
| 110 |
+
|
| 111 |
+
# Normalize the image
|
| 112 |
+
image = image.astype(np.float32) / 255.0
|
| 113 |
+
|
| 114 |
+
# Add batch dimension
|
| 115 |
+
image = np.expand_dims(image, axis=0)
|
| 116 |
+
|
| 117 |
+
return image
|
| 118 |
+
|
| 119 |
+
def postprocess_prediction(self, prediction):
|
| 120 |
+
"""Postprocess the model prediction"""
|
| 121 |
+
# Remove batch dimension
|
| 122 |
+
prediction = prediction[0]
|
| 123 |
+
|
| 124 |
+
# Apply threshold to get binary mask
|
| 125 |
+
threshold = 0.5
|
| 126 |
+
binary_mask = (prediction > threshold).astype(np.uint8) * 255
|
| 127 |
+
|
| 128 |
+
return binary_mask
|
| 129 |
+
|
| 130 |
+
def segment_wound(self, input_image):
    """Run the full segmentation pipeline on an uploaded image.

    Returns a (mask, status_message) pair; the mask is None on any failure.
    """
    # Guard: the model may have failed to load at startup.
    if self.model is None:
        return None, "Error: Segmentation model not loaded. Please check the model files."

    # Guard: nothing to segment.
    if input_image is None:
        return None, "Please upload an image."

    try:
        batch = self.preprocess_image(input_image)
        if batch is None:
            return None, "Error processing image."

        # Forward pass followed by thresholding into a binary mask.
        raw_output = self.model.predict(batch, verbose=0)
        mask = self.postprocess_prediction(raw_output)
        return mask, "Segmentation completed successfully!"

    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return None, f"Error during segmentation: {str(e)}"
|
| 155 |
+
|
| 156 |
+
# Initialize the segmentation model
# Instantiated once at import time so every request reuses the loaded weights.
segmentation_model = WoundSegmentationModel()

# --- PyTorch: Set Device and Load Depth Model ---
# Prefer CUDA when at least one GPU is visible; otherwise fall back to CPU.
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Per-encoder architecture hyperparameters for Depth-Anything-V2.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
# Large ViT backbone is the default choice.
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
# map_location keeps the weights on the selected device during deserialization.
state_dict = torch.load(
    f'checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
# eval() switches off training-only behavior (dropout, batch-norm updates).
depth_model = depth_model.to(map_device).eval()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# --- Custom CSS for unified dark theme ---
# Injected into the Gradio app: dark backgrounds, white text, rounded
# buttons/tabs, and height caps on the image panes so large uploads
# don't overflow the viewport.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
|
| 241 |
+
|
| 242 |
+
# --- Wound Classification Functions ---
|
| 243 |
+
def preprocess_input(img):
    """Resize a PIL image to 224x224 and return a normalized (1, 224, 224, C) batch."""
    resized = img.resize((224, 224))
    # Convert to a float array and scale pixel values into [0, 1].
    pixels = keras_image.img_to_array(resized) / 255.0
    # The classifier expects a leading batch dimension.
    return np.expand_dims(pixels, axis=0)
|
| 248 |
+
|
| 249 |
+
def get_reasoning_from_gemini(img, prediction):
    """Return a canned clinical explanation for the predicted wound class.

    The Gemini API call is intentionally stubbed out; `img` is accepted for
    interface compatibility but is not inspected.
    """
    try:
        # Static per-class explanations used in place of a live Gemini call.
        explanations = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
        }

        # Unknown classes get a generic referral message instead of a KeyError.
        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return explanations.get(prediction, fallback)

    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
|
| 265 |
+
|
| 266 |
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound photo and return (prediction card, reasoning card) as HTML.

    NOTE(review): relies on module-level `wound_model` and `class_labels`,
    which are defined elsewhere in this file -- confirm their ordering matches.
    """
    if img is None:
        return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""

    # Resize/normalize and run the Keras classifier; [0] takes the only batch row.
    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning from Gemini (currently a canned-explanation stub).
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    predicted_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Predicted Wound Type
        </div>
        <div style='font-size: 26px; color: white;'>
            {pred_class}
        </div>
    </div>
    """

    # Reasoning Card
    reasoning_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Reasoning
        </div>
        <div style='font-size: 16px; color: white; min-height: 80px;'>
            {reasoning_text}
        </div>
    </div>
    """

    return predicted_card, reasoning_card
|
| 306 |
+
|
| 307 |
+
# --- Enhanced Wound Severity Estimation Functions ---
|
| 308 |
+
@spaces.GPU
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with proper calibration and medical standards.

    Based on wound depth classification standards:
    - Superficial: 0-2mm (epidermis only)
    - Partial thickness: 2-4mm (epidermis + partial dermis)
    - Full thickness: 4-6mm (epidermis + full dermis)
    - Deep: >6mm (involving subcutaneous tissue)

    Args:
        depth_map: 2-D array of relative depth values, same HxW as `mask`.
        mask: wound mask; pixels with value > 127 count as wound.
        pixel_spacing_mm: physical edge length of one pixel, in millimeters.
        depth_calibration_mm: assumed real-world depth of the deepest point,
            used by calibrate_depth_map to convert relative depths to mm.

    Returns:
        dict of area/depth/volume statistics plus quality metrics. Every key
        is present even when the mask contains no wound pixels.
    """
    pixel_spacing_mm = float(pixel_spacing_mm)

    # Area covered by a single pixel, in cm^2.
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Binarize the mask and close small gaps so the wound region is contiguous.
    wound_mask = (mask > 127).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)

    # Depth values restricted to the wound region.
    wound_depths = depth_map[wound_mask > 0]

    if len(wound_depths) == 0:
        # BUG FIX: this early return previously omitted 'analysis_quality',
        # 'depth_consistency' and 'wound_pixel_count', which the report code
        # in analyze_wound_severity reads unconditionally -- that raised a
        # KeyError whenever the mask contained no wound pixels.
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0},
            'analysis_quality': "Low",
            'depth_consistency': "Low",
            'wound_pixel_count': 0
        }

    # Calibrate depth map for more accurate measurements (relative -> mm).
    calibrated_depth_map = calibrate_depth_map(depth_map, reference_depth_mm=depth_calibration_mm)

    # Calibrated depth values for the wound region only.
    wound_depths_mm = calibrated_depth_map[wound_mask > 0]

    # Medical depth classification bands (see docstring).
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0

    # Convert pixel counts to physical areas.
    total_pixels = np.sum(wound_mask > 0)
    total_area_cm2 = total_pixels * pixel_area_cm2

    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2

    # Aggregate depth statistics.
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)

    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }

    # Approximate volume: area * average depth (depth converted mm -> cm).
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)

    # Fraction of the wound area deeper than 6mm.
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0

    # Heuristic quality metrics: more samples -> more reliable statistics.
    wound_pixel_count = len(wound_depths_mm)
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"

    # Lower standard deviation means a flatter, more consistent wound bed.
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"

    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count
    }
|
| 412 |
+
|
| 413 |
+
def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Enhanced wound severity classification based on medical standards.

    Scores the wound on five criteria (max depth, mean depth, deep-tissue
    ratio, total area, volume) and maps the summed score to a severity label.

    Args:
        depth_stats: statistics dict as produced by
            compute_enhanced_depth_statistics.

    Returns:
        One of "Unknown", "Superficial", "Mild", "Moderate", "Severe",
        "Very Severe".
    """
    # No measurable wound area -> nothing to classify.
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"

    # Extract only the metrics that feed the score. (The previously extracted
    # deep_area / full_thickness_area locals were unused and have been removed.)
    total_area = depth_stats['total_area_cm2']
    mean_depth = depth_stats['mean_depth_mm']
    max_depth = depth_stats['max_depth_mm']
    wound_volume = depth_stats['wound_volume_cm3']
    deep_ratio = depth_stats['deep_ratio']

    severity_score = 0

    # Criterion 1: maximum depth (mm).
    if max_depth >= 10.0:
        severity_score += 3  # Very severe
    elif max_depth >= 6.0:
        severity_score += 2  # Severe
    elif max_depth >= 4.0:
        severity_score += 1  # Moderate

    # Criterion 2: mean depth (mm).
    if mean_depth >= 5.0:
        severity_score += 2
    elif mean_depth >= 3.0:
        severity_score += 1

    # Criterion 3: fraction of the wound deeper than 6mm.
    if deep_ratio >= 0.5:
        severity_score += 3  # More than 50% deep tissue
    elif deep_ratio >= 0.25:
        severity_score += 2  # 25-50% deep tissue
    elif deep_ratio >= 0.1:
        severity_score += 1  # 10-25% deep tissue

    # Criterion 4: total wound area (cm^2).
    if total_area >= 10.0:
        severity_score += 2  # Large wound (>10 cm^2)
    elif total_area >= 5.0:
        severity_score += 1  # Medium wound (5-10 cm^2)

    # Criterion 5: wound volume (cm^3).
    if wound_volume >= 5.0:
        severity_score += 2  # High volume
    elif wound_volume >= 2.0:
        severity_score += 1  # Medium volume

    # Map the total score onto a severity band.
    if severity_score >= 8:
        return "Very Severe"
    elif severity_score >= 6:
        return "Severe"
    elif severity_score >= 4:
        return "Moderate"
    elif severity_score >= 2:
        return "Mild"
    else:
        return "Superficial"
|
| 478 |
+
|
| 479 |
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """Enhanced wound severity analysis with medical-grade metrics.

    Combines the depth statistics and severity classifier into an HTML report
    string for the Gradio UI. Returns an error string when any input is missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "β Please upload image, depth map, and wound mask."

    # Collapse an RGB mask to a single channel by averaging.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have same dimensions.
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize the mask (via PIL) to match the depth map's HxW.
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    # Compute enhanced statistics, then classify severity from them.
    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
    severity = classify_wound_severity_by_enhanced_metrics(stats)

    # Severity -> accent color used throughout the report.
    severity_color = {
        "Superficial": "#4CAF50",  # Green
        "Mild": "#8BC34A",         # Light Green
        "Moderate": "#FF9800",     # Orange
        "Severe": "#F44336",       # Red
        "Very Severe": "#9C27B0"   # Purple
    }.get(severity, "#9E9E9E")     # Gray for unknown

    # Create comprehensive medical report (self-contained HTML card).
    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            π©Ή Enhanced Wound Severity Analysis
        </div>

        <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Tissue Involvement Analysis
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π’ <b>Superficial (0-2mm):</b> {stats['superficial_area_cm2']:.2f} cmΒ²</div>
                    <div>π‘ <b>Partial Thickness (2-4mm):</b> {stats['partial_thickness_area_cm2']:.2f} cmΒ²</div>
                    <div>π <b>Full Thickness (4-6mm):</b> {stats['full_thickness_area_cm2']:.2f} cmΒ²</div>
                    <div>π₯ <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cmΒ²</div>
                    <div>π <b>Total Area:</b> {stats['total_area_cm2']:.2f} cmΒ²</div>
                </div>
            </div>

            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Depth Statistics
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π <b>Mean Depth:</b> {stats['mean_depth_mm']:.1f} mm</div>
                    <div>π <b>Max Depth:</b> {stats['max_depth_mm']:.1f} mm</div>
                    <div>π <b>Depth Std Dev:</b> {stats['depth_std_mm']:.1f} mm</div>
                    <div>π¦ <b>Wound Volume:</b> {stats['wound_volume_cm3']:.2f} cmΒ³</div>
                    <div>π₯ <b>Deep Tissue Ratio:</b> {stats['deep_ratio']*100:.1f}%</div>
                </div>
            </div>
        </div>

        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px;'>
            <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                π Depth Percentiles & Quality Metrics
            </div>
            <div style='color: #cccccc; line-height: 1.6; display: grid; grid-template-columns: 1fr 1fr; gap: 15px;'>
                <div>
                    <div>π <b>25th Percentile:</b> {stats['depth_percentiles']['25']:.1f} mm</div>
                    <div>π <b>Median (50th):</b> {stats['depth_percentiles']['50']:.1f} mm</div>
                    <div>π <b>75th Percentile:</b> {stats['depth_percentiles']['75']:.1f} mm</div>
                </div>
                <div>
                    <div>π <b>Analysis Quality:</b> {stats['analysis_quality']}</div>
                    <div>π <b>Depth Consistency:</b> {stats['depth_consistency']}</div>
                    <div>π <b>Data Points:</b> {stats['wound_pixel_count']:,}</div>
                </div>
            </div>
        </div>

        <div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
                π― Medical Severity Assessment: {severity}
            </div>
            <div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
                {get_enhanced_severity_description(severity)}
            </div>
        </div>
    </div>
    """

    return report
|
| 574 |
+
|
| 575 |
+
def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """
    Linearly rescale a relative depth map into millimeters.

    The deepest point is assumed to correspond to `reference_depth_mm`.
    Constant maps (and None) are returned unchanged, since a flat map
    carries no scale information.
    """
    if depth_map is None:
        return depth_map

    lo = np.min(depth_map)
    hi = np.max(depth_map)

    # A flat map has no depth variation -- nothing to calibrate.
    if hi == lo:
        return depth_map

    # Min-max normalize, then stretch onto the reference depth range.
    return (depth_map - lo) / (hi - lo) * reference_depth_mm
|
| 595 |
+
|
| 596 |
+
def get_enhanced_severity_description(severity):
    """Map a severity label to its clinical description text."""
    fallback = "Severity assessment unavailable."
    catalog = {
        "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.",
        "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.",
        "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.",
        "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.",
        "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data or poor image quality."
    }
    # Unrecognized labels fall back to a neutral message.
    return catalog.get(severity, fallback)
|
| 607 |
+
|
| 608 |
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a filled circular test mask: 255 inside the circle, 0 outside."""
    h, w = image_shape[0], image_shape[1]

    # Default to the image midpoint, expressed as (x, y).
    cx, cy = center if center is not None else (w // 2, h // 2)

    yy, xx = np.ogrid[:h, :w]
    # Euclidean distance of every pixel from the circle center.
    distances = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)

    out = np.zeros((h, w), dtype=np.uint8)
    out[distances <= radius] = 255
    return out
|
| 621 |
+
|
| 622 |
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Generate a synthetic, irregular wound mask for testing.

    method='elliptical' draws a noisy ellipse; method='irregular' draws a
    circle with three lobed extensions. Any other value yields an empty mask.
    Output is uint8, wound pixels set to 255, smoothed by a closing op.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    yy, xx = np.ogrid[:h, :w]
    cx, cy = w // 2, h // 2

    if method == 'elliptical':
        rx = min(w, h) // 3
        ry = min(w, h) // 4

        # Points inside the axis-aligned ellipse.
        inside = ((xx - cx) ** 2 / (rx ** 2) +
                  (yy - cy) ** 2 / (ry ** 2)) <= 1

        # Sprinkle random pixels (~20%) so the boundary looks organic.
        speckle = np.random.random((h, w)) > 0.8
        mask = (inside | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        r = min(w, h) // 4
        core = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2) <= r

        # Three satellite circles at 120-degree spacing around the rim.
        lobes = np.zeros_like(core)
        for i in range(3):
            theta = i * 2 * np.pi / 3
            lx = int(cx + r * 0.8 * np.cos(theta))
            ly = int(cy + r * 0.8 * np.sin(theta))
            lobes = lobes | (np.sqrt((xx - lx) ** 2 + (yy - ly) ** 2) <= (r // 3))

        mask = (core | lobes).astype(np.uint8) * 255

    # Close small gaps so the final mask is smooth and contiguous.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
| 668 |
+
|
| 669 |
+
# --- Depth Estimation Functions ---
|
| 670 |
+
@spaces.GPU
def predict_depth(image):
    """Run Depth-Anything-V2 on one image and return its raw depth map.

    NOTE(review): callers pass image[:, :, ::-1] (channel-reversed), so
    infer_image presumably expects BGR input -- confirm against the model docs.
    """
    return depth_model.infer_image(image)
|
| 673 |
+
|
| 674 |
+
def calculate_max_points(image):
    """Return the point-count cap for 3D rendering: 3x the pixel count,
    clamped into [1000, 300000]. Falls back to 10000 when no image is given."""
    if image is None:
        return 10000  # Default value

    height, width = image.shape[:2]
    requested = height * width * 3
    # Clamp into a range that stays interactive in the browser.
    return min(max(requested, 1000), 300000)
|
| 682 |
+
|
| 683 |
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider so its maximum tracks the new image size."""
    max_points = calculate_max_points(image)
    # Default to 10% of the maximum, capped at 10k, so the first render stays fast.
    default_value = min(10000, max_points // 10)
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")
|
| 689 |
+
|
| 690 |
+
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a colored Open3D point cloud.

    Uses a pinhole model centered on the image: X = (u - w/2) / fx * Z, and
    similarly for Y. The sampling stride is halved relative to the naive
    budget, so the cloud keeps roughly 4x max_points for extra detail.
    """
    rows, cols = depth_map.shape

    # Stride derived from the point budget, then halved for denser sampling.
    stride = max(1, int(np.sqrt(rows * cols / max_points) * 0.5))

    # Sampled pixel grid (row index first, column index second).
    grid_v, grid_u = np.mgrid[0:rows:stride, 0:cols:stride]

    # Normalized camera-plane coordinates.
    norm_u = (grid_u - cols / 2) / focal_length_x
    norm_v = (grid_v - rows / 2) / focal_length_y

    z = depth_map[::stride, ::stride]

    # Back-projected 3D coordinates, flattened into an (N, 3) array.
    xyz = np.stack([(norm_u * z).flatten(),
                    (norm_v * z).flatten(),
                    z.flatten()], axis=1)

    # Matching per-point RGB colors scaled into [0, 1].
    rgb = image[::stride, ::stride, :].reshape(-1, 3) / 255.0

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud
|
| 726 |
+
|
| 727 |
+
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert point cloud to a mesh using Poisson reconstruction with very high detail."""
    # Poisson reconstruction requires oriented normals; a small radius with
    # many neighbors gives precise local estimates.
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    # Flip normals into a globally consistent orientation.
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # depth=12 octree -> very high-resolution surface (memory-intensive).
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)

    # Low-density vertices are deliberately kept (no density-based trimming).
    return mesh
|
| 739 |
+
|
| 740 |
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection.

    Builds a Plotly Scatter3d of back-projected pixels, colored by the
    original image, and returns the configured figure.
    """
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance.
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Sampled pixel grid (row index first, column index second).
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Pinhole back-projection with a fixed default focal length.
    focal_length = 470.4  # Default focal length
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    depth_values = depth_map[::step, ::step]

    # 3D points: (x_cam * depth, y_cam * depth, depth).
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Per-point colors sampled from the image at the same stride.
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    # Interactive scatter plot of the back-projected points.
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            # Preserve the data's true aspect ratio rather than a cube.
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
|
| 807 |
+
|
| 808 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Depth-tab handler: predict depth, export artifacts, and build 3D views.

    Returns [(original, colored_depth), gray_png_path, raw16_png_path,
    mesh_ply_path, plotly_figure] for the Gradio outputs.
    """
    original_image = image.copy()

    h, w = image.shape[:2]  # NOTE(review): h/w are currently unused.

    # Predict depth using the model (channel-reversed; presumably RGB -> BGR
    # for the model's expected input -- confirm against infer_image docs).
    depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed

    # Save raw 16-bit depth for download.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Min-max normalize to 8-bit for display, then colorize with Spectral_r.
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    # Grayscale depth image export.
    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Create point cloud. NOTE(review): this feeds the normalized uint8 depth,
    # not the raw metric depth -- confirm whether that is intentional.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)

    # Reconstruct a Poisson surface mesh from the point cloud.
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply for download.
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Interactive 3D scatter visualization (same normalized depth source).
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
| 844 |
+
|
| 845 |
+
# --- Actual Wound Segmentation Functions ---
|
| 846 |
+
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a wound mask from an image using the deep learning model.

    Args:
        image: Input image (numpy array), or None.
        method: Segmentation method. Only 'deep_learning' is implemented; any
            other value falls back to the deep learning model, so the original
            duplicated if/else branches were collapsed into one call.

    Returns:
        Binary wound mask (numpy array), or None when no image was given.
    """
    if image is None:
        return None

    # Every recognized and unrecognized method currently routes to the same
    # deep learning model — call it unconditionally.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
|
| 868 |
+
|
| 869 |
+
def post_process_wound_mask(mask, min_area=100):
    """Clean up a wound mask: binarize, denoise, drop small blobs, fill holes.

    Args:
        mask: Candidate wound mask. May arrive as uint8 (0/255 or 0/1), bool,
            or float probabilities in [0, 1].
        min_area: Minimum contour area (in pixels) to keep.

    Returns:
        Cleaned binary mask (uint8, 0/255), or None when mask is None.
    """
    if mask is None:
        return None

    # Robust binarization: the previous plain ``astype(np.uint8)`` silently
    # turned a float mask in [0, 1] into all zeros. Threshold on "any
    # positive value" instead, yielding a canonical 0/255 uint8 mask.
    mask = (np.asarray(mask) > 0).astype(np.uint8) * 255

    # Morphological close/open to remove speckle noise and small gaps
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    # Keep only sufficiently large connected components
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        if cv2.contourArea(contour) >= min_area:
            cv2.fillPoly(mask_clean, [contour], 255)

    # Final close to fill remaining interior holes
    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)

    return mask_clean
|
| 896 |
+
|
| 897 |
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Full automatic pipeline: segment the wound, clean the mask, score severity."""
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    # Segment with the deep learning model
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    # Clean the raw prediction before measuring
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is None or not np.any(processed_mask > 0):
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
|
| 916 |
+
|
| 917 |
+
# --- Main Gradio Interface ---
|
| 918 |
+
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
|
| 919 |
+
gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
|
| 920 |
+
gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")
|
| 921 |
+
|
| 922 |
+
# Shared image state
|
| 923 |
+
shared_image = gr.State()
|
| 924 |
+
|
| 925 |
+
with gr.Tabs():
|
| 926 |
+
# Tab 1: Wound Classification
|
| 927 |
+
with gr.Tab("1. Wound Classification"):
|
| 928 |
+
gr.Markdown("### Step 1: Upload and classify your wound image")
|
| 929 |
+
gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")
|
| 930 |
+
|
| 931 |
+
with gr.Row():
|
| 932 |
+
with gr.Column(scale=1):
|
| 933 |
+
wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
|
| 934 |
+
|
| 935 |
+
with gr.Column(scale=1):
|
| 936 |
+
wound_prediction_box = gr.HTML()
|
| 937 |
+
wound_reasoning_box = gr.HTML()
|
| 938 |
+
|
| 939 |
+
# Button to pass image to depth estimation
|
| 940 |
+
with gr.Row():
|
| 941 |
+
pass_to_depth_btn = gr.Button("π Pass Image to Depth Analysis", variant="secondary", size="lg")
|
| 942 |
+
pass_status = gr.HTML("")
|
| 943 |
+
|
| 944 |
+
wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
|
| 945 |
+
outputs=[wound_prediction_box, wound_reasoning_box])
|
| 946 |
+
|
| 947 |
+
# Store image when uploaded for classification
|
| 948 |
+
wound_image_input.change(
|
| 949 |
+
fn=lambda img: img,
|
| 950 |
+
inputs=[wound_image_input],
|
| 951 |
+
outputs=[shared_image]
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
# Tab 2: Depth Estimation
|
| 955 |
+
with gr.Tab("2. Depth Estimation & 3D Visualization"):
|
| 956 |
+
gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
|
| 957 |
+
gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
|
| 958 |
+
|
| 959 |
+
with gr.Row():
|
| 960 |
+
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
| 961 |
+
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
|
| 962 |
+
|
| 963 |
+
with gr.Row():
|
| 964 |
+
depth_submit = gr.Button(value="Compute Depth", variant="primary")
|
| 965 |
+
load_shared_btn = gr.Button("π Load Image from Classification", variant="secondary")
|
| 966 |
+
points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
|
| 967 |
+
label="Number of 3D points (upload image to update max)")
|
| 968 |
+
|
| 969 |
+
with gr.Row():
|
| 970 |
+
focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 971 |
+
label="Focal Length X (pixels)")
|
| 972 |
+
focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 973 |
+
label="Focal Length Y (pixels)")
|
| 974 |
+
|
| 975 |
+
with gr.Row():
|
| 976 |
+
gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
|
| 977 |
+
raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
|
| 978 |
+
point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
|
| 979 |
+
|
| 980 |
+
# 3D Visualization
|
| 981 |
+
gr.Markdown("### 3D Point Cloud Visualization")
|
| 982 |
+
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
|
| 983 |
+
depth_3d_plot = gr.Plot(label="3D Point Cloud")
|
| 984 |
+
|
| 985 |
+
# Store depth map for severity analysis
|
| 986 |
+
depth_map_state = gr.State()
|
| 987 |
+
|
| 988 |
+
# Tab 3: Wound Severity Analysis
|
| 989 |
+
with gr.Tab("3. π©Ή Wound Severity Analysis"):
|
| 990 |
+
gr.Markdown("### Step 3: Analyze wound severity using depth maps")
|
| 991 |
+
gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
|
| 992 |
+
|
| 993 |
+
with gr.Row():
|
| 994 |
+
severity_input_image = gr.Image(label="Original Image", type='numpy')
|
| 995 |
+
severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
|
| 996 |
+
|
| 997 |
+
with gr.Row():
|
| 998 |
+
wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
|
| 999 |
+
severity_output = gr.HTML(label="Severity Analysis Report")
|
| 1000 |
+
|
| 1001 |
+
gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
|
| 1002 |
+
|
| 1003 |
+
with gr.Row():
|
| 1004 |
+
auto_severity_button = gr.Button("π€ Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
|
| 1005 |
+
manual_severity_button = gr.Button("π Manual Mask Analysis", variant="secondary", size="lg")
|
| 1006 |
+
pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
|
| 1007 |
+
label="Pixel Spacing (mm/pixel)")
|
| 1008 |
+
depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
|
| 1009 |
+
label="Depth Calibration (mm)",
|
| 1010 |
+
info="Adjust based on expected maximum wound depth")
|
| 1011 |
+
|
| 1012 |
+
gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
|
| 1013 |
+
gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.")
|
| 1014 |
+
|
| 1015 |
+
with gr.Row():
|
| 1016 |
+
# Load depth map from previous tab
|
| 1017 |
+
load_depth_btn = gr.Button("π Load Depth Map from Tab 2", variant="secondary")
|
| 1018 |
+
|
| 1019 |
+
gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
|
| 1020 |
+
|
| 1021 |
+
# Update slider when image is uploaded
|
| 1022 |
+
depth_input_image.change(
|
| 1023 |
+
fn=update_slider_on_image_upload,
|
| 1024 |
+
inputs=[depth_input_image],
|
| 1025 |
+
outputs=[points_slider]
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
# Modified depth submit function to store depth map
|
| 1029 |
+
def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
    """Wrapper around on_depth_submit that additionally returns the normalized
    depth map so the severity tab can pick it up from gr.State.

    NOTE(review): this runs the depth model a second time for the same frame;
    having on_depth_submit return the raw depth would avoid the extra inference.
    """
    results = on_depth_submit(image, num_points, focal_x, focal_y)
    depth_map = None
    if image is not None:
        depth = predict_depth(image[:, :, ::-1])  # RGB -> BGR for the model
        # Normalize to 0-255; guard against a constant depth map (division by zero)
        depth_range = depth.max() - depth.min()
        if depth_range > 0:
            depth_map = ((depth - depth.min()) / depth_range * 255.0).astype(np.uint8)
        else:
            depth_map = np.zeros_like(depth, dtype=np.uint8)
    return results + [depth_map]
|
| 1039 |
+
|
| 1040 |
+
depth_submit.click(on_depth_submit_with_state,
|
| 1041 |
+
inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
|
| 1042 |
+
outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
|
| 1043 |
+
|
| 1044 |
+
# Load depth map to severity tab and auto-generate mask
|
| 1045 |
+
def load_depth_to_severity(depth_map, original_image):
    """Copy the tab-2 depth map into the severity tab and auto-segment a mask.

    Returns (depth_map, image, mask, status_message). Guard clauses replace
    the original nested if/else pyramid; behavior is unchanged.
    """
    if depth_map is None:
        return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."

    if original_image is None:
        return depth_map, original_image, None, "✅ Depth map loaded successfully!"

    # Auto-generate a wound mask with the segmentation model
    auto_mask, _ = segmentation_model.segment_wound(original_image)
    if auto_mask is None:
        return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image."

    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is not None and np.sum(processed_mask > 0) > 0:
        return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
    return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image."
|
| 1063 |
+
|
| 1064 |
+
load_depth_btn.click(
|
| 1065 |
+
fn=load_depth_to_severity,
|
| 1066 |
+
inputs=[depth_map_state, depth_input_image],
|
| 1067 |
+
outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
# Automatic severity analysis function
|
| 1071 |
+
def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
    """Automatic severity analysis: model-generated mask -> cleanup -> scoring."""
    if depth_map is None:
        return "❌ Please load depth map from Tab 2 first."

    # Segment with the deep learning model
    auto_mask = create_automatic_wound_mask(image, method='deep_learning')
    if auto_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    # Clean the raw prediction (fixed minimum blob size)
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is None or not np.any(processed_mask > 0):
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)
|
| 1089 |
+
|
| 1090 |
+
# Manual severity analysis function
|
| 1091 |
+
def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing, depth_calibration):
    """Severity analysis using a user-supplied binary wound mask."""
    if depth_map is None:
        return "❌ Please load depth map from Tab 2 first."
    if wound_mask is None:
        return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
    return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing, depth_calibration)
|
| 1098 |
+
|
| 1099 |
+
# Connect event handlers
|
| 1100 |
+
auto_severity_button.click(
|
| 1101 |
+
fn=run_auto_severity_analysis,
|
| 1102 |
+
inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
|
| 1103 |
+
outputs=[severity_output]
|
| 1104 |
+
)
|
| 1105 |
+
|
| 1106 |
+
manual_severity_button.click(
|
| 1107 |
+
fn=run_manual_severity_analysis,
|
| 1108 |
+
inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider, depth_calibration_slider],
|
| 1109 |
+
outputs=[severity_output]
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
# Auto-generate mask when image is uploaded
|
| 1115 |
+
def auto_generate_mask_on_image_upload(image):
    """Auto-segment a wound mask whenever a new image lands in the severity tab.

    Returns (mask_or_None, status_message).
    """
    if image is None:
        return None, "❌ No image uploaded."

    # Run the deep learning segmentation model
    auto_mask, _ = segmentation_model.segment_wound(image)
    if auto_mask is None:
        return None, "✅ Image uploaded but segmentation failed. Try uploading a different image."

    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is not None and np.sum(processed_mask > 0) > 0:
        return processed_mask, "✅ Wound mask auto-generated using deep learning model!"
    return None, "✅ Image uploaded but no wound detected. Try uploading a different image."
|
| 1130 |
+
|
| 1131 |
+
# Load shared image from classification tab
|
| 1132 |
+
def load_shared_image(shared_img):
    """Move the image stored by the classification tab into the depth tab.

    Returns (image_for_depth_input, status_message).
    """
    if shared_img is None:
        return gr.Image(), "❌ No image available from classification tab"

    # PIL images expose .convert; numpy arrays do not — convert PIL to numpy
    # because the depth input component is type='numpy'.
    if hasattr(shared_img, 'convert'):
        return np.array(shared_img), "✅ Image loaded from classification tab"
    return shared_img, "✅ Image loaded from classification tab"
|
| 1144 |
+
|
| 1145 |
+
# Auto-generate mask when image is uploaded to severity tab
|
| 1146 |
+
severity_input_image.change(
|
| 1147 |
+
fn=auto_generate_mask_on_image_upload,
|
| 1148 |
+
inputs=[severity_input_image],
|
| 1149 |
+
outputs=[wound_mask_input, gr.HTML()]
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
load_shared_btn.click(
|
| 1153 |
+
fn=load_shared_image,
|
| 1154 |
+
inputs=[shared_image],
|
| 1155 |
+
outputs=[depth_input_image, gr.HTML()]
|
| 1156 |
+
)
|
| 1157 |
+
|
| 1158 |
+
# Pass image to depth tab function
|
| 1159 |
+
def pass_image_to_depth(img):
    """Return the status message shown when handing the classified image to tab 2."""
    if img is None:
        return "❌ No image uploaded in classification tab"
    return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
|
| 1163 |
+
|
| 1164 |
+
pass_to_depth_btn.click(
|
| 1165 |
+
fn=pass_image_to_depth,
|
| 1166 |
+
inputs=[shared_image],
|
| 1167 |
+
outputs=[pass_status]
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
if __name__ == '__main__':
    # Queue requests and expose the app publicly; HF Spaces expects the
    # server to bind 0.0.0.0:7860.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
|
temp_files/predict.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
from keras.models import load_model
from keras.utils.generic_utils import CustomObjectScope

from models.unets import Unet2D
from models.deeplab import Deeplabv3, relu6, BilinearUpsampling, DepthwiseConv2D
from models.FCN import FCN_Vgg16_16s

from utils.learning.metrics import dice_coef, precision, recall
from utils.BilinearUpSampling import BilinearUpSampling2D
from utils.io.data import load_data, save_results, save_rgb_results, save_history, load_test_images, DataGen


# --- settings ---
input_dim_x = 224
input_dim_y = 224
color_space = 'rgb'
path = './data/Medetec_foot_ulcer_224/'
weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
pred_save_path = '2019-12-19 01%3A53%3A15.480800/'

# split_ratio=0.0 -> the generator serves the whole dataset as test data
data_gen = DataGen(path, split_ratio=0.0, x=input_dim_x, y=input_dim_y, color_space=color_space)
x_test, test_label_filenames_list = load_test_images(path)

# MobileNetV2 / DeeplabV3 model.
# NOTE(review): the original script constructed `Deeplabv3(...)` and then
# immediately replaced it with `load_model`, which rebuilds the full
# architecture from the checkpoint — the redundant construction (and the
# large commented-out alternative-model blocks) have been removed.
model = load_model('./training_history/' + weight_file_name,
                   custom_objects={'recall': recall,
                                   'precision': precision,
                                   'dice_coef': dice_coef,
                                   'relu6': relu6,
                                   'DepthwiseConv2D': DepthwiseConv2D,
                                   'BilinearUpsampling': BilinearUpsampling})

# Predict on the whole test set in a single batch, save results, then stop.
for image_batch, label_batch in data_gen.generate_data(batch_size=len(x_test), test=True):
    prediction = model.predict(image_batch, verbose=1)
    save_results(prediction, 'rgb', path + 'test/predictions/' + pred_save_path, test_label_filenames_list)
    break
|
temp_files/requirements.txt
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles
|
| 2 |
+
annotated-types
|
| 3 |
+
anyio
|
| 4 |
+
asttokens
|
| 5 |
+
attrs
|
| 6 |
+
blinker
|
| 7 |
+
certifi
|
| 8 |
+
charset-normalizer
|
| 9 |
+
click
|
| 10 |
+
colorama
|
| 11 |
+
comm
|
| 12 |
+
ConfigArgParse
|
| 13 |
+
contourpy
|
| 14 |
+
cycler
|
| 15 |
+
dash
|
| 16 |
+
decorator
|
| 17 |
+
executing
|
| 18 |
+
fastapi
|
| 19 |
+
fastjsonschema
|
| 20 |
+
ffmpy
|
| 21 |
+
filelock
|
| 22 |
+
Flask
|
| 23 |
+
fonttools
|
| 24 |
+
fsspec
|
| 25 |
+
gdown
|
| 26 |
+
gradio
|
| 27 |
+
gradio_client
|
| 28 |
+
gradio_imageslider
|
| 29 |
+
groovy
|
| 30 |
+
h11
|
| 31 |
+
httpcore
|
| 32 |
+
httpx
|
| 33 |
+
huggingface-hub
|
| 34 |
+
idna
|
| 35 |
+
importlib_metadata
|
| 36 |
+
itsdangerous
|
| 37 |
+
jedi
|
| 38 |
+
Jinja2
|
| 39 |
+
jsonschema
|
| 40 |
+
jsonschema-specifications
|
| 41 |
+
jupyter_core
|
| 42 |
+
jupyterlab_widgets
|
| 43 |
+
kiwisolver
|
| 44 |
+
markdown-it-py
|
| 45 |
+
MarkupSafe
|
| 46 |
+
matplotlib
|
| 47 |
+
matplotlib-inline
|
| 48 |
+
mdurl
|
| 49 |
+
mpmath
|
| 50 |
+
narwhals
|
| 51 |
+
nbformat
|
| 52 |
+
nest-asyncio
|
| 53 |
+
networkx
|
| 54 |
+
numpy<2
|
| 55 |
+
open3d
|
| 56 |
+
opencv-python
|
| 57 |
+
orjson
|
| 58 |
+
packaging
|
| 59 |
+
pandas
|
| 60 |
+
parso
|
| 61 |
+
pillow
|
| 62 |
+
platformdirs
|
| 63 |
+
plotly
|
| 64 |
+
prompt_toolkit
|
| 65 |
+
pure_eval
|
| 66 |
+
pydantic_core
|
| 67 |
+
pydub
|
| 68 |
+
Pygments
|
| 69 |
+
pyparsing
|
| 70 |
+
python-dateutil
|
| 71 |
+
python-multipart
|
| 72 |
+
pytz
|
| 73 |
+
PyYAML
|
| 74 |
+
referencing
|
| 75 |
+
requests
|
| 76 |
+
retrying
|
| 77 |
+
rich
|
| 78 |
+
rpds-py
|
| 79 |
+
ruff
|
| 80 |
+
safehttpx
|
| 81 |
+
scikit-image
|
| 82 |
+
semantic-version
|
| 83 |
+
setuptools
|
| 84 |
+
shellingham
|
| 85 |
+
six
|
| 86 |
+
sniffio
|
| 87 |
+
stack-data
|
| 88 |
+
starlette
|
| 89 |
+
sympy
|
| 90 |
+
tensorflow<2.11
|
| 91 |
+
tensorflow_hub
|
| 92 |
+
tomlkit
|
| 93 |
+
torch
|
| 94 |
+
torchvision
|
| 95 |
+
tqdm
|
| 96 |
+
traitlets
|
| 97 |
+
typer
|
| 98 |
+
typing-inspection
|
| 99 |
+
typing_extensions
|
| 100 |
+
tzdata
|
| 101 |
+
urllib3
|
| 102 |
+
uvicorn
|
| 103 |
+
wcwidth
|
| 104 |
+
websockets
|
| 105 |
+
Werkzeug
|
| 106 |
+
wheel
|
| 107 |
+
widgetsnbextension
|
| 108 |
+
zipp
|
| 109 |
+
pydantic==2.10.6
|
temp_files/run_gradio_app.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple launcher for the Wound Segmentation Gradio App
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
def check_dependencies():
    """Return True when all required third-party packages can be imported."""
    required_packages = ['gradio', 'tensorflow', 'cv2', 'numpy']
    missing_packages = []
    for package in required_packages:
        try:
            # __import__ handles every package name uniformly, including cv2,
            # so the original cv2 special case is unnecessary.
            __import__(package)
        except ImportError:
            missing_packages.append(package)

    if missing_packages:
        print("❌ Missing required packages:")
        for package in missing_packages:
            print(f"   - {package}")
        print("\n📦 Install missing packages with:")
        print("   pip install -r requirements.txt")
        return False

    print("✅ All required packages are installed!")
    return True
|
| 33 |
+
|
| 34 |
+
def check_model_files():
    """Return True when at least one trained model checkpoint exists on disk."""
    model_files = [
        'training_history/2025-08-07_12-30-43.hdf5',
        'training_history/2019-12-19 01%3A53%3A15.480800.hdf5',
    ]
    existing_models = [f for f in model_files if os.path.exists(f)]

    if not existing_models:
        print("❌ No model files found!")
        print("   Please ensure you have trained models in the training_history/ directory")
        return False

    print(f"✅ Found {len(existing_models)} model file(s):")
    for model in existing_models:
        print(f"   - {model}")
    return True
|
| 55 |
+
|
| 56 |
+
def main():
    """Validate the environment, then launch the Gradio interface."""
    print("🚀 Starting Wound Segmentation Gradio App...")
    print("=" * 50)

    # Bail out early when the environment is not ready
    if not check_dependencies():
        sys.exit(1)
    if not check_model_files():
        sys.exit(1)

    print("\n🎯 Launching Gradio interface...")
    print("   The app will be available at: http://localhost:7860")
    print("   Press Ctrl+C to stop the server")
    print("=" * 50)

    try:
        # Deferred import so missing-dependency errors surface through
        # check_dependencies first.
        from gradio_app import create_gradio_interface

        interface = create_gradio_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True,
        )
    except KeyboardInterrupt:
        print("\n👋 Gradio app stopped by user")
    except Exception as e:
        print(f"\n❌ Error launching Gradio app: {e}")
        sys.exit(1)
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # Entry point when executed as a script (`python run_gradio_app.py`).
    main()
|
temp_files/segmentation_app.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from tensorflow import keras
|
| 6 |
+
from keras.models import load_model
|
| 7 |
+
from keras.utils.generic_utils import CustomObjectScope
|
| 8 |
+
|
| 9 |
+
# Import custom modules
|
| 10 |
+
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
|
| 11 |
+
from utils.learning.metrics import dice_coef, precision, recall
|
| 12 |
+
from utils.io.data import normalize
|
| 13 |
+
|
| 14 |
+
class WoundSegmentationApp:
    """Wraps the trained wound-segmentation network behind a simple segment API."""

    def __init__(self):
        # Fixed model input resolution (determined at training time)
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained segmentation model, preferring the newest checkpoint
        and falling back to the 2019 one; leaves self.model as None on failure."""
        custom_objects = {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling,
        }
        try:
            model_path = './training_history/2025-08-07_12-30-43.hdf5'  # most recent model
            self.model = load_model(model_path, custom_objects=custom_objects)
            print(f"Model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fall back to the older checkpoint
            try:
                model_path = './training_history/2019-12-19 01%3A53%3A15.480800.hdf5'
                self.model = load_model(model_path, custom_objects=custom_objects)
                print(f"Model loaded successfully from {model_path}")
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                self.model = None

    def preprocess_image(self, image):
        """Resize, scale to [0, 1], and batch an uploaded image for the network."""
        if image is None:
            return None
        # NOTE(review): channels are swapped unconditionally for any 3-channel
        # input (original behavior) — confirm the upstream color order.
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def postprocess_prediction(self, prediction):
        """Threshold the raw network output into a 3-channel 0/255 mask."""
        prediction = prediction[0]  # drop batch dimension
        binary_mask = (prediction > 0.5).astype(np.uint8) * 255
        return cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB)

    def segment_wound(self, input_image):
        """Segment the wound; returns (mask, overlay) on success or
        (None, error_message_string) on failure."""
        if self.model is None:
            return None, "Error: Model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."

        try:
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."

            prediction = self.model.predict(processed_image, verbose=0)
            segmented_mask = self.postprocess_prediction(prediction)

            # Overlay: original image blended with the mask in the red channel
            original_resized = cv2.resize(input_image, (self.input_dim_x, self.input_dim_y))
            if len(original_resized.shape) == 3:
                original_resized = cv2.cvtColor(original_resized, cv2.COLOR_RGB2BGR)

            mask_red = np.zeros_like(original_resized)
            mask_red[:, :, 2] = segmented_mask[:, :, 0]

            alpha = 0.6
            overlay = cv2.addWeighted(original_resized.copy(), 1 - alpha, mask_red, alpha, 0)
            return segmented_mask, overlay
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"
|
| 133 |
+
|
| 134 |
+
def create_gradio_interface():
    """Create and return the Gradio interface.

    Builds the full Blocks layout: an upload column with a trigger button,
    an output column with mask and overlay images, a status textbox, and
    wires both the button click and image upload to the segmentation app.
    """
    
    # Initialize the app (loads the segmentation model once per process)
    app = WoundSegmentationApp()
    
    # Define the interface
    with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # π©Ή Wound Segmentation Tool
            
            Upload an image of a wound to get an automated segmentation mask.
            The model will identify and highlight the wound area in the image.
            
            **Instructions:**
            1. Upload an image of a wound
            2. Click "Segment Wound" to process the image
            3. View the segmentation mask and overlay results
            """
        )
        
        with gr.Row():
            # Left column: input image + manual trigger
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Wound Image",
                    type="numpy",
                    height=400
                )
                
                segment_btn = gr.Button(
                    "π Segment Wound",
                    variant="primary",
                    size="lg"
                )
            
            # Right column: the two result views
            with gr.Column():
                mask_output = gr.Image(
                    label="Segmentation Mask",
                    height=400
                )
                
                overlay_output = gr.Image(
                    label="Overlay Result",
                    height=400
                )
        
        # Status message
        status_msg = gr.Textbox(
            label="Status",
            interactive=False,
            placeholder="Ready to process images..."
        )
        
        # Example images
        gr.Markdown("### πΈ Example Images")
        gr.Markdown("You can test the tool with wound images from the dataset.")
        
        # Connect the button to the segmentation function.
        # segment_wound returns (None, error_text) on failure, so the second
        # element doubles as the status message in that case.
        def process_image(image):
            mask, overlay = app.segment_wound(image)
            if mask is None:
                return None, None, overlay  # overlay contains error message
            return mask, overlay, "Segmentation completed successfully!"
        
        segment_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )
        
        # Auto-process when image is uploaded
        input_image.change(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )
    
    return interface
|
| 213 |
+
|
| 214 |
+
if __name__ == "__main__":
    # Create and launch the interface.
    # 0.0.0.0 / 7860 is the standard binding for Hugging Face Spaces;
    # share=True additionally requests a public tunnel link.
    interface = create_gradio_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )
|
temp_files/test1.txt
ADDED
|
@@ -0,0 +1,843 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import tempfile
|
| 8 |
+
from gradio_imageslider import ImageSlider
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
import os
|
| 14 |
+
import tensorflow as tf
|
| 15 |
+
from tensorflow.keras.models import load_model
|
| 16 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
| 17 |
+
import base64
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import gdown
|
| 20 |
+
import spaces
|
| 21 |
+
import cv2
|
| 22 |
+
from skimage import filters, morphology, measure
|
| 23 |
+
from skimage.segmentation import clear_border
|
| 24 |
+
|
| 25 |
+
# --- LINEAR INITIALIZATION - NO MODULAR FUNCTIONS ---
# All model loading happens at import time, top to bottom, so that the
# ZeroGPU runtime sees every heavyweight allocation before any @spaces.GPU
# function is called.
print("Starting linear initialization for ZeroGPU compatibility...")

# Define path and file ID for the Depth Anything V2 checkpoint
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present (cached between restarts when the
# checkpoints directory persists)
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("/home/user/app/keras_model.h5")
# labels.txt lines look like "<index> <label>"; keep only the label part.
with open("/home/user/app/labels.txt", "r") as f:
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]

# --- PyTorch: Set Device and Load Depth Model ---
print("Initializing PyTorch device...")
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Architecture hyperparameters per Depth Anything V2 encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
# eval() disables dropout/batch-norm updates for inference-only use.
depth_model = depth_model.to(map_device).eval()
|
| 71 |
+
|
| 72 |
+
# --- Custom CSS for unified dark theme ---
|
| 73 |
+
css = """
|
| 74 |
+
.gradio-container {
|
| 75 |
+
font-family: 'Segoe UI', sans-serif;
|
| 76 |
+
background-color: #121212;
|
| 77 |
+
color: #ffffff;
|
| 78 |
+
padding: 20px;
|
| 79 |
+
}
|
| 80 |
+
.gr-button {
|
| 81 |
+
background-color: #2c3e50;
|
| 82 |
+
color: white;
|
| 83 |
+
border-radius: 10px;
|
| 84 |
+
}
|
| 85 |
+
.gr-button:hover {
|
| 86 |
+
background-color: #34495e;
|
| 87 |
+
}
|
| 88 |
+
.gr-html, .gr-html div {
|
| 89 |
+
white-space: normal !important;
|
| 90 |
+
overflow: visible !important;
|
| 91 |
+
text-overflow: unset !important;
|
| 92 |
+
word-break: break-word !important;
|
| 93 |
+
}
|
| 94 |
+
#img-display-container {
|
| 95 |
+
max-height: 100vh;
|
| 96 |
+
}
|
| 97 |
+
#img-display-input {
|
| 98 |
+
max-height: 80vh;
|
| 99 |
+
}
|
| 100 |
+
#img-display-output {
|
| 101 |
+
max-height: 80vh;
|
| 102 |
+
}
|
| 103 |
+
#download {
|
| 104 |
+
height: 62px;
|
| 105 |
+
}
|
| 106 |
+
h1 {
|
| 107 |
+
text-align: center;
|
| 108 |
+
font-size: 3rem;
|
| 109 |
+
font-weight: bold;
|
| 110 |
+
margin: 2rem 0;
|
| 111 |
+
color: #ffffff;
|
| 112 |
+
}
|
| 113 |
+
h2 {
|
| 114 |
+
color: #ffffff;
|
| 115 |
+
text-align: center;
|
| 116 |
+
margin: 1rem 0;
|
| 117 |
+
}
|
| 118 |
+
.gr-tabs {
|
| 119 |
+
background-color: #1e1e1e;
|
| 120 |
+
border-radius: 10px;
|
| 121 |
+
padding: 10px;
|
| 122 |
+
}
|
| 123 |
+
.gr-tab-nav {
|
| 124 |
+
background-color: #2c3e50;
|
| 125 |
+
border-radius: 8px;
|
| 126 |
+
}
|
| 127 |
+
.gr-tab-nav button {
|
| 128 |
+
color: #ffffff !important;
|
| 129 |
+
}
|
| 130 |
+
.gr-tab-nav button.selected {
|
| 131 |
+
background-color: #34495e !important;
|
| 132 |
+
}
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
# --- LINEAR FUNCTION DEFINITIONS (NO MODULAR CALLS) ---
|
| 136 |
+
|
| 137 |
+
# Wound Classification Functions
|
| 138 |
+
def preprocess_input(img):
    """Resize a PIL image to 224x224 and return a normalized (1, 224, 224, 3) batch."""
    resized = img.resize((224, 224))
    # Convert to a float array and scale pixel values into [0, 1].
    pixels = keras_image.img_to_array(resized) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return pixels[np.newaxis, ...]
|
| 143 |
+
|
| 144 |
+
def get_reasoning_from_gemini(img, prediction):
    """Return a canned clinical explanation for a predicted wound class.

    Despite the name, no external service is queried; the text comes from
    a static lookup table keyed by class label. ``img`` is accepted only
    for interface compatibility and is not used.
    """
    try:
        explanations = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
        }
        # Fall back to a generic message for labels without a canned text.
        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return explanations.get(prediction, fallback)
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
|
| 156 |
+
|
| 157 |
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound photo with the Keras model and build two HTML cards.

    Returns (prediction_card_html, reasoning_card_html). When no image is
    supplied, the first slot carries the error markup and the second is
    an empty string. Relies on the module-level ``wound_model`` and
    ``class_labels`` loaded at import time.
    """
    if img is None:
        return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""
    
    img_array = preprocess_input(img)
    # Single-sample batch; take the first (only) probability vector.
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]
    
    # Static lookup despite the name — no network call is made.
    reasoning_text = get_reasoning_from_gemini(img, pred_class)
    
    predicted_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Predicted Wound Type
        </div>
        <div style='font-size: 26px; color: white;'>
            {pred_class}
        </div>
    </div>
    """
    
    reasoning_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Reasoning
        </div>
        <div style='font-size: 16px; color: white; min-height: 80px;'>
            {reasoning_text}
        </div>
    </div>
    """
    
    return predicted_card, reasoning_card
|
| 194 |
+
|
| 195 |
+
# Depth Estimation Functions
|
| 196 |
+
@spaces.GPU
def predict_depth(image):
    """Run the global Depth Anything V2 model on an image and return the raw depth map.

    The caller passes a channel-reversed array (image[:, :, ::-1]), so this
    presumably expects BGR input — confirm against the model's docs.
    """
    return depth_model.infer_image(image)
|
| 199 |
+
|
| 200 |
+
def calculate_max_points(image):
    """Upper bound on 3D points for an image.

    One point per pixel per channel, clamped into [1000, 300000].
    Returns the 10000 default when no image is provided.
    """
    if image is None:
        return 10000
    height, width = image.shape[:2]
    candidate = height * width * 3
    # Clamp: never below 1000, never above 300k (keeps Plotly responsive).
    return min(max(candidate, 1000), 300000)
|
| 206 |
+
|
| 207 |
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider so its range matches the uploaded image."""
    upper = calculate_max_points(image)
    # Default to 10% of the maximum, capped at 10k points.
    initial = min(10000, upper // 10)
    return gr.Slider(
        minimum=1000,
        maximum=upper,
        value=initial,
        step=1000,
        label=f"Number of 3D points (max: {upper:,})",
    )
|
| 212 |
+
|
| 213 |
+
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a colored Open3D point cloud.

    Pixels are subsampled on a regular grid (stride chosen so roughly
    ``max_points`` survive), then unprojected with a pinhole model whose
    principal point is the image center.
    """
    height, width = depth_map.shape
    stride = max(1, int(np.sqrt(height * width / max_points) * 0.5))

    rows, cols = np.mgrid[0:height:stride, 0:width:stride]
    # Normalized camera-plane coordinates relative to the principal point.
    u = (cols - width / 2) / focal_length_x
    v = (rows - height / 2) / focal_length_y
    z = depth_map[::stride, ::stride]

    # Scale the rays by depth to obtain 3D coordinates; stack as (N, 3).
    xyz = np.stack([(u * z).flatten(), (v * z).flatten(), z.flatten()], axis=1)
    # Matching per-point colors, normalized to [0, 1] for Open3D.
    rgb = image[::stride, ::stride, :].reshape(-1, 3) / 255.0

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud
|
| 236 |
+
|
| 237 |
+
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Poisson-reconstruct a triangle mesh from a point cloud.

    Normals are estimated (hybrid KD-tree search) and consistently oriented
    first, since Poisson reconstruction needs oriented normals. The density
    array Open3D returns alongside the mesh is discarded.
    """
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)
    # depth=12 selects a fine octree; higher values cost memory and time.
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    return mesh
|
| 243 |
+
|
| 244 |
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Build an interactive Plotly 3D scatter of the back-projected depth map.

    Uses the same pinhole back-projection as create_point_cloud (but with a
    fixed focal length of 470.4) and colors each point with the matching
    image pixel. Returns a plotly.graph_objects.Figure.
    """
    h, w = depth_map.shape
    # Grid stride chosen so roughly max_points samples survive.
    step = max(1, int(np.sqrt(h * w / max_points)))

    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Fixed focal length; matches create_point_cloud's default.
    focal_length = 470.4
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length
    depth_values = depth_map[::step, ::step]

    # Scale normalized rays by depth to get 3D coordinates.
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Per-point colors sampled at the same stride as the geometry.
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
|
| 299 |
+
|
| 300 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Full depth pipeline for the Gradio depth tab.

    Runs monocular depth estimation, writes raw (uint16) and normalized
    grayscale depth PNGs plus a Poisson-reconstructed mesh (.ply) to temp
    files, and builds the interactive 3D figure.

    Returns [(original, colored_depth) slider pair, gray PNG path,
    raw PNG path, mesh path, plotly figure].
    """
    original_image = image.copy()

    # The caller reverses channels because the depth model was fed BGR
    # elsewhere in this codebase — confirm against DepthAnythingV2 docs.
    depth = predict_depth(image[:, :, ::-1])

    # Persist the raw depth (uint16) for download; delete=False keeps the
    # file alive after close so Gradio can serve it.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to 0-255 for display. Guard against a constant depth map,
    # which would otherwise divide by zero and fill the image with NaNs.
    depth_range = depth.max() - depth.min()
    if depth_range > 0:
        norm_depth = (depth - depth.min()) / depth_range * 255.0
    else:
        norm_depth = np.zeros_like(depth)
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    # Grayscale depth preview for download.
    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Point cloud -> Poisson mesh, saved as .ply for download.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Interactive Plotly scatter of the same back-projection.
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
| 327 |
+
|
| 328 |
+
# Wound Severity Analysis Functions
|
| 329 |
+
@spaces.GPU
|
| 330 |
+
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Bucket wound area by depth band and return summary statistics.

    Pixels where ``mask > 127`` count as wound. Their depth values are
    split into shallow (<3), moderate (3-6) and deep (>=6) bands; areas
    are reported in cm^2 using the given pixel spacing.
    """
    # Area of one pixel in cm^2 (spacing is given in mm).
    cm2_per_pixel = (pixel_spacing_mm / 10.0) ** 2
    in_wound = mask > 127
    depths = depth_map[in_wound]

    total = np.count_nonzero(in_wound) * cm2_per_pixel
    shallow_area = np.count_nonzero(depths < 3) * cm2_per_pixel
    moderate_area = np.count_nonzero((depths >= 3) & (depths < 6)) * cm2_per_pixel
    deep_area = np.count_nonzero(depths >= 6) * cm2_per_pixel

    return {
        'total_area_cm2': total,
        'shallow_area_cm2': shallow_area,
        'moderate_area_cm2': moderate_area,
        'deep_area_cm2': deep_area,
        # Guard the ratio against an empty wound mask.
        'deep_ratio': deep_area / total if total > 0 else 0,
        'max_depth': depths.max() if depths.size > 0 else 0,
    }
|
| 353 |
+
|
| 354 |
+
def classify_wound_severity_by_area(depth_stats):
    """Map depth-band areas (from compute_depth_area_statistics) to a severity label."""
    total = depth_stats['total_area_cm2']
    if total == 0:
        return "Unknown"

    deep = depth_stats['deep_area_cm2']
    moderate = depth_stats['moderate_area_cm2']

    # Severe: large absolute deep area, or deep tissue dominating the wound.
    if deep > 2 or deep / total > 0.3:
        return "Severe"
    # Moderate: substantial mid-depth involvement.
    if moderate > 1.5 or moderate / total > 0.4:
        return "Moderate"
    return "Mild"
|
| 368 |
+
|
| 369 |
+
def get_severity_description(severity):
    """Return a one-line, human-readable explanation for a severity label."""
    if severity == "Mild":
        return "Superficial wound with minimal tissue damage. Usually heals well with basic care."
    if severity == "Moderate":
        return "Moderate tissue involvement requiring careful monitoring and proper treatment."
    if severity == "Severe":
        return "Deep tissue damage requiring immediate medical attention and specialized care."
    if severity == "Unknown":
        return "Unable to determine severity due to insufficient data."
    # Any label outside the known set.
    return "Severity assessment unavailable."
|
| 377 |
+
|
| 378 |
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Compute depth/area statistics for a wound and render an HTML report.

    Accepts the original image, a depth map, and a wound mask (grayscale
    or RGB; RGB is averaged to one channel). The mask is resized to the
    depth map's resolution when the shapes differ. Returns an HTML string
    (or an error string when any input is missing).
    """
    if image is None or depth_map is None or wound_mask is None:
        return "β Please upload image, depth map, and wound mask."

    # Collapse an RGB mask to a single channel by averaging.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Bring the mask to the depth map's resolution (PIL uses (width, height)).
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Accent color per severity level; gray for "Unknown".
    severity_color = {
        "Mild": "#4CAF50",
        "Moderate": "#FF9800",
        "Severe": "#F44336"
    }.get(severity, "#9E9E9E")

    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            π©Ή Wound Severity Analysis
        </div>
        
        <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Area Measurements
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π’ <b>Total Area:</b> {stats['total_area_cm2']:.2f} cmΒ²</div>
                    <div>π© <b>Shallow (0-3mm):</b> {stats['shallow_area_cm2']:.2f} cmΒ²</div>
                    <div>π¨ <b>Moderate (3-6mm):</b> {stats['moderate_area_cm2']:.2f} cmΒ²</div>
                    <div>π₯ <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cmΒ²</div>
                </div>
            </div>
            
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Depth Analysis
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π₯ <b>Deep Coverage:</b> {stats['deep_ratio']*100:.1f}%</div>
                    <div>π <b>Max Depth:</b> {stats['max_depth']:.1f} mm</div>
                    <div>β‘ <b>Pixel Spacing:</b> {pixel_spacing_mm} mm</div>
                </div>
            </div>
        </div>
        
        <div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
                π― Predicted Severity: {severity}
            </div>
            <div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
                {get_severity_description(severity)}
            </div>
        </div>
    </div>
    """
    
    return report
|
| 443 |
+
|
| 444 |
+
# Automatic Wound Mask Generation Functions
|
| 445 |
+
def create_automatic_wound_mask(image, method='adaptive'):
    """Generate a wound mask using the requested segmentation strategy.

    Supported methods: 'adaptive' (default), 'otsu', 'color', 'combined'.
    Unrecognized names fall back to adaptive thresholding.
    Returns None when no image is given.
    """
    if image is None:
        return None

    # Color-based methods need the RGB image; threshold methods need grayscale.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image.copy()

    if method == 'otsu':
        return otsu_threshold_segmentation(gray)
    if method == 'color':
        return color_based_segmentation(image)
    if method == 'combined':
        return combined_segmentation(image, gray)
    # 'adaptive' and any unknown method share the same path.
    return adaptive_threshold_segmentation(gray)
|
| 466 |
+
|
| 467 |
+
def adaptive_threshold_segmentation(gray):
    """Segment dark regions via Gaussian adaptive thresholding plus cleanup."""
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)
    # Inverted adaptive threshold: marked regions are darker than their surroundings.
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )
    # Close gaps, then remove speckle, with an elliptical structuring element.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Keep only connected regions larger than 1000 px.
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for candidate in contours:
        if cv2.contourArea(candidate) > 1000:
            cv2.fillPoly(result, [candidate], 255)

    return result
|
| 484 |
+
|
| 485 |
+
def otsu_threshold_segmentation(gray):
    """Segment via inverted Otsu global thresholding plus morphological cleanup."""
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)
    # Otsu picks the threshold automatically; INV marks darker regions.
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Keep only connected regions larger than 800 px.
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for candidate in contours:
        if cv2.contourArea(candidate) > 800:
            cv2.fillPoly(result, [candidate], 255)

    return result
|
| 501 |
+
|
| 502 |
+
def color_based_segmentation(image):
    """Segment wound tissue by HSV color (reds, yellows, browns).

    Builds binary masks for typical wound hues, merges them with
    saturating bitwise OR, cleans the result morphologically, and keeps
    only contours larger than 600 px.

    Fix: the original merged masks with numpy ``+`` on uint8 arrays,
    which wraps around (255 + 255 -> 254) wherever ranges overlap —
    and the brown hue range [10, 20] overlaps the yellow range [15, 35].
    ``cv2.bitwise_or`` is the correct, overlap-safe union.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Red wraps around the hue axis, so it needs two ranges.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([160, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)

    # Yellow hues.
    lower_yellow = np.array([15, 30, 30])
    upper_yellow = np.array([35, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)

    # Brown hues; this range partially overlaps the yellow range.
    lower_brown = np.array([10, 50, 20])
    upper_brown = np.array([20, 255, 200])
    brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)

    # Overlap-safe union of all color masks.
    color_mask = cv2.bitwise_or(red_mask, cv2.bitwise_or(yellow_mask, brown_mask))

    # Close gaps, then remove speckle.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)

    # Keep only connected regions larger than 600 px.
    contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(color_mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 600:
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean
|
| 536 |
+
|
| 537 |
+
def combined_segmentation(image, gray):
    """Fuse adaptive, Otsu, and color-based segmentations into one mask.

    Args:
        image: RGB image passed to the color-based segmenter.
        gray: grayscale version passed to the threshold-based segmenters.

    Returns:
        uint8 binary mask; falls back to a synthetic elliptical mask when
        no contour survives the area filter.
    """
    # Union of the three individual segmentation strategies.
    fused = cv2.bitwise_or(adaptive_threshold_segmentation(gray),
                           otsu_threshold_segmentation(gray))
    fused = cv2.bitwise_or(fused, color_based_segmentation(image))

    # A generous closing merges fragments the individual methods missed.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    fused = cv2.morphologyEx(fused, cv2.MORPH_CLOSE, kernel)

    # Keep only regions larger than 500 px.
    contours, _ = cv2.findContours(fused, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(fused)
    for cnt in contours:
        if cv2.contourArea(cnt) > 500:
            cv2.fillPoly(result, [cnt], 255)

    # If everything was filtered out, synthesize a mask so downstream
    # severity analysis always has something to work with.
    if np.sum(result) == 0:
        result = create_realistic_wound_mask(fused.shape, method='elliptical')

    return result
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Synthesize a plausible wound mask for fallback/demo purposes.

    Args:
        image_shape: shape tuple of the target image; only (H, W) are used.
        method: 'elliptical' for a noisy ellipse, 'irregular' for a circle
            with three lobed extensions.

    Returns:
        uint8 binary mask (0 / 255). The 'elliptical' variant sprinkles
        unseeded random speckle, so its output is not deterministic.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    cx, cy = w // 2, h // 2
    yy, xx = np.ogrid[:h, :w]

    if method == 'elliptical':
        rx = min(w, h) // 3
        ry = min(w, h) // 4
        # Points inside the axis-aligned ellipse centered on the image.
        inside = ((xx - cx) ** 2 / (rx ** 2) +
                  (yy - cy) ** 2 / (ry ** 2)) <= 1
        # ~20% random speckle roughens the boundary.
        speckle = np.random.random((h, w)) > 0.8
        mask = (inside | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        r = min(w, h) // 4
        blob = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2) <= r
        # Add three lobes evenly spaced around the base circle.
        for k in range(3):
            theta = k * 2 * np.pi / 3
            ex = int(cx + r * 0.8 * np.cos(theta))
            ey = int(cy + r * 0.8 * np.sin(theta))
            blob = blob | (np.sqrt((xx - ex) ** 2 + (yy - ey) ** 2) <= r // 3)
        mask = blob.astype(np.uint8) * 255

    # Smooth out jagged edges left by the noise/lobes.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
def post_process_wound_mask(mask, min_area=100):
    """Clean a raw binary wound mask: fill gaps and drop small blobs.

    Args:
        mask: binary mask array, or None (passed through unchanged).
        min_area: smallest contour area in pixels to keep.

    Returns:
        uint8 mask containing only the retained, smoothed regions,
        or None when the input was None.
    """
    if mask is None:
        return None

    work = mask if mask.dtype == np.uint8 else mask.astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    # Close gaps first, then open to strip speckle noise.
    work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, kernel)
    work = cv2.morphologyEx(work, cv2.MORPH_OPEN, kernel)

    # Re-draw only contours that meet the area threshold.
    contours, _ = cv2.findContours(work, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cleaned = np.zeros_like(work)
    for cnt in contours:
        if cv2.contourArea(cnt) >= min_area:
            cv2.fillPoly(cleaned, [cnt], 255)

    # A final close merges regions left adjacent after filtering.
    return cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel)
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Create a filled circular demo mask.

    Args:
        image_shape: target image shape; only (H, W) are used.
        center: (x, y) circle center; defaults to the image center.
        radius: circle radius in pixels.

    Returns:
        uint8 mask with 255 inside the circle and 0 elsewhere.
    """
    h, w = image_shape[0], image_shape[1]
    if center is None:
        center = (w // 2, h // 2)

    yy, xx = np.ogrid[:h, :w]
    inside = np.sqrt((xx - center[0]) ** 2 + (yy - center[1]) ** 2) <= radius

    mask = np.zeros((h, w), dtype=np.uint8)
    mask[inside] = 255
    return mask
| 635 |
+
# --- MAIN GRADIO INTERFACE (LINEAR EXECUTION) ---
# Three-tab app: classification -> depth estimation -> severity analysis.
# Images/depth maps are handed between tabs via gr.State objects.
print("Creating Gradio interface...")

with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Holds the image uploaded in Tab 1 so Tab 2 can reuse it.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)

                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            with gr.Row():
                pass_to_depth_btn = gr.Button("π Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            # Classify as soon as an image is uploaded...
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])

            # ...and mirror the same upload into the shared cross-tab state.
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("π Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            with gr.Row():
                # Camera intrinsics used for the 3D back-projection.
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Normalized uint8 depth map stashed for Tab 3's severity analysis.
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. π©Ή Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")

            with gr.Row():
                auto_severity_button = gr.Button("π€ Auto-Analyze Severity", variant="primary", size="lg")
                manual_severity_button = gr.Button("π Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                segmentation_method = gr.Dropdown(
                    choices=["combined", "adaptive", "otsu", "color"],
                    value="combined",
                    label="Segmentation Method",
                    info="Choose automatic segmentation method"
                )
                min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
                                            label="Minimum Area (pixels)",
                                            info="Minimum wound area to detect")

            with gr.Row():
                load_depth_btn = gr.Button("π Load Depth Map from Tab 2", variant="secondary")
                sample_mask_btn = gr.Button("π― Generate Sample Mask", variant="secondary")
                realistic_mask_btn = gr.Button("π₯ Generate Realistic Mask", variant="secondary")
                preview_mask_btn = gr.Button("ποΈ Preview Auto Mask", variant="secondary")

            gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")

    # Event handlers
    def generate_sample_mask(image):
        # Produce a simple circular mask sized to the current image.
        if image is None:
            return None, "β Please load an image first."
        sample_mask = create_sample_wound_mask(image.shape)
        return sample_mask, "β Sample circular wound mask generated!"

    def generate_realistic_mask(image):
        # Produce a noisy elliptical mask sized to the current image.
        if image is None:
            return None, "β Please load an image first."
        realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
        return realistic_mask, "β Realistic elliptical wound mask generated!"

    def load_depth_to_severity(depth_map, original_image):
        # Copy Tab 2's computed depth map (and source image) into Tab 3.
        if depth_map is None:
            return None, None, "β No depth map available. Please compute depth in Tab 2 first."
        return depth_map, original_image, "β Depth map loaded successfully!"

    def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
        # Full automatic pipeline: segment -> post-process -> severity report.
        if depth_map is None:
            return "β Please load depth map from Tab 2 first."

        def post_process_with_area(mask):
            # Bind the UI-selected minimum area into the post-processing step.
            return post_process_wound_mask(mask, min_area=min_area)

        auto_mask = create_automatic_wound_mask(image, method=seg_method)

        if auto_mask is None:
            return "β Failed to generate automatic wound mask."

        processed_mask = post_process_with_area(auto_mask)

        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return "β No wound region detected. Try adjusting segmentation parameters or use manual mask."

        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)

    def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
        # Severity analysis with a user-supplied binary mask.
        if depth_map is None:
            return "β Please load depth map from Tab 2 first."
        if wound_mask is None:
            return "β Please upload a wound mask (binary image where white pixels represent the wound area)."
        return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)

    def preview_auto_mask(image, seg_method, min_area):
        # Dry-run of the automatic segmentation so the user can inspect it.
        if image is None:
            return None, "β Please load an image first."
        auto_mask = create_automatic_wound_mask(image, method=seg_method)
        if auto_mask is None:
            return None, "β Failed to generate automatic wound mask."
        processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return None, "β No wound region detected. Try adjusting parameters."
        return processed_mask, f"β Auto mask generated using {seg_method} method!"

    def load_shared_image(shared_img):
        # Move Tab 1's (PIL) image into Tab 2's numpy image component.
        if shared_img is None:
            return gr.Image(), "β No image available from classification tab"
        if hasattr(shared_img, 'convert'):
            # Duck-typed PIL check: convert to numpy for the depth input.
            img_array = np.array(shared_img)
            return img_array, "β Image loaded from classification tab"
        else:
            return shared_img, "β Image loaded from classification tab"

    def pass_image_to_depth(img):
        # Only a status message; the actual hand-off lives in shared_image.
        if img is None:
            return "β No image uploaded in classification tab"
        return "β Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"

    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        # Wrap on_depth_submit and additionally stash a normalized uint8
        # depth map into state for the severity tab.
        # NOTE(review): this runs depth inference a second time on top of
        # whatever on_depth_submit already computed — consider reusing its
        # result instead of calling predict_depth again.
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        depth_map = None
        if image is not None:
            depth = predict_depth(image[:, :, ::-1])  # RGB -> BGR channel flip
            norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
            depth_map = norm_depth.astype(np.uint8)
        return results + [depth_map]

    # Connect all event handlers
    # NOTE(review): several outputs lists end with a freshly constructed
    # gr.HTML() that is never placed in the layout, so those status
    # messages are not visible to the user — confirm intent.
    sample_mask_btn.click(fn=generate_sample_mask, inputs=[severity_input_image], outputs=[wound_mask_input, gr.HTML()])
    realistic_mask_btn.click(fn=generate_realistic_mask, inputs=[severity_input_image], outputs=[wound_mask_input, gr.HTML()])
    depth_input_image.change(fn=update_slider_on_image_upload, inputs=[depth_input_image], outputs=[points_slider])
    depth_submit.click(on_depth_submit_with_state, inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
    load_depth_btn.click(fn=load_depth_to_severity, inputs=[depth_map_state, depth_input_image], outputs=[severity_depth_map, severity_input_image, gr.HTML()])
    auto_severity_button.click(fn=run_auto_severity_analysis, inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, segmentation_method, min_area_slider], outputs=[severity_output])
    manual_severity_button.click(fn=run_manual_severity_analysis, inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider], outputs=[severity_output])
    preview_mask_btn.click(fn=preview_auto_mask, inputs=[severity_input_image, segmentation_method, min_area_slider], outputs=[wound_mask_input, gr.HTML()])
    load_shared_btn.click(fn=load_shared_image, inputs=[shared_image], outputs=[depth_input_image, gr.HTML()])
    pass_to_depth_btn.click(fn=pass_image_to_depth, inputs=[shared_image], outputs=[pass_status])

print("Gradio interface created successfully!")

if __name__ == '__main__':
    print("Launching app...")
    # queue() enables request queuing; share=True exposes a public tunnel URL.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
|
temp_files/test2.txt
ADDED
|
@@ -0,0 +1,1063 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import tempfile
|
| 8 |
+
from gradio_imageslider import ImageSlider
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
import os
|
| 14 |
+
import tensorflow as tf
|
| 15 |
+
from tensorflow.keras.models import load_model
|
| 16 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
| 17 |
+
import base64
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import gdown
|
| 20 |
+
import spaces
|
| 21 |
+
|
| 22 |
+
# Define path and file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("/home/user/app/keras_model.h5")
with open("/home/user/app/labels.txt", "r") as f:
    # Each labels.txt line looks like "<index> <name>"; keep only the name.
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]

# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Per-encoder DepthAnythingV2 configurations; only the 'vitl' variant is
# instantiated below (matching the checkpoint downloaded above).
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()


# --- Custom CSS for unified dark theme ---
# Injected via gr.Blocks(css=...); styles the dark layout, buttons, tab bar,
# and the element IDs (#img-display-*, #download) used by the depth tab.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
+
|
| 129 |
+
# --- Wound Classification Functions ---
def preprocess_input(img):
    """Resize a PIL image to the classifier's 224x224 input and scale to [0, 1].

    Returns:
        float array with a leading batch axis: shape (1, 224, 224, 3).
    """
    resized = img.resize((224, 224))
    arr = keras_image.img_to_array(resized) / 255.0
    # Add the batch dimension expected by model.predict.
    return arr[np.newaxis, ...]
+
def get_reasoning_from_gemini(img, prediction):
    """Return a canned textual explanation for a predicted wound class.

    The Gemini API call is currently stubbed out, and `img` is unused.
    Unknown classes fall back to a generic consult-a-professional message.
    """
    try:
        # For now, return a simple explanation without Gemini API to avoid typing issues
        # In production, you would implement the proper Gemini API call here
        explanations = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
        }

        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return explanations.get(prediction, fallback)

    except Exception as e:
        # Defensive catch retained from the original (future API errors).
        return f"(Reasoning unavailable: {str(e)})"
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound image and build two HTML result cards.

    Args:
        img: PIL image from the Gradio upload, or None.

    Returns:
        Tuple of (prediction_card_html, reasoning_card_html); an inline
        error card plus empty string when no image was supplied.
    """
    if img is None:
        return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""

    img_array = preprocess_input(img)
    # verbose=0 keeps Keras from logging a progress bar on every request.
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning from Gemini
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    predicted_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Predicted Wound Type
        </div>
        <div style='font-size: 26px; color: white;'>
            {pred_class}
        </div>
    </div>
    """

    # Reasoning Card
    reasoning_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Reasoning
        </div>
        <div style='font-size: 16px; color: white; min-height: 80px;'>
            {reasoning_text}
        </div>
    </div>
    """

    return predicted_card, reasoning_card
+
# --- Wound Severity Estimation Functions ---
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Compute per-depth-band area statistics over the masked wound region.

    Args:
        depth_map: 2D array of depth values.
        mask: array treated as binary at threshold 127 (wound where > 127).
        pixel_spacing_mm: physical size of one pixel edge, in mm.

    Returns:
        dict with total/shallow/moderate/deep areas (cm^2), the deep-area
        ratio, and the maximum depth inside the wound.
    """
    # Physical footprint of one pixel in cm^2.
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Restrict analysis to the wound region.
    inside = mask > 127
    depths = depth_map[inside]
    total_area = np.sum(inside) * pixel_area_cm2

    # Bucket wound pixels by depth value: <3 shallow, 3-6 moderate, >=6 deep.
    # NOTE(review): the 3/6 thresholds assume depth units where those values
    # are meaningful — confirm against the depth-map producer.
    shallow_area = np.sum(depths < 3) * pixel_area_cm2
    moderate_area = np.sum((depths >= 3) & (depths < 6)) * pixel_area_cm2
    deep_area = np.sum(depths >= 6) * pixel_area_cm2

    return {
        'total_area_cm2': total_area,
        'shallow_area_cm2': shallow_area,
        'moderate_area_cm2': moderate_area,
        'deep_area_cm2': deep_area,
        'deep_ratio': deep_area / total_area if total_area > 0 else 0,
        'max_depth': np.max(depths) if len(depths) > 0 else 0
    }
+
def classify_wound_severity_by_area(depth_stats):
    """Map depth-area statistics to a severity label.

    Args:
        depth_stats: dict produced by compute_depth_area_statistics.

    Returns:
        One of "Severe", "Moderate", "Mild", or "Unknown" (empty wound).
    """
    total = depth_stats['total_area_cm2']
    if total == 0:
        return "Unknown"

    deep = depth_stats['deep_area_cm2']
    moderate = depth_stats['moderate_area_cm2']

    # Rules: a large absolute or relative deep area dominates; otherwise a
    # substantial moderate-depth area; else mild.
    if deep > 2 or deep / total > 0.3:
        return "Severe"
    if moderate > 1.5 or moderate / total > 0.4:
        return "Moderate"
    return "Mild"
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Analyze wound severity from a depth map and a wound mask.

    Aligns the mask to the depth map if shapes differ, computes per-band
    area statistics, classifies severity, and renders an HTML report.

    Returns:
        HTML report string, or an error message when inputs are missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "β Please upload image, depth map, and wound mask."

    # Collapse RGB masks to a single channel.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have the same dimensions.
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        # FIX: use nearest-neighbour resampling. The default filter blends
        # mask edge values, softening the binary >127 cut used downstream.
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]),
                                   resample=Image.NEAREST)
        wound_mask = np.array(mask_pil)

    # Compute statistics and classify.
    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Accent color per severity level for the report.
    severity_color = {
        "Mild": "#4CAF50",      # Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336"     # Red
    }.get(severity, "#9E9E9E")  # Gray for unknown

    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            π©Ή Wound Severity Analysis
        </div>

        <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Area Measurements
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π’ <b>Total Area:</b> {stats['total_area_cm2']:.2f} cmΒ²</div>
                    <div>π© <b>Shallow (0-3mm):</b> {stats['shallow_area_cm2']:.2f} cmΒ²</div>
                    <div>π¨ <b>Moderate (3-6mm):</b> {stats['moderate_area_cm2']:.2f} cmΒ²</div>
                    <div>π₯ <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cmΒ²</div>
                </div>
            </div>

            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    π Depth Analysis
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div>π₯ <b>Deep Coverage:</b> {stats['deep_ratio']*100:.1f}%</div>
                    <div>π <b>Max Depth:</b> {stats['max_depth']:.1f} mm</div>
                    <div>β‘ <b>Pixel Spacing:</b> {pixel_spacing_mm} mm</div>
                </div>
            </div>
        </div>

        <div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
                π― Predicted Severity: {severity}
            </div>
            <div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
                {get_severity_description(severity)}
            </div>
        </div>
    </div>
    """

    return report
|
| 313 |
+
|
| 314 |
+
def get_severity_description(severity):
    """Return a short, human-readable note for a severity label."""
    if severity == "Mild":
        return "Superficial wound with minimal tissue damage. Usually heals well with basic care."
    if severity == "Moderate":
        return "Moderate tissue involvement requiring careful monitoring and proper treatment."
    if severity == "Severe":
        return "Deep tissue damage requiring immediate medical attention and specialized care."
    if severity == "Unknown":
        return "Unable to determine severity due to insufficient data."
    # Any unrecognized label falls through to a generic message.
    return "Severity assessment unavailable."
|
| 323 |
+
|
| 324 |
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a filled circular test mask (255 inside the circle, 0 outside).

    Args:
        image_shape: shape tuple of the target image; only the first two
            entries (rows, cols) are used.
        center: (x, y) circle centre; defaults to the image midpoint.
        radius: circle radius in pixels.
    """
    rows, cols = image_shape[0], image_shape[1]
    if center is None:
        center = (cols // 2, rows // 2)
    cx, cy = center

    # Squared-distance comparison avoids the sqrt of the naive version
    # while selecting exactly the same integer pixels.
    yy, xx = np.ogrid[:rows, :cols]
    inside = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2

    out = np.zeros((rows, cols), dtype=np.uint8)
    out[inside] = 255
    return out
|
| 337 |
+
|
| 338 |
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Create a more realistic wound mask with irregular shapes"""
    # Supported methods: 'elliptical' (noisy ellipse) and 'irregular'
    # (circle with three lobed extensions). Any other value leaves the
    # mask all-zero; only the final smoothing pass runs.
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if method == 'elliptical':
        # Create elliptical wound mask centred on the image midpoint.
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        # Add some irregularity to make it more realistic
        ellipse = ((x - center[0])**2 / (radius_x**2) +
                   (y - center[1])**2 / (radius_y**2)) <= 1

        # Add some noise and irregularity.
        # NOTE(review): the speckle noise covers the WHOLE image, not just
        # the ellipse neighbourhood, and is non-deterministic — confirm
        # this is intended for a synthetic sample mask.
        noise = np.random.random((h, w)) > 0.8
        mask = (ellipse | noise).astype(np.uint8) * 255

    elif method == 'irregular':
        # Create irregular wound mask: base circle plus lobes.
        center = (w // 2, h // 2)
        radius = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius

        # Add three irregular extensions spaced 120 degrees apart, each
        # centred at 0.8*radius from the middle with a third of the radius.
        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3

            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle

        mask = (base_circle | extensions).astype(np.uint8) * 255

    # Apply morphological closing to smooth jagged edges / fill pinholes.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask
|
| 384 |
+
|
| 385 |
+
# --- Depth Estimation Functions ---
|
| 386 |
+
@spaces.GPU
def predict_depth(image):
    """Run monocular depth inference on `image` using the global model.

    NOTE(review): relies on a module-level `depth_model`; callers in this
    file pass BGR (`image[:, :, ::-1]`) — confirm the channel order
    `infer_image` expects.
    """
    return depth_model.infer_image(image)
|
| 389 |
+
|
| 390 |
+
def calculate_max_points(image):
    """Return a 3D-point budget of 3x the pixel count, clamped to [1000, 300000].

    Falls back to 10000 when no image has been uploaded yet.
    """
    if image is None:
        return 10000  # Default value

    height, width = image.shape[0], image.shape[1]
    budget = height * width * 3
    # Clamp into a sane range for the UI slider.
    return min(max(budget, 1000), 300000)
|
| 398 |
+
|
| 399 |
+
def update_slider_on_image_upload(image):
    """Rebuild the points slider so its maximum tracks the uploaded image."""
    ceiling = calculate_max_points(image)
    # Start at roughly 10% of the budget, capped at the old default.
    start = min(10000, ceiling // 10)
    return gr.Slider(
        minimum=1000,
        maximum=ceiling,
        value=start,
        step=1000,
        label=f"Number of 3D points (max: {ceiling:,})",
    )
|
| 405 |
+
|
| 406 |
+
@spaces.GPU
|
| 407 |
+
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Back-project a depth map into a colored Open3D point cloud.

    Uses a pinhole model with the principal point at the image centre.
    The sampling stride is halved versus the naive budget so roughly 4x
    max_points survive, keeping extra surface detail.
    """
    h, w = depth_map.shape

    # Halved stride => denser sampling than max_points alone would give.
    stride = max(1, int(np.sqrt(h * w / max_points) * 0.5))

    # Pixel grid at the chosen stride.
    v, u = np.mgrid[0:h:stride, 0:w:stride]
    z = depth_map[::stride, ::stride]

    # Pinhole back-projection: X = (u - cx)/fx * Z, Y = (v - cy)/fy * Z.
    x = (u - w / 2) / focal_length_x * z
    y = (v - h / 2) / focal_length_y * z

    xyz = np.stack([x.flatten(), y.flatten(), z.flatten()], axis=1)
    # Per-point color from the image, normalized to [0, 1].
    rgb = image[::stride, ::stride, :].reshape(-1, 3) / 255.0

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud
|
| 442 |
+
|
| 443 |
+
@spaces.GPU
|
| 444 |
+
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Poisson-reconstruct a high-detail triangle mesh from a point cloud."""
    # Poisson reconstruction needs oriented normals: estimate them with a
    # tight hybrid search, then enforce a globally consistent orientation.
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)
    )
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # depth=12 gives a very fine octree; density-based vertex filtering is
    # deliberately skipped so thin, sparsely-sampled regions are kept.
    mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    return mesh
|
| 455 |
+
|
| 456 |
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection"""
    # Builds a plotly Scatter3d of depth pixels back-projected through a
    # fixed-focal-length pinhole model, colored by the RGB image.
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (normalized by focal length).
    # NOTE(review): hard-coded here, unlike create_point_cloud which takes
    # focal lengths as parameters — consider unifying.
    focal_length = 470.4  # Default focal length
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Get depth values
    depth_values = depth_map[::step, ::step]

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    # Create 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
|
| 523 |
+
|
| 524 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Full depth pipeline: predict depth, export files, build 3D outputs.

    Returns:
        [(original, colored depth) slider pair, grayscale depth path,
         raw 16-bit depth path, reconstructed-mesh .ply path, 3D figure].
    """
    original_image = image.copy()

    # Model is fed BGR; the UI provides RGB.
    depth = predict_depth(image[:, :, ::-1])

    # Save raw (un-normalized) 16-bit depth.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to 0-255 for display. FIX: guard against a constant depth
    # map, where max == min would divide by zero and NaN every output.
    depth_range = depth.max() - depth.min()
    if depth_range > 0:
        norm_depth = (depth - depth.min()) / depth_range * 255.0
    else:
        norm_depth = np.zeros_like(depth)
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Point cloud -> Poisson mesh, exported as .ply for download.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Interactive 3D scatter of the back-projected points.
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
| 560 |
+
|
| 561 |
+
# --- Automatic Wound Mask Generation Functions ---
|
| 562 |
+
import cv2
|
| 563 |
+
from skimage import filters, morphology, measure
|
| 564 |
+
from skimage.segmentation import clear_border
|
| 565 |
+
|
| 566 |
+
def create_automatic_wound_mask(image, method='adaptive'):
    """
    Automatically generate wound mask from image using various segmentation methods

    Args:
        image: Input image (numpy array)
        method: Segmentation method ('adaptive', 'otsu', 'color', 'combined')

    Returns:
        mask: Binary wound mask
    """
    if image is None:
        return None

    # Color methods need the RGB image; threshold methods need grayscale.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if image.ndim == 3 else image.copy()

    if method == 'otsu':
        return otsu_threshold_segmentation(gray)
    if method == 'color':
        return color_based_segmentation(image)
    if method == 'combined':
        return combined_segmentation(image, gray)
    # 'adaptive' and any unrecognized method use adaptive thresholding.
    return adaptive_threshold_segmentation(gray)
|
| 599 |
+
|
| 600 |
+
def adaptive_threshold_segmentation(gray):
    """Segment wound candidates via Gaussian-adaptive thresholding."""
    # Heavy blur suppresses texture noise before thresholding.
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)

    # Inverted binary so darker (wound-like) tissue becomes foreground.
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )

    # Close gaps, then open to drop speckles.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Re-draw only sufficiently large connected regions.
    found, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for c in found:
        if cv2.contourArea(c) > 1000:  # Minimum area threshold
            cv2.fillPoly(result, [c], 255)
    return result
|
| 626 |
+
|
| 627 |
+
def otsu_threshold_segmentation(gray):
    """Segment wound candidates with Otsu's global threshold."""
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)

    # Inverted Otsu threshold: darker tissue becomes foreground.
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Morphological cleanup: close gaps, then remove speckles.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Keep only connected regions above the area floor.
    found, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for c in found:
        if cv2.contourArea(c) > 800:  # Minimum area threshold
            cv2.fillPoly(result, [c], 255)
    return result
|
| 651 |
+
|
| 652 |
+
def color_based_segmentation(image):
    """Detect wound-like regions by color (red / yellow / brown in HSV).

    FIX: masks are combined with cv2.bitwise_or instead of numpy `+`.
    The HSV ranges overlap (brown overlaps both red1 and yellow), so uint8
    addition of 255-valued masks wraps past 255 and corrupts mask values.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Reddish wound colors — red hue wraps around 0/180 in OpenCV HSV,
    # so two ranges are needed.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([160, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)

    # Yellowish wound colors - broader range
    lower_yellow = np.array([15, 30, 30])
    upper_yellow = np.array([35, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)

    # Brownish wound colors
    lower_brown = np.array([10, 50, 20])
    upper_brown = np.array([20, 255, 200])
    brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)

    # Union of all color masks (overflow-safe).
    color_mask = cv2.bitwise_or(cv2.bitwise_or(red_mask, yellow_mask), brown_mask)

    # Clean up the mask with larger kernels.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)

    # Keep only sufficiently large contours.
    contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(color_mask)
    for contour in contours:
        if cv2.contourArea(contour) > 600:  # Minimum area threshold
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean
|
| 697 |
+
|
| 698 |
+
def combined_segmentation(image, gray):
    """Union of adaptive, Otsu and color segmentations, with a fallback."""
    # Merge the three candidate masks (set union).
    merged = cv2.bitwise_or(adaptive_threshold_segmentation(gray),
                            otsu_threshold_segmentation(gray))
    merged = cv2.bitwise_or(merged, color_based_segmentation(image))

    # Close remaining gaps between overlapping detections.
    big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    merged = cv2.morphologyEx(merged, cv2.MORPH_CLOSE, big_kernel)

    # Keep only large connected components.
    found, _ = cv2.findContours(merged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(merged)
    for c in found:
        if cv2.contourArea(c) > 500:  # Minimum area threshold
            cv2.fillPoly(result, [c], 255)

    # If nothing survived, fall back to a synthetic elliptical mask so
    # downstream analysis still has a region to work with.
    if np.sum(result) == 0:
        result = create_realistic_wound_mask(merged.shape, method='elliptical')

    return result
|
| 728 |
+
|
| 729 |
+
def post_process_wound_mask(mask, min_area=100):
    """Denoise a binary mask: morphology, small-object removal, hole fill."""
    if mask is None:
        return None

    # findContours requires uint8 input.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Close gaps, then open to remove isolated specks.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    smoothed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, ellipse)

    # Re-draw only components whose area reaches min_area.
    found, _ = cv2.findContours(smoothed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(smoothed)
    for c in found:
        if cv2.contourArea(c) >= min_area:
            cv2.fillPoly(result, [c], 255)

    # Final close pass fills holes left inside the kept components.
    return cv2.morphologyEx(result, cv2.MORPH_CLOSE, ellipse)
|
| 756 |
+
|
| 757 |
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='combined', min_area=500):
    """Analyze wound severity with automatic mask generation.

    Args:
        image: RGB image as a numpy array.
        depth_map: Depth map aligned with `image`.
        pixel_spacing_mm: Physical size of one pixel, in mm.
        segmentation_method: One of 'adaptive', 'otsu', 'color', 'combined'.
        min_area: Minimum connected-component area (pixels) kept during
            mask post-processing. Was hard-coded to 500; now a defaulted
            parameter so the UI's area slider can actually drive it.

    Returns:
        HTML severity report, or an error-message string.
    """
    if image is None or depth_map is None:
        return "β Please provide both image and depth map."

    # Generate the automatic wound mask.
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "β Failed to generate automatic wound mask."

    # Post-process the mask and bail out if nothing remains.
    processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
    if processed_mask is None or np.sum(processed_mask > 0) == 0:
        return "β No wound region detected. Try adjusting segmentation parameters or upload a manual mask."

    # Analyze severity using the automatic mask.
    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
|
| 776 |
+
|
| 777 |
+
# --- Main Gradio Interface ---
|
| 778 |
+
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
|
| 779 |
+
gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
|
| 780 |
+
gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")
|
| 781 |
+
|
| 782 |
+
# Shared image state
|
| 783 |
+
shared_image = gr.State()
|
| 784 |
+
|
| 785 |
+
with gr.Tabs():
|
| 786 |
+
# Tab 1: Wound Classification
|
| 787 |
+
with gr.Tab("1. Wound Classification"):
|
| 788 |
+
gr.Markdown("### Step 1: Upload and classify your wound image")
|
| 789 |
+
gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")
|
| 790 |
+
|
| 791 |
+
with gr.Row():
|
| 792 |
+
with gr.Column(scale=1):
|
| 793 |
+
wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
|
| 794 |
+
|
| 795 |
+
with gr.Column(scale=1):
|
| 796 |
+
wound_prediction_box = gr.HTML()
|
| 797 |
+
wound_reasoning_box = gr.HTML()
|
| 798 |
+
|
| 799 |
+
# Button to pass image to depth estimation
|
| 800 |
+
with gr.Row():
|
| 801 |
+
pass_to_depth_btn = gr.Button("π Pass Image to Depth Analysis", variant="secondary", size="lg")
|
| 802 |
+
pass_status = gr.HTML("")
|
| 803 |
+
|
| 804 |
+
wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
|
| 805 |
+
outputs=[wound_prediction_box, wound_reasoning_box])
|
| 806 |
+
|
| 807 |
+
# Store image when uploaded for classification
|
| 808 |
+
wound_image_input.change(
|
| 809 |
+
fn=lambda img: img,
|
| 810 |
+
inputs=[wound_image_input],
|
| 811 |
+
outputs=[shared_image]
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
# Tab 2: Depth Estimation
|
| 815 |
+
with gr.Tab("2. Depth Estimation & 3D Visualization"):
|
| 816 |
+
gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
|
| 817 |
+
gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
|
| 818 |
+
|
| 819 |
+
with gr.Row():
|
| 820 |
+
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
| 821 |
+
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
|
| 822 |
+
|
| 823 |
+
with gr.Row():
|
| 824 |
+
depth_submit = gr.Button(value="Compute Depth", variant="primary")
|
| 825 |
+
load_shared_btn = gr.Button("π Load Image from Classification", variant="secondary")
|
| 826 |
+
points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
|
| 827 |
+
label="Number of 3D points (upload image to update max)")
|
| 828 |
+
|
| 829 |
+
with gr.Row():
|
| 830 |
+
focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 831 |
+
label="Focal Length X (pixels)")
|
| 832 |
+
focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
| 833 |
+
label="Focal Length Y (pixels)")
|
| 834 |
+
|
| 835 |
+
with gr.Row():
|
| 836 |
+
gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
|
| 837 |
+
raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
|
| 838 |
+
point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
|
| 839 |
+
|
| 840 |
+
# 3D Visualization
|
| 841 |
+
gr.Markdown("### 3D Point Cloud Visualization")
|
| 842 |
+
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
|
| 843 |
+
depth_3d_plot = gr.Plot(label="3D Point Cloud")
|
| 844 |
+
|
| 845 |
+
# Store depth map for severity analysis
|
| 846 |
+
depth_map_state = gr.State()
|
| 847 |
+
|
| 848 |
+
# Tab 3: Wound Severity Analysis
|
| 849 |
+
with gr.Tab("3. π©Ή Wound Severity Analysis"):
|
| 850 |
+
gr.Markdown("### Step 3: Analyze wound severity using depth maps")
|
| 851 |
+
gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
|
| 852 |
+
|
| 853 |
+
with gr.Row():
|
| 854 |
+
severity_input_image = gr.Image(label="Original Image", type='numpy')
|
| 855 |
+
severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
|
| 856 |
+
|
| 857 |
+
with gr.Row():
|
| 858 |
+
wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
|
| 859 |
+
severity_output = gr.HTML(label="Severity Analysis Report")
|
| 860 |
+
|
| 861 |
+
gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")
|
| 862 |
+
|
| 863 |
+
with gr.Row():
|
| 864 |
+
auto_severity_button = gr.Button("π€ Auto-Analyze Severity", variant="primary", size="lg")
|
| 865 |
+
manual_severity_button = gr.Button("π Manual Mask Analysis", variant="secondary", size="lg")
|
| 866 |
+
pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
|
| 867 |
+
label="Pixel Spacing (mm/pixel)")
|
| 868 |
+
|
| 869 |
+
gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
|
| 870 |
+
|
| 871 |
+
with gr.Row():
|
| 872 |
+
segmentation_method = gr.Dropdown(
|
| 873 |
+
choices=["combined", "adaptive", "otsu", "color"],
|
| 874 |
+
value="combined",
|
| 875 |
+
label="Segmentation Method",
|
| 876 |
+
info="Choose automatic segmentation method"
|
| 877 |
+
)
|
| 878 |
+
min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
|
| 879 |
+
label="Minimum Area (pixels)",
|
| 880 |
+
info="Minimum wound area to detect")
|
| 881 |
+
|
| 882 |
+
with gr.Row():
|
| 883 |
+
# Load depth map from previous tab
|
| 884 |
+
load_depth_btn = gr.Button("π Load Depth Map from Tab 2", variant="secondary")
|
| 885 |
+
sample_mask_btn = gr.Button("π― Generate Sample Mask", variant="secondary")
|
| 886 |
+
realistic_mask_btn = gr.Button("π₯ Generate Realistic Mask", variant="secondary")
|
| 887 |
+
preview_mask_btn = gr.Button("ποΈ Preview Auto Mask", variant="secondary")
|
| 888 |
+
|
| 889 |
+
gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")
|
| 890 |
+
|
| 891 |
+
# Generate sample mask function
|
| 892 |
+
def generate_sample_mask(image):
|
| 893 |
+
if image is None:
|
| 894 |
+
return None, "β Please load an image first."
|
| 895 |
+
|
| 896 |
+
sample_mask = create_sample_wound_mask(image.shape)
|
| 897 |
+
return sample_mask, "β
Sample circular wound mask generated!"
|
| 898 |
+
|
| 899 |
+
# Generate realistic mask function
|
| 900 |
+
def generate_realistic_mask(image):
|
| 901 |
+
if image is None:
|
| 902 |
+
return None, "β Please load an image first."
|
| 903 |
+
|
| 904 |
+
realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
|
| 905 |
+
return realistic_mask, "β
Realistic elliptical wound mask generated!"
|
| 906 |
+
|
| 907 |
+
sample_mask_btn.click(
|
| 908 |
+
fn=generate_sample_mask,
|
| 909 |
+
inputs=[severity_input_image],
|
| 910 |
+
outputs=[wound_mask_input, gr.HTML()]
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
realistic_mask_btn.click(
|
| 914 |
+
fn=generate_realistic_mask,
|
| 915 |
+
inputs=[severity_input_image],
|
| 916 |
+
outputs=[wound_mask_input, gr.HTML()]
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
# Update slider when image is uploaded
|
| 920 |
+
depth_input_image.change(
|
| 921 |
+
fn=update_slider_on_image_upload,
|
| 922 |
+
inputs=[depth_input_image],
|
| 923 |
+
outputs=[points_slider]
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
# Modified depth submit function to store depth map
def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
    """Run the depth pipeline and additionally return an 8-bit depth map.

    Wraps ``on_depth_submit`` so its outputs are passed through unchanged,
    with a normalized ``uint8`` depth map appended for the severity tab
    (stored in ``depth_map_state``).  ``depth_map`` is ``None`` when no
    image was provided.
    """
    results = on_depth_submit(image, num_points, focal_x, focal_y)
    depth_map = None
    if image is not None:
        # predict_depth expects BGR channel order; the Gradio image is RGB.
        depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
        # Normalize depth to 0-255 for severity analysis.  Guard against a
        # constant depth map (max == min), which would otherwise divide by
        # zero and fill the map with NaNs.
        span = depth.max() - depth.min()
        if span > 0:
            norm_depth = (depth - depth.min()) / span * 255.0
        else:
            norm_depth = np.zeros_like(depth)
        depth_map = norm_depth.astype(np.uint8)
    return results + [depth_map]
|
| 937 |
+
|
| 938 |
+
# Depth submit: same outputs as the plain depth handler, plus the stored
# depth map (depth_map_state) that the severity tab later loads.
depth_submit.click(on_depth_submit_with_state,
                   inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                   outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
|
| 941 |
+
|
| 942 |
+
# Load depth map to severity tab
def load_depth_to_severity(depth_map, original_image):
    """Copy the Tab-2 depth map (and its source image) into the severity tab.

    Returns ``(depth_map, original_image, status_message)``; both payloads
    are ``None`` when no depth map has been computed yet.
    """
    if depth_map is not None:
        return depth_map, original_image, "β Depth map loaded successfully!"
    return None, None, "β No depth map available. Please compute depth in Tab 2 first."
|
| 947 |
+
|
| 948 |
+
# Copy the stored depth map and original image into the severity tab.
# NOTE(review): the inline ``gr.HTML()`` output is an unrendered component,
# so the status message is discarded — verify this is intended.
load_depth_btn.click(
    fn=load_depth_to_severity,
    inputs=[depth_map_state, depth_input_image],
    outputs=[severity_depth_map, severity_input_image, gr.HTML()]
)
|
| 953 |
+
|
| 954 |
+
# Automatic severity analysis function
def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
    """Segment the wound automatically and score its severity.

    Parameters mirror the Gradio inputs: the severity-tab image, the depth
    map loaded from Tab 2, the mm-per-pixel spacing, the segmentation
    method name, and the minimum connected-component area kept during
    post-processing.  Returns an HTML/markdown report string, or an error
    message when a prerequisite is missing.
    """
    # Guard both prerequisites up front; the original only checked the
    # depth map and would pass ``None`` into the segmenter.
    if image is None:
        return "β Please load an image first."
    if depth_map is None:
        return "β Please load depth map from Tab 2 first."

    # Generate automatic wound mask
    auto_mask = create_automatic_wound_mask(image, method=seg_method)
    if auto_mask is None:
        return "β Failed to generate automatic wound mask."

    # Post-process with the user-defined minimum area (single-use closure
    # from the original flattened into a direct call).
    processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
    if processed_mask is None or np.sum(processed_mask > 0) == 0:
        return "β No wound region detected. Try adjusting segmentation parameters or use manual mask."

    # Analyze severity using the automatic mask
    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)
|
| 977 |
+
|
| 978 |
+
# Manual severity analysis function
def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
    """Score wound severity using a user-supplied binary mask.

    Returns the severity report, or an error string when the depth map or
    the wound mask is missing.
    """
    missing = None
    if depth_map is None:
        missing = "β Please load depth map from Tab 2 first."
    elif wound_mask is None:
        missing = "β Please upload a wound mask (binary image where white pixels represent the wound area)."
    if missing is not None:
        return missing

    return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)
|
| 986 |
+
|
| 987 |
+
# Preview automatic mask function
def preview_auto_mask(image, seg_method, min_area):
    """Generate and post-process an automatic wound mask for preview.

    Returns ``(mask, status_message)``; the mask is ``None`` when the input
    image is missing, segmentation fails, or no region survives the
    minimum-area filter.
    """
    if image is None:
        return None, "β Please load an image first."

    raw_mask = create_automatic_wound_mask(image, method=seg_method)
    if raw_mask is None:
        return None, "β Failed to generate automatic wound mask."

    cleaned = post_process_wound_mask(raw_mask, min_area=min_area)
    has_wound_pixels = cleaned is not None and np.sum(cleaned > 0) > 0
    if not has_wound_pixels:
        return None, "β No wound region detected. Try adjusting parameters."

    return cleaned, f"β Auto mask generated using {seg_method} method!"
|
| 1005 |
+
|
| 1006 |
+
# Connect event handlers for the severity tab: automatic analysis, manual
# (user-mask) analysis, and the mask-preview button.
auto_severity_button.click(
    fn=run_auto_severity_analysis,
    inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider,
            segmentation_method, min_area_slider],
    outputs=[severity_output]
)

manual_severity_button.click(
    fn=run_manual_severity_analysis,
    inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
    outputs=[severity_output]
)

# NOTE(review): the inline ``gr.HTML()`` output below is an unrendered
# component, so the preview status message is discarded.
preview_mask_btn.click(
    fn=preview_auto_mask,
    inputs=[severity_input_image, segmentation_method, min_area_slider],
    outputs=[wound_mask_input, gr.HTML()]
)
|
| 1025 |
+
|
| 1026 |
+
# Load shared image from classification tab
def load_shared_image(shared_img):
    """Hand the classification tab's image to the depth-input component.

    PIL images (anything exposing ``.convert``) are converted to a numpy
    array; numpy arrays pass through untouched.  Returns
    ``(image_or_component, status_message)``.
    """
    if shared_img is None:
        return gr.Image(), "β No image available from classification tab"

    # PIL images expose .convert; normalize those to a numpy array,
    # otherwise forward the (already-numpy) payload as-is.
    payload = np.array(shared_img) if hasattr(shared_img, 'convert') else shared_img
    return payload, "β Image loaded from classification tab"
|
| 1039 |
+
|
| 1040 |
+
# Pull the shared classification image into the depth tab's input.
# NOTE(review): the inline ``gr.HTML()`` output is an unrendered component,
# so the status message is discarded — verify this is intended.
load_shared_btn.click(
    fn=load_shared_image,
    inputs=[shared_image],
    outputs=[depth_input_image, gr.HTML()]
)
|
| 1045 |
+
|
| 1046 |
+
# Pass image to depth tab function
def pass_image_to_depth(img):
    """Report whether the classification image is ready for depth analysis.

    Purely informational: returns a status string telling the user to
    switch to Tab 2; the actual handoff happens via ``shared_image`` state.
    """
    if img is not None:
        return "β Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
    return "β No image uploaded in classification tab"
|
| 1051 |
+
|
| 1052 |
+
# Show a readiness message when the user asks to send the classification
# image to the depth tab (pass_status is a dedicated status component).
pass_to_depth_btn.click(
    fn=pass_image_to_depth,
    inputs=[shared_image],
    outputs=[pass_status]
)
|
| 1057 |
+
|
| 1058 |
+
if __name__ == '__main__':
    # queue() enables request queuing so long-running depth jobs don't block
    # other sessions; share=True additionally opens a public Gradio tunnel.
    demo.queue().launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,       # conventional Hugging Face Spaces port
        share=True
    )
|