stmdit-anon commited on
Commit
3bf033a
·
verified ·
1 Parent(s): a5d4a6a

Add xattn-pma-p05 (XAttn-PMA-p05)

Browse files
xattn-pma-p05/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - histopathology
5
+ - diffusion
6
+ - spatial-transcriptomics
7
+ - icml-2026-sd4h-workshop
8
+ ---
9
+
10
+ # XAttn-PMA-p05
11
+
12
+ EMA-only inference weights for the **XAttn-PMA-p05** row reported in the
13
+ ICML 2026 SD4H workshop submission *Transcriptomics-Conditioned Virtual Tissue
14
+ Synthesis via Diffusion Transformers*.
15
+
16
+ - **Source checkpoint**: `step_2703000_ema.pt`
17
+ - **Architecture**: see `training_config.yaml` in this folder.
18
+ - **License**: Apache-2.0.
19
+
20
+ See the umbrella repo README at `stmdit-anon/stmdit-checkpoints` for usage.
xattn-pma-p05/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c95724274673f2265e9d591737621bb026c0542744a62113279f8e53934f20e7
3
+ size 630725612
xattn-pma-p05/training_config.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Configuration - PixArtGEXAttn-B with PMA Token Source
2
+ # Dropout p=0.5 — equal weight across all 4 conditioning regimes.
3
+ # bf16-mixed + torch.compile for faster training.
4
+
5
+ output_dir: "/cluster/work/grlab/projects/projects2025-virtual-tissue-gen/scratch/10x_TuPro/PixCell-GE/training/pixart-ge-cf-B-xattn-pma-p05"
6
+ device: "cuda"
7
+
8
+ # ============================================================================
9
+ # MODEL
10
+ # ============================================================================
11
+
12
+ model:
13
+ type: "pixart_ge_xattn"
14
+ variant: "B" # 130M params: depth=12, hidden=768, heads=12
15
+ ge_encoder_type: "cancerfoundation"
16
+ ge_hidden_dim: 512
17
+ cf_model_dir: "/cluster/home/pvlachas/leomed-home/pretrained_model_weights/cancer-foundation"
18
+ cf_freeze_backbone: true
19
+ ge_token_source: "pma"
20
+ ge_num_tokens: 32
21
+ ge_xattn_fusion: "xattn_only"
22
+
23
+ # ============================================================================
24
+ # DATA
25
+ # ============================================================================
26
+
27
+ data:
28
+ features_dir: "/cluster/work/grlab/projects/projects2025-virtual-tissue-gen/scratch/10x_TuPro/feat-extraction/features_train"
29
+ load_gene_expression: true
30
+ load_gene_expression_binned: true
31
+ num_workers: 8
32
+ pin_memory: true
33
+ val_split: 0.1
34
+
35
+ # ============================================================================
36
+ # DIFFUSION
37
+ # ============================================================================
38
+
39
+ diffusion:
40
+ timesteps: 1000
41
+ beta_schedule: "linear"
42
+ image_size: 256
43
+ latent_size: 32
44
+
45
+ # ============================================================================
46
+ # TRAINING
47
+ # ============================================================================
48
+
49
+ training:
50
+ batch_size: 32
51
+ batch_size_val: 32
52
+ gradient_accumulation_steps: 4 # effective batch = 128
53
+ num_epochs: 1000
54
+ seed: 42
55
+ gradient_clip: 0.01
56
+ ema_rate: 0.9999
57
+
58
+ optimizer:
59
+ lr: 2e-5
60
+ weight_decay: 0.01
61
+ betas: [0.9, 0.999]
62
+
63
+ scheduler:
64
+ warmup_steps: 1000
65
+ min_lr_ratio: 0.1
66
+
67
+ classifier_free_guidance:
68
+ conditioning_schedule:
69
+ - mask: [uni, ge] # full conditioning (UNI + GE active)
70
+ weight: 25
71
+ - mask: [ge] # GE only (UNI dropped)
72
+ weight: 25
73
+ - mask: [uni] # UNI only (GE dropped)
74
+ weight: 25
75
+ - mask: [] # unconditional (both dropped)
76
+ weight: 25
77
+
78
+ modality_monitor:
79
+ enabled: true
80
+ diagnostic_freq: 10
81
+ diagnostic_batch_size: 64
82
+
83
+ convergence:
84
+ monitor_timestep_range: [900, 1000]
85
+ patience: 50
86
+ min_epochs: 300
87
+
88
+ # ============================================================================
89
+ # DISTRIBUTED
90
+ # ============================================================================
91
+
92
+ distributed:
93
+ precision: "bf16-mixed"
94
+ compile_model: true
95
+
96
+ # ============================================================================
97
+ # CHECKPOINT
98
+ # ============================================================================
99
+
100
+ checkpoint:
101
+ save_every: 1000
102
+ resume: null
103
+
104
+ # ============================================================================
105
+ # LOGGING
106
+ # ============================================================================
107
+
108
+ logging:
109
+ log_every: 100
110
+ validate_every: 0
111
+ gpu_monitor: true
112
+ gpu_monitor_interval: 60.0
113
+ sample_every_epochs: 10
114
+ sample_every_steps: 0
115
+ num_samples: 16
116
+ sample_guidance_scale: 3.0
117
+ sample_num_steps: 20
118
+ sample_vae_path: "/cluster/home/pvlachas/leomed-home/pretrained_model_weights/stability-ai-stable-diffusion-3-5-large/models--stabilityai--stable-diffusion-3.5-large/snapshots/ceddf0a7fdf2064ea28e2213e3b84e4afa170a0f/vae"