LAXMAYDAY commited on
Commit
b34f97f
·
verified ·
1 Parent(s): 2d4351a

Official-consistency ablation: code, configs, paper-style writeup

Browse files

Mirror of GitHub commit 555d1c0. See https://github.com/Yidhar/sensenova-u1-lora-trainer/commit/555d1c0 for full message.

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +33 -0
  2. README.md +86 -29
  3. configs/default.yaml +31 -14
  4. configs/official_alignment.yaml +57 -0
  5. docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg +3 -0
  6. docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg +3 -0
  7. docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg +3 -0
  8. docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg +3 -0
  9. docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg +3 -0
  10. docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg +3 -0
  11. docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg +3 -0
  12. docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg +3 -0
  13. docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg +3 -0
  14. docs/assets/small_data_style_ablation/detail00_v18.jpg +3 -0
  15. docs/assets/small_data_style_ablation/detail00_v19.jpg +3 -0
  16. docs/assets/small_data_style_ablation/detail00_v19a.jpg +3 -0
  17. docs/assets/small_data_style_ablation/detail00_v19b.jpg +3 -0
  18. docs/assets/small_data_style_ablation/detail00_v19c.jpg +3 -0
  19. docs/assets/small_data_style_ablation/detail01_v18.jpg +3 -0
  20. docs/assets/small_data_style_ablation/detail01_v19.jpg +3 -0
  21. docs/assets/small_data_style_ablation/detail01_v19a.jpg +3 -0
  22. docs/assets/small_data_style_ablation/detail01_v19b.jpg +3 -0
  23. docs/assets/small_data_style_ablation/detail02_v18.jpg +3 -0
  24. docs/assets/small_data_style_ablation/detail02_v19.jpg +3 -0
  25. docs/assets/small_data_style_ablation/detail02_v19a.jpg +3 -0
  26. docs/assets/small_data_style_ablation/detail02_v19b.jpg +3 -0
  27. docs/assets/small_data_style_ablation/detail02_v19c.jpg +3 -0
  28. docs/assets/small_data_style_ablation/detail07_v18.jpg +3 -0
  29. docs/assets/small_data_style_ablation/detail07_v19.jpg +3 -0
  30. docs/assets/small_data_style_ablation/detail07_v19a.jpg +3 -0
  31. docs/assets/small_data_style_ablation/detail07_v19b.jpg +3 -0
  32. docs/assets/small_data_style_ablation/detail07_v19c.jpg +3 -0
  33. docs/assets/small_data_style_ablation/tb_active_loss_curves.png +0 -0
  34. docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png +0 -0
  35. docs/assets/small_data_style_ablation/tb_t_distribution.png +0 -0
  36. docs/assets/small_data_style_ablation/tb_v_mse_curves.png +0 -0
  37. docs/assets/small_data_style_ablation/tb_x0_mse_curves.png +0 -0
  38. docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg +3 -0
  39. docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg +3 -0
  40. docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg +3 -0
  41. docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg +3 -0
  42. docs/small_data_style_ablation.html +941 -0
  43. docs/small_data_style_ablation.pdf +3 -0
  44. train.sh +4 -4
  45. train_u1/README.md +4 -1
  46. train_u1/config.py +41 -5
  47. train_u1/data/collators.py +121 -34
  48. train_u1/data/datasets.py +9 -1
  49. train_u1/model/lora.py +108 -11
  50. train_u1/model/losses.py +67 -5
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg filter=lfs diff=lfs merge=lfs -text
2
+ docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg filter=lfs diff=lfs merge=lfs -text
3
+ docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg filter=lfs diff=lfs merge=lfs -text
4
+ docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg filter=lfs diff=lfs merge=lfs -text
5
+ docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text
6
+ docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text
7
+ docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text
8
+ docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text
9
+ docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg filter=lfs diff=lfs merge=lfs -text
10
+ docs/assets/small_data_style_ablation/detail00_v18.jpg filter=lfs diff=lfs merge=lfs -text
11
+ docs/assets/small_data_style_ablation/detail00_v19.jpg filter=lfs diff=lfs merge=lfs -text
12
+ docs/assets/small_data_style_ablation/detail00_v19a.jpg filter=lfs diff=lfs merge=lfs -text
13
+ docs/assets/small_data_style_ablation/detail00_v19b.jpg filter=lfs diff=lfs merge=lfs -text
14
+ docs/assets/small_data_style_ablation/detail00_v19c.jpg filter=lfs diff=lfs merge=lfs -text
15
+ docs/assets/small_data_style_ablation/detail01_v18.jpg filter=lfs diff=lfs merge=lfs -text
16
+ docs/assets/small_data_style_ablation/detail01_v19.jpg filter=lfs diff=lfs merge=lfs -text
17
+ docs/assets/small_data_style_ablation/detail01_v19a.jpg filter=lfs diff=lfs merge=lfs -text
18
+ docs/assets/small_data_style_ablation/detail01_v19b.jpg filter=lfs diff=lfs merge=lfs -text
19
+ docs/assets/small_data_style_ablation/detail02_v18.jpg filter=lfs diff=lfs merge=lfs -text
20
+ docs/assets/small_data_style_ablation/detail02_v19.jpg filter=lfs diff=lfs merge=lfs -text
21
+ docs/assets/small_data_style_ablation/detail02_v19a.jpg filter=lfs diff=lfs merge=lfs -text
22
+ docs/assets/small_data_style_ablation/detail02_v19b.jpg filter=lfs diff=lfs merge=lfs -text
23
+ docs/assets/small_data_style_ablation/detail02_v19c.jpg filter=lfs diff=lfs merge=lfs -text
24
+ docs/assets/small_data_style_ablation/detail07_v18.jpg filter=lfs diff=lfs merge=lfs -text
25
+ docs/assets/small_data_style_ablation/detail07_v19.jpg filter=lfs diff=lfs merge=lfs -text
26
+ docs/assets/small_data_style_ablation/detail07_v19a.jpg filter=lfs diff=lfs merge=lfs -text
27
+ docs/assets/small_data_style_ablation/detail07_v19b.jpg filter=lfs diff=lfs merge=lfs -text
28
+ docs/assets/small_data_style_ablation/detail07_v19c.jpg filter=lfs diff=lfs merge=lfs -text
29
+ docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text
30
+ docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text
31
+ docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text
32
+ docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text
33
+ docs/small_data_style_ablation.pdf filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -16,7 +16,7 @@ checkpoints drop straight into the official inference scripts.
16
  ```
17
 
18
  Fits on a 32 GB GPU (RTX 5090 / A100-40 / RTX 6000 Ada). Peak VRAM ~20 GB
19
- on the maintainer's 56-image hayateluc dataset at 2048².
20
 
21
  ---
22
 
@@ -25,16 +25,26 @@ on the maintainer's 56-image hayateluc dataset at 2048².
25
  - **Config-first**: every run is one YAML file (`configs/default.yaml`).
26
  - **Per-module rank + enable**: each LoRA target (`q_proj_mot_gen`, `mlp_mot_gen.down_proj`,
27
  `fm_modules.fm_head.0`, …) takes its own rank / alpha / on-off independently.
28
- - **Default = official coverage at rank 64**: the same 296 module wraps as
29
- upstream's 8-step distill LoRA (168 attn + 126 mlp + 2 fm_head), but at
30
- rank 64 instead of 128 — half the trainable params, half the on-disk size,
31
- retains full module surface.
 
 
 
 
 
 
32
  - **Upstream-format save**: load straight into `examples/t2i/inference.py`
33
  via `--lora_path`, or stack with the official 8-step LoRA.
34
  - **bf16 training, not 4/8-bit**. Earlier 4-bit nf4 LoRA training produced
35
  grid artefacts and limb collapse on the gen tower; switching the base to
36
  bf16 (with offload + static prefix-KV cache) eliminated both.
37
 
 
 
 
 
38
  ---
39
 
40
  ## Hardware
@@ -48,6 +58,29 @@ on the maintainer's 56-image hayateluc dataset at 2048².
48
  `bitsandbytes>=0.45` and `torch>=2.9` must be linked against your CUDA
49
  runtime. On RTX 5090 (sm_120) you'll likely need the cu128 torch wheel.
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ---
52
 
53
  ## Install
@@ -85,10 +118,10 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot
85
  └── … └── …
86
  ```
87
 
88
- Each `.txt` is a single-paragraph natural-language caption. Embed the
89
- artist credit / style anchor inside the description naturally don't
90
- rely on a hard-coded trigger prepend (`style.trigger` in the YAML is
91
- for backward compat only; the v18 recipe uses an empty trigger).
92
 
93
  **Optional**: append a `<think>...</think>` reasoning label inside the
94
  same `.txt` after a `---think---` delimiter line:
@@ -101,12 +134,11 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot
101
  6. **Explicit Prompt:** ...
102
  ```
103
 
104
- When present, the trainer renders this into the prompt template's
105
- `<think>` window so train-time distribution matches inference
106
- `--think-mode` (avoids prefix-distribution shift on long autoregressive
107
- think). For batch generation of think labels see Agent B's prompt in
108
- the v18 commit history; or write them yourself in the upstream
109
- 6-section format.
110
 
111
  **Parquet/arrow shards** (recommended for ≥ ~10k images, e.g. 1M
112
  scaling):
@@ -192,6 +224,7 @@ data:
192
  data_dir: dataset/my_style
193
  cap_max_pixels: 4194304 # 2048² hard cap per image
194
  snap_bucket: true # snap to upstream bucket grid
 
195
  # n_samples: 56 # cap dataset size; default = use everything
196
 
197
  style:
@@ -199,13 +232,15 @@ style:
199
  prompt_template: official # 'official' (recommended) | 'plain'
200
 
201
  lora:
202
- preset: default # = attn+mlp+fm_head, all r=64 a=64
203
- # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off;fm_head=r128a128"
204
  dropout: 0.0
205
 
206
  unfreeze: # full-FT (non-LoRA) regex patterns
207
  - '^fm_modules\.timestep_embedder\.'
208
  - '^fm_modules\.noise_scale_embedder\.'
 
 
209
 
210
  train:
211
  steps: 6000
@@ -214,6 +249,15 @@ train:
214
  shuffle: true
215
  grad_accum: 1
216
  checkpoint_every: 600
 
 
 
 
 
 
 
 
 
217
 
218
  runtime:
219
  keep_kvs_on_gpu: true
@@ -259,30 +303,41 @@ q_proj_mot_gen=r=128,a=64;k_proj_mot_gen=r=64,a=64 # asymmetric ranks
259
 
260
  | Preset | Coverage | Trainable LoRA params | Use when |
261
  |---|---|---|---|
262
- | `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M | first try / production |
263
  | `attn_only` | 168 attn, r=64 | ~50 M | ablation |
264
  | `attn_mlp` | attn + mlp (no fm_head), r=64 | ~75 M | when fm_head is full-FT'd separately |
 
 
265
  | `official_r128` | exact upstream shape (r=128 across all 296 wraps) | ~298 M | parameter-matching upstream's 8-step LoRA |
266
 
 
 
 
 
 
 
 
267
  ---
268
 
269
  ## Stack with the official 8-step distill LoRA
270
 
271
  Upstream released a step-distillation LoRA that brings inference down to 8
272
- NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it.
 
 
 
273
 
274
  ```yaml
275
- # configs/stack_8step.yaml (already in this repo)
276
  runtime:
277
  upstream_lora_path: hf_cache/.../SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors
278
  upstream_lora_skip: ['fm_modules.fm_head'] # don't clobber our fm_head LoRA
279
  ```
280
 
281
- At sample time, also pass the same upstream LoRA:
282
 
283
  ```bash
284
- ./sample.sh configs/stack_8step.yaml \
285
- artifacts/my_style_8step/trainable_state.safetensors \
286
  --prompt "…" \
287
  --upstream-lora-path SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors \
288
  --upstream-lora-skip fm_modules.fm_head \
@@ -302,9 +357,8 @@ At sample time, also pass the same upstream LoRA:
302
  ├── pyproject.toml # package metadata
303
  ├── LICENSE # Apache-2.0
304
  ├── configs/
305
- │ ├── default.yaml # opinionated starting point
306
- ── v16c.yaml # production recipe (LoRA + ts/ns/vision/fm_head full-FT)
307
- │ └── stack_8step.yaml # train on top of 8-step distill LoRA
308
  ├── train_u1/ # importable package
309
  │ ├── config.py # YAML config schema
310
  │ ├── constants.py # pinned MODEL_SHA / CODE_COMMIT / arch constants
@@ -314,7 +368,7 @@ At sample time, also pass the same upstream LoRA:
314
  │ │ ├── lora_io.py # save/load + upstream merge
315
  │ │ ├── loader.py # bf16 base load + tower offload
316
  │ │ ├── wrapper.py # forward_t2i_step
317
- │ │ ├── losses.py # fm_loss_x0
318
  │ │ ├── patching.py # patchify/unpatchify
319
  │ │ └── …
320
  │ ├��─ scripts/
@@ -325,7 +379,10 @@ At sample time, also pass the same upstream LoRA:
325
  │ │ └── install_modeling_into_snapshot.py
326
  │ └── tests/
327
  ├── docs/
328
- ── SETUP.md # data layout, design rationale, pinned-upstream details
 
 
 
329
  ├── artifacts/ # local-only: checkpoints + sweeps (gitignored)
330
  ├── dataset/ # local-only: image+caption pairs (gitignored)
331
  ├── hf_cache/ # local-only: HF snapshot (gitignored)
@@ -346,4 +403,4 @@ At sample time, also pass the same upstream LoRA:
346
  release; consumed via the `upstream_lora_path` mechanism.
347
  - **This trainer** is licensed under Apache-2.0 (see `LICENSE`).
348
 
349
- **Thanks to comfy.org for the GPU power support. The open-source community will not forget.**
 
16
  ```
17
 
18
  Fits on a 32 GB GPU (RTX 5090 / A100-40 / RTX 6000 Ada). Peak VRAM ~20 GB
19
+ on the train dataset at 2048².
20
 
21
  ---
22
 
 
25
  - **Config-first**: every run is one YAML file (`configs/default.yaml`).
26
  - **Per-module rank + enable**: each LoRA target (`q_proj_mot_gen`, `mlp_mot_gen.down_proj`,
27
  `fm_modules.fm_head.0`, …) takes its own rank / alpha / on-off independently.
28
+ - **Experimental MoE target grammar**: A3B-style generation experts can be
29
+ addressed explicitly (`gen_moe_mlp`, `gen_moe_router`,
30
+ `mlp_mot_gen.experts.*.gate_proj`) without changing the stable 8B main path.
31
+ - **Default = small-data style baseline**: `configs/default.yaml` uses
32
+ `x0 + uniform t + no condition dropout`, short captions, LoRA on attn+mlp,
33
+ and full fine-tuning of the timestep/noise embedders, gen vision bridge, and
34
+ fm_head.
35
+ - **Official-alignment recipe is optional**: `configs/official_alignment.yaml`
36
+ keeps the public report knobs together for research ablations, but it is not
37
+ the safest first run for small style datasets.
38
  - **Upstream-format save**: load straight into `examples/t2i/inference.py`
39
  via `--lora_path`, or stack with the official 8-step LoRA.
40
  - **bf16 training, not 4/8-bit**. Earlier 4-bit nf4 LoRA training produced
41
  grid artefacts and limb collapse on the gen tower; switching the base to
42
  bf16 (with offload + static prefix-KV cache) eliminated both.
43
 
44
+ See the ablation write-up with training curves and sample grids:
45
+ [`docs/small_data_style_ablation.html`](docs/small_data_style_ablation.html)
46
+ or [`docs/small_data_style_ablation.pdf`](docs/small_data_style_ablation.pdf).
47
+
48
  ---
49
 
50
  ## Hardware
 
58
  `bitsandbytes>=0.45` and `torch>=2.9` must be linked against your CUDA
59
  runtime. On RTX 5090 (sm_120) you'll likely need the cu128 torch wheel.
60
 
61
+ ### A3B / MoE Status
62
+
63
+ The trainer now has experimental target grammar for future
64
+ `SenseNova-U1-A3B-MoT` generation-side MoE LoRA work:
65
+
66
+ ```yaml
67
+ lora:
68
+ spec: "attn=r8a8;gen_moe_mlp=r8a8;gen_moe_router=r8a8"
69
+ ```
70
+
71
+ This is a compatibility layer, not the main training path and not an end-to-end
72
+ A3B training claim. The stable release target remains `SenseNova-U1-8B-MoT`;
73
+ A3B training depends on public MoE runtime support that can instantiate the
74
+ `mlp_mot_gen.experts.*` modules.
75
+
76
+ Before training, estimate MoE LoRA size from metadata only:
77
+
78
+ ```bash
79
+ python -m train_u1.scripts.inspect_lora_targets \
80
+ --model path/to/A3B/config.json \
81
+ --spec "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8"
82
+ ```
83
+
84
  ---
85
 
86
  ## Install
 
118
  └── … └── …
119
  ```
120
 
121
+ Each `.txt` is a single-paragraph natural-language caption. Put the style
122
+ or artist anchor in a stable way and keep `style.trigger` aligned with how
123
+ you will sample later. The default config prepends that trigger to every
124
+ caption.
125
 
126
  **Optional**: append a `<think>...</think>` reasoning label inside the
127
  same `.txt` after a `---think---` delimiter line:
 
134
  6. **Explicit Prompt:** ...
135
  ```
136
 
137
+ Think labels are **ignored by default** because low-quality or highly
138
+ templated think text can dominate the prefix and hurt style binding. To use
139
+ them, set `data.use_think_labels: true` and evaluate with the same think
140
+ distribution at sample time. Do this only when your think labels are
141
+ curated and repeatable.
 
142
 
143
  **Parquet/arrow shards** (recommended for ≥ ~10k images, e.g. 1M
144
  scaling):
 
224
  data_dir: dataset/my_style
225
  cap_max_pixels: 4194304 # 2048² hard cap per image
226
  snap_bucket: true # snap to upstream bucket grid
227
+ use_think_labels: false # keep prefixes short by default
228
  # n_samples: 56 # cap dataset size; default = use everything
229
 
230
  style:
 
232
  prompt_template: official # 'official' (recommended) | 'plain'
233
 
234
  lora:
235
+ preset: attn_mlp_no_head # attn+mlp LoRA; fm_head is full-FT below
236
+ # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
237
  dropout: 0.0
238
 
239
  unfreeze: # full-FT (non-LoRA) regex patterns
240
  - '^fm_modules\.timestep_embedder\.'
241
  - '^fm_modules\.noise_scale_embedder\.'
242
+ - '^fm_modules\.vision_model_mot_gen\.'
243
+ - '^fm_modules\.fm_head\.'
244
 
245
  train:
246
  steps: 6000
 
249
  shuffle: true
250
  grad_accum: 1
251
  checkpoint_every: 600
252
+ # Small-data style baseline. See docs/small_data_style_ablation.html before
253
+ # switching to the official-alignment recipe.
254
+ loss_type: x0
255
+ t_dist: uniform
256
+ t_logit_mean: -0.8
257
+ t_logit_std: 0.8
258
+ # huber_delta: 1.0 # only used for *_huber
259
+ cond_dropout_text: 0.0
260
+ cond_dropout_both: 0.0
261
 
262
  runtime:
263
  keep_kvs_on_gpu: true
 
303
 
304
  | Preset | Coverage | Trainable LoRA params | Use when |
305
  |---|---|---|---|
306
+ | `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M | match upstream 8-step LoRA coverage |
307
  | `attn_only` | 168 attn, r=64 | ~50 M | ablation |
308
  | `attn_mlp` | attn + mlp (no fm_head), r=64 | ~75 M | when fm_head is full-FT'd separately |
309
+ | `attn_only_no_head` | alias for `attn_only`; explicit no-fm_head intent | ~50 M | conservative small-data style training |
310
+ | `attn_mlp_no_head` | alias for `attn_mlp`; explicit no-fm_head intent | ~75 M | conservative small-data style training |
311
  | `official_r128` | exact upstream shape (r=128 across all 296 wraps) | ~298 M | parameter-matching upstream's 8-step LoRA |
312
 
313
+ The shipped `configs/default.yaml` uses `attn_mlp_no_head` and full-FTs
314
+ `fm_head` separately because that was the most stable small-data baseline in
315
+ our ablations. The `default` preset name inside the LoRA parser still means
316
+ "match upstream 8-step LoRA coverage"; use it only when that exact module
317
+ coverage is what you want. For report-alignment research, start from
318
+ `configs/official_alignment.yaml`.
319
+
320
  ---
321
 
322
  ## Stack with the official 8-step distill LoRA
323
 
324
  Upstream released a step-distillation LoRA that brings inference down to 8
325
+ NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it
326
+ by setting `runtime.upstream_lora_path` in your YAML — at training time we
327
+ bake-in the official 8-step delta into the bf16 base (skipping `fm_head` so
328
+ we don't clobber our own fm_head LoRA), then wrap our LoRA on top.
329
 
330
  ```yaml
 
331
  runtime:
332
  upstream_lora_path: hf_cache/.../SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors
333
  upstream_lora_skip: ['fm_modules.fm_head'] # don't clobber our fm_head LoRA
334
  ```
335
 
336
+ At sample time, pass the same upstream LoRA and use 8 steps at cfg=1.0:
337
 
338
  ```bash
339
+ ./sample.sh configs/my_style.yaml \
340
+ artifacts/my_style/trainable_state.safetensors \
341
  --prompt "…" \
342
  --upstream-lora-path SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors \
343
  --upstream-lora-skip fm_modules.fm_head \
 
357
  ├── pyproject.toml # package metadata
358
  ├── LICENSE # Apache-2.0
359
  ├── configs/
360
+ │ ├── default.yaml # recommended small-data style baseline
361
+ ── official_alignment.yaml # optional report-alignment research config
 
362
  ├── train_u1/ # importable package
363
  │ ├── config.py # YAML config schema
364
  │ ├── constants.py # pinned MODEL_SHA / CODE_COMMIT / arch constants
 
368
  │ │ ├── lora_io.py # save/load + upstream merge
369
  │ │ ├── loader.py # bf16 base load + tower offload
370
  │ │ ├── wrapper.py # forward_t2i_step
371
+ │ │ ├── losses.py # fm_loss_x0 / fm_loss_v / fm_loss dispatcher
372
  │ │ ├── patching.py # patchify/unpatchify
373
  │ │ └── …
374
  │ ├��─ scripts/
 
379
  │ │ └── install_modeling_into_snapshot.py
380
  │ └── tests/
381
  ├── docs/
382
+ ── SETUP.md # data layout, design rationale, pinned-upstream details
383
+ │ ├── small_data_style_ablation.html
384
+ │ ├── small_data_style_ablation.pdf
385
+ │ └── assets/ # figures used by the ablation document
386
  ├── artifacts/ # local-only: checkpoints + sweeps (gitignored)
387
  ├── dataset/ # local-only: image+caption pairs (gitignored)
388
  ├── hf_cache/ # local-only: HF snapshot (gitignored)
 
403
  release; consumed via the `upstream_lora_path` mechanism.
404
  - **This trainer** is licensed under Apache-2.0 (see `LICENSE`).
405
 
406
+ **Thanks to comfy.org for the GPU power support. The open-source community will not forget.**
configs/default.yaml CHANGED
@@ -1,9 +1,11 @@
1
  # SenseNova-U1 LoRA trainer — default config
2
  #
3
- # This is the recommended starting point. Trains a LoRA whose module
4
- # coverage matches the official 8-step distill LoRA (168 attn + 126 mlp +
5
- # 2 fm_head = 296 wraps), but at rank 64 instead of upstream's rank 128.
6
- # Roughly 75 M trainable LoRA params, ~750 MB on disk in fp32.
 
 
7
  #
8
  # Edit `data.data_dir`, `style.trigger`, and `run_name` and you should be
9
  # good to go for a 32 GB single-GPU run.
@@ -17,6 +19,9 @@ data:
17
  cap_max_pixels: 4194304
18
  # Snap each image to nearest official bucket so train shape == infer shape.
19
  snap_bucket: true
 
 
 
20
  # n_samples: cap on dataset size (omit / null = use entire data_dir)
21
 
22
  style:
@@ -24,20 +29,21 @@ style:
24
  prompt_template: official # 'official' (recommended) | 'plain'
25
 
26
  lora:
27
- # Named preset: 'default' = official coverage at rank 64.
28
- # Other presets: 'attn_only', 'attn_mlp', 'official_r128'.
29
- preset: default
 
 
30
  # Or override per-target:
31
- # spec: "attn=r64a64;mlp=r64a64;fm_head=r64a64;mlp_mot_gen.down_proj=off"
32
  dropout: 0.0
33
 
34
  unfreeze:
35
- # Full-finetune (non-LoRA) regex patterns. Empty = LoRA-only training.
36
- # Examples for our v16c-style recipe:
37
- # - '^fm_modules\.timestep_embedder\.'
38
- # - '^fm_modules\.noise_scale_embedder\.'
39
- # - '^fm_modules\.vision_model_mot_gen\.'
40
- []
41
 
42
  train:
43
  steps: 6000
@@ -47,6 +53,17 @@ train:
47
  grad_accum: 1
48
  checkpoint_every: 600 # 0 = only save final state
49
  # checkpoint_dir: null # default: artifacts/{run_name}/checkpoints
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  runtime:
52
  keep_kvs_on_gpu: true # 56 samples × ~59 MB = ~3.3 GB; keep on GPU for speed
 
1
  # SenseNova-U1 LoRA trainer — default config
2
  #
3
+ # This is the recommended starting point for small/medium style training.
4
+ # It follows the local small-data baseline: x0 loss + uniform t + no train-time
5
+ # condition dropout, explicit style trigger, LoRA on attn+mlp, and full-FT
6
+ # of ts/ns/vision/fm_head. Public official-consistency knobs are available
7
+ # in configs/official_alignment.yaml for research ablations, but they are
8
+ # not the default because the ablation study showed worse small-data sampling.
9
  #
10
  # Edit `data.data_dir`, `style.trigger`, and `run_name` and you should be
11
  # good to go for a 32 GB single-GPU run.
 
19
  cap_max_pixels: 4194304
20
  # Snap each image to nearest official bucket so train shape == infer shape.
21
  snap_bucket: true
22
+ # Keep the prefix short by default. Set true only when your think labels are
23
+ # curated and your sampling path uses the same think distribution.
24
+ use_think_labels: false
25
  # n_samples: cap on dataset size (omit / null = use entire data_dir)
26
 
27
  style:
 
29
  prompt_template: official # 'official' (recommended) | 'plain'
30
 
31
  lora:
32
+ # Small-data baseline: LoRA on attn+mlp only. fm_head is full-FT'd below.
33
+ # 'default' remains available when you explicitly want upstream 8-step
34
+ # coverage including fm_head LoRA.
35
+ # A3B/MoE presets are experimental compatibility helpers, not the main path.
36
+ preset: attn_mlp_no_head
37
  # Or override per-target:
38
+ # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
39
  dropout: 0.0
40
 
41
  unfreeze:
42
+ # Full-finetune (non-LoRA) regex patterns for the small-data baseline.
43
+ - '^fm_modules\.timestep_embedder\.'
44
+ - '^fm_modules\.noise_scale_embedder\.'
45
+ - '^fm_modules\.vision_model_mot_gen\.'
46
+ - '^fm_modules\.fm_head\.'
 
47
 
48
  train:
49
  steps: 6000
 
53
  grad_accum: 1
54
  checkpoint_every: 600 # 0 = only save final state
55
  # checkpoint_dir: null # default: artifacts/{run_name}/checkpoints
56
+ # Local small-data baseline. See docs/small_data_style_ablation.html before changing these:
57
+ # official-style v-loss/logit-normal/dropout is an optional ablation path,
58
+ # not the small-data default.
59
+ loss_type: x0 # 'x0' | 'v' | 'x0_huber' | 'v_huber'
60
+ t_dist: uniform # 'uniform' | 'logit_normal'
61
+ t_logit_mean: -0.8 # only used for logit_normal
62
+ t_logit_std: 0.8 # only used for logit_normal
63
+ # huber_delta: 1.0 # only used for *_huber
64
+ # CFG / condition dropout. Keep off by default for small style datasets.
65
+ cond_dropout_text: 0.0
66
+ cond_dropout_both: 0.0
67
 
68
  runtime:
69
  keep_kvs_on_gpu: true # 56 samples × ~59 MB = ~3.3 GB; keep on GPU for speed
configs/official_alignment.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Optional official-consistency research config.
2
+ #
3
+ # This keeps the public SenseNova-U1 report knobs together:
4
+ # - v-loss
5
+ # - logit-normal t sampler, mean=-0.8 std=0.8
6
+ # - train-time condition dropout, text=0.10 and both=0.10
7
+ # - think labels enabled when the dataset provides them
8
+ #
9
+ # Do not treat this as the small-data style-training default. The ablation
10
+ # study showed weaker sampling than the default baseline on the 56-image
11
+ # Hayateluc task. Use this config when you explicitly want to study report
12
+ # alignment, not when you want the safest first run.
13
+
14
+ run_name: official_alignment
15
+
16
+ data:
17
+ data_dir: dataset/my_style
18
+ cap_max_pixels: 4194304
19
+ snap_bucket: true
20
+ use_think_labels: true
21
+
22
+ style:
23
+ trigger: "" # prefer captions that carry their own style anchor
24
+ prompt_template: official
25
+
26
+ lora:
27
+ # Matches upstream 8-step LoRA coverage at r=64.
28
+ preset: default
29
+ dropout: 0.0
30
+
31
+ unfreeze:
32
+ - '^fm_modules\.timestep_embedder\.'
33
+ - '^fm_modules\.noise_scale_embedder\.'
34
+ - '^fm_modules\.vision_model_mot_gen\.'
35
+ - '^fm_modules\.fm_head\.'
36
+
37
+ train:
38
+ steps: 6000
39
+ lr: 5.0e-5
40
+ seed: 0
41
+ shuffle: true
42
+ grad_accum: 1
43
+ checkpoint_every: 600
44
+ loss_type: v
45
+ t_dist: logit_normal
46
+ t_logit_mean: -0.8
47
+ t_logit_std: 0.8
48
+ cond_dropout_text: 0.10
49
+ cond_dropout_both: 0.10
50
+
51
+ runtime:
52
+ # Long think prefixes and unconditional-prefix precompute can push VRAM up.
53
+ # Offload KVs and checkpoint all layers unless you have confirmed headroom.
54
+ keep_kvs_on_gpu: false
55
+ gc_skip_last: 0
56
+ device: cuda
57
+ cpu_device: cpu
docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg ADDED

Git LFS Details

  • SHA256: ad42ff152cf0b9b60c5fad6c1248477268b9b0fcf982e4d2c6fa8c4e4ad84e24
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB
docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg ADDED

Git LFS Details

  • SHA256: d8865dc230203718b62f2738db1dee0526146be58b2e0c37cec7d8ec1490011c
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg ADDED

Git LFS Details

  • SHA256: 43e8cd49b524ca6d63b8bd5ab3200bce8def0f5ac0bee8cb4062b061cc7efce4
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg ADDED

Git LFS Details

  • SHA256: f0560e2bed21e4f87c36642c284226987756d9982690df4dd30d941856a6555a
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg ADDED

Git LFS Details

  • SHA256: 7bac586857be3f55de3474ede157710b41a601b2cc92957a82f46748fde99437
  • Pointer size: 131 Bytes
  • Size of remote file: 444 kB
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg ADDED

Git LFS Details

  • SHA256: 5c57298303e7bee189eb88ca23700cabd19309f0bf1b34f5004503dfe551958c
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg ADDED

Git LFS Details

  • SHA256: 579a2dd4e839b515beaf3c8519a537ebe18234a193719445959e2e09953bde87
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg ADDED

Git LFS Details

  • SHA256: 9744f109e7b7ec6686954e0c79e82c18a0ac3007148f38760bcdeb57b058e4c3
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg ADDED

Git LFS Details

  • SHA256: 7535ce9e3a7a11d91098b8ab1a33a48a404711f9e8826341cfc009836a7e808a
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
docs/assets/small_data_style_ablation/detail00_v18.jpg ADDED

Git LFS Details

  • SHA256: d170e9029386c69037819b3d7f1612ec653b2f645cee4ee6ac4219f4abed073a
  • Pointer size: 131 Bytes
  • Size of remote file: 554 kB
docs/assets/small_data_style_ablation/detail00_v19.jpg ADDED

Git LFS Details

  • SHA256: 7d9421b5a44907af560906ca5cb28d54536308229eb21307e73a917145304b48
  • Pointer size: 131 Bytes
  • Size of remote file: 446 kB
docs/assets/small_data_style_ablation/detail00_v19a.jpg ADDED

Git LFS Details

  • SHA256: c5202b4f7774bde04f8e4ef589c8c925f44a9875f8b8bc5964a4d49c047f0562
  • Pointer size: 131 Bytes
  • Size of remote file: 460 kB
docs/assets/small_data_style_ablation/detail00_v19b.jpg ADDED

Git LFS Details

  • SHA256: 8d5406077a9074310ff30bd98242b9f3af4e15ba5f2f992961583129bb73c167
  • Pointer size: 131 Bytes
  • Size of remote file: 552 kB
docs/assets/small_data_style_ablation/detail00_v19c.jpg ADDED

Git LFS Details

  • SHA256: f2e38bfc550d83bfd0f0ca34ade521630b68499561997cfa144d2dc874f6a650
  • Pointer size: 131 Bytes
  • Size of remote file: 461 kB
docs/assets/small_data_style_ablation/detail01_v18.jpg ADDED

Git LFS Details

  • SHA256: 09e26f8cfd7a15b7983c45e8121de43bd2018921206f9b149024900224eeebee
  • Pointer size: 131 Bytes
  • Size of remote file: 228 kB
docs/assets/small_data_style_ablation/detail01_v19.jpg ADDED

Git LFS Details

  • SHA256: fa0147b231dd67646980751cdf3b17c25f74b32b975b757cae535be79dbbb7c2
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
docs/assets/small_data_style_ablation/detail01_v19a.jpg ADDED

Git LFS Details

  • SHA256: da197305288ed7b139e1fbe2d16dbd91f6a44db1d8c83fa16dc25e8fc0f1d1d5
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB
docs/assets/small_data_style_ablation/detail01_v19b.jpg ADDED

Git LFS Details

  • SHA256: a25ee998774d952c30f500488eb7c502d0ec1c86580d8b4c28e84ab84f32427c
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
docs/assets/small_data_style_ablation/detail02_v18.jpg ADDED

Git LFS Details

  • SHA256: e6a91aff5fbe8485c8aee5551574e96bfdc0135821406191dda4bc67507e00a2
  • Pointer size: 131 Bytes
  • Size of remote file: 359 kB
docs/assets/small_data_style_ablation/detail02_v19.jpg ADDED

Git LFS Details

  • SHA256: 31a20bc3c7cd598240aceb6e5ef86e76931da4f8e30ccc5aa3398fe8d0c05715
  • Pointer size: 131 Bytes
  • Size of remote file: 288 kB
docs/assets/small_data_style_ablation/detail02_v19a.jpg ADDED

Git LFS Details

  • SHA256: 98ff82f6d3e5f03a8a05f74048ecbdf3f9f86d2bd748b7f3891871a39b61e8bd
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
docs/assets/small_data_style_ablation/detail02_v19b.jpg ADDED

Git LFS Details

  • SHA256: 8de9b562cb11b3fd5ab77e127bc0ebb0f9b157d4a9e9f1461cdadbdc5bd1855a
  • Pointer size: 131 Bytes
  • Size of remote file: 344 kB
docs/assets/small_data_style_ablation/detail02_v19c.jpg ADDED

Git LFS Details

  • SHA256: 4ea2874cf2939cf68327fcde2571c460c45facf9dbfd41b6996fd54c16c474a8
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
docs/assets/small_data_style_ablation/detail07_v18.jpg ADDED

Git LFS Details

  • SHA256: d32a35bc3a20486cbc9e88434889773a1263d5e0098b49a4bf23ab81ba8eee0f
  • Pointer size: 131 Bytes
  • Size of remote file: 466 kB
docs/assets/small_data_style_ablation/detail07_v19.jpg ADDED

Git LFS Details

  • SHA256: 3d5f3b96288f70bc5eb7d601cf58771a5416956dafe62969470a31e170839bd8
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
docs/assets/small_data_style_ablation/detail07_v19a.jpg ADDED

Git LFS Details

  • SHA256: fcca9a543c59209fd792465a93a2f6374fce778770d4b004526ffb244bb1f075
  • Pointer size: 131 Bytes
  • Size of remote file: 586 kB
docs/assets/small_data_style_ablation/detail07_v19b.jpg ADDED

Git LFS Details

  • SHA256: 6c5c4dcd0b646817443f71479c8ffcd0430e8d596c844b56a7f49c9aff7d23dd
  • Pointer size: 131 Bytes
  • Size of remote file: 446 kB
docs/assets/small_data_style_ablation/detail07_v19c.jpg ADDED

Git LFS Details

  • SHA256: 3080f6442f76e1a522c6e776a738899562d918c004f5d45a676990441b08b57c
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
docs/assets/small_data_style_ablation/tb_active_loss_curves.png ADDED
docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png ADDED
docs/assets/small_data_style_ablation/tb_t_distribution.png ADDED
docs/assets/small_data_style_ablation/tb_v_mse_curves.png ADDED
docs/assets/small_data_style_ablation/tb_x0_mse_curves.png ADDED
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg ADDED

Git LFS Details

  • SHA256: dd484d8b36543be0ce1f3b348f28fb7fe4a1d21f750c412377f1333296442fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 855 kB
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg ADDED

Git LFS Details

  • SHA256: a5441764becf0a4154547831c3e6bf81fee014e0fccb84ae09c40d2992e3ecb8
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg ADDED

Git LFS Details

  • SHA256: 47876dcf4cf7ec68604f7f0e229d72d350ac781f2fc5e19e3453688e523b13ee
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg ADDED

Git LFS Details

  • SHA256: 8476298d8dd7c75a4ee97a916e38af85b69d489630554f12de94178c3bb93891
  • Pointer size: 131 Bytes
  • Size of remote file: 315 kB
docs/small_data_style_ablation.html ADDED
@@ -0,0 +1,941 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="zh-CN">
3
+ <head>
4
+ <meta charset="utf-8">
5
+ <title>对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究</title>
6
+ <style>
7
+ @page {
8
+ size: A4;
9
+ margin: 22mm 18mm 22mm 18mm;
10
+ }
11
+ :root {
12
+ --ink: #1a1a1a;
13
+ --muted: #555;
14
+ --line: #c8c8c8;
15
+ --rule: #2a2a2a;
16
+ --soft: #f6f6f4;
17
+ --accent: #6b3410;
18
+ }
19
+ * { box-sizing: border-box; }
20
+ body {
21
+ margin: 0;
22
+ color: var(--ink);
23
+ font-family: "Source Han Serif SC", "Noto Serif CJK SC", "STSong",
24
+ "Latin Modern Roman", "Linux Libertine O", "Times New Roman", Georgia, serif;
25
+ font-size: 10.5pt;
26
+ line-height: 1.55;
27
+ background: white;
28
+ text-align: justify;
29
+ hyphens: auto;
30
+ }
31
+ h1, h2, h3, h4 {
32
+ font-family: "Source Han Sans SC", "Noto Sans CJK SC",
33
+ "Latin Modern Sans", "Helvetica Neue", Arial, sans-serif;
34
+ font-weight: 700;
35
+ line-height: 1.25;
36
+ page-break-after: avoid;
37
+ color: var(--ink);
38
+ }
39
+ h1 {
40
+ font-size: 19pt;
41
+ margin: 0 0 6pt;
42
+ text-align: center;
43
+ letter-spacing: 0.5pt;
44
+ }
45
+ h2 {
46
+ font-size: 13pt;
47
+ margin: 24pt 0 8pt;
48
+ padding-bottom: 4pt;
49
+ border-bottom: 0.7pt solid var(--rule);
50
+ }
51
+ h3 {
52
+ font-size: 11pt;
53
+ margin: 14pt 0 4pt;
54
+ }
55
+ h4 {
56
+ font-size: 10pt;
57
+ margin: 10pt 0 3pt;
58
+ font-style: italic;
59
+ font-weight: 600;
60
+ }
61
+ p { margin: 5pt 0; text-indent: 2em; }
62
+ p.noindent, h1 + p, h2 + p, h3 + p, h4 + p,
63
+ .abstract p, .figure p, .table-wrap p, .equation,
64
+ li > p, .meta { text-indent: 0; }
65
+ .meta {
66
+ text-align: center;
67
+ color: var(--muted);
68
+ font-size: 10pt;
69
+ margin: 0 0 14pt;
70
+ }
71
+ .affil {
72
+ text-align: center;
73
+ color: var(--muted);
74
+ font-size: 9.5pt;
75
+ margin: 0 0 18pt;
76
+ }
77
+ .abstract {
78
+ margin: 4pt 8% 18pt;
79
+ padding: 10pt 14pt;
80
+ border: 0.5pt solid var(--line);
81
+ background: var(--soft);
82
+ }
83
+ .abstract h3 {
84
+ margin: 0 0 4pt;
85
+ text-align: center;
86
+ font-size: 10.5pt;
87
+ letter-spacing: 1pt;
88
+ text-transform: uppercase;
89
+ }
90
+ .abstract p { text-indent: 0; margin: 4pt 0; font-size: 10pt; }
91
+ .keywords {
92
+ font-size: 9.5pt;
93
+ margin-top: 8pt;
94
+ }
95
+ .keywords strong { letter-spacing: 0.5pt; }
96
+ code, pre, .mono {
97
+ font-family: "Latin Modern Mono", "DejaVu Sans Mono", "Consolas", monospace;
98
+ }
99
+ code { font-size: 9.5pt; background: #f0f0ec; padding: 0 2pt; border-radius: 2pt; }
100
+ pre {
101
+ background: #f4f4f0;
102
+ color: #1a1a1a;
103
+ padding: 8pt 10pt;
104
+ font-size: 9pt;
105
+ border: 0.5pt solid var(--line);
106
+ border-radius: 3pt;
107
+ overflow: hidden;
108
+ white-space: pre-wrap;
109
+ page-break-inside: avoid;
110
+ }
111
+ table {
112
+ width: 100%;
113
+ border-collapse: collapse;
114
+ margin: 4pt 0 6pt;
115
+ font-size: 9.5pt;
116
+ page-break-inside: avoid;
117
+ }
118
+ th, td {
119
+ border-top: 0.5pt solid var(--rule);
120
+ border-bottom: 0.5pt solid var(--rule);
121
+ padding: 4pt 6pt;
122
+ vertical-align: top;
123
+ text-align: left;
124
+ }
125
+ th {
126
+ background: transparent;
127
+ font-weight: 600;
128
+ border-top: 0.8pt solid var(--rule);
129
+ border-bottom: 0.8pt solid var(--rule);
130
+ }
131
+ tbody tr td { border-top: none; }
132
+ tbody tr:last-child td { border-bottom: 0.8pt solid var(--rule); }
133
+ table.bordered td, table.bordered th { border: 0.5pt solid var(--line); }
134
+
135
+ .figure {
136
+ page-break-inside: avoid;
137
+ margin: 10pt 0 12pt;
138
+ }
139
+ .figure img {
140
+ display: block;
141
+ width: 100%;
142
+ max-width: 100%;
143
+ margin: 0 auto;
144
+ border: 0.5pt solid var(--line);
145
+ background: white;
146
+ }
147
+ .figure.center img { margin: 0 auto; }
148
+ .figure.narrow img { width: 78%; }
149
+ .figure.full img { width: 100%; }
150
+ /* Multi-slice figures: a tall composite split into A4-fitting slices.
151
+ Each slice is its own atomic page-break-inside-avoid; the caption
152
+ sits on the last slice. */
153
+ .figure.sliced {
154
+ page-break-inside: auto;
155
+ }
156
+ .figure.sliced .slice {
157
+ page-break-inside: avoid;
158
+ margin: 0 0 4pt;
159
+ }
160
+ .figure.sliced .slice img {
161
+ width: 100%;
162
+ max-height: 245mm;
163
+ object-fit: contain;
164
+ }
165
+ .figure.sliced.narrow .slice img {
166
+ width: 65%;
167
+ margin: 0 auto;
168
+ display: block;
169
+ }
170
+ /* Side-by-side / stacked closeup with row labels */
171
+ .closeup-row {
172
+ display: grid;
173
+ grid-template-columns: 70pt 1fr;
174
+ align-items: center;
175
+ gap: 8pt;
176
+ margin: 0 0 4pt;
177
+ page-break-inside: avoid;
178
+ }
179
+ .closeup-row .label {
180
+ font-size: 9.5pt;
181
+ font-family: "Source Han Sans SC", "Noto Sans CJK SC", sans-serif;
182
+ color: var(--ink);
183
+ text-align: right;
184
+ padding-right: 4pt;
185
+ }
186
+ .closeup-row .label .tag {
187
+ font-weight: 700;
188
+ font-size: 11pt;
189
+ display: block;
190
+ margin-bottom: 1pt;
191
+ }
192
+ .closeup-row .label .sub {
193
+ color: var(--muted);
194
+ font-size: 8.5pt;
195
+ }
196
+ .closeup-row img {
197
+ width: 100%;
198
+ max-height: 70mm;
199
+ object-fit: cover;
200
+ object-position: center;
201
+ border: 0.5pt solid var(--line);
202
+ display: block;
203
+ }
204
+ .closeup-figure {
205
+ page-break-inside: avoid;
206
+ margin: 10pt 0 14pt;
207
+ }
208
+ .caption {
209
+ color: var(--ink);
210
+ font-size: 9pt;
211
+ margin: 4pt 0 0;
212
+ line-height: 1.45;
213
+ text-align: left;
214
+ text-indent: 0;
215
+ }
216
+ .caption strong { font-size: 9pt; letter-spacing: 0.3pt; }
217
+
218
+ .grid-2 {
219
+ display: grid;
220
+ grid-template-columns: 1fr 1fr;
221
+ gap: 8pt;
222
+ margin: 8pt 0;
223
+ }
224
+ .grid-2 .figure { margin: 0; }
225
+ .table-wrap { page-break-inside: avoid; margin: 6pt 0 14pt; }
226
+ .table-title { font-size: 9.5pt; margin: 0 0 4pt; text-indent: 0; }
227
+ .table-title strong { letter-spacing: 0.3pt; }
228
+
229
+ .equation {
230
+ margin: 8pt 0;
231
+ text-align: center;
232
+ font-size: 10.5pt;
233
+ page-break-inside: avoid;
234
+ }
235
+ .eq-row {
236
+ display: flex;
237
+ align-items: center;
238
+ justify-content: center;
239
+ gap: 14pt;
240
+ }
241
+ .eq-label { font-style: italic; color: var(--muted); font-size: 9.5pt; }
242
+ /* Static math rendering — no MathJax, works in weasyprint PDF */
243
+ .m, em.m { font-style: italic; font-family: "Latin Modern Math", "STIX Two Math", "Cambria Math", "Times New Roman", serif; }
244
+ .m-up { font-style: normal; }
245
+ .m-cal { font-style: italic; font-family: "Latin Modern Math", "STIX Two Math", "Cambria Math", cursive, serif; }
246
+ .frac {
247
+ display: inline-block;
248
+ vertical-align: -0.35em;
249
+ text-align: center;
250
+ font-size: 90%;
251
+ margin: 0 3pt;
252
+ line-height: 1.1;
253
+ white-space: nowrap;
254
+ }
255
+ .frac .num,
256
+ .frac .den {
257
+ display: block;
258
+ white-space: nowrap;
259
+ padding: 0 4pt;
260
+ }
261
+ .frac .num { border-bottom: 0.6pt solid var(--ink); padding-bottom: 1pt; }
262
+ .frac .den { padding-top: 1pt; }
263
+ sub, sup { line-height: 0; font-size: 75%; }
264
+ .eq-bracket {
265
+ font-size: 150%;
266
+ vertical-align: -0.25em;
267
+ font-weight: 400;
268
+ }
269
+
270
+ ol.refs { padding-left: 1.2em; font-size: 9.5pt; line-height: 1.5; }
271
+ ol.refs li { margin: 3pt 0; }
272
+
273
+ .page-break { page-break-before: always; }
274
+ .small { font-size: 9pt; color: var(--muted); }
275
+ section.body { padding: 0 4pt; }
276
+
277
+ /* Column layout for compact paragraphs (not used full-width to keep figures full) */
278
+ .two-col {
279
+ column-count: 2;
280
+ column-gap: 14pt;
281
+ column-rule: 0.3pt solid var(--line);
282
+ }
283
+ .two-col p:first-of-type { margin-top: 0; }
284
+ </style>
285
+ </head>
286
+ <body>
287
+
288
+ <header>
289
+ <h1>对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究</h1>
290
+ <p class="meta">
291
+ 内部技术报告 · 2026 年 5 月 · train_u1 工具链
292
+ </p>
293
+ <p class="affil">
294
+ 针对 SenseNova-U1-8B-MoT 基模的 LoRA / Partial Fine-tuning 训练器维护团队
295
+ </p>
296
+
297
+ <div class="abstract">
298
+ <h3>Abstract</h3>
299
+ <p>
300
+ SenseNova-U1 公开技术报告披露的 text-to-image 流匹配训练在三个关键算法层
301
+ 上做出特定选择:以速度 (velocity) 为目标的损失函数、logit-normal 时间步采样、
302
+ 以及 condition dropout 形式的无分类器引导校准。在大规模预训练或后训阶段,
303
+ 这些选择共同支撑了模型最终行为。但对小数据风格 LoRA 微调而言,是否应当继承
304
+ 这一组合并未由现有公开文献回答。本文以 56 张 Hayateluc 风格图像与 8B-MoT 基模
305
+ 为受控环境,对三个官方算法层逐一做单变量消融,并与 v18 本地基线 (x<sub>0</sub>-MSE
306
+ 损失、uniform 时间步、零 dropout) 进行对比,同时纳入 prompt 前缀格式作为额外
307
+ 受控变量。所有训练运行 6000 步,并在统一 prompt 集上做多步采样以做视觉评估。
308
+ 结果表明,逐项替换为官方设置后,<strong>训练 loss 不发散,但完整采样图像呈
309
+ 现可观察的图像重建质量劣化</strong>:(i) velocity 损失导致天空与大面积渐变上
310
+ 规则横向条纹 (banding) 与 photoreal 化;(ii) logit-normal 时间步导致大面积平
311
+ 滑色块上高频 speckle 噪点与色阶离散化;(iii) 两者组合时进一步出现远景元素缺
312
+ 失、构图被压缩的结构扭曲 (structural distortion),已超出 "style drift" 范畴而
313
+ 属重建质量缺陷。condition dropout 是三项中唯一未引入可见劣化的项,反而轻
314
+ 微改善图像结构质量。基于该结论,我们将仓库的发布默认值保持为本地基线,并
315
+ 把官方对齐组合作为可选的 ablation 配置发布,供研究复现使用。
316
+ </p>
317
+ <p class="keywords">
318
+ <strong>Keywords —</strong>
319
+ 流匹配 (flow matching) ·
320
+ LoRA 微调 ·
321
+ 消融研究 ·
322
+ classifier-free guidance ·
323
+ 风格迁移 ·
324
+ SenseNova-U1
325
+ </p>
326
+ </div>
327
+ </header>
328
+
329
+ <section class="body">
330
+
331
+ <h2>1. 引言</h2>
332
+ <p>
333
+ SenseNova-U1 公开报告 [1] 在 text-to-image 流匹配训练中采用三个具体的算法层:
334
+ 速度损失 (velocity loss)、logit-normal 时间步采样、以及在条件 prefix 上的 dropout
335
+ 作为 classifier-free guidance (CFG) 的校准信号。这一组合在该报告所对应的大规模
336
+ 数据与多阶段训练管线下是合理且互相支撑的设计选择。
337
+ </p>
338
+ <p>
339
+ 与之相对,在小数据风格 LoRA 微调场景下,训练目标的属性发生了实质变化:训练数据
340
+ 规模通常在 10<sup>1</sup>–10<sup>2</sup> 张量级,目的是从基模注入一个相对集中的
341
+ 视觉风格 fingerprint,而非更新一个完整的世界模型。这种情况下,将官方训练配置整
342
+ 体迁移过来是否仍最优,并不显然。
343
+ </p>
344
+ <p>
345
+ 本文以 SenseNova-U1-8B-MoT 为基模,在 56 张 Hayateluc 风格图像数据上做受控消
346
+ 融。具体贡献为三点:
347
+ </p>
348
+ <p>
349
+ <strong>(i)</strong> 提出一组单变量消融配置,分别替换官方三个算法层中的一个,并
350
+ 与 v18 本地基线对照,使得每对差异可被归因到单一 lever。
351
+ <strong>(ii)</strong> 指出 prompt 前缀格式 (训练时 caption 中风格 anchor 的位置)
352
+ 是一个独立而强的混淆变量;在没有控制该变量前,单纯的 lever 替换会得到误导性的视
353
+ 觉差异。
354
+ <strong>(iii)</strong> 基于训练标量曲线与多步采样视觉评估,识别每个官方 lever
355
+ 引入的具体图像重建质量缺陷:velocity 损失→横纹 + photoreal 化、logit-normal
356
+ 时间步→噪点 + 色阶离散化、两者叠加→结构扭曲。condition dropout 是唯一对小数据
357
+ 风格 LoRA 无损甚至有益的官方算法层。
358
+ </p>
359
+
360
+ <h2>2. 背景与符号</h2>
361
+
362
+ <h3>2.1 流匹配训练目标</h3>
363
+ <p>
364
+ 设 <em class="m">x</em><sub>0</sub> 为目标图像 patch,<em class="m">ε</em> 为各向同
365
+ 性高斯噪声,时间步 <em class="m">t</em> ∈ [<em class="m">t</em><sub class="m">ε</sub>,
366
+ 1 − <em class="m">t</em><sub class="m">ε</sub>]。本仓库与上游推理共用如下线性插
367
+ 值约定 (linear-<em class="m">z</em> schedule,<em class="m">t</em> 越接近 1 越接
368
+ 近干净图):
369
+ </p>
370
+ <div class="equation">
371
+ <span style="white-space: nowrap;">
372
+ <em class="m">z</em><sub class="m">t</sub> =
373
+ <em class="m">t</em>&thinsp;<em class="m">x</em><sub>0</sub> +
374
+ (1&minus;<em class="m">t</em>)&thinsp;<em class="m">ε</em>,
375
+ </span>
376
+ &emsp;&emsp;
377
+ <span style="white-space: nowrap;">
378
+ <em class="m">v</em><sup>★</sup> =
379
+ <span class="frac">
380
+ <span class="num"><em class="m">x</em><sub>0</sub> &minus; <em class="m">z</em><sub class="m">t</sub></span>
381
+ <span class="den">1 &minus; <em class="m">t</em></span>
382
+ </span>.
383
+ </span>
384
+ &emsp;<span class="eq-label">(1)</span>
385
+ </div>
386
+ <p>
387
+ 在 x<sub>0</sub>-MSE 损失下,训练目标为
388
+ <em class="m-cal">L</em><sub class="m">x</sub><sub>0</sub> =
389
+ 𝔼 ‖<em class="m">x</em><sub class="m">θ</sub>(<em class="m">z</em><sub class="m">t</sub>,<em class="m">t</em>)
390
+ − <em class="m">x</em><sub>0</sub>‖<sup>2</sup>;
391
+ 在 velocity 损失下,<em class="m-cal">L</em><sub class="m">v</sub> = 𝔼
392
+ ‖<em class="m">v</em><sub class="m">θ</sub>(<em class="m">z</em><sub class="m">t</sub>,<em class="m">t</em>)
393
+ − <em class="m">v</em><sup>★</sup>‖<sup>2</sup>。代入
394
+ <em class="m">v</em><sup>★</sup> = (<em class="m">x</em><sub>0</sub> − <em class="m">z</em><sub class="m">t</sub>)/(1 − <em class="m">t</em>)
395
+ 与 <em class="m">v</em><sub class="m">θ</sub> = (<em class="m">x</em><sub class="m">θ</sub> − <em class="m">z</em><sub class="m">t</sub>)/(1 − <em class="m">t</em>),可得二者关系为
396
+ </p>
397
+ <div class="equation">
398
+ <span style="white-space: nowrap;">
399
+ <em class="m-cal">L</em><sub class="m">v</sub> =
400
+ 𝔼<span class="eq-bracket">[</span>&hairsp;
401
+ <span class="frac">
402
+ <span class="num">‖<em class="m">x</em><sub class="m">θ</sub> &minus; <em class="m">x</em><sub>0</sub>‖<sup>2</sup></span>
403
+ <span class="den">(1 &minus; <em class="m">t</em>)<sup>2</sup></span>
404
+ </span>
405
+ &hairsp;<span class="eq-bracket">]</span>.
406
+ </span>
407
+ &emsp;<span class="eq-label">(2)</span>
408
+ </div>
409
+ <p>
410
+ 即 velocity 损失等价于以 (1 − <em class="m">t</em>)<sup>−2</sup> 重新加权的
411
+ x<sub>0</sub>-MSE。该权重在 <em class="m">t</em> → 1 (近 clean 端) 发散,因此训
412
+ 练梯度的有效分布严重偏向高 <em class="m">t</em> 区间。
413
+ </p>
414
+
415
+ <h3>2.2 时间步采样分布</h3>
416
+ <p>
417
+ uniform 采样取 <em class="m">t</em> ∼ <em class="m-cal">U</em>(<em class="m">t</em><sub class="m">ε</sub>,
418
+ 1 − <em class="m">t</em><sub class="m">ε</sub>)。logit-normal 采样定义为
419
+ <em class="m">u</em> ∼ <em class="m-cal">N</em>(<em class="m">μ</em>, <em class="m">σ</em><sup>2</sup>),
420
+ <em class="m">t</em> = <em class="m">σ</em>(<em class="m">u</em>),其中
421
+ <em class="m">σ</em> 为 sigmoid。本文 logit-normal 取
422
+ <em class="m">μ</em> = −0.8, <em class="m">σ</em> = 0.8,对应
423
+ 𝔼[<em class="m">t</em>] ≈ 0.34,将概率质量偏向较低 <em class="m">t</em>
424
+ (较 noisy 一侧)。
425
+ </p>
426
+
427
+ <h3>2.3 Condition dropout</h3>
428
+ <p>
429
+ 设每步训练以独立概率 <em class="m">p</em><sub class="m-up">text</sub> 将文本条件
430
+ 替换为空 prompt 对应的 prefix KV;以独立概率 <em class="m">p</em><sub class="m-up">both</sub>
431
+ 走 “text + image” 全部 drop 的分支。在纯 T2I 场景下,这两个 drop 模式都会回退到
432
+ 统一的 unconditional prefix。本文取
433
+ <em class="m">p</em><sub class="m-up">text</sub> = <em class="m">p</em><sub class="m-up">both</sub> = 0.10,
434
+ 即约 20% 步使用无条件 prefix。
435
+ </p>
436
+
437
+ <h2>3. 实验配置</h2>
438
+
439
+ <h3>3.1 数据集与基模</h3>
440
+ <p>
441
+ 训练数据为 56 张 Hayateluc 风格自然语言 caption 配对图像,分布在 7 个 aspect-ratio
442
+ bucket 上 (最大像素数 ≤ 2048<sup>2</sup>)。基模为 SenseNova-U1-8B-MoT [1],加载方式为
443
+ bf16 CPU 驻留 + 静态前缀 KV cache 的低显存 LoRA 训练,单卡 32 GB 峰值约 21 GB。
444
+ </p>
445
+
446
+ <h3>3.2 可训练面</h3>
447
+ <p>
448
+ 所有实验共享如下可训练面,保证不同运行间差异仅来自表 1 列出的三个 lever。
449
+ 表 2 给出训练面分层;表 3 给出训练超参。
450
+ </p>
451
+
452
+ <div class="table-wrap">
453
+ <p class="table-title"><strong>表 2.</strong> 训练面分层。LoRA / partial FT / frozen
454
+ 三类合计 286M 可训练参数。所有运行共享此结构。</p>
455
+ <table class="bordered">
456
+ <thead>
457
+ <tr>
458
+ <th style="width: 14%;">类别</th>
459
+ <th>覆盖模块</th>
460
+ <th style="width: 14%;">参数量</th>
461
+ <th style="width: 18%;">备注</th>
462
+ </tr>
463
+ </thead>
464
+ <tbody>
465
+ <tr>
466
+ <td><strong>LoRA wrap</strong></td>
467
+ <td>
468
+ 注意力:<code>q_proj_mot_gen</code>, <code>k_proj_mot_gen</code>,
469
+ <code>v_proj_mot_gen</code>, <code>o_proj_mot_gen</code>;<br>
470
+ MLP:<code>mlp_mot_gen.{gate,up,down}_proj</code>
471
+ </td>
472
+ <td>~204M</td>
473
+ <td>共 294 wrap;<em class="m">r</em> = 64, <em class="m">α</em> = 64</td>
474
+ </tr>
475
+ <tr>
476
+ <td><strong>Partial fine-tune</strong></td>
477
+ <td>
478
+ <code>fm_modules.timestep_embedder</code>,<br>
479
+ <code>fm_modules.noise_scale_embedder</code>,<br>
480
+ <code>fm_modules.vision_model_mot_gen</code>,<br>
481
+ <code>fm_modules.fm_head</code>
482
+ </td>
483
+ <td>~82M</td>
484
+ <td>仅 <code>fm_modules</code> 子树</td>
485
+ </tr>
486
+ <tr>
487
+ <td><strong>Frozen</strong></td>
488
+ <td>Understand path 全部模块及其它未列出的所有权重</td>
489
+ <td>—</td>
490
+ <td>不更新</td>
491
+ </tr>
492
+ </tbody>
493
+ </table>
494
+ </div>
495
+
496
+ <div class="table-wrap">
497
+ <p class="table-title"><strong>表 3.</strong> 训练超参。所有运行共享。</p>
498
+ <table class="bordered">
499
+ <tbody>
500
+ <tr>
501
+ <td style="width: 22%;"><strong>训练步数</strong></td>
502
+ <td>6000</td>
503
+ <td style="width: 22%;"><strong>学习率</strong></td>
504
+ <td>5 × 10<sup>−5</sup></td>
505
+ </tr>
506
+ <tr>
507
+ <td><strong>优化器</strong></td>
508
+ <td>PagedAdamW8bit</td>
509
+ <td><strong>Batch size</strong></td>
510
+ <td>1 (native resolution)</td>
511
+ </tr>
512
+ <tr>
513
+ <td><strong>Gradient accumulation</strong></td>
514
+ <td>1</td>
515
+ <td><strong>Seed</strong></td>
516
+ <td>固定 (cross-run)</td>
517
+ </tr>
518
+ </tbody>
519
+ </table>
520
+ </div>
521
+
522
+ <h3>3.3 受控变量</h3>
523
+ <p>
524
+ 本文设五组运行做单变量消融,其中四组训练运行共享数据与可训练面 (表 1)。
525
+ </p>
526
+
527
+ <div class="table-wrap">
528
+ <p class="table-title"><strong>表 1.</strong> 五组消融运行的算法配置。<em>baseline</em>
529
+ Baseline 为本仓库默认;运行 (a)–(d) 对应官方算法层逐项与全部替换。</p>
530
+ <table>
531
+ <thead>
532
+ <tr>
533
+ <th>运行</th>
534
+ <th><em class="m-cal">L</em></th>
535
+ <th><em class="m">t</em> 分布</th>
536
+ <th><em class="m">p</em><sub class="m-up">text</sub>, <em class="m">p</em><sub class="m-up">both</sub></th>
537
+ <th>训练 prefix 注释</th>
538
+ </tr>
539
+ </thead>
540
+ <tbody>
541
+ <tr><td><strong>v18 baseline</strong></td><td>x<sub>0</sub></td><td>uniform</td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
542
+ <tr><td>(a) +velocity loss</td><td><strong>v</strong></td><td>uniform</td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
543
+ <tr><td>(b) +logit-normal <em class="m">t</em></td><td>x<sub>0</sub></td><td><strong>logit-normal</strong></td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
544
+ <tr><td>(c) +cond. dropout</td><td>x<sub>0</sub></td><td>uniform</td><td><strong>0.10, 0.10</strong></td><td>含 think sidecar 长前缀</td></tr>
545
+ <tr><td>(d) full official (a + b + c)</td><td><strong>v</strong></td><td><strong>logit-normal</strong></td><td><strong>0.10, 0.10</strong></td><td>含 think sidecar 长前缀</td></tr>
546
+ </tbody>
547
+ </table>
548
+ </div>
549
+
550
+ <h3>3.4 评估协议</h3>
551
+ <p>
552
+ 训练时记录每步 active loss、x<sub>0</sub>-MSE、v-MSE、<em class="m">t</em> 统计量、
553
+ 以及 dropout 路由计数;121-step 滑动平均后绘制。采样评估使用一组固定的 12 条
554
+ 自然风景 prompts,在统一的 7 个 bucket 分辨率下做 50 步 Euler、
555
+ <em class="m-up">cfg_scale</em> = 4.0、<em class="m-up">timestep_shift</em> = 3.0。
556
+ 我们额外引入 prompt 前缀格式变量:v1 prompts 保留与 baseline 训练分布一致的
557
+ 简短前缀;v2 prompts 在每条句首嵌入 artist anchor,与含 think sidecar 的训练
558
+ caption 分布对齐。
559
+ </p>
560
+
561
+ <h2>4. 训练动力学结果</h2>
562
+
563
+ <p>
564
+ 图 1–3 给出 x0-MSE、active loss 与 v-MSE 三条标量曲线。x0-MSE 是唯一可跨损
565
+ 失类型公平比较的指标:它是所有运行共同执行的诊断量。可见运行 (b) 与 (d) 在
566
+ x<sub>0</sub>-MSE 上系统性高于 Baseline,差距贯穿整个训练,并非过
567
+ 渡期偏差。
568
+ </p>
569
+
570
+ <div class="figure">
571
+ <img src="assets/small_data_style_ablation/tb_x0_mse_curves.png" alt="x0 MSE curves">
572
+ <p class="caption">
573
+ <strong>图 1.</strong>&nbsp; x<sub>0</sub>-MSE 在五组运行中的演化 (121-step 滑动平均)。
574
+ (b) 与 (d) 系统性偏高,表明把 <em class="m">t</em> 密度搬向较 noisy 区间会降低
575
+ 模型在更清晰图像状态下学习颜色与构图的机会。
576
+ </p>
577
+ </div>
578
+
579
+ <div class="figure">
580
+ <img src="assets/small_data_style_ablation/tb_active_loss_curves.png" alt="active loss curves">
581
+ <p class="caption">
582
+ <strong>图 2.</strong>&nbsp; active loss 的绝对值不可跨目标比较 (velocity 损失与
583
+ x<sub>0</sub>-MSE 单位不同)。该图仅用于检查每组运行内部是否存在有效下降趋势,
584
+ 所有运行均下降稳定,未出现训练发散。
585
+ </p>
586
+ </div>
587
+
588
+ <div class="figure">
589
+ <img src="assets/small_data_style_ablation/tb_v_mse_curves.png" alt="v MSE curves">
590
+ <p class="caption">
591
+ <strong>图 3.</strong>&nbsp; v-MSE 作为诊断量在所有运行中都被计算。可观察到 v-MSE
592
+ 在 <em class="m">t</em> → 1 端出现极端尖峰,与 §2.1 中给出的
593
+ (1 − <em class="m">t</em>)<sup>−2</sup> 加权一致。在 velocity 训练的运行 (a) 与 (d)
594
+ 中这些尖峰直接进入梯度,被高权重区间支配。
595
+ </p>
596
+ </div>
597
+
598
+ <div class="grid-2">
599
+ <div class="figure">
600
+ <img src="assets/small_data_style_ablation/tb_t_distribution.png" alt="t distribution">
601
+ <p class="caption">
602
+ <strong>图 4.</strong>&nbsp; 实测 <em class="m">t</em> 分布。logit-normal 运行 (b) 与
603
+ (d) 的经验均值约 0.34,与理论值 <em class="m">σ</em>(−0.8) ≈ 0.31 一致;其余运行
604
+ 约为 0.50。
605
+ </p>
606
+ </div>
607
+ <div class="figure">
608
+ <img src="assets/small_data_style_ablation/tb_condition_dropout_counts.png" alt="dropout counts">
609
+ <p class="caption">
610
+ <strong>图 5.</strong>&nbsp; 含 condition dropout 的运行 (c) 与 (d) 中每步走
611
+ cond / uncond 分支的累计次数。最终约 21% 步走 unconditional prefix,与设定的
612
+ <em class="m">p</em><sub class="m-up">text</sub> + <em class="m">p</em><sub class="m-up">both</sub>
613
+ = 0.20 在大样本统计上吻合。
614
+ </p>
615
+ </div>
616
+ </div>
617
+
618
+ <h2 class="page-break">5. 视觉评估结果</h2>
619
+
620
+ <p>
621
+ 训练标量只反映单步去噪行为,无法替代从纯噪声开始的多步采样。在固定 seed 与
622
+ identical sampling 超参下,我们对每组运行生成 12 张 1024–2048 像素级风景图。
623
+ </p>
624
+
625
+ <h3>5.1 同 prompt 多配方对比</h3>
626
+
627
+ <p>
628
+ 图 6 为 forest / wildflower meadow / dandelion field 三个最能体现风格 fingerprint
629
+ 的 prompt 在五组运行 (含 Baseline) 下的 contact sheet。关键观察:
630
+ </p>
631
+ <p>
632
+ <strong>(i)</strong> Baseline 与运行 (c) 在所有三个 prompt 上都保持暖橙地平线、
633
+ 青蓝高空、清晰前景轮廓的视觉指纹。
634
+ <strong>(ii)</strong> 运行 (a) 在 dandelion field 上系统性向 photoreal 夜景偏移:
635
+ cyan 天空消失、puffball 密度降低、暖金核被压暗。
636
+ <strong>(iii)</strong> 运行 (b) 在所有 prompt 上整体 palette 偏冷,紫蓝山失去层
637
+ 次。
638
+ <strong>(iv)</strong> 运行 (d) 同时承袭 (a) 与 (b) 的劣化,并叠加 composition
639
+ cropping:dandelion 图中失去远景 forest mass 与天空。
640
+ </p>
641
+
642
+ <div class="figure sliced">
643
+ <div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg" alt="visual same-prompt contact sheet (1/3)"></div>
644
+ <div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg" alt="visual same-prompt contact sheet (2/3)"></div>
645
+ <div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg" alt="visual same-prompt contact sheet (3/3)">
646
+ <p class="caption">
647
+ <strong>图 6.</strong>&nbsp; 同 prompt × 五组运行的 contact sheet (上→中→下三页连读)。
648
+ 每列为一组运行,每行为一条 prompt。列名直接对应表 1 的算法配置差异。
649
+ 风格 fingerprint 由暖金地平线、青蓝高空、painterly 厚笔触三要素共同定义;
650
+ 运行 (a) 与 (d) 出现 atmospheric drift,运行 (b) 出现 palette 偏冷。
651
+ </p>
652
+ </div>
653
+ </div>
654
+
655
+ <h3>5.2 局部细节:图像重建质量劣化</h3>
656
+
657
+ <p>
658
+ §5.1 给出的是整图尺度上 palette / brushwork / 构图层面的偏移。但在原始
659
+ 分辨率下查看局部,三个 v19 系列方案相比 v18 baseline 还呈现三类<strong>图像
660
+ 重建质量</strong>层面的劣化,且每一类都与具体 lever 的训练分布偏置直接挂钩:
661
+ </p>
662
+ <ul>
663
+ <li><strong>横纹 (banding):</strong>云形或色温过渡上出现规则横向条纹,主要见于
664
+ 运行 (a) 与 (d),与 velocity 损失的高 <em class="m">t</em> 加权 (§6.1) 一致。</li>
665
+ <li><strong>噪点 (speckle noise):</strong>大面积渐变上出现非自然色阶离散化、
666
+ 细粒度高频纹理,主要见于运行 (b) 与 (d),与 logit-normal 时间步的低
667
+ <em class="m">t</em> 偏置 (§6.2) 一致。</li>
668
+ <li><strong>结构扭曲 (structural distortion):</strong>远景元素 (forest mass、
669
+ cloud column) 缺失或被压平、构图整体被裁切,集中见于运行 (d),是
670
+ 上述两个分布偏置叠加效应。</li>
671
+ </ul>
672
+ <p>
673
+ 以下分别从三个 prompt 截取细节区域 (图 7、图 8、图 9)。每图以 v18 baseline 在最
674
+ 上,按 (a) (b) (d) 顺序向下排列。
675
+ </p>
676
+
677
+ <h4>5.2.1 Sample 00 森林暗部:竖向 striation 格纹 (运行 (d) 主)</h4>
678
+ <p>
679
+ Sample 00 (dense old-growth forest interior at dawn) 的画面左下、左侧粗树干、
680
+ 右侧暗部都是低光强、低 SNR 区域,是 v-loss + logit-normal 组合下竖向格纹伪
681
+ 影最先暴露的位置。运行 (d) 的左侧粗树干表面与右下苔藓层均出现可见的等间距
682
+ 竖线。
683
+ </p>
684
+
685
+ <div class="closeup-figure">
686
+ <div class="closeup-row">
687
+ <div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
688
+ <img src="assets/small_data_style_ablation/detail00_v18.jpg" alt="v18 forest bottom">
689
+ </div>
690
+ <div class="closeup-row">
691
+ <div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
692
+ <img src="assets/small_data_style_ablation/detail00_v19a.jpg" alt="v19a forest bottom">
693
+ </div>
694
+ <div class="closeup-row">
695
+ <div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
696
+ <img src="assets/small_data_style_ablation/detail00_v19b.jpg" alt="v19b forest bottom">
697
+ </div>
698
+ <div class="closeup-row">
699
+ <div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
700
+ <img src="assets/small_data_style_ablation/detail00_v19c.jpg" alt="v19c forest bottom">
701
+ </div>
702
+ <div class="closeup-row">
703
+ <div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
704
+ <img src="assets/small_data_style_ablation/detail00_v19.jpg" alt="v19 forest bottom with vertical striations">
705
+ </div>
706
+ <p class="caption">
707
+ <strong>图 7.</strong>&nbsp; Sample 00 画面下半 (左粗树干 + 中间地面 + 右下苔藓 +
708
+ 右侧上层树叶) 五向对比。
709
+ <strong>v18:</strong>painterly chunky 厚笔触树干 + 暖色秋叶碎片 + 多层 silhouette
710
+ 清晰可分;
711
+ <strong>(a):</strong><span style="color: #c53030; font-weight: 600;">god-rays 与 mist 大幅放大主导画面</span>,
712
+ painterly chunky 厚笔触被 atmospheric haze 稀释成更软的笔触;树干与
713
+ 地面 silhouette 基本保留,主要问题是风格 softening 与雾化过度;
714
+ <strong>(b):</strong>painterly 厚笔触与树形保留较好,但
715
+ <span style="color: #c53030; font-weight: 600;">左下暗部 + 右下暗部仍有未完全恢复的轻度竖向条纹</span>
716
+ (强度远低于 (d));高光 god-rays 中段也出现轻度结构化痕迹;
717
+ <strong>(c):</strong>五行中重建质量最接近 v18 — painterly chunky 树干 + 完整地面
718
+ 碎叶都保留,god-rays 强度略偏 (a) 但 painterly 信息完整,无格纹或条纹伪影;
719
+ <strong>(d):</strong><span style="color: #c53030; font-weight: 600;">左侧粗树干表面 + 中景树干 + 右下苔藓暗部出现明显竖向 striation 格纹</span>,
720
+ 暗部 banding 最严重,painterly 信息几乎完全缺失。
721
+ </p>
722
+ </div>
723
+
724
+ <h4>5.2.2 Sample 02 湖岸:树 silhouette 与水面倒影竖纹</h4>
725
+ <p>
726
+ Sample 02 (deep mountain lake at dawn) 的两侧针叶林 silhouette 与中央镜面倒影
727
+ 是考察树形 painterly 是否退化、倒影是否出现竖纹的核心区域。运行 (d) 在两侧
728
+ 树林区 + 水面倒影区均呈现 screen-door 竖向条纹。
729
+ </p>
730
+
731
+ <div class="closeup-figure">
732
+ <div class="closeup-row">
733
+ <div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
734
+ <img src="assets/small_data_style_ablation/detail02_v18.jpg" alt="v18 lakeside trees + reflection">
735
+ </div>
736
+ <div class="closeup-row">
737
+ <div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
738
+ <img src="assets/small_data_style_ablation/detail02_v19a.jpg" alt="v19a lakeside trees + reflection">
739
+ </div>
740
+ <div class="closeup-row">
741
+ <div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
742
+ <img src="assets/small_data_style_ablation/detail02_v19b.jpg" alt="v19b lakeside trees + reflection">
743
+ </div>
744
+ <div class="closeup-row">
745
+ <div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
746
+ <img src="assets/small_data_style_ablation/detail02_v19c.jpg" alt="v19c lakeside trees + reflection">
747
+ </div>
748
+ <div class="closeup-row">
749
+ <div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
750
+ <img src="assets/small_data_style_ablation/detail02_v19.jpg" alt="v19 lakeside trees + reflection with vertical striations">
751
+ </div>
752
+ <p class="caption">
753
+ <strong>图 8.</strong>&nbsp; Sample 02 中段 (左右两侧 painted tree silhouette + 中央
754
+ 水面倒影 + 远雾) 五向对比。
755
+ <strong>v18:</strong>painterly chunky 树形 + 厚笔触倒影 + 远山三层 silhouette
756
+ 清晰可分;
757
+ <strong>(a):</strong>整图 photoreal 化,painterly 笔触退化为 stock-photo 树形,
758
+ 倒影保留 photoreal 镜面但失去厚笔触;
759
+ <strong>(b):</strong>painted 风格部分保留,但远雾 + 倒影出现微弱竖向条纹;
760
+ <strong>(c):</strong>painted 风格保留较好,树形+倒影皆有 painterly 触感,未观察
761
+ 到明显伪影;
762
+ <strong>(d):</strong><span style="color: #c53030; font-weight: 600;">左侧山体 + 右侧针叶林 silhouette + 整个倒影区出现 screen-door 竖向 striation 格纹</span>,
763
+ 覆盖范围广。
764
+ </p>
765
+ </div>
766
+
767
+ <h4>5.2.3 Sample 07 蒲公英主体:grass 区竖向条纹 + bokeh 退化</h4>
768
+ <p>
769
+ Sample 07 (backlit dandelion field) 主体为前景蒲公英 + 草丛 + 背光,是考察
770
+ painterly 笔触是否被替换为 photographic bokeh、草丛暗区是否出现条纹的核心
771
+ 区域。
772
+ </p>
773
+
774
+ <div class="closeup-figure">
775
+ <div class="closeup-row">
776
+ <div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
777
+ <img src="assets/small_data_style_ablation/detail07_v18.jpg" alt="v18 dandelion main subject">
778
+ </div>
779
+ <div class="closeup-row">
780
+ <div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
781
+ <img src="assets/small_data_style_ablation/detail07_v19a.jpg" alt="v19a dandelion main subject">
782
+ </div>
783
+ <div class="closeup-row">
784
+ <div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
785
+ <img src="assets/small_data_style_ablation/detail07_v19b.jpg" alt="v19b dandelion main subject">
786
+ </div>
787
+ <div class="closeup-row">
788
+ <div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
789
+ <img src="assets/small_data_style_ablation/detail07_v19c.jpg" alt="v19c dandelion main subject">
790
+ </div>
791
+ <div class="closeup-row">
792
+ <div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
793
+ <img src="assets/small_data_style_ablation/detail07_v19.jpg" alt="v19 dandelion main subject with cross-hatch grid">
794
+ </div>
795
+ <p class="caption">
796
+ <strong>图 9.</strong>&nbsp; Sample 07 主体区 (前景蒲公英 + 草丛暗部 + 暖背光) 五向
797
+ 对比。
798
+ <strong>v18:</strong>painterly chunky 笔触草叶 + painted 蒲公英 + 远景树线
799
+ 可见;
800
+ <strong>(a):</strong><span style="color: #c53030; font-weight: 600;">painterly 完全退化为 photographic bokeh</span>,
801
+ 前景蒲公英变疏散小,背景暗化呈相机散景;
802
+ <strong>(b):</strong>painterly 较为完整 — chunky 草叶 + painted 蒲公英都还在,
803
+ 主要变化是视野相对收紧 (前景密度减少);草丛暗部存在轻度的竖向 streak,密度
804
+ 低于 (d);
805
+ <strong>(c):</strong>painterly fingerprint 接近 v18 — 蒲公英密度高、chunky
806
+ 草叶 + 浮散种子全保留,无观察到格纹伪影;
807
+ <strong>(d):</strong><span style="color: #c53030; font-weight: 600;">草丛暗部出现密集 cross-hatch 格纹</span>
808
+ (竖向 striation 为主,叠加微横向 banding),painterly 信息几乎完全缺失,整图
809
+ 带 halftone 质感。
810
+ </p>
811
+ </div>
812
+
813
+ <h3>5.3 Prompt 前缀格式作为混淆变量</h3>
814
+
815
+ <p>
816
+ 我们在受控分析中发现,若 inference prompt 的 artist anchor 位置与训练 caption
817
+ 不一致,会引入与 lever 无关但视觉量级相当的差异。因此对 Baseline 与运行 (c) 额外
818
+ 做两组 prompt 集对比,结果见图 10。该结果说明把 prompt 写法与训练分布对齐是
819
+ 正确归因 lever 效果的必要前置条件,并解释了我们对 README 默认配置中保留
820
+ 显式 <code>style.trigger</code> 与忽略 think sidecar 的选择。
821
+ </p>
822
+
823
+ <div class="figure sliced narrow">
824
+ <div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg" alt="baseline prompt-set contact sheet (1/4)"></div>
825
+ <div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg" alt="baseline prompt-set contact sheet (2/4)"></div>
826
+ <div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg" alt="baseline prompt-set contact sheet (3/4)"></div>
827
+ <div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg" alt="baseline prompt-set contact sheet (4/4)">
828
+ <p class="caption">
829
+ <strong>图 10.</strong>&nbsp; Baseline 在两组 prompt 集下的采样 (上→下四页连读)。
830
+ 左:与训练同分布的短前缀 (含显式 style trigger);右:分布外的长 prompt。
831
+ 完整采样在分布内稳定;这表明 Baseline 的强 fingerprint 不依赖额外 prefix 文本。
832
+ </p>
833
+ </div>
834
+ </div>
835
+
836
+ <h2>6. 讨论</h2>
837
+
838
+ <h3>6.1 Velocity 损失的隐式重加权</h3>
839
+ <p>
840
+ 由式 (2),<em class="m-cal">L</em><sub class="m">v</sub> =
841
+ 𝔼 [‖<em class="m">x</em><sub class="m">θ</sub> − <em class="m">x</em><sub>0</sub>‖<sup>2</sup>
842
+ · (1 − <em class="m">t</em>)<sup>−2</sup>],velocity 损失等价于在 x<sub>0</sub>-MSE
843
+ 上施加一个 (1 − <em class="m">t</em>)<sup>−2</sup> 的权重函数。在 uniform
844
+ <em class="m">t</em> 下,该权重在 <em class="m">t</em> ∈ [0.9, 1.0] 区间承担约 64%
845
+ 的总积分质量。意即在 velocity 训练中,模型有大半的梯度信号被分配到 "接近 clean
846
+ 图但仍残留少量噪声" 的窗口。对于风格 LoRA 而言,颜色调性、构图、笔触特征均在
847
+ <em class="m">t</em> ∈ [0.3, 0.7] 这一中频窗口形成;velocity 损失显著削减了该窗口的
848
+ 有效更新量,因而出现了 §5.1 (ii) 报告的 photoreal 夜景偏移。
849
+ </p>
850
+
851
+ <h3>6.2 Logit-normal 与小数据匹配的张力</h3>
852
+ <p>
853
+ logit-normal (<em class="m">μ</em> = −0.8, <em class="m">σ</em> = 0.8) 将
854
+ 𝔼[<em class="m">t</em>] 从 0.5 移到 0.34,即把更多概率质量推向 noisy 端。该选择在
855
+ 大规模数据下使得各 <em class="m">t</em> 段都获得足够样本数;但在 56 张图、6000
856
+ 步、batch=1 的小数据规模下,每个 mid-<em class="m">t</em> bin 实际获得的有效样
857
+ 本数本就稀疏,进一步的密度偏移导致 palette 与 mid-frequency texture 统计学习
858
+ 不足,对应 §5.1 (iii) 报告的整体偏冷。
859
+ </p>
860
+
861
+ <h3>6.3 Condition dropout 的正面作用</h3>
862
+ <p>
863
+ 与上述两项不同,condition dropout 在小数据风格 LoRA 上未观察到风格损伤。其作
864
+ 用机制是把 unconditional branch 与 conditional branch 的相对几何在训练期同时
865
+ 暴露给模型,从而使 inference 时 CFG 的 uncond + <em class="m">s</em> · (cond − uncond)
866
+ 推算成立。在 <em class="m">s</em> = 4.0 这一相对较强的 CFG 下,没有 dropout 训练
867
+ 的 LoRA 经常把输出推到 LoRA 未覆盖的方向,导致结构 artefact。运行 (c) 与 Baseline
868
+ 的视觉对比显示,前者在风格指纹保持不变的前提下,前景结构 (puffball 形状、
869
+ 树干轮廓) 更清晰、稳定。
870
+ </p>
871
+
872
+ <h3>6.4 Prompt 前缀格式的去混淆作用</h3>
873
+ <p>
874
+ §5.3 显示 prompt 中 artist anchor 的位置 (句首 vs. 句末 tag vs. 完全缺失) 与训练
875
+ caption 分布对齐与否直接决定 LoRA delta 是否被正确激活。这一变量与目标 lever
876
+ 正交,但视觉量级与 lever 效果相当;任何不控制该变量的 lever 评估都会被
877
+ prompt-format mismatch 污染。本文 §5.1、§5.2 与 §5.3 结论均在 v2 anchored prompt
878
+ 下复测确认。
879
+ </p>
880
+
881
+ <h2>7. 结论与发布决策</h2>
882
+ <p>
883
+ 在 SenseNova-U1-8B-MoT × 56 张 Hayateluc 风格 LoRA 微调上:
884
+ (1) velocity 损失系统性向高 <em class="m">t</em> 端搬移梯度密度,导致天空与大面积
885
+ 渐变区域出现规则横向条纹 banding 与 painterly→photoreal 退化;
886
+ (2) logit-normal 时间步在小数据下加剧 mid-<em class="m">t</em> 样本稀疏,引发 palette
887
+ 偏冷与大面积平滑色块的 speckle 噪点;
888
+ (3) (1) + (2) 叠加进一步引入远景元素缺失与构图压缩等结构扭曲,劣化超出 style
889
+ drift 范畴;
890
+ (4) condition dropout 是三项中唯一无损反而有益的算法层;
891
+ (5) prompt 前缀格式与训练分布的匹配是评估前置条件,必须先控制。
892
+ </p>
893
+ <p>
894
+ 据此,我们��� <code>configs/default.yaml</code> 保留为本地基线
895
+ (<code>x0 + uniform t + 0 dropout</code> + 显式 style trigger + 忽略 think
896
+ sidecar),并以 <code>configs/official_alignment.yaml</code> 形式发布官方对齐组
897
+ 合,标注其用途为研究复现。该选择不否定公开报告中的算法设计,只反映其与本仓
898
+ 库目标 (小数据风格 LoRA) 的不匹配。
899
+ </p>
900
+
901
+ <h3>默认配置摘要</h3>
902
+ <pre>data:
903
+ use_think_labels: false
904
+
905
+ style:
906
+ trigger: "my style"
907
+ prompt_template: official
908
+
909
+ lora:
910
+ preset: attn_mlp_no_head
911
+
912
+ unfreeze:
913
+ - '^fm_modules\.timestep_embedder\.'
914
+ - '^fm_modules\.noise_scale_embedder\.'
915
+ - '^fm_modules\.vision_model_mot_gen\.'
916
+ - '^fm_modules\.fm_head\.'
917
+
918
+ train:
919
+ loss_type: x0
920
+ t_dist: uniform
921
+ cond_dropout_text: 0.0
922
+ cond_dropout_both: 0.0
923
+ </pre>
924
+
925
+ <h2>参考文献</h2>
926
+ <ol class="refs">
927
+ <li>SenseNova-U1 Technical Report. <em>OpenSenseNova</em>, 2026.
928
+ <span class="small">https://github.com/OpenSenseNova/SenseNova-U1</span></li>
929
+ <li>Lipman, Y. et al. Flow Matching for Generative Modeling. <em>ICLR</em>, 2023.</li>
930
+ <li>Esser, P. et al. Scaling Rectified Flow Transformers for High-Resolution
931
+ Image Synthesis. <em>ICML</em>, 2024 — 提出 logit-normal <em class="m">t</em> 采样.</li>
932
+ <li>Ho, J. &amp; Salimans, T. Classifier-Free Diffusion Guidance.
933
+ <em>NeurIPS Workshop on Deep Generative Models</em>, 2021.</li>
934
+ <li>Hu, E. J. et al. LoRA: Low-Rank Adaptation of Large Language Models.
935
+ <em>ICLR</em>, 2022.</li>
936
+ </ol>
937
+
938
+ </section>
939
+
940
+ </body>
941
+ </html>
docs/small_data_style_ablation.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:446f868e6c5fe4518cebe60baf5d5caed50ff0f1b0bad87b693265550b19daad
3
+ size 9161869
train.sh CHANGED
@@ -2,12 +2,12 @@
2
  # train.sh — quick-launch a LoRA training run from a YAML config.
3
  #
4
  # Usage:
5
- # ./train.sh # uses configs/default.yaml
6
- # ./train.sh configs/v16c.yaml
7
- # ./train.sh configs/v16c.yaml --steps 12000 # extra args forwarded
8
  #
9
  # Long-running training tip:
10
- # setsid nohup ./train.sh configs/v16c.yaml </dev/null >run.log 2>&1 &
11
  # disown
12
  # (a bare `nohup &` can be SIGHUP'd when the SSH/IDE session disconnects).
13
 
 
2
  # train.sh — quick-launch a LoRA training run from a YAML config.
3
  #
4
  # Usage:
5
+ # ./train.sh # uses configs/default.yaml
6
+ # ./train.sh configs/my_style.yaml
7
+ # ./train.sh configs/my_style.yaml --steps 12000 # extra args forwarded
8
  #
9
  # Long-running training tip:
10
+ # setsid nohup ./train.sh configs/my_style.yaml </dev/null >run.log 2>&1 &
11
  # disown
12
  # (a bare `nohup &` can be SIGHUP'd when the SSH/IDE session disconnects).
13
 
train_u1/README.md CHANGED
@@ -49,7 +49,10 @@ train_u1/
49
  ```
50
 
51
  **未实现 / pending**:
52
- - `model/peft_targets.py` — LoRA target 工具(实验 C 时落地)
 
 
 
53
  - `scripts/train_balanced.py` — 48GB 平衡场景(实验 C 后视效果决定)
54
  - `scripts/sample_t2i.py` 完整管线(依赖 wrapper 的 `t2i_generate` 钩子)
55
 
 
49
  ```
50
 
51
  **未实现 / pending**:
52
+ - `model/peft_targets.py` — 独立 LoRA target 工具仍未拆出;当前
53
+ `model/lora.py` 已内置 8B dense target 与实验性 A3B generation MoE
54
+ target grammar(`gen_moe_mlp` / `gen_moe_router`)。A3B/MoE 不是当前
55
+ 主训练路径,必须等可实例化 `mlp_mot_gen.experts.*` 的公开 runtime。
56
  - `scripts/train_balanced.py` — 48GB 平衡场景(实验 C 后视效果决定)
57
  - `scripts/sample_t2i.py` 完整管线(依赖 wrapper 的 `t2i_generate` 钩子)
58
 
train_u1/config.py CHANGED
@@ -15,12 +15,15 @@ A single YAML file describes a complete LoRA training run. Example::
15
  prompt_template: official # or 'plain'
16
 
17
  lora:
18
- preset: default # = attn+mlp+fm_head all at r=64,a=64
19
- # spec: "attn=r64a64;mlp=r64a64;fm_head=r64a64;mlp_mot_gen.down_proj=off"
20
  dropout: 0.0
21
 
22
  unfreeze: # full-finetune (non-LoRA) regex patterns
23
- [] # default: LoRA-only training
 
 
 
24
 
25
  train:
26
  steps: 6000
@@ -29,6 +32,10 @@ A single YAML file describes a complete LoRA training run. Example::
29
  shuffle: true
30
  grad_accum: 1
31
  checkpoint_every: 600
 
 
 
 
32
 
33
  runtime:
34
  keep_kvs_on_gpu: true
@@ -68,6 +75,7 @@ class DataConfig:
68
  snap_bucket: bool = True
69
  n_samples: int | None = None # default: use entire dataset
70
  sample_buckets_file: str | None = None
 
71
 
72
 
73
  @dataclass
@@ -78,7 +86,7 @@ class StyleConfig:
78
 
79
  @dataclass
80
  class LoRAConfig:
81
- preset: str | None = "default" # one of LORA_PRESETS
82
  spec: str | None = None # overrides preset if set
83
  dropout: float = 0.0
84
 
@@ -107,6 +115,25 @@ class TrainConfig:
107
  grad_accum: int = 1
108
  checkpoint_every: int = 600
109
  checkpoint_dir: str | None = None # default: artifacts/{run_name}/checkpoints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  @dataclass
@@ -119,13 +146,22 @@ class RuntimeConfig:
119
  upstream_lora_skip: tuple[str, ...] = ()
120
 
121
 
 
 
 
 
 
 
 
 
 
122
  @dataclass
123
  class TrainRunConfig:
124
  run_name: str = "my_run"
125
  data: DataConfig = field(default_factory=DataConfig)
126
  style: StyleConfig = field(default_factory=StyleConfig)
127
  lora: LoRAConfig = field(default_factory=LoRAConfig)
128
- unfreeze: list[str] = field(default_factory=list)
129
  train: TrainConfig = field(default_factory=TrainConfig)
130
  runtime: RuntimeConfig = field(default_factory=RuntimeConfig)
131
 
 
15
  prompt_template: official # or 'plain'
16
 
17
  lora:
18
+ preset: attn_mlp_no_head # small-data baseline: LoRA attn+mlp only
19
+ # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
20
  dropout: 0.0
21
 
22
  unfreeze: # full-finetune (non-LoRA) regex patterns
23
+ - '^fm_modules\\.timestep_embedder\\.'
24
+ - '^fm_modules\\.noise_scale_embedder\\.'
25
+ - '^fm_modules\\.vision_model_mot_gen\\.'
26
+ - '^fm_modules\\.fm_head\\.'
27
 
28
  train:
29
  steps: 6000
 
32
  shuffle: true
33
  grad_accum: 1
34
  checkpoint_every: 600
35
+ loss_type: x0
36
+ t_dist: uniform
37
+ cond_dropout_text: 0.0
38
+ cond_dropout_both: 0.0
39
 
40
  runtime:
41
  keep_kvs_on_gpu: true
 
75
  snap_bucket: bool = True
76
  n_samples: int | None = None # default: use entire dataset
77
  sample_buckets_file: str | None = None
78
+ use_think_labels: bool = False
79
 
80
 
81
  @dataclass
 
86
 
87
  @dataclass
88
  class LoRAConfig:
89
+ preset: str | None = "attn_mlp_no_head" # one of LORA_PRESETS
90
  spec: str | None = None # overrides preset if set
91
  dropout: float = 0.0
92
 
 
115
  grad_accum: int = 1
116
  checkpoint_every: int = 600
117
  checkpoint_dir: str | None = None # default: artifacts/{run_name}/checkpoints
118
+ # FM loss objective. Default is the local small-data baseline (`x0`) because
119
+ # the ablation study showed that official-style v-loss is not a good
120
+ # small-data style-training default. `v` remains available for explicit
121
+ # official alignment experiments.
122
+ # Choose one of `x0` | `v` | `x0_huber` | `v_huber`.
123
+ loss_type: str = "x0"
124
+ huber_delta: float = 1.0
125
+ # FM `t`-sampling distribution. Default is uniform for the same local
126
+ # baseline reason. `logit_normal` is kept for report-alignment ablations.
127
+ t_dist: str = "uniform"
128
+ t_logit_mean: float = -0.8
129
+ t_logit_std: float = 0.8
130
+ # CFG / condition dropout. `cond_dropout_text` drops text condition only;
131
+ # `cond_dropout_both` is the additional unconditional bucket from the
132
+ # report. In the current pure-T2I trainer there is no separate reference
133
+ # image condition, so both modes use the sampler's unconditional prompt
134
+ # prefix while preserving separate log labels.
135
+ cond_dropout_text: float = 0.0
136
+ cond_dropout_both: float = 0.0
137
 
138
 
139
  @dataclass
 
146
  upstream_lora_skip: tuple[str, ...] = ()
147
 
148
 
149
+ def _default_unfreeze_patterns() -> list[str]:
150
+ return [
151
+ r"^fm_modules\.timestep_embedder\.",
152
+ r"^fm_modules\.noise_scale_embedder\.",
153
+ r"^fm_modules\.vision_model_mot_gen\.",
154
+ r"^fm_modules\.fm_head\.",
155
+ ]
156
+
157
+
158
  @dataclass
159
  class TrainRunConfig:
160
  run_name: str = "my_run"
161
  data: DataConfig = field(default_factory=DataConfig)
162
  style: StyleConfig = field(default_factory=StyleConfig)
163
  lora: LoRAConfig = field(default_factory=LoRAConfig)
164
+ unfreeze: list[str] = field(default_factory=_default_unfreeze_patterns)
165
  train: TrainConfig = field(default_factory=TrainConfig)
166
  runtime: RuntimeConfig = field(default_factory=RuntimeConfig)
167
 
train_u1/data/collators.py CHANGED
@@ -21,8 +21,6 @@ from typing import Iterable
21
 
22
  import torch
23
 
24
- import math
25
-
26
  from train_u1.constants import (
27
  FM_OUTPUT_DIM,
28
  NOISE_SCALE_BASE_IMAGE_SEQ_LEN,
@@ -50,7 +48,23 @@ class CollatorConfig:
50
  # then batch=1).
51
  image_hw: tuple[int, int] | None = (512, 512)
52
  t_eps: float = T_EPS_DEFAULT
53
- t_dist: str = "uniform" # uniform on (t_eps, 1] for MVP
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  add_noise_scale: bool = True
55
  # Base noise_scale value (config.noise_scale = 1.0). The *effective* per-sample
56
  # noise_scale is computed at collator runtime as
@@ -80,6 +94,12 @@ class CollatorConfig:
80
  # Format used: f"{style_trigger}, {original_caption}".
81
  style_trigger: str = ""
82
 
 
 
 
 
 
 
83
 
84
  class SenseNovaU1Collator:
85
  """Stateful collator: holds tokenizer + config, callable on a list of `T2ISample`.
@@ -94,6 +114,7 @@ class SenseNovaU1Collator:
94
  self.tok = tokenizer
95
  self.cfg = cfg or CollatorConfig()
96
  self._gen = torch.Generator().manual_seed(self.cfg.seed)
 
97
  if self.cfg.prompt_template == "official":
98
  if model is None or not hasattr(model, "_build_t2i_query"):
99
  raise ValueError(
@@ -147,10 +168,87 @@ class SenseNovaU1Collator:
147
  if self.cfg.t_dist == "uniform":
148
  t = torch.rand(batch_size, generator=self._gen)
149
  t = t * (1.0 - self.cfg.t_eps) + self.cfg.t_eps
 
 
 
 
 
 
150
  else:
151
  raise NotImplementedError(f"t_dist={self.cfg.t_dist}")
152
  return t
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  @staticmethod
155
  def _check_image_hw(image_hw: tuple[int, int]) -> None:
156
  H, W = image_hw
@@ -163,7 +261,12 @@ class SenseNovaU1Collator:
163
  # ------------------------------------------------------------------ #
164
  # Main entry #
165
  # ------------------------------------------------------------------ #
166
- def __call__(self, samples: list[T2ISample]) -> dict[str, torch.Tensor]:
 
 
 
 
 
167
  cfg = self.cfg
168
  if cfg.enforce_batch_one and len(samples) != 1:
169
  raise ValueError(
@@ -196,38 +299,16 @@ class SenseNovaU1Collator:
196
  # 1) text → ids + per-sample lengths. With enforce_batch_one we know
197
  # `len(samples) == 1` so no batch padding is applied — `L_text` is
198
  # exactly this prompt's length (matches upstream `_build_t2i_text_inputs`).
199
- # Apply style trigger BEFORE chat-template wrap so the trigger lives
200
- # inside the user-message portion of the chat (not in system or
201
- # assistant). Identical formatting must be replicated at sample time.
202
- if cfg.style_trigger:
203
- raw_prompts = [f"{cfg.style_trigger}, {s.prompt}" for s in samples]
204
  else:
205
- raw_prompts = [s.prompt for s in samples]
206
- if self._build_t2i_query is not None:
207
- prompts = []
208
- for rp, s in zip(raw_prompts, samples):
209
- # Per-sample think injection: when the dataset supplies a
210
- # `think` text, render it INSIDE the otherwise-empty
211
- # `<think></think>` block of the official prompt template.
212
- # This makes training distribution match inference-time
213
- # `--think-mode`, where the model autoregressively fills the
214
- # same window with ~250-400 reasoning tokens. Without this,
215
- # the gen tower sees an unfamiliar prefix length/content
216
- # at inference and the LoRA delta is calibrated against
217
- # the wrong cond-KV distribution.
218
- if s.think:
219
- append_text = f"<think>\n{s.think}\n</think>\n\n<img>"
220
- else:
221
- append_text = self._gen_append
222
- prompts.append(
223
- self._build_t2i_query(
224
- rp,
225
- system_message=self._sys_msg_for_gen,
226
- append_text=append_text,
227
- )
228
  )
229
- else:
230
- prompts = list(raw_prompts)
231
  input_ids, text_lens = self._tokenize(prompts)
232
  B = input_ids.shape[0]
233
  L_text = input_ids.shape[1]
@@ -298,6 +379,12 @@ class SenseNovaU1Collator:
298
  "noisy_pixel_values": noisy_pixel_values, # (B, 3, H, W)
299
  "noisy_grid_hw": noisy_grid_hw, # (B, 2)
300
  "noise_scale": noise_scale, # (B,) or None
 
 
 
 
 
 
301
  "sample_ids": [s.sample_id for s in samples],
302
  "text_lens": text_lens,
303
  "token_hw": (token_h, token_w),
 
21
 
22
  import torch
23
 
 
 
24
  from train_u1.constants import (
25
  FM_OUTPUT_DIM,
26
  NOISE_SCALE_BASE_IMAGE_SEQ_LEN,
 
48
  # then batch=1).
49
  image_hw: tuple[int, int] | None = (512, 512)
50
  t_eps: float = T_EPS_DEFAULT
51
+ # **Default = `logit_normal` (mean=-0.8, std=0.8)** to match SenseNova-U1
52
+ # report Table 2:
53
+ # u ~ Normal(t_logit_mean, t_logit_std); t = sigmoid(u)
54
+ # clamped to [t_eps, 1 - t_eps]. Biases t toward the *low* end (near-clean);
55
+ # combined with v-loss this recovers the official training density.
56
+ # `uniform` on (t_eps, 1] is kept for back-compat / ablation.
57
+ t_dist: str = "logit_normal"
58
+ t_logit_mean: float = -0.8
59
+ t_logit_std: float = 0.8
60
+ # Classifier-free guidance condition dropout. Keep the collator default at
61
+ # zero so eval/smoke/diagnostic callers stay fully conditional unless they
62
+ # opt in. The official training entry point passes 0.10/0.10 from TrainConfig.
63
+ # For pure T2I training there is no separate reference-image condition, so
64
+ # `text_image` uses the same unconditional prompt path as `text` while
65
+ # recording the intended mode.
66
+ cond_dropout_text: float = 0.0
67
+ cond_dropout_both: float = 0.0
68
  add_noise_scale: bool = True
69
  # Base noise_scale value (config.noise_scale = 1.0). The *effective* per-sample
70
  # noise_scale is computed at collator runtime as
 
94
  # Format used: f"{style_trigger}, {original_caption}".
95
  style_trigger: str = ""
96
 
97
+ def __post_init__(self) -> None:
98
+ if self.cond_dropout_text < 0 or self.cond_dropout_both < 0:
99
+ raise ValueError("condition dropout probabilities must be non-negative")
100
+ if self.cond_dropout_text + self.cond_dropout_both > 1.0:
101
+ raise ValueError("cond_dropout_text + cond_dropout_both must be <= 1.0")
102
+
103
 
104
  class SenseNovaU1Collator:
105
  """Stateful collator: holds tokenizer + config, callable on a list of `T2ISample`.
 
114
  self.tok = tokenizer
115
  self.cfg = cfg or CollatorConfig()
116
  self._gen = torch.Generator().manual_seed(self.cfg.seed)
117
+ self._cond_gen = torch.Generator().manual_seed(self.cfg.seed + 10_003)
118
  if self.cfg.prompt_template == "official":
119
  if model is None or not hasattr(model, "_build_t2i_query"):
120
  raise ValueError(
 
168
  if self.cfg.t_dist == "uniform":
169
  t = torch.rand(batch_size, generator=self._gen)
170
  t = t * (1.0 - self.cfg.t_eps) + self.cfg.t_eps
171
+ elif self.cfg.t_dist == "logit_normal":
172
+ # u ~ N(mu, sigma) -> t = sigmoid(u) in (0, 1)
173
+ u = torch.randn(batch_size, generator=self._gen)
174
+ u = u * self.cfg.t_logit_std + self.cfg.t_logit_mean
175
+ t = torch.sigmoid(u)
176
+ t = t.clamp(min=self.cfg.t_eps, max=1.0 - self.cfg.t_eps)
177
  else:
178
  raise NotImplementedError(f"t_dist={self.cfg.t_dist}")
179
  return t
180
 
181
+ def _sample_condition_modes(self, batch_size: int) -> list[str]:
182
+ p_text = float(self.cfg.cond_dropout_text)
183
+ p_both = float(self.cfg.cond_dropout_both)
184
+ if p_text == 0.0 and p_both == 0.0:
185
+ return ["none"] * batch_size
186
+ u = torch.rand(batch_size, generator=self._cond_gen)
187
+ modes: list[str] = []
188
+ for v in u.tolist():
189
+ if v < p_text:
190
+ modes.append("text")
191
+ elif v < p_text + p_both:
192
+ modes.append("text_image")
193
+ else:
194
+ modes.append("none")
195
+ return modes
196
+
197
+ @staticmethod
198
+ def _prefix_cache_key(mode: str) -> str:
199
+ if mode == "none":
200
+ return "cond"
201
+ if mode in ("text", "text_image"):
202
+ return "uncond"
203
+ raise ValueError(f"unknown condition dropout mode {mode!r}")
204
+
205
+ def _render_prompts(
206
+ self,
207
+ samples: list[T2ISample],
208
+ condition_modes: list[str],
209
+ ) -> list[str]:
210
+ prompts: list[str] = []
211
+ for s, mode in zip(samples, condition_modes):
212
+ if mode not in ("none", "text", "text_image"):
213
+ raise ValueError(f"unknown condition dropout mode {mode!r}")
214
+ drop_text = mode in ("text", "text_image")
215
+ if drop_text:
216
+ raw_prompt = ""
217
+ elif self.cfg.style_trigger:
218
+ raw_prompt = f"{self.cfg.style_trigger}, {s.prompt}"
219
+ else:
220
+ raw_prompt = s.prompt
221
+
222
+ if self._build_t2i_query is not None:
223
+ if drop_text:
224
+ # Match the sampler's unconditional CFG prefix exactly:
225
+ # `_build_t2i_query("", append_text="<img>")`.
226
+ prompts.append(self._build_t2i_query("", append_text="<img>"))
227
+ continue
228
+ # Per-sample think injection: when the dataset supplies a
229
+ # `think` text, render it INSIDE the otherwise-empty
230
+ # `<think></think>` block of the official prompt template.
231
+ # This makes training distribution match inference-time
232
+ # `--think-mode`, where the model autoregressively fills the
233
+ # same window with ~250-400 reasoning tokens. Without this,
234
+ # the gen tower sees an unfamiliar prefix length/content
235
+ # at inference and the LoRA delta is calibrated against
236
+ # the wrong cond-KV distribution.
237
+ if s.think:
238
+ append_text = f"<think>\n{s.think}\n</think>\n\n<img>"
239
+ else:
240
+ append_text = self._gen_append
241
+ prompts.append(
242
+ self._build_t2i_query(
243
+ raw_prompt,
244
+ system_message=self._sys_msg_for_gen,
245
+ append_text=append_text,
246
+ )
247
+ )
248
+ else:
249
+ prompts.append(" " if drop_text else raw_prompt)
250
+ return prompts
251
+
252
  @staticmethod
253
  def _check_image_hw(image_hw: tuple[int, int]) -> None:
254
  H, W = image_hw
 
261
  # ------------------------------------------------------------------ #
262
  # Main entry #
263
  # ------------------------------------------------------------------ #
264
+ def __call__(
265
+ self,
266
+ samples: list[T2ISample],
267
+ *,
268
+ condition_modes: list[str] | None = None,
269
+ ) -> dict[str, torch.Tensor]:
270
  cfg = self.cfg
271
  if cfg.enforce_batch_one and len(samples) != 1:
272
  raise ValueError(
 
299
  # 1) text → ids + per-sample lengths. With enforce_batch_one we know
300
  # `len(samples) == 1` so no batch padding is applied — `L_text` is
301
  # exactly this prompt's length (matches upstream `_build_t2i_text_inputs`).
302
+ if condition_modes is None:
303
+ condition_modes = self._sample_condition_modes(len(samples))
 
 
 
304
  else:
305
+ condition_modes = list(condition_modes)
306
+ if len(condition_modes) != len(samples):
307
+ raise ValueError(
308
+ f"condition_modes length {len(condition_modes)} != samples length {len(samples)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  )
310
+ prompts = self._render_prompts(samples, condition_modes)
311
+ prefix_cache_keys = [self._prefix_cache_key(m) for m in condition_modes]
312
  input_ids, text_lens = self._tokenize(prompts)
313
  B = input_ids.shape[0]
314
  L_text = input_ids.shape[1]
 
379
  "noisy_pixel_values": noisy_pixel_values, # (B, 3, H, W)
380
  "noisy_grid_hw": noisy_grid_hw, # (B, 2)
381
  "noise_scale": noise_scale, # (B,) or None
382
+ "cond_drop_text": torch.tensor(
383
+ [m in ("text", "text_image") for m in condition_modes],
384
+ dtype=torch.bool,
385
+ ),
386
+ "cond_drop_mode": condition_modes,
387
+ "prefix_cache_key": prefix_cache_keys,
388
  "sample_ids": [s.sample_id for s in samples],
389
  "text_lens": text_lens,
390
  "token_hw": (token_h, token_w),
train_u1/data/datasets.py CHANGED
@@ -141,6 +141,7 @@ class PairedFolderT2IDataset(Dataset):
141
  prompt_template: str | None = None,
142
  image_extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
143
  snap_bucket: bool = False,
 
144
  ):
145
  self.folder = Path(folder)
146
  if not self.folder.is_dir():
@@ -148,6 +149,7 @@ class PairedFolderT2IDataset(Dataset):
148
  self.cap_max_pixels = cap_max_pixels
149
  self.prompt_template = prompt_template
150
  self.snap_bucket = snap_bucket
 
151
 
152
  pairs: list[tuple[Path, Path, str]] = []
153
  for ext in image_extensions:
@@ -172,11 +174,13 @@ class PairedFolderT2IDataset(Dataset):
172
  with open(txt_path, encoding="utf-8") as f:
173
  raw = f.read()
174
  caption, think_text = parse_caption_and_think(raw)
 
 
175
  if self.prompt_template:
176
  caption = self.prompt_template.format(caption=caption)
177
  # Legacy fallback: `<id>.think.txt` separate sidecar (deprecated;
178
  # `parse_caption_and_think` is the preferred path).
179
- if think_text is None:
180
  think_path = img_path.with_suffix(".think.txt")
181
  if think_path.is_file():
182
  with open(think_path, encoding="utf-8") as f:
@@ -257,6 +261,7 @@ class ArrowT2IDataset(Dataset):
257
  cap_max_pixels: int | None = None,
258
  prompt_template: str | None = None,
259
  snap_bucket: bool = False,
 
260
  ):
261
  try:
262
  import pyarrow.parquet as pq # noqa: F401
@@ -269,6 +274,7 @@ class ArrowT2IDataset(Dataset):
269
  self.cap_max_pixels = cap_max_pixels
270
  self.prompt_template = prompt_template
271
  self.snap_bucket = snap_bucket
 
272
  self._table = None # lazy-loaded
273
  self._n: int | None = None
274
 
@@ -311,6 +317,8 @@ class ArrowT2IDataset(Dataset):
311
  sample_id = row["sample_id"][0]
312
  caption = row["caption"][0]
313
  think = (row.get("think") or [None])[0] or None
 
 
314
 
315
  if "image" in self._table.column_names and row["image"][0] is not None:
316
  from PIL import Image
 
141
  prompt_template: str | None = None,
142
  image_extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
143
  snap_bucket: bool = False,
144
+ use_think_labels: bool = True,
145
  ):
146
  self.folder = Path(folder)
147
  if not self.folder.is_dir():
 
149
  self.cap_max_pixels = cap_max_pixels
150
  self.prompt_template = prompt_template
151
  self.snap_bucket = snap_bucket
152
+ self.use_think_labels = use_think_labels
153
 
154
  pairs: list[tuple[Path, Path, str]] = []
155
  for ext in image_extensions:
 
174
  with open(txt_path, encoding="utf-8") as f:
175
  raw = f.read()
176
  caption, think_text = parse_caption_and_think(raw)
177
+ if not self.use_think_labels:
178
+ think_text = None
179
  if self.prompt_template:
180
  caption = self.prompt_template.format(caption=caption)
181
  # Legacy fallback: `<id>.think.txt` separate sidecar (deprecated;
182
  # `parse_caption_and_think` is the preferred path).
183
+ if self.use_think_labels and think_text is None:
184
  think_path = img_path.with_suffix(".think.txt")
185
  if think_path.is_file():
186
  with open(think_path, encoding="utf-8") as f:
 
261
  cap_max_pixels: int | None = None,
262
  prompt_template: str | None = None,
263
  snap_bucket: bool = False,
264
+ use_think_labels: bool = True,
265
  ):
266
  try:
267
  import pyarrow.parquet as pq # noqa: F401
 
274
  self.cap_max_pixels = cap_max_pixels
275
  self.prompt_template = prompt_template
276
  self.snap_bucket = snap_bucket
277
+ self.use_think_labels = use_think_labels
278
  self._table = None # lazy-loaded
279
  self._n: int | None = None
280
 
 
317
  sample_id = row["sample_id"][0]
318
  caption = row["caption"][0]
319
  think = (row.get("think") or [None])[0] or None
320
+ if not self.use_think_labels:
321
+ think = None
322
 
323
  if "image" in self._table.column_names and row["image"][0] is not None:
324
  from PIL import Image
train_u1/model/lora.py CHANGED
@@ -17,6 +17,16 @@ Wrapped modules supported (per-module rank/alpha/enable independently):
17
  Patch decoder (×2):
18
  fm_modules.fm_head.0 fm_modules.fm_head.2
19
 
 
 
 
 
 
 
 
 
 
 
20
  The adapter is implemented as `y = base(x) + scaling * lora_up(lora_down(x))`
21
  with `scaling = alpha / r`. Initial state: `lora_down` kaiming uniform,
22
  `lora_up` zeros — so the wrapped module starts at exactly the base output.
@@ -50,17 +60,42 @@ import torch.nn as nn
50
  ATTN_TARGETS = ("q_proj_mot_gen", "k_proj_mot_gen", "v_proj_mot_gen", "o_proj_mot_gen")
51
  MLP_TARGETS = ("mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj", "mlp_mot_gen.down_proj")
52
  FM_HEAD_TARGETS = ("fm_modules.fm_head.0", "fm_modules.fm_head.2")
 
 
 
 
 
 
 
53
 
54
- ALL_KNOWN_TARGETS = ATTN_TARGETS + MLP_TARGETS + FM_HEAD_TARGETS
 
55
 
56
  # Convenience expansions used by the CLI parser (`attn`, `mlp`, `fm_head`).
57
  TARGET_GROUPS: dict[str, tuple[str, ...]] = {
58
  "attn": ATTN_TARGETS,
59
  "mlp": MLP_TARGETS,
60
  "fm_head": FM_HEAD_TARGETS,
61
- "all": ALL_KNOWN_TARGETS,
 
 
 
 
 
 
 
 
62
  }
63
 
 
 
 
 
 
 
 
 
 
64
 
65
  # --------------------------------------------------------------------------- #
66
  # Spec types #
@@ -71,7 +106,8 @@ TARGET_GROUPS: dict[str, tuple[str, ...]] = {
71
  class LoRASpec:
72
  """Per-target LoRA configuration.
73
 
74
- `target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix).
 
75
  `r` is the LoRA rank. `alpha` is the LoRA alpha; `scaling = alpha / r`.
76
  `dropout` applies to the input before `lora_down`.
77
  `enabled=False` lets a preset entry be turned off without removing it.
@@ -84,7 +120,7 @@ class LoRASpec:
84
  enabled: bool = True
85
 
86
  def __post_init__(self) -> None:
87
- if self.target not in ALL_KNOWN_TARGETS:
88
  raise ValueError(
89
  f"unknown LoRA target {self.target!r}. "
90
  f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
@@ -219,6 +255,50 @@ def _walk_mlp_targets(model: nn.Module, target_name: str):
219
  yield sub, leaf, idx
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  def _walk_fm_head_targets(model: nn.Module, target_name: str):
223
  """Yield `(parent, attr, idx)` for each fm_head linear matching target_name.
224
 
@@ -251,6 +331,10 @@ def _resolve_target_walker(target: str):
251
  return _walk_attn_targets
252
  if target in MLP_TARGETS:
253
  return _walk_mlp_targets
 
 
 
 
254
  if target in FM_HEAD_TARGETS:
255
  return _walk_fm_head_targets
256
  raise ValueError(f"no walker for target {target!r}")
@@ -327,10 +411,7 @@ def apply_lora_specs(
327
  # --------------------------------------------------------------------------- #
328
 
329
 
330
- _SPEC_TOK_RE = re.compile(
331
- r"^(?P<target>[A-Za-z0-9_.]+)"
332
- r"(?:=(?P<body>.+))?$"
333
- )
334
  _RA_RE = re.compile(r"^r(?P<r>\d+)(?:a(?P<alpha>\d+(?:\.\d+)?))?$")
335
 
336
 
@@ -343,14 +424,16 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]:
343
  - `off` disable a target (overrides earlier entries)
344
  - `r=N,a=M` alternative comma form (more readable)
345
 
346
- Group expansions: `attn`, `mlp`, `fm_head`, `all` expand to their member
347
- targets, all sharing the same body.
 
348
 
349
  Examples::
350
 
351
  attn=r64a64;mlp=r64a64
352
  q_proj_mot_gen=r128a128; k_proj_mot_gen=r128a128
353
  all=r64a64; mlp_mot_gen.down_proj=off
 
354
  fm_head=r=128,a=128
355
  """
356
  specs: dict[str, LoRASpec] = {}
@@ -366,7 +449,7 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]:
366
 
367
  targets = TARGET_GROUPS.get(target, (target,))
368
  for t in targets:
369
- if t not in ALL_KNOWN_TARGETS:
370
  raise ValueError(
371
  f"unknown LoRA target {t!r}. "
372
  f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
@@ -434,8 +517,22 @@ LORA_PRESETS: dict[str, str] = {
434
  # Attn + MLP only (no fm_head); equivalent to our pre-v16c v15a recipe.
435
  "attn_mlp": "attn=r64a64;mlp=r64a64",
436
 
 
 
 
 
 
 
 
 
 
437
  # Exact upstream 8-step distill LoRA shape (rank 128 alpha 128).
438
  "official_r128": "attn=r128a128;mlp=r128a128;fm_head=r128a128",
 
 
 
 
 
439
  }
440
 
441
 
 
17
  Patch decoder (×2):
18
  fm_modules.fm_head.0 fm_modules.fm_head.2
19
 
20
+ Experimental A3B/MoE target grammar (requires an A3B runtime whose modules
21
+ match the public checkpoint names):
22
+
23
+ Generation MoE experts:
24
+ mlp_mot_gen.experts.*.gate_proj
25
+ mlp_mot_gen.experts.*.up_proj
26
+ mlp_mot_gen.experts.*.down_proj
27
+ Generation MoE router:
28
+ mlp_mot_gen.gate
29
+
30
  The adapter is implemented as `y = base(x) + scaling * lora_up(lora_down(x))`
31
  with `scaling = alpha / r`. Initial state: `lora_down` kaiming uniform,
32
  `lora_up` zeros — so the wrapped module starts at exactly the base output.
 
60
  ATTN_TARGETS = ("q_proj_mot_gen", "k_proj_mot_gen", "v_proj_mot_gen", "o_proj_mot_gen")
61
  MLP_TARGETS = ("mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj", "mlp_mot_gen.down_proj")
62
  FM_HEAD_TARGETS = ("fm_modules.fm_head.0", "fm_modules.fm_head.2")
63
+ GEN_MOE_MLP_TARGETS = (
64
+ "mlp_mot_gen.experts.*.gate_proj",
65
+ "mlp_mot_gen.experts.*.up_proj",
66
+ "mlp_mot_gen.experts.*.down_proj",
67
+ )
68
+ GEN_MOE_ROUTER_TARGETS = ("mlp_mot_gen.gate",)
69
+ GEN_MOE_TARGETS = GEN_MOE_MLP_TARGETS + GEN_MOE_ROUTER_TARGETS
70
 
71
+ DENSE_KNOWN_TARGETS = ATTN_TARGETS + MLP_TARGETS + FM_HEAD_TARGETS
72
+ ALL_KNOWN_TARGETS = DENSE_KNOWN_TARGETS + GEN_MOE_TARGETS
73
 
74
  # Convenience expansions used by the CLI parser (`attn`, `mlp`, `fm_head`).
75
  TARGET_GROUPS: dict[str, tuple[str, ...]] = {
76
  "attn": ATTN_TARGETS,
77
  "mlp": MLP_TARGETS,
78
  "fm_head": FM_HEAD_TARGETS,
79
+ # A3B generation-side MoE aliases. These are deliberately separate from
80
+ # `mlp`/`all` so existing 8B configs remain byte-for-byte semantic matches.
81
+ "gen_moe_mlp": GEN_MOE_MLP_TARGETS,
82
+ "moe_mlp": GEN_MOE_MLP_TARGETS,
83
+ "gen_moe_router": GEN_MOE_ROUTER_TARGETS,
84
+ "moe_router": GEN_MOE_ROUTER_TARGETS,
85
+ "gen_moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS,
86
+ "moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS,
87
+ "all": DENSE_KNOWN_TARGETS,
88
  }
89
 
90
+ _GEN_MOE_EXPERT_TARGET_RE = re.compile(
91
+ r"^mlp_mot_gen\.experts\.(?P<expert>\*|\d+)\."
92
+ r"(?P<leaf>gate_proj|up_proj|down_proj)$"
93
+ )
94
+
95
+
96
+ def _is_known_target(target: str) -> bool:
97
+ return target in ALL_KNOWN_TARGETS or _GEN_MOE_EXPERT_TARGET_RE.match(target) is not None
98
+
99
 
100
  # --------------------------------------------------------------------------- #
101
  # Spec types #
 
106
  class LoRASpec:
107
  """Per-target LoRA configuration.
108
 
109
+ `target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix) or
110
+ an A3B MoE expert target like `mlp_mot_gen.experts.0.gate_proj`.
111
  `r` is the LoRA rank. `alpha` is the LoRA alpha; `scaling = alpha / r`.
112
  `dropout` applies to the input before `lora_down`.
113
  `enabled=False` lets a preset entry be turned off without removing it.
 
120
  enabled: bool = True
121
 
122
  def __post_init__(self) -> None:
123
+ if not _is_known_target(self.target):
124
  raise ValueError(
125
  f"unknown LoRA target {self.target!r}. "
126
  f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
 
255
  yield sub, leaf, idx
256
 
257
 
258
+ def _walk_moe_mlp_targets(model: nn.Module, target_name: str):
259
+ """Yield generation-side MoE expert projections for A3B-style modules.
260
+
261
+ `target_name` is `mlp_mot_gen.experts.*.gate_proj` or a single expert
262
+ target such as `mlp_mot_gen.experts.7.down_proj`. The walker is intentionally
263
+ shape/runtime agnostic: if the loaded model has no `experts` ModuleList, it
264
+ yields nothing so 8B dense configs are unaffected.
265
+ """
266
+ m = _GEN_MOE_EXPERT_TARGET_RE.match(target_name)
267
+ if m is None:
268
+ raise ValueError(f"invalid MoE expert target {target_name!r}")
269
+ expert_selector = m.group("expert")
270
+ leaf = m.group("leaf")
271
+
272
+ layers = model.language_model.model.layers
273
+ for layer_idx, layer in enumerate(layers):
274
+ sub = getattr(layer, "mlp_mot_gen", None)
275
+ experts = getattr(sub, "experts", None)
276
+ if experts is None:
277
+ continue
278
+ if expert_selector == "*":
279
+ expert_indices = range(len(experts))
280
+ else:
281
+ expert_idx = int(expert_selector)
282
+ if expert_idx >= len(experts):
283
+ continue
284
+ expert_indices = (expert_idx,)
285
+ for expert_idx in expert_indices:
286
+ expert = experts[expert_idx]
287
+ if hasattr(expert, leaf):
288
+ yield expert, leaf, layer_idx
289
+
290
+
291
+ def _walk_moe_router_targets(model: nn.Module, target_name: str):
292
+ """Yield generation-side MoE router gates (`mlp_mot_gen.gate`) per layer."""
293
+ if target_name != "mlp_mot_gen.gate":
294
+ raise ValueError(f"invalid MoE router target {target_name!r}")
295
+ layers = model.language_model.model.layers
296
+ for idx, layer in enumerate(layers):
297
+ sub = getattr(layer, "mlp_mot_gen", None)
298
+ if sub is not None and hasattr(sub, "gate"):
299
+ yield sub, "gate", idx
300
+
301
+
302
  def _walk_fm_head_targets(model: nn.Module, target_name: str):
303
  """Yield `(parent, attr, idx)` for each fm_head linear matching target_name.
304
 
 
331
  return _walk_attn_targets
332
  if target in MLP_TARGETS:
333
  return _walk_mlp_targets
334
+ if _GEN_MOE_EXPERT_TARGET_RE.match(target):
335
+ return _walk_moe_mlp_targets
336
+ if target in GEN_MOE_ROUTER_TARGETS:
337
+ return _walk_moe_router_targets
338
  if target in FM_HEAD_TARGETS:
339
  return _walk_fm_head_targets
340
  raise ValueError(f"no walker for target {target!r}")
 
411
  # --------------------------------------------------------------------------- #
412
 
413
 
414
+ _SPEC_TOK_RE = re.compile(r"^(?P<target>[A-Za-z0-9_.*]+)(?:=(?P<body>.+))?$")
 
 
 
415
  _RA_RE = re.compile(r"^r(?P<r>\d+)(?:a(?P<alpha>\d+(?:\.\d+)?))?$")
416
 
417
 
 
424
  - `off` disable a target (overrides earlier entries)
425
  - `r=N,a=M` alternative comma form (more readable)
426
 
427
+ Group expansions: `attn`, `mlp`, `fm_head`, `gen_moe_mlp`,
428
+ `gen_moe_router`, `gen_moe_all`, `all` expand to their member targets,
429
+ all sharing the same body.
430
 
431
  Examples::
432
 
433
  attn=r64a64;mlp=r64a64
434
  q_proj_mot_gen=r128a128; k_proj_mot_gen=r128a128
435
  all=r64a64; mlp_mot_gen.down_proj=off
436
+ gen_moe_mlp=r8a8; gen_moe_router=r4a4
437
  fm_head=r=128,a=128
438
  """
439
  specs: dict[str, LoRASpec] = {}
 
449
 
450
  targets = TARGET_GROUPS.get(target, (target,))
451
  for t in targets:
452
+ if not _is_known_target(t):
453
  raise ValueError(
454
  f"unknown LoRA target {t!r}. "
455
  f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
 
517
  # Attn + MLP only (no fm_head); equivalent to our pre-v16c v15a recipe.
518
  "attn_mlp": "attn=r64a64;mlp=r64a64",
519
 
520
+ # **Safe presets** that explicitly drop fm_head from the trained surface.
521
+ # The technical report's grid-artifact discussion attributes artifacts to
522
+ # the final FFN + MLP head independently modelling disjoint 32×32 patches,
523
+ # and notes that the official T2I RL stage freezes the generation-branch
524
+ # MLP head and the last three transformer layers for exactly this reason.
525
+ # Use these when you want to avoid touching the head at all.
526
+ "attn_only_no_head": "attn=r64a64",
527
+ "attn_mlp_no_head": "attn=r64a64;mlp=r64a64",
528
+
529
  # Exact upstream 8-step distill LoRA shape (rank 128 alpha 128).
530
  "official_r128": "attn=r128a128;mlp=r128a128;fm_head=r128a128",
531
+
532
+ # Experimental A3B/MoE coverage. Small ranks are intentional: covering all
533
+ # 48 layers × 32 gen experts × 3 projections gets large quickly.
534
+ "a3b_moe_r8": "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8",
535
+ "a3b_moe_router_r8": "gen_moe_router=r8a8",
536
  }
537
 
538
 
train_u1/model/losses.py CHANGED
@@ -1,11 +1,14 @@
1
  """Training losses for the FM step.
2
 
3
- Two primaries (report §5 / §2.1):
4
- - `fm_loss_x0(x_pred, x0_patch)` — MVP recommended. MSE on clean patches.
5
- - `fm_loss_v(v_pred, v_target)` — velocity-target ablation.
 
 
6
 
7
- Plus optional Huber variants for outlier robustness, and a tiny CE
8
- guardrail for the unified-training scenario (Phase C).
 
9
  """
10
  from __future__ import annotations
11
 
@@ -35,6 +38,65 @@ def fm_loss_v_huber(v_pred: torch.Tensor, v_target: torch.Tensor, delta: float =
35
  return F.huber_loss(v_pred.float(), v_target.float(), delta=delta)
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def text_ce_guardrail(
39
  logits: torch.Tensor,
40
  labels: torch.Tensor,
 
1
  """Training losses for the FM step.
2
 
3
+ Two primaries (report Eq. (5) / Table 2):
4
+ - `fm_loss_x0(x_pred, x0_patch)` — MSE on clean patches (legacy MVP default).
5
+ - `fm_loss_v(v_pred, v_target)` — MSE on velocity (matches the official
6
+ x-predict + v-loss training objective; equivalent to
7
+ `MSE(x_pred - x0) / (1 - t)^2`, i.e. an x0-MSE re-weighted by `(1-t)^-2`).
8
 
9
+ Plus Huber variants and an `fm_loss` dispatcher that selects by `loss_type`.
10
+
11
+ CE guardrail kept for the Phase C unified-training scenario.
12
  """
13
  from __future__ import annotations
14
 
 
38
  return F.huber_loss(v_pred.float(), v_target.float(), delta=delta)
39
 
40
 
41
+ def compute_v_target(
42
+ x0_patch: torch.Tensor,
43
+ z_t: torch.Tensor,
44
+ t: torch.Tensor,
45
+ *,
46
+ t_eps: float = 1e-3,
47
+ ) -> torch.Tensor:
48
+ """Closed-form velocity target for rectified-flow / linear-z_t.
49
+
50
+ Report Eq. (5): `v* = (x0 - z_t) / (1 - t)` with `z_t = t x0 + (1-t) eps`.
51
+ `t` is expected to be a (B,) tensor — we broadcast to (B, 1, 1) to match
52
+ the (B, N, D) patch tensors.
53
+ """
54
+ if x0_patch.shape != z_t.shape:
55
+ raise ValueError(f"shape mismatch x0 {x0_patch.shape} vs z_t {z_t.shape}")
56
+ t = t.to(x0_patch.dtype)
57
+ while t.dim() < x0_patch.dim():
58
+ t = t.unsqueeze(-1)
59
+ denom = (1.0 - t).clamp(min=t_eps)
60
+ return (x0_patch - z_t) / denom
61
+
62
+
63
+ # --------------------------------------------------------------------------- #
64
+ # Dispatcher #
65
+ # --------------------------------------------------------------------------- #
66
+
67
+ VALID_LOSS_TYPES = ("x0", "v", "x0_huber", "v_huber")
68
+
69
+
70
+ def fm_loss(
71
+ *,
72
+ loss_type: str,
73
+ x_pred: torch.Tensor,
74
+ x0_patch: torch.Tensor,
75
+ v_pred: torch.Tensor | None = None,
76
+ v_target: torch.Tensor | None = None,
77
+ huber_delta: float = 1.0,
78
+ ) -> torch.Tensor:
79
+ """Single entry point selecting one of the four FM losses.
80
+
81
+ `x0` / `x0_huber` only need `x_pred` + `x0_patch`.
82
+ `v` / `v_huber` require `v_pred` + `v_target` (caller computes them via
83
+ `compute_v_target` from the same `(x0, z_t, t)` used to build the batch).
84
+ """
85
+ if loss_type == "x0":
86
+ return fm_loss_x0(x_pred, x0_patch)
87
+ if loss_type == "x0_huber":
88
+ return fm_loss_x0_huber(x_pred, x0_patch, delta=huber_delta)
89
+ if loss_type == "v":
90
+ if v_pred is None or v_target is None:
91
+ raise ValueError("loss_type='v' requires v_pred and v_target")
92
+ return fm_loss_v(v_pred, v_target)
93
+ if loss_type == "v_huber":
94
+ if v_pred is None or v_target is None:
95
+ raise ValueError("loss_type='v_huber' requires v_pred and v_target")
96
+ return fm_loss_v_huber(v_pred, v_target, delta=huber_delta)
97
+ raise ValueError(f"unknown loss_type {loss_type!r}; valid: {VALID_LOSS_TYPES}")
98
+
99
+
100
  def text_ce_guardrail(
101
  logits: torch.Tensor,
102
  labels: torch.Tensor,