Commit ·
4bee0a6
0
Parent(s):
upload
Browse files- .gitattributes +35 -0
- README.md +76 -0
- ckpt_best.pt +3 -0
- config.snapshot.yaml +116 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-nd-4.0
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
tags:
|
| 5 |
+
- medical-imaging
|
| 6 |
+
- 3d-cnn
|
| 7 |
+
- ultrasound
|
| 8 |
+
- focused-ultrasound
|
| 9 |
+
- transcranial-ultrasound
|
| 10 |
+
- reproduction
|
| 11 |
+
datasets:
|
| 12 |
+
- vinkle-srivastav/TFUScapes
|
| 13 |
+
language:
|
| 14 |
+
- en
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# DeepTFUS: base (run-1 reproduction)
|
| 18 |
+
|
| 19 |
+
*A reproduction attempt of DeepTFUS, proposed by [Srivastav et al. (arXiv:2505.12998)](https://arxiv.org/abs/2505.12998).*
|
| 20 |
+
|
| 21 |
+
This is the from-scratch baseline: 50 epochs on the paper recipe
|
| 22 |
+
(weighted-MSE + λ·gradient-L1, no focal-position aux), `base_width=16`
|
| 23 |
+
(3.4 M params), `pure-bf16`, `batch=4` at 256³ resolution. Given a 3D
|
| 24 |
+
head CT and a transducer placement, predicts the resulting in-skull
|
| 25 |
+
pressure field in <1 s on an H100 (≈ 50× faster than the k-Wave
|
| 26 |
+
physics simulator the dataset was generated from).
|
| 27 |
+
|
| 28 |
+
⭐ Partial reproduction: matched paper on `relative_l2`, did not match
|
| 29 |
+
on `focal_position_error_mm` (~2× worse) or `max_pressure_error`. This
|
| 30 |
+
gap motivated the 5 fine-tune variants in this model collection.
|
| 31 |
+
|
| 32 |
+
## Test results (n = 597 held-out CT × placement combinations)
|
| 33 |
+
|
| 34 |
+
| metric | paper | base (this model) | reproduced? |
|
| 35 |
+
|---|---:|---:|---|
|
| 36 |
+
| `relative_l2` mean ± std | 0.414 ± 0.086 | **0.384 ± 0.078** | ✅ Yes (slightly beats paper) |
|
| 37 |
+
| `relative_l2` median | 0.394 | **0.369** | ✅ |
|
| 38 |
+
| `focal_position_error_mm` mean ± std | 2.89 ± 2.14 | 6.49 ± 4.58 | ❌ No (~2.25× worse mean) |
|
| 39 |
+
| `focal_position_error_mm` median | 2.45 | 5.15 | ❌ |
|
| 40 |
+
| `max_pressure_error` mean ± std | 0.199 ± 0.158 | 0.225 ± 0.116 | ✅ Yes (within paper's std) |
|
| 41 |
+
| `max_pressure_error` median | 0.166 | 0.217 | (slightly above paper) |
|
| 42 |
+
| `focal_pressure_error` median | : | 0.528 | : |
|
| 43 |
+
| `focal_iou_fwhm` median | : | 0.143 | : |
|
| 44 |
+
| `inference_latency_s` (b=1, H100) | 11.4 (RTX 4090) | 0.233 | 49× faster (different HW) |
|
| 45 |
+
|
| 46 |
+
## Other variants and discussion
|
| 47 |
+
|
| 48 |
+
See the [Collection](https://huggingface.co/collections/masonwang025/deeptfus-reproduction-6a03e39286a09470b960511f)
|
| 49 |
+
for the 5 fine-tune variants built from this base ckpt, and the
|
| 50 |
+
[project page](https://masonjwang.com/projects/reproducing-deeptfus)
|
| 51 |
+
for the full reproduction story, interactive viewer, and discussion of
|
| 52 |
+
trade-offs.
|
| 53 |
+
|
| 54 |
+
## Usage
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
from huggingface_hub import hf_hub_download
|
| 58 |
+
import torch
|
| 59 |
+
|
| 60 |
+
ckpt = torch.load(
|
| 61 |
+
hf_hub_download("masonwang025/deeptfus-base", "ckpt_best.pt"),
|
| 62 |
+
map_location="cpu", weights_only=False,
|
| 63 |
+
)
|
| 64 |
+
# ckpt['model'] : state_dict for the model defined in masonwang025/deeptfus repo
|
| 65 |
+
# ckpt['config'] : training config (architecture knobs + train hyperparams)
|
| 66 |
+
# ckpt['epoch'] : 43 (best by val_rel_l2)
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Model code: [github.com/masonwang025/deeptfus](https://github.com/masonwang025/deeptfus).
|
| 70 |
+
|
| 71 |
+
## Citation & License
|
| 72 |
+
|
| 73 |
+
Paper: Srivastav et al., [arXiv:2505.12998](https://arxiv.org/abs/2505.12998), 2025.
|
| 74 |
+
|
| 75 |
+
License: [CC-BY-NC-ND-4.0](https://creativecommons.org/licenses/by-nc-nd/4.0/),
|
| 76 |
+
matching the TFUScapes dataset license.
|
ckpt_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b362a03019d2f614913e9e5b375d5a501b41217ad43bfe68f20cdba40803059
|
| 3 |
+
size 20743755
|
config.snapshot.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepTFUS reproduction config (GPU defaults).
|
| 2 |
+
#
|
| 3 |
+
# Paper: Srivastav et al., "A Skull-Adaptive Framework for AI-Based 3D
|
| 4 |
+
# Transcranial Focused Ultrasound Simulation", arXiv:2505.12998. PDF at
|
| 5 |
+
# repo root: original-paper.pdf.
|
| 6 |
+
#
|
| 7 |
+
# For local Mac smoke testing, use `python scripts/local_verify.py`; that
|
| 8 |
+
# script overrides the heavy fields in-process. Do not edit this file for
|
| 9 |
+
# verification.
|
| 10 |
+
#
|
| 11 |
+
# Architectural specifics the paper does not pin down (base_width, depth,
|
| 12 |
+
# dynamic-conv kernel size, cross-attention head count, which encoder
|
| 13 |
+
# levels carry cross-attention) are flagged TENTATIVE below. An email
|
| 14 |
+
# is out to the authors; update those values when they reply.
|
| 15 |
+
|
| 16 |
+
data:
|
| 17 |
+
resolution: 256 # paper: 256^3 cropped subvolumes
|
| 18 |
+
n_transducer_points: 512 # uniform random subsample per step
|
| 19 |
+
|
| 20 |
+
model:
|
| 21 |
+
base_width: 16 # TENTATIVE; paper does not specify capacity. We picked 16
|
| 22 |
+
# because the paper's own ablation (Table 1) shows
|
| 23 |
+
# DeepTFUStiny within 1σ of DeepTFUS on every metric, and
|
| 24 |
+
# neither variant has a published param count, so there is
|
| 25 |
+
# no paper-anchored "right size". 16 → 2.6M params, 7h
|
| 26 |
+
# at 50 epochs, fits without grad-checkpointing.
|
| 27 |
+
cond_dim: 128 # transducer embedding dim (z_T)
|
| 28 |
+
n_transducer_freqs: 8 # TENTATIVE; paper says "Fourier PE" no count
|
| 29 |
+
dynamic_conv_kernel: 3 # TENTATIVE; paper does not specify
|
| 30 |
+
cross_attention_heads: 4 # TENTATIVE
|
| 31 |
+
cross_attention_levels: [level1, level2, level3, bottleneck]
|
| 32 |
+
# Paper says "each encoder level"; level0 at
|
| 33 |
+
# 256^3 is OOM on a single 80GB H100 even with
|
| 34 |
+
# direction-1 disabled (the concat+1x1x1 fusion
|
| 35 |
+
# at that resolution adds ~15 GB on top of the
|
| 36 |
+
# ~55 GB base). level1..bottleneck fit and
|
| 37 |
+
# match the paper at each of those levels.
|
| 38 |
+
|
| 39 |
+
cross_attention_bidirectional: false
|
| 40 |
+
# Paper says "bi-directional ... two multi-head
|
| 41 |
+
# attention blocks". We default to false
|
| 42 |
+
# because direction 1 (CT-queries-z_T) is
|
| 43 |
+
# degenerate with z_T as a single token:
|
| 44 |
+
# softmax over one key is identically 1, so it
|
| 45 |
+
# collapses to a learned broadcast of a
|
| 46 |
+
# projection of z_T — same function the
|
| 47 |
+
# encoder's DynamicConv layers already serve.
|
| 48 |
+
# Flip to true to recover the paper-faithful
|
| 49 |
+
# bi-directional design.
|
| 50 |
+
|
| 51 |
+
use_film_decoder: false # Paper §3.2 puts FiLM in the decoding path.
|
| 52 |
+
# We default to false because the paper's own
|
| 53 |
+
# Table 1 "No FiLM" row shows lower
|
| 54 |
+
# max_pressure_error than full DeepTFUS and is
|
| 55 |
+
# within 1 sigma on every other metric. With
|
| 56 |
+
# FiLM off, decoder is a plain U-Net decoder.
|
| 57 |
+
# Flip to true to recover paper-faithful FiLM.
|
| 58 |
+
|
| 59 |
+
loss:
|
| 60 |
+
alpha: 5.0 # paper Eq 5: exponent in w(v)=exp(α(P-maxP))/E[...]
|
| 61 |
+
grad_weight: 0.1 # paper: lambda for the gradient-L1 term
|
| 62 |
+
|
| 63 |
+
train:
|
| 64 |
+
epochs: 50 # paper
|
| 65 |
+
batch_size: 4 # paper
|
| 66 |
+
lr: 0.001 # paper
|
| 67 |
+
weight_decay: 0.0001
|
| 68 |
+
grad_clip: 1.0
|
| 69 |
+
seed: 0
|
| 70 |
+
num_workers: 4 # for CUDA; local_verify forces 0 on MPS
|
| 71 |
+
val_every: 1 # epochs
|
| 72 |
+
|
| 73 |
+
# Precision / memory. The H100 path needs both of these to fit batch=4 at
|
| 74 |
+
# 256^3 (fp32 OOMs at >78 GiB on the first encoder level; autocast bf16 also
|
| 75 |
+
# OOMs because GroupNorm/FiLM upcast paths leak fp32 into downstream
|
| 76 |
+
# activations). See docs/1-reproduction-setup/synthetic_bench.md and
|
| 77 |
+
# investigation.md for the receipts. local_verify.py overrides these to
|
| 78 |
+
# fp32/false for the CPU/MPS smoke test.
|
| 79 |
+
precision: pure-bf16 # fp32 | pure-bf16. autocast bf16 is a trap on this
|
| 80 |
+
# model (GroupNorm/FiLM/DynConv leak fp32 promotions).
|
| 81 |
+
grad_checkpoint_encoder: false # at base_width=16 we fit without checkpointing.
|
| 82 |
+
# Flip true if you bump base_width back to 24+.
|
| 83 |
+
|
| 84 |
+
# Speedups (benched 2026-05-11; see docs/2-paper-audit/bench_speedups.txt).
|
| 85 |
+
# Both DEFAULT OFF after empirical findings on this specific model:
|
| 86 |
+
channels_last: false # channels_last_3d was 17% SLOWER on this model
|
| 87 |
+
# (1574 ms vs 1344 ms baseline) because cuDNN's
|
| 88 |
+
# 3D NHWC kernels do not have a good path for
|
| 89 |
+
# our depthwise grouped DynamicConv3d. Flip
|
| 90 |
+
# true only if you remove DynamicConv3d.
|
| 91 |
+
compile: false # torch.compile OOMs the up1 decoder stage
|
| 92 |
+
# because Inductor materializes the (B, 48, 256^3)
|
| 93 |
+
# = 6 GiB concat intermediate that eager mode
|
| 94 |
+
# streams through. Could be re-attempted with
|
| 95 |
+
# mode="reduce-overhead" or by compiling only
|
| 96 |
+
# sub-modules; left as future work.
|
| 97 |
+
|
| 98 |
+
# Run observability. WandB init no-ops if WANDB_API_KEY is unset OR
|
| 99 |
+
# wandb_project is null.
|
| 100 |
+
wandb_project: deeptfus-reproduction
|
| 101 |
+
wandb_entity: mason-wang # team entity; personal entities disabled on this account
|
| 102 |
+
|
| 103 |
+
# Per-epoch checkpoint frequency. 0 disables (only ckpt_best and ckpt_last
|
| 104 |
+
# are written). Set to e.g. 5 to keep ckpt_epoch_004.pt, ckpt_epoch_009.pt,
|
| 105 |
+
# ... in addition. Each ckpt is ~20 MB at bw=16; 50 epochs = ~1 GB total.
|
| 106 |
+
save_every_epochs: 0
|
| 107 |
+
|
| 108 |
+
eval:
|
| 109 |
+
voxel_size_mm: 0.5 # paper Section 3.3 (k-Wave grid)
|
| 110 |
+
focal_threshold_db: -6.0 # iso-surface threshold for focal-volume Dice
|
| 111 |
+
off_target_min_dist_mm: 10.0 # secondary-lobe radius exclusion
|
| 112 |
+
n_warmup_inferences: 3 # for latency measurement
|
| 113 |
+
save_predictions: true # write per-sample pred npz for figures
|
| 114 |
+
|
| 115 |
+
output:
|
| 116 |
+
run_dir: runs/deeptfus
|