masonwang025 committed on
Commit 4bee0a6 · 0 Parent(s)
Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +76 -0
  3. ckpt_best.pt +3 -0
  4. config.snapshot.yaml +116 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
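To see which files in this commit the patterns above route through Git LFS, the matching can be approximated in a few lines. This is a minimal sketch, not Git's own matcher: it uses Python's `fnmatch` on the basename, which is an assumption that holds for simple `*.ext` patterns but does not model rules such as `saved_model/**/*`. The `LFS_PATTERNS` subset and the helper name `is_lfs_tracked` are illustrative.

```python
from fnmatch import fnmatch

# A subset of the LFS patterns added in .gitattributes (illustration only).
LFS_PATTERNS = ["*.pt", "*.pth", "*.ckpt", "*.safetensors", "*.bin", "*.npz"]

# The four files touched by this commit.
COMMIT_FILES = [".gitattributes", "README.md", "ckpt_best.pt", "config.snapshot.yaml"]


def is_lfs_tracked(path, patterns=LFS_PATTERNS):
    """Approximate gitattributes matching with fnmatch on the basename.

    Assumption: adequate for simple `*.ext` patterns; directory and `**`
    patterns follow extra gitattributes rules that are not modeled here.
    """
    basename = path.rsplit("/", 1)[-1]
    return any(fnmatch(basename, pattern) for pattern in patterns)


lfs_files = [name for name in COMMIT_FILES if is_lfs_tracked(name)]
print(lfs_files)
```

Under these assumptions only `ckpt_best.pt` is matched (by `*.pt`), which agrees with the diff showing a three-line LFS pointer for that file rather than its binary content.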
README.md ADDED
@@ -0,0 +1,76 @@
1
+ ---
2
+ license: cc-by-nc-nd-4.0
3
+ library_name: pytorch
4
+ tags:
5
+ - medical-imaging
6
+ - 3d-cnn
7
+ - ultrasound
8
+ - focused-ultrasound
9
+ - transcranial-ultrasound
10
+ - reproduction
11
+ datasets:
12
+ - vinkle-srivastav/TFUScapes
13
+ language:
14
+ - en
15
+ ---
16
+
17
+ # DeepTFUS: base (run-1 reproduction)
18
+
19
+ *A reproduction attempt of DeepTFUS, proposed by [Srivastav et al. (arXiv:2505.12998)](https://arxiv.org/abs/2505.12998).*
20
+
21
+ This is the from-scratch baseline: 50 epochs on the paper recipe
22
+ (weighted-MSE + λ·gradient-L1, no focal-position aux), `base_width=16`
23
+ (3.4 M params), `pure-bf16`, `batch=4` at 256³ resolution. Given a 3D
24
+ head CT and a transducer placement, the model predicts the resulting in-skull
25
+ pressure field in <1 s on an H100 (≈ 50× faster than the k-Wave
26
+ physics simulator the dataset was generated from).
27
+
28
+ ⭐ Partial reproduction: matched the paper on `relative_l2`, did not match
29
+ on `focal_position_error_mm` (~2.25× worse mean); `max_pressure_error` is within
30
+ the paper's std on the mean but higher on the median. These gaps motivated the 5 fine-tune variants in this model collection.
31
+
32
+ ## Test results (n = 597 held-out CT × placement combinations)
33
+
34
+ | metric | paper | base (this model) | reproduced? |
35
+ |---|---:|---:|---|
36
+ | `relative_l2` mean ± std | 0.414 ± 0.086 | **0.384 ± 0.078** | ✅ Yes (slightly beats paper) |
37
+ | `relative_l2` median | 0.394 | **0.369** | ✅ |
38
+ | `focal_position_error_mm` mean ± std | 2.89 ± 2.14 | 6.49 ± 4.58 | ❌ No (~2.25× worse mean) |
39
+ | `focal_position_error_mm` median | 2.45 | 5.15 | ❌ |
40
+ | `max_pressure_error` mean ± std | 0.199 ± 0.158 | 0.225 ± 0.116 | ✅ Yes (within paper's std) |
41
+ | `max_pressure_error` median | 0.166 | 0.217 | (slightly above paper) |
42
+ | `focal_pressure_error` median | n/a | 0.528 | n/a |
43
+ | `focal_iou_fwhm` median | n/a | 0.143 | n/a |
44
+ | `inference_latency_s` (b=1, H100) | 11.4 (RTX 4090) | 0.233 | 49× faster (different HW) |
45
+
46
+ ## Other variants and discussion
47
+
48
+ See the [Collection](https://huggingface.co/collections/masonwang025/deeptfus-reproduction-6a03e39286a09470b960511f)
49
+ for the 5 fine-tune variants built from this base ckpt, and the
50
+ [project page](https://masonjwang.com/projects/reproducing-deeptfus)
51
+ for the full reproduction story, interactive viewer, and discussion of
52
+ trade-offs.
53
+
54
+ ## Usage
55
+
56
+ ```python
57
+ from huggingface_hub import hf_hub_download
58
+ import torch
59
+
60
+ ckpt = torch.load(
61
+     hf_hub_download("masonwang025/deeptfus-base", "ckpt_best.pt"),
61
+     map_location="cpu", weights_only=False,
63
+ )
64
+ # ckpt['model'] : state_dict for the model defined in masonwang025/deeptfus repo
65
+ # ckpt['config'] : training config (architecture knobs + train hyperparams)
66
+ # ckpt['epoch'] : 43 (best by val_rel_l2)
67
+ ```
68
+
69
+ Model code: [github.com/masonwang025/deeptfus](https://github.com/masonwang025/deeptfus).
70
+
71
+ ## Citation & License
72
+
73
+ Paper: Srivastav et al., [arXiv:2505.12998](https://arxiv.org/abs/2505.12998), 2025.
74
+
75
+ License: [CC-BY-NC-ND-4.0](https://creativecommons.org/licenses/by-nc-nd/4.0/),
76
+ matching the TFUScapes dataset license.
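The derived figures in the README's results table (the "~2.25× worse" ratio, the "49× faster" latency, and the "within paper's std" claim) can be rechecked from the tabulated numbers alone. The sketch below only redoes that arithmetic; the dictionary keys are illustrative names, and the values are copied from the table.

```python
# Values copied from the README results table (paper vs. this checkpoint).
paper = {
    "rel_l2_mean": 0.414,
    "focal_mean_mm": 2.89,
    "max_p_mean": 0.199,
    "max_p_std": 0.158,
    "latency_s": 11.4,   # reported on an RTX 4090
}
base = {
    "rel_l2_mean": 0.384,
    "focal_mean_mm": 6.49,
    "max_p_mean": 0.225,
    "latency_s": 0.233,  # measured on an H100
}

# Ratio behind the "~2.25x worse mean" note for focal position error.
focal_ratio = base["focal_mean_mm"] / paper["focal_mean_mm"]

# Ratio behind the "49x faster" note (different hardware on each side).
speedup = paper["latency_s"] / base["latency_s"]

# "Within paper's std": gap between means compared to the paper's std.
max_p_gap = abs(base["max_p_mean"] - paper["max_p_mean"])
within_std = max_p_gap <= paper["max_p_std"]

# Lower relative L2 is better, so this is the "slightly beats paper" row.
beats_rel_l2 = base["rel_l2_mean"] < paper["rel_l2_mean"]

print(f"focal ratio {focal_ratio:.2f}, speedup {speedup:.1f}, "
      f"max-pressure gap {max_p_gap:.3f}, within std: {within_std}")
```

Note that the latency ratio compares two different GPUs, as the table itself flags, so it is not a like-for-like speedup.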
ckpt_best.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b362a03019d2f614913e9e5b375d5a501b41217ad43bfe68f20cdba40803059
3
+ size 20743755
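The three lines above are a Git LFS pointer, not the checkpoint itself: they record the hash algorithm, digest, and byte size of the real file. A downloaded `ckpt_best.pt` can be checked against them with the standard library. This is a minimal sketch; `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, not part of any Git LFS tooling.

```python
import hashlib

# Pointer contents as shown in the diff above.
POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:4b362a03019d2f614913e9e5b375d5a501b41217ad43bfe68f20cdba40803059
size 20743755
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line and unpack the 'algo:digest' oid."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path, pointer, chunk_size=1 << 20):
    """Stream the file and compare its size and digest to the pointer."""
    hasher = hashlib.new(pointer["algo"])
    size = 0
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            hasher.update(chunk)
            size += len(chunk)
    return size == pointer["size"] and hasher.hexdigest() == pointer["digest"]


pointer = parse_lfs_pointer(POINTER_TEXT)
print(pointer["algo"], pointer["size"])
```

Calling `matches_pointer("ckpt_best.pt", pointer)` on a local copy would return `True` only if both the size and the SHA-256 digest agree. The recorded size, 20,743,755 bytes, is consistent with the "~20 MB at bw=16" note in the config.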
config.snapshot.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # DeepTFUS reproduction config (GPU defaults).
2
+ #
3
+ # Paper: Srivastav et al., "A Skull-Adaptive Framework for AI-Based 3D
4
+ # Transcranial Focused Ultrasound Simulation", arXiv:2505.12998. PDF at
5
+ # repo root: original-paper.pdf.
6
+ #
7
+ # For local Mac smoke testing, use `python scripts/local_verify.py`; that
8
+ # script overrides the heavy fields in-process. Do not edit this file for
9
+ # verification.
10
+ #
11
+ # Architectural specifics the paper does not pin down (base_width, depth,
12
+ # dynamic-conv kernel size, cross-attention head count, which encoder
13
+ # levels carry cross-attention) are flagged TENTATIVE below. An email
14
+ # is out to the authors; update those values when they reply.
15
+
16
+ data:
17
+ resolution: 256 # paper: 256^3 cropped subvolumes
18
+ n_transducer_points: 512 # uniform random subsample per step
19
+
20
+ model:
21
+ base_width: 16 # TENTATIVE; paper does not specify capacity. We picked 16
22
+ # because the paper's own ablation (Table 1) shows
23
+ # DeepTFUStiny within 1σ of DeepTFUS on every metric, and
24
+ # neither variant has a published param count, so there is
25
+ # no paper-anchored "right size". 16 → 2.6M params, 7h
26
+ # at 50 epochs, fits without grad-checkpointing.
27
+ cond_dim: 128 # transducer embedding dim (z_T)
28
+ n_transducer_freqs: 8 # TENTATIVE; paper says "Fourier PE" but gives no count
29
+ dynamic_conv_kernel: 3 # TENTATIVE; paper does not specify
30
+ cross_attention_heads: 4 # TENTATIVE
31
+ cross_attention_levels: [level1, level2, level3, bottleneck]
32
+ # Paper says "each encoder level"; level0 at
33
+ # 256^3 is OOM on a single 80GB H100 even with
34
+ # direction-1 disabled (the concat+1x1x1 fusion
35
+ # at that resolution adds ~15 GB on top of the
36
+ # ~55 GB base). level1..bottleneck fit and
37
+ # match the paper at each of those levels.
38
+
39
+ cross_attention_bidirectional: false
40
+ # Paper says "bi-directional ... two multi-head
41
+ # attention blocks". We default to false
42
+ # because direction 1 (CT-queries-z_T) is
43
+ # degenerate with z_T as a single token:
44
+ # softmax over one key is identically 1, so it
45
+ # collapses to a learned broadcast of a
46
+ # projection of z_T — same function the
47
+ # encoder's DynamicConv layers already serve.
48
+ # Flip to true to recover the paper-faithful
49
+ # bi-directional design.
50
+
51
+ use_film_decoder: false # Paper §3.2 puts FiLM in the decoding path.
52
+ # We default to false because the paper's own
53
+ # Table 1 "No FiLM" row shows lower
54
+ # max_pressure_error than full DeepTFUS and is
55
+ # within 1 sigma on every other metric. With
56
+ # FiLM off, decoder is a plain U-Net decoder.
57
+ # Flip to true to recover paper-faithful FiLM.
58
+
59
+ loss:
60
+ alpha: 5.0 # paper Eq 5: exponent in w(v)=exp(α(P-maxP))/E[...]
61
+ grad_weight: 0.1 # paper: lambda for the gradient-L1 term
62
+
63
+ train:
64
+ epochs: 50 # paper
65
+ batch_size: 4 # paper
66
+ lr: 0.001 # paper
67
+ weight_decay: 0.0001
68
+ grad_clip: 1.0
69
+ seed: 0
70
+ num_workers: 4 # for CUDA; local_verify forces 0 on MPS
71
+ val_every: 1 # epochs
72
+
73
+ # Precision / memory. The H100 path needs both of these to fit batch=4 at
74
+ # 256^3 (fp32 OOMs at >78 GiB on the first encoder level; autocast bf16 also
75
+ # OOMs because GroupNorm/FiLM upcast paths leak fp32 into downstream
76
+ # activations). See docs/1-reproduction-setup/synthetic_bench.md and
77
+ # investigation.md for the receipts. local_verify.py overrides these to
78
+ # fp32/false for the CPU/MPS smoke test.
79
+ precision: pure-bf16 # fp32 | pure-bf16. autocast bf16 is a trap on this
80
+ # model (GroupNorm/FiLM/DynConv leak fp32 promotions).
81
+ grad_checkpoint_encoder: false # at base_width=16 we fit without checkpointing.
82
+ # Flip true if you bump base_width back to 24+.
83
+
84
+ # Speedups (benched 2026-05-11; see docs/2-paper-audit/bench_speedups.txt).
85
+ # Both DEFAULT OFF after empirical findings on this specific model:
86
+ channels_last: false # channels_last_3d was 17% SLOWER on this model
87
+ # (1574 ms vs 1344 ms baseline) because cuDNN's
88
+ # 3D NHWC kernels do not have a good path for
89
+ # our depthwise grouped DynamicConv3d. Flip
90
+ # true only if you remove DynamicConv3d.
91
+ compile: false # torch.compile OOMs the up1 decoder stage
92
+ # because Inductor materializes the (B, 48, 256^3)
93
+ # = 6 GiB concat intermediate that eager mode
94
+ # streams through. Could be re-attempted with
95
+ # mode="reduce-overhead" or by compiling only
96
+ # sub-modules; left as future work.
97
+
98
+ # Run observability. WandB init no-ops if WANDB_API_KEY is unset OR
99
+ # wandb_project is null.
100
+ wandb_project: deeptfus-reproduction
101
+ wandb_entity: mason-wang # team entity; personal entities disabled on this account
102
+
103
+ # Per-epoch checkpoint frequency. 0 disables (only ckpt_best and ckpt_last
104
+ # are written). Set to e.g. 5 to keep ckpt_epoch_004.pt, ckpt_epoch_009.pt,
105
+ # ... in addition. Each ckpt is ~20 MB at bw=16; 50 epochs = ~1 GB total.
106
+ save_every_epochs: 0
107
+
108
+ eval:
109
+ voxel_size_mm: 0.5 # paper Section 3.3 (k-Wave grid)
110
+ focal_threshold_db: -6.0 # iso-surface threshold for focal-volume Dice
111
+ off_target_min_dist_mm: 10.0 # secondary-lobe radius exclusion
112
+ n_warmup_inferences: 3 # for latency measurement
113
+ save_predictions: true # write per-sample pred npz for figures
114
+
115
+ output:
116
+ run_dir: runs/deeptfus
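The `cross_attention_bidirectional` comment in the config argues that attention over a single key is degenerate: the softmax over one score is always 1, so every query receives the same value vector. The toy sketch below illustrates that claim with plain Python. It is not the repository's attention module; the vectors, the single-head form, and the helper names are assumptions chosen for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output


# Hypothetical projections of a single conditioning token (stand-in for z_T).
z_key = [0.3, -1.2, 0.7]
z_value = [2.0, -1.0, 0.5]

# Three very different queries (stand-ins for CT feature vectors).
queries = [[1.0, 0.0, 0.0], [-5.0, 2.0, 9.0], [0.0, 0.0, 0.0]]

results = [attend(q, [z_key], [z_value]) for q in queries]
for weights, output in results:
    print(weights, output)
```

Every query gets weight `[1.0]` and an output equal to `z_value`, regardless of its content, which is the "learned broadcast of a projection of z_T" that the comment describes. With more than one key the weights would depend on the query again.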