stmdit-anon committed on
Commit
6044093
·
verified ·
1 Parent(s): 2cde047

Add ptpl-xattn-pma-p05 (PTPL-XAttn-PMA-B (p=0.5))

Browse files
ptpl-xattn-pma-p05/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - histopathology
5
+ - diffusion
6
+ - spatial-transcriptomics
7
+ - icml-2026-sd4h-workshop
8
+ ---
9
+
10
+ # PTPL-XAttn-PMA-B (p=0.5)
11
+
12
+ EMA-only inference weights for the **PTPL-XAttn-PMA-B (p=0.5)** row reported in the
13
+ ICML 2026 SD4H workshop submission *Transcriptomics-Conditioned Virtual Tissue
14
+ Synthesis via Diffusion Transformers*.
15
+
16
+ - **Source checkpoint**: `step_2617000_ema.pt`
17
+ - **Architecture**: see `training_config.yaml` in this folder.
18
+ - **License**: Apache-2.0.
19
+
20
+ See the umbrella repo README at `stmdit-anon/stmdit-checkpoints` for usage.
ptpl-xattn-pma-p05/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0505ed6ab7ffaa4c5dded254187e17b54bec2a98f5ef9f8434450d7d35356df
3
+ size 630725612
ptpl-xattn-pma-p05/training_config.yaml ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Configuration - PTPL-XAttn-PMA-B (Pooling by Multihead Attention)
2
+ # ===========================================================================
3
+ # Pooling by Multihead Attention (PMA): 32 learned seed tokens attend over the CancerFoundation (CF) tokens.
4
+ # Trained on PTPL features (DeepSpot-predicted gene expression (GE), with corrected normalization).
5
+ # Uses the best conditioning dropout from the ablation study: p=0.5 (conditioning-mask weights 25/25/25/25).
6
+ #
7
+ # Usage:
8
+ # run-training configs/training.yaml --lightning
9
+
10
+ output_dir: "/cluster/work/grlab/projects/projects2025-virtual-tissue-gen/scratch/10x_TuPro/PixCell-PTPL-XAttn/training/ptpl-xattn-pma-B"
11
+ device: "cuda"
12
+
13
+ # ============================================================================
14
+ # MODEL
15
+ # ============================================================================
16
+
17
+ model:
18
+ type: "pixart_ge_xattn"
19
+ variant: "B" # 130M params: depth=12, hidden=768, heads=12
20
+ ge_encoder_type: "cancerfoundation"
21
+ ge_hidden_dim: 512
22
+ cf_model_dir: "/cluster/home/pvlachas/leomed-home/pretrained_model_weights/cancer-foundation"
23
+ cf_freeze_backbone: true
24
+ ge_token_source: "pma"
25
+ ge_num_tokens: 32
26
+ ge_xattn_fusion: "xattn_only"
27
+
28
+ # ============================================================================
29
+ # DATA — PTPL features (DeepSpot-predicted GE; same VAE/UNI as Visium)
30
+ # ============================================================================
31
+
32
+ data:
33
+ features_dir: "/cluster/work/grlab/projects/projects2025-virtual-tissue-gen/scratch/10x_TuPro-PTPL/feat-extraction/features_train"
34
+ load_gene_expression: true
35
+ load_gene_expression_binned: true
36
+ num_workers: 8
37
+ pin_memory: true
38
+ val_split: 0.1
39
+
40
+ # ============================================================================
41
+ # DIFFUSION
42
+ # ============================================================================
43
+
44
+ diffusion:
45
+ timesteps: 1000
46
+ beta_schedule: "linear"
47
+ image_size: 256
48
+ latent_size: 32
49
+
50
+ # ============================================================================
51
+ # TRAINING
52
+ # ============================================================================
53
+
54
+ training:
55
+ batch_size: 32
56
+ batch_size_val: 32
57
+ gradient_accumulation_steps: 4 # effective batch = 128
58
+ num_epochs: 1000
59
+ seed: 42
60
+ gradient_clip: 0.01
61
+ ema_rate: 0.9999
62
+
63
+ optimizer:
64
+ lr: 2e-5
65
+ weight_decay: 0.01
66
+ betas: [0.9, 0.999]
67
+
68
+ scheduler:
69
+ warmup_steps: 1000
70
+ min_lr_ratio: 0.1
71
+
72
+ # Best conditioning dropout from the ablation: p=0.5 (equal weights 25/25/25/25 over the four masks below)
73
+ classifier_free_guidance:
74
+ conditioning_schedule:
75
+ - mask: [uni, ge] # full conditioning (UNI + GE active)
76
+ weight: 25
77
+ - mask: [ge] # GE only (UNI dropped)
78
+ weight: 25
79
+ - mask: [uni] # UNI only (GE dropped)
80
+ weight: 25
81
+ - mask: [] # unconditional (both dropped)
82
+ weight: 25
83
+
84
+ convergence:
85
+ monitor_timestep_range: [900, 1000]
86
+ patience: 10
87
+ min_epochs: 50
88
+
89
+ # ============================================================================
90
+ # DISTRIBUTED
91
+ # ============================================================================
92
+
93
+ distributed:
94
+ precision: "bf16-mixed"
95
+ compile_model: true
96
+
97
+ # ============================================================================
98
+ # CHECKPOINT
99
+ # ============================================================================
100
+
101
+ checkpoint:
102
+ save_every: 1000
103
+ resume: null
104
+
105
+ # ============================================================================
106
+ # LOGGING
107
+ # ============================================================================
108
+
109
+ logging:
110
+ log_every: 100
111
+ validate_every: 0
112
+ gpu_monitor: true
113
+ gpu_monitor_interval: 60.0
114
+ sample_every_epochs: 10
115
+ sample_every_steps: 0
116
+ num_samples: 16
117
+ sample_guidance_scale: 3.0
118
+ sample_num_steps: 20
119
+ sample_vae_path: "/cluster/home/pvlachas/leomed-home/pretrained_model_weights/stability-ai-stable-diffusion-3-5-large/models--stabilityai--stable-diffusion-3.5-large/snapshots/ceddf0a7fdf2064ea28e2213e3b84e4afa170a0f/vae"