diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..2d2b24700fe20a632017a1fe9cc17eb6a9a18a27 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,33 @@ +docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail00_v18.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail00_v19.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail00_v19a.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail00_v19b.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail00_v19c.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail01_v18.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail01_v19.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail01_v19a.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail01_v19b.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail02_v18.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail02_v19.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail02_v19a.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail02_v19b.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail02_v19c.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail07_v18.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail07_v19.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail07_v19a.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail07_v19b.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/detail07_v19c.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text +docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text +docs/small_data_style_ablation.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 7eb6a4b906c5c21a3836c7edf02a96a1976134e2..73d9717ee42f5d1771b1288cac748d908d3e848a 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ checkpoints drop straight into the official inference scripts. ``` Fits on a 32 GB GPU (RTX 5090 / A100-40 / RTX 6000 Ada). Peak VRAM ~20 GB -on the maintainer's 56-image hayateluc dataset at 2048². +on the train dataset at 2048². --- @@ -25,16 +25,26 @@ on the maintainer's 56-image hayateluc dataset at 2048². - **Config-first**: every run is one YAML file (`configs/default.yaml`). - **Per-module rank + enable**: each LoRA target (`q_proj_mot_gen`, `mlp_mot_gen.down_proj`, `fm_modules.fm_head.0`, …) takes its own rank / alpha / on-off independently. -- **Default = official coverage at rank 64**: the same 296 module wraps as - upstream's 8-step distill LoRA (168 attn + 126 mlp + 2 fm_head), but at - rank 64 instead of 128 — half the trainable params, half the on-disk size, - retains full module surface. +- **Experimental MoE target grammar**: A3B-style generation experts can be + addressed explicitly (`gen_moe_mlp`, `gen_moe_router`, + `mlp_mot_gen.experts.*.gate_proj`) without changing the stable 8B main path. +- **Default = small-data style baseline**: `configs/default.yaml` uses + `x0 + uniform t + no condition dropout`, short captions, LoRA on attn+mlp, + and full fine-tuning of the timestep/noise embedders, gen vision bridge, and + fm_head. +- **Official-alignment recipe is optional**: `configs/official_alignment.yaml` + keeps the public report knobs together for research ablations, but it is not + the safest first run for small style datasets. - **Upstream-format save**: load straight into `examples/t2i/inference.py` via `--lora_path`, or stack with the official 8-step LoRA. - **bf16 training, not 4/8-bit**. Earlier 4-bit nf4 LoRA training produced grid artefacts and limb collapse on the gen tower; switching the base to bf16 (with offload + static prefix-KV cache) eliminated both. +See the ablation write-up with training curves and sample grids: +[`docs/small_data_style_ablation.html`](docs/small_data_style_ablation.html) +or [`docs/small_data_style_ablation.pdf`](docs/small_data_style_ablation.pdf). + --- ## Hardware @@ -48,6 +58,29 @@ on the maintainer's 56-image hayateluc dataset at 2048². `bitsandbytes>=0.45` and `torch>=2.9` must be linked against your CUDA runtime. On RTX 5090 (sm_120) you'll likely need the cu128 torch wheel. +### A3B / MoE Status + +The trainer now has experimental target grammar for future +`SenseNova-U1-A3B-MoT` generation-side MoE LoRA work: + +```yaml +lora: + spec: "attn=r8a8;gen_moe_mlp=r8a8;gen_moe_router=r8a8" +``` + +This is a compatibility layer, not the main training path and not an end-to-end +A3B training claim. The stable release target remains `SenseNova-U1-8B-MoT`; +A3B training depends on public MoE runtime support that can instantiate the +`mlp_mot_gen.experts.*` modules. + +Before training, estimate MoE LoRA size from metadata only: + +```bash +python -m train_u1.scripts.inspect_lora_targets \ + --model path/to/A3B/config.json \ + --spec "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8" +``` + --- ## Install @@ -85,10 +118,10 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot └── … └── … ``` - Each `.txt` is a single-paragraph natural-language caption. Embed the - artist credit / style anchor inside the description naturally — don't - rely on a hard-coded trigger prepend (`style.trigger` in the YAML is - for backward compat only; the v18 recipe uses an empty trigger). + Each `.txt` is a single-paragraph natural-language caption. Put the style + or artist anchor in a stable way and keep `style.trigger` aligned with how + you will sample later. The default config prepends that trigger to every + caption. **Optional**: append a `...` reasoning label inside the same `.txt` after a `---think---` delimiter line: @@ -101,12 +134,11 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot 6. **Explicit Prompt:** ... ``` - When present, the trainer renders this into the prompt template's - `` window so train-time distribution matches inference - `--think-mode` (avoids prefix-distribution shift on long autoregressive - think). For batch generation of think labels see Agent B's prompt in - the v18 commit history; or write them yourself in the upstream - 6-section format. + Think labels are **ignored by default** because low-quality or highly + templated think text can dominate the prefix and hurt style binding. To use + them, set `data.use_think_labels: true` and evaluate with the same think + distribution at sample time. Do this only when your think labels are + curated and repeatable. **Parquet/arrow shards** (recommended for ≥ ~10k images, e.g. 1M scaling): @@ -192,6 +224,7 @@ data: data_dir: dataset/my_style cap_max_pixels: 4194304 # 2048² hard cap per image snap_bucket: true # snap to upstream bucket grid + use_think_labels: false # keep prefixes short by default # n_samples: 56 # cap dataset size; default = use everything style: @@ -199,13 +232,15 @@ style: prompt_template: official # 'official' (recommended) | 'plain' lora: - preset: default # = attn+mlp+fm_head, all r=64 a=64 - # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off;fm_head=r128a128" + preset: attn_mlp_no_head # attn+mlp LoRA; fm_head is full-FT below + # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off" dropout: 0.0 unfreeze: # full-FT (non-LoRA) regex patterns - '^fm_modules\.timestep_embedder\.' - '^fm_modules\.noise_scale_embedder\.' + - '^fm_modules\.vision_model_mot_gen\.' + - '^fm_modules\.fm_head\.' train: steps: 6000 @@ -214,6 +249,15 @@ train: shuffle: true grad_accum: 1 checkpoint_every: 600 + # Small-data style baseline. See docs/small_data_style_ablation.html before + # switching to the official-alignment recipe. + loss_type: x0 + t_dist: uniform + t_logit_mean: -0.8 + t_logit_std: 0.8 + # huber_delta: 1.0 # only used for *_huber + cond_dropout_text: 0.0 + cond_dropout_both: 0.0 runtime: keep_kvs_on_gpu: true @@ -259,30 +303,41 @@ q_proj_mot_gen=r=128,a=64;k_proj_mot_gen=r=64,a=64 # asymmetric ranks | Preset | Coverage | Trainable LoRA params | Use when | |---|---|---|---| -| `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M | first try / production | +| `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M | match upstream 8-step LoRA coverage | | `attn_only` | 168 attn, r=64 | ~50 M | ablation | | `attn_mlp` | attn + mlp (no fm_head), r=64 | ~75 M | when fm_head is full-FT'd separately | +| `attn_only_no_head` | alias for `attn_only`; explicit no-fm_head intent | ~50 M | conservative small-data style training | +| `attn_mlp_no_head` | alias for `attn_mlp`; explicit no-fm_head intent | ~75 M | conservative small-data style training | | `official_r128` | exact upstream shape (r=128 across all 296 wraps) | ~298 M | parameter-matching upstream's 8-step LoRA | +The shipped `configs/default.yaml` uses `attn_mlp_no_head` and full-FTs +`fm_head` separately because that was the most stable small-data baseline in +our ablations. The `default` preset name inside the LoRA parser still means +"match upstream 8-step LoRA coverage"; use it only when that exact module +coverage is what you want. For report-alignment research, start from +`configs/official_alignment.yaml`. + --- ## Stack with the official 8-step distill LoRA Upstream released a step-distillation LoRA that brings inference down to 8 -NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it. +NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it +by setting `runtime.upstream_lora_path` in your YAML — at training time we +bake-in the official 8-step delta into the bf16 base (skipping `fm_head` so +we don't clobber our own fm_head LoRA), then wrap our LoRA on top. ```yaml -# configs/stack_8step.yaml (already in this repo) runtime: upstream_lora_path: hf_cache/.../SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors upstream_lora_skip: ['fm_modules.fm_head'] # don't clobber our fm_head LoRA ``` -At sample time, also pass the same upstream LoRA: +At sample time, pass the same upstream LoRA and use 8 steps at cfg=1.0: ```bash -./sample.sh configs/stack_8step.yaml \ - artifacts/my_style_8step/trainable_state.safetensors \ +./sample.sh configs/my_style.yaml \ + artifacts/my_style/trainable_state.safetensors \ --prompt "…" \ --upstream-lora-path SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors \ --upstream-lora-skip fm_modules.fm_head \ @@ -302,9 +357,8 @@ At sample time, also pass the same upstream LoRA: ├── pyproject.toml # package metadata ├── LICENSE # Apache-2.0 ├── configs/ -│ ├── default.yaml # opinionated starting point -│ ├── v16c.yaml # production recipe (LoRA + ts/ns/vision/fm_head full-FT) -│ └── stack_8step.yaml # train on top of 8-step distill LoRA +│ ├── default.yaml # recommended small-data style baseline +│ └── official_alignment.yaml # optional report-alignment research config ├── train_u1/ # importable package │ ├── config.py # YAML config schema │ ├── constants.py # pinned MODEL_SHA / CODE_COMMIT / arch constants @@ -314,7 +368,7 @@ At sample time, also pass the same upstream LoRA: │ │ ├── lora_io.py # save/load + upstream merge │ │ ├── loader.py # bf16 base load + tower offload │ │ ├── wrapper.py # forward_t2i_step -│ │ ├── losses.py # fm_loss_x0 +│ │ ├── losses.py # fm_loss_x0 / fm_loss_v / fm_loss dispatcher │ │ ├── patching.py # patchify/unpatchify │ │ └── … │ ├── scripts/ @@ -325,7 +379,10 @@ At sample time, also pass the same upstream LoRA: │ │ └── install_modeling_into_snapshot.py │ └── tests/ ├── docs/ -│ └── SETUP.md # data layout, design rationale, pinned-upstream details +│ ├── SETUP.md # data layout, design rationale, pinned-upstream details +│ ├── small_data_style_ablation.html +│ ├── small_data_style_ablation.pdf +│ └── assets/ # figures used by the ablation document ├── artifacts/ # local-only: checkpoints + sweeps (gitignored) ├── dataset/ # local-only: image+caption pairs (gitignored) ├── hf_cache/ # local-only: HF snapshot (gitignored) @@ -346,4 +403,4 @@ At sample time, also pass the same upstream LoRA: release; consumed via the `upstream_lora_path` mechanism. - **This trainer** is licensed under Apache-2.0 (see `LICENSE`). -**Thanks to comfy.org for the GPU power support. The open-source community will not forget.** \ No newline at end of file +**Thanks to comfy.org for the GPU power support. The open-source community will not forget.** diff --git a/configs/default.yaml b/configs/default.yaml index fdfa24e8532032529b0ea6dd188990f857ea7a37..7ea84d7d450788e3620487ed4a418ff8388257f5 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,9 +1,11 @@ # SenseNova-U1 LoRA trainer — default config # -# This is the recommended starting point. Trains a LoRA whose module -# coverage matches the official 8-step distill LoRA (168 attn + 126 mlp + -# 2 fm_head = 296 wraps), but at rank 64 instead of upstream's rank 128. -# Roughly 75 M trainable LoRA params, ~750 MB on disk in fp32. +# This is the recommended starting point for small/medium style training. +# It follows the local small-data baseline: x0 loss + uniform t + no train-time +# condition dropout, explicit style trigger, LoRA on attn+mlp, and full-FT +# of ts/ns/vision/fm_head. Public official-consistency knobs are available +# in configs/official_alignment.yaml for research ablations, but they are +# not the default because the ablation study showed worse small-data sampling. # # Edit `data.data_dir`, `style.trigger`, and `run_name` and you should be # good to go for a 32 GB single-GPU run. @@ -17,6 +19,9 @@ data: cap_max_pixels: 4194304 # Snap each image to nearest official bucket so train shape == infer shape. snap_bucket: true + # Keep the prefix short by default. Set true only when your think labels are + # curated and your sampling path uses the same think distribution. + use_think_labels: false # n_samples: cap on dataset size (omit / null = use entire data_dir) style: @@ -24,20 +29,21 @@ style: prompt_template: official # 'official' (recommended) | 'plain' lora: - # Named preset: 'default' = official coverage at rank 64. - # Other presets: 'attn_only', 'attn_mlp', 'official_r128'. - preset: default + # Small-data baseline: LoRA on attn+mlp only. fm_head is full-FT'd below. + # 'default' remains available when you explicitly want upstream 8-step + # coverage including fm_head LoRA. + # A3B/MoE presets are experimental compatibility helpers, not the main path. + preset: attn_mlp_no_head # Or override per-target: - # spec: "attn=r64a64;mlp=r64a64;fm_head=r64a64;mlp_mot_gen.down_proj=off" + # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off" dropout: 0.0 unfreeze: - # Full-finetune (non-LoRA) regex patterns. Empty = LoRA-only training. - # Examples for our v16c-style recipe: - # - '^fm_modules\.timestep_embedder\.' - # - '^fm_modules\.noise_scale_embedder\.' - # - '^fm_modules\.vision_model_mot_gen\.' - [] + # Full-finetune (non-LoRA) regex patterns for the small-data baseline. + - '^fm_modules\.timestep_embedder\.' + - '^fm_modules\.noise_scale_embedder\.' + - '^fm_modules\.vision_model_mot_gen\.' + - '^fm_modules\.fm_head\.' train: steps: 6000 @@ -47,6 +53,17 @@ train: grad_accum: 1 checkpoint_every: 600 # 0 = only save final state # checkpoint_dir: null # default: artifacts/{run_name}/checkpoints + # Local small-data baseline. See docs/small_data_style_ablation.html before changing these: + # official-style v-loss/logit-normal/dropout is an optional ablation path, + # not the small-data default. + loss_type: x0 # 'x0' | 'v' | 'x0_huber' | 'v_huber' + t_dist: uniform # 'uniform' | 'logit_normal' + t_logit_mean: -0.8 # only used for logit_normal + t_logit_std: 0.8 # only used for logit_normal + # huber_delta: 1.0 # only used for *_huber + # CFG / condition dropout. Keep off by default for small style datasets. + cond_dropout_text: 0.0 + cond_dropout_both: 0.0 runtime: keep_kvs_on_gpu: true # 56 samples × ~59 MB = ~3.3 GB; keep on GPU for speed diff --git a/configs/official_alignment.yaml b/configs/official_alignment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85be23880d6f4536cfa8ed3fc83d7f4f443c52a6 --- /dev/null +++ b/configs/official_alignment.yaml @@ -0,0 +1,57 @@ +# Optional official-consistency research config. +# +# This keeps the public SenseNova-U1 report knobs together: +# - v-loss +# - logit-normal t sampler, mean=-0.8 std=0.8 +# - train-time condition dropout, text=0.10 and both=0.10 +# - think labels enabled when the dataset provides them +# +# Do not treat this as the small-data style-training default. The ablation +# study showed weaker sampling than the default baseline on the 56-image +# Hayateluc task. Use this config when you explicitly want to study report +# alignment, not when you want the safest first run. + +run_name: official_alignment + +data: + data_dir: dataset/my_style + cap_max_pixels: 4194304 + snap_bucket: true + use_think_labels: true + +style: + trigger: "" # prefer captions that carry their own style anchor + prompt_template: official + +lora: + # Matches upstream 8-step LoRA coverage at r=64. + preset: default + dropout: 0.0 + +unfreeze: + - '^fm_modules\.timestep_embedder\.' + - '^fm_modules\.noise_scale_embedder\.' + - '^fm_modules\.vision_model_mot_gen\.' + - '^fm_modules\.fm_head\.' + +train: + steps: 6000 + lr: 5.0e-5 + seed: 0 + shuffle: true + grad_accum: 1 + checkpoint_every: 600 + loss_type: v + t_dist: logit_normal + t_logit_mean: -0.8 + t_logit_std: 0.8 + cond_dropout_text: 0.10 + cond_dropout_both: 0.10 + +runtime: + # Long think prefixes and unconditional-prefix precompute can push VRAM up. + # Offload KVs and checkpoint all layers unless you have confirmed headroom. + keep_kvs_on_gpu: false + gc_skip_last: 0 + device: cuda + cpu_device: cpu diff --git a/docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg b/docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg new file mode 100644 index 0000000000000000000000000000000000000000..90206584e766cd13fca9fda45f6aa12f66a50750 --- /dev/null +++ b/docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad42ff152cf0b9b60c5fad6c1248477268b9b0fcf982e4d2c6fa8c4e4ad84e24 +size 203053 diff --git a/docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg b/docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a7346f4027bb88f62920bb1df1e16252d6988105 --- /dev/null +++ b/docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8865dc230203718b62f2738db1dee0526146be58b2e0c37cec7d8ec1490011c +size 130937 diff --git a/docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg b/docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg new file mode 100644 index 0000000000000000000000000000000000000000..211f1fe2682e1b118648b55c7453032dc872756e --- /dev/null +++ b/docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43e8cd49b524ca6d63b8bd5ab3200bce8def0f5ac0bee8cb4062b061cc7efce4 +size 199443 diff --git a/docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg b/docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad3dae510c3c1cb22303d5f8f553425f20872e41 --- /dev/null +++ b/docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0560e2bed21e4f87c36642c284226987756d9982690df4dd30d941856a6555a +size 128918 diff --git a/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9d3d485bd4894cd4cc5d7bdbd040fb419e0a4bc8 --- /dev/null +++ b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bac586857be3f55de3474ede157710b41a601b2cc92957a82f46748fde99437 +size 443833 diff --git a/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ced68a08ceaec9f20a89784e398b4b61a8ffda0 --- /dev/null +++ b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c57298303e7bee189eb88ca23700cabd19309f0bf1b34f5004503dfe551958c +size 112023 diff --git a/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e41c13206e763ad8d5eadce61de3e446041f0034 --- /dev/null +++ b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:579a2dd4e839b515beaf3c8519a537ebe18234a193719445959e2e09953bde87 +size 129075 diff --git a/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26475c8a2efa0c12c8679c5899576fe50765c250 --- /dev/null +++ b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9744f109e7b7ec6686954e0c79e82c18a0ac3007148f38760bcdeb57b058e4c3 +size 124917 diff --git a/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee35b3fe6fc1e6f89d5c6bf0daae797e815ad6c3 --- /dev/null +++ b/docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7535ce9e3a7a11d91098b8ab1a33a48a404711f9e8826341cfc009836a7e808a +size 107888 diff --git a/docs/assets/small_data_style_ablation/detail00_v18.jpg b/docs/assets/small_data_style_ablation/detail00_v18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a1981194e6c08424814b187b325f25ca6dc35d0b --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail00_v18.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d170e9029386c69037819b3d7f1612ec653b2f645cee4ee6ac4219f4abed073a +size 554475 diff --git a/docs/assets/small_data_style_ablation/detail00_v19.jpg b/docs/assets/small_data_style_ablation/detail00_v19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..850ce38d537348dbf5a1c0abeeb69c2340647d57 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail00_v19.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d9421b5a44907af560906ca5cb28d54536308229eb21307e73a917145304b48 +size 445764 diff --git a/docs/assets/small_data_style_ablation/detail00_v19a.jpg b/docs/assets/small_data_style_ablation/detail00_v19a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6dbc855a63511602e45c51e08fdb604100356df --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail00_v19a.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5202b4f7774bde04f8e4ef589c8c925f44a9875f8b8bc5964a4d49c047f0562 +size 459699 diff --git a/docs/assets/small_data_style_ablation/detail00_v19b.jpg b/docs/assets/small_data_style_ablation/detail00_v19b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6a07c0274e6c3fba244f1674ea8f587b7325778 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail00_v19b.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d5406077a9074310ff30bd98242b9f3af4e15ba5f2f992961583129bb73c167 +size 551653 diff --git a/docs/assets/small_data_style_ablation/detail00_v19c.jpg b/docs/assets/small_data_style_ablation/detail00_v19c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee095cd98bef75afeb83f92017e56c538217280b --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail00_v19c.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2e38bfc550d83bfd0f0ca34ade521630b68499561997cfa144d2dc874f6a650 +size 461099 diff --git a/docs/assets/small_data_style_ablation/detail01_v18.jpg b/docs/assets/small_data_style_ablation/detail01_v18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25a180032301872d6124d6426b0a5239a075bebd --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail01_v18.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09e26f8cfd7a15b7983c45e8121de43bd2018921206f9b149024900224eeebee +size 227997 diff --git a/docs/assets/small_data_style_ablation/detail01_v19.jpg b/docs/assets/small_data_style_ablation/detail01_v19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad7a5d328439a2b9de62a0615187456647ac7b92 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail01_v19.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa0147b231dd67646980751cdf3b17c25f74b32b975b757cae535be79dbbb7c2 +size 144362 diff --git a/docs/assets/small_data_style_ablation/detail01_v19a.jpg b/docs/assets/small_data_style_ablation/detail01_v19a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..995afa5498230197b50f7fd0cb79cc71d8f5b2f5 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail01_v19a.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da197305288ed7b139e1fbe2d16dbd91f6a44db1d8c83fa16dc25e8fc0f1d1d5 +size 212424 diff --git a/docs/assets/small_data_style_ablation/detail01_v19b.jpg b/docs/assets/small_data_style_ablation/detail01_v19b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b45fa0a934ac56abc88329620da4510bfe60d60 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail01_v19b.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a25ee998774d952c30f500488eb7c502d0ec1c86580d8b4c28e84ab84f32427c +size 192881 diff --git a/docs/assets/small_data_style_ablation/detail02_v18.jpg b/docs/assets/small_data_style_ablation/detail02_v18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb050e0dbf6e7bdb1b3bf0cf9e49edf882118ea4 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail02_v18.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6a91aff5fbe8485c8aee5551574e96bfdc0135821406191dda4bc67507e00a2 +size 359106 diff --git a/docs/assets/small_data_style_ablation/detail02_v19.jpg b/docs/assets/small_data_style_ablation/detail02_v19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..190c0b6c656ec50ecde3e1874a86ef55b71b9849 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail02_v19.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31a20bc3c7cd598240aceb6e5ef86e76931da4f8e30ccc5aa3398fe8d0c05715 +size 288446 diff --git a/docs/assets/small_data_style_ablation/detail02_v19a.jpg b/docs/assets/small_data_style_ablation/detail02_v19a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..699f6ddc31a0ebd96941a8d5c3ac3c606fe8e465 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail02_v19a.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98ff82f6d3e5f03a8a05f74048ecbdf3f9f86d2bd748b7f3891871a39b61e8bd +size 289487 diff --git a/docs/assets/small_data_style_ablation/detail02_v19b.jpg b/docs/assets/small_data_style_ablation/detail02_v19b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..696acda5c28af7bdb733257a7497510be87e027b --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail02_v19b.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8de9b562cb11b3fd5ab77e127bc0ebb0f9b157d4a9e9f1461cdadbdc5bd1855a +size 344034 diff --git a/docs/assets/small_data_style_ablation/detail02_v19c.jpg b/docs/assets/small_data_style_ablation/detail02_v19c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f26c8b07a11cdb9f2a989cd2e5939696664a991f --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail02_v19c.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ea2874cf2939cf68327fcde2571c460c45facf9dbfd41b6996fd54c16c474a8 +size 318605 diff --git a/docs/assets/small_data_style_ablation/detail07_v18.jpg b/docs/assets/small_data_style_ablation/detail07_v18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9cbfc94b0f3faed6085169d4ec44b66f4a0898d3 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail07_v18.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d32a35bc3a20486cbc9e88434889773a1263d5e0098b49a4bf23ab81ba8eee0f +size 465753 diff --git a/docs/assets/small_data_style_ablation/detail07_v19.jpg b/docs/assets/small_data_style_ablation/detail07_v19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d95d01fe809ab6f1ac24185a50e68d1df31b8135 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail07_v19.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d5f3b96288f70bc5eb7d601cf58771a5416956dafe62969470a31e170839bd8 +size 458036 diff --git a/docs/assets/small_data_style_ablation/detail07_v19a.jpg b/docs/assets/small_data_style_ablation/detail07_v19a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..23ee27074561798b8b3d52025583f7d5eef0ff2a --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail07_v19a.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcca9a543c59209fd792465a93a2f6374fce778770d4b004526ffb244bb1f075 +size 586349 diff --git a/docs/assets/small_data_style_ablation/detail07_v19b.jpg b/docs/assets/small_data_style_ablation/detail07_v19b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e11080a8d90b7c8e8e5d9b3341ebcea2fcfab4ff --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail07_v19b.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c5c4dcd0b646817443f71479c8ffcd0430e8d596c844b56a7f49c9aff7d23dd +size 446102 diff --git a/docs/assets/small_data_style_ablation/detail07_v19c.jpg b/docs/assets/small_data_style_ablation/detail07_v19c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c8790ddf40b8180a92b4874003009ad3d6a5f0a4 --- /dev/null +++ b/docs/assets/small_data_style_ablation/detail07_v19c.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3080f6442f76e1a522c6e776a738899562d918c004f5d45a676990441b08b57c +size 463662 diff --git a/docs/assets/small_data_style_ablation/tb_active_loss_curves.png b/docs/assets/small_data_style_ablation/tb_active_loss_curves.png new file mode 100644 index 0000000000000000000000000000000000000000..d3c47984a5ed4d4d4f4d7efcac099335a3ca89ae Binary files /dev/null and b/docs/assets/small_data_style_ablation/tb_active_loss_curves.png differ diff --git a/docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png b/docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png new file mode 100644 index 0000000000000000000000000000000000000000..352f5b070b7dc446bac6ddb0039379bee6d3c571 Binary files /dev/null and b/docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png differ diff --git a/docs/assets/small_data_style_ablation/tb_t_distribution.png b/docs/assets/small_data_style_ablation/tb_t_distribution.png new file mode 100644 index 0000000000000000000000000000000000000000..d0e891cc984f83dcf8e8a1a84c0d144a7a4e4971 Binary files /dev/null and b/docs/assets/small_data_style_ablation/tb_t_distribution.png differ diff --git a/docs/assets/small_data_style_ablation/tb_v_mse_curves.png b/docs/assets/small_data_style_ablation/tb_v_mse_curves.png new file mode 100644 index 0000000000000000000000000000000000000000..50dad12af2aca722f72d781838956efcdb29661b Binary files /dev/null and b/docs/assets/small_data_style_ablation/tb_v_mse_curves.png differ diff --git a/docs/assets/small_data_style_ablation/tb_x0_mse_curves.png b/docs/assets/small_data_style_ablation/tb_x0_mse_curves.png new file mode 100644 index 0000000000000000000000000000000000000000..83548b4852c1a8cc3faf0c14c34179b0395b6609 Binary files /dev/null and b/docs/assets/small_data_style_ablation/tb_x0_mse_curves.png differ diff --git a/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..66aee3c6b6f54d1e3164205e6e2bf690d47614fd --- /dev/null +++ b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd484d8b36543be0ce1f3b348f28fb7fe4a1d21f750c412377f1333296442fc4 +size 854885 diff --git a/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c816f7e1b9994d4a128285ffb6474c801730223 --- /dev/null +++ b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5441764becf0a4154547831c3e6bf81fee014e0fccb84ae09c40d2992e3ecb8 +size 295112 diff --git a/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c1f0516021a6e0c6ef62e9e1cf42477ac695a965 --- /dev/null +++ b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47876dcf4cf7ec68604f7f0e229d72d350ac781f2fc5e19e3453688e523b13ee +size 308151 diff --git a/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6a31edb9d8426136833cfb5ecdac0fe64b5f524 --- /dev/null +++ b/docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8476298d8dd7c75a4ee97a916e38af85b69d489630554f12de94178c3bb93891 +size 315307 diff --git a/docs/small_data_style_ablation.html b/docs/small_data_style_ablation.html new file mode 100644 index 0000000000000000000000000000000000000000..065deba5af580595179159b2cdd925361efe5298 --- /dev/null +++ b/docs/small_data_style_ablation.html @@ -0,0 +1,941 @@ + + + + + 对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究 + + + + +
+

对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究

+

+ 内部技术报告 · 2026 年 5 月 · train_u1 工具链 +

+

+ 针对 SenseNova-U1-8B-MoT 基模的 LoRA / Partial Fine-tuning 训练器维护团队 +

+ +
+

Abstract

+

+ SenseNova-U1 公开技术报告披露的 text-to-image 流匹配训练在三个关键算法层 + 上做出特定选择:以速度 (velocity) 为目标的损失函数、logit-normal 时间步采样、 + 以及 condition dropout 形式的无分类器引导校准。在大规模预训练或后训阶段, + 这些选择共同支撑了模型最终行为。但对小数据风格 LoRA 微调而言,是否应当继承 + 这一组合并未由现有公开文献回答。本文以 56 张 Hayateluc 风格图像与 8B-MoT 基模 + 为受控环境,对三个官方算法层逐一做单变量消融,并与 v18 本地基线 (x0-MSE + 损失、uniform 时间步、零 dropout) 进行对比,同时纳入 prompt 前缀格式作为额外 + 受控变量。所有训练运行 6000 步,并在统一 prompt 集上做多步采样以做视觉评估。 + 结果表明,逐项替换为官方设置后,训练 loss 不发散,但完整采样图像呈 + 现可观察的图像重建质量劣化:(i) velocity 损失导致天空与大面积渐变上 + 规则横向条纹 (banding) 与 photoreal 化;(ii) logit-normal 时间步导致大面积平 + 滑色块上高频 speckle 噪点与色阶离散化;(iii) 两者组合时进一步出现远景元素缺 + 失、构图被压缩的结构扭曲 (structural distortion),已超出 "style drift" 范畴而 + 属重建质量缺陷。condition dropout 是三项中唯一未引入可见劣化的项,反而轻 + 微改善图像结构质量。基于该结论,我们将仓库的发布默认值保持为本地基线,并 + 把官方对齐组合作为可选的 ablation 配置发布,供研究复现使用。 +

+

+ Keywords — + 流匹配 (flow matching) · + LoRA 微调 · + 消融研究 · + classifier-free guidance · + 风格迁移 · + SenseNova-U1 +

+
+
+ +
+ +

1. 引言

+

+ SenseNova-U1 公开报告 [1] 在 text-to-image 流匹配训练中采用三个具体的算法层: + 速度损失 (velocity loss)、logit-normal 时间步采样、以及在条件 prefix 上的 dropout + 作为 classifier-free guidance (CFG) 的校准信号。这一组合在该报告所对应的大规模 + 数据与多阶段训练管线下是合理且互相支撑的设计选择。 +

+

+ 与之相对,在小数据风格 LoRA 微调场景下,训练目标的属性发生了实质变化:训练数据 + 规模通常在 101–102 张量级,目的是从基模注入一个相对集中的 + 视觉风格 fingerprint,而非更新一个完整的世界模型。这种情况下,将官方训练配置整 + 体迁移过来是否仍最优,并不显然。 +

+

+ 本文以 SenseNova-U1-8B-MoT 为基模,在 56 张 Hayateluc 风格图像数据上做受控消 + 融。具体贡献为三点: +

+

+ (i) 提出一组单变量消融配置,分别替换官方三个算法层中的一个,并 + 与 v18 本地基线对照,使得每对差异可被归因到单一 lever。 + (ii) 指出 prompt 前缀格式 (训练时 caption 中风格 anchor 的位置) + 是一个独立而强的混淆变量;在没有控制该变量前,单纯的 lever 替换会得到误导性的视 + 觉差异。 + (iii) 基于训练标量曲线与多步采样视觉评估,识别每个官方 lever + 引入的具体图像重建质量缺陷:velocity 损失→横纹 + photoreal 化、logit-normal + 时间步→噪点 + 色阶离散化、两者叠加→结构扭曲。condition dropout 是唯一对小数据 + 风格 LoRA 无损甚至有益的官方算法层。 +

+ +

2. 背景与符号

+ +

2.1 流匹配训练目标

+

+ 设 x0 为目标图像 patch,ε 为各向同 + 性高斯噪声,时间步 t ∈ [tε, + 1 − tε]。本仓库与上游推理共用如下线性插 + 值约定 (linear-z schedule,t 越接近 1 越接 + 近干净图): +

+
+ + zt = + tx0 + + (1−t) ε, + +    + + v = + + x0zt + 1 − t + . + +  (1) +
+

+ 在 x0-MSE 损失下,训练目标为 + Lx0 = + 𝔼 ‖xθ(zt,t) + − x02; + 在 velocity 损失下,Lv = 𝔼 + ‖vθ(zt,t) + − v2。代入 + v = (x0zt)/(1 − t) + 与 vθ = (xθzt)/(1 − t),可得二者关系为 +

+
+ + Lv = + 𝔼[  + + xθx02 + (1 − t)2 + +  ]. + +  (2) +
+

+ 即 velocity 损失等价于以 (1 − t)−2 重新加权的 + x0-MSE。该权重在 t → 1 (近 clean 端) 发散,因此训 + 练梯度的有效分布严重偏向高 t 区间。 +

+ +

2.2 时间步采样分布

+

+ uniform 采样取 tU(tε, + 1 − tε)。logit-normal 采样定义为 + uN(μ, σ2), + t = σ(u),其中 + σ 为 sigmoid。本文 logit-normal 取 + μ = −0.8, σ = 0.8,对应 + 𝔼[t] ≈ 0.34,将概率质量偏向较低 t + (较 noisy 一侧)。 +

+ +

2.3 Condition dropout

+

+ 设每步训练以独立概率 ptext 将文本条件 + 替换为空 prompt 对应的 prefix KV;以独立概率 pboth + 走 “text + image” 全部 drop 的分支。在纯 T2I 场景下,这两个 drop 模式都会回退到 + 统一的 unconditional prefix。本文取 + ptext = pboth = 0.10, + 即约 20% 步使用无条件 prefix。 +

+ +

3. 实验配置

+ +

3.1 数据集与基模

+

+ 训练数据为 56 张 Hayateluc 风格自然语言 caption 配对图像,分布在 7 个 aspect-ratio + bucket 上 (最大像素数 ≤ 20482)。基模为 SenseNova-U1-8B-MoT [1],加载方式为 + bf16 CPU 驻留 + 静态前缀 KV cache 的低显存 LoRA 训练,单卡 32 GB 峰值约 21 GB。 +

+ +

3.2 可训练面

+

+ 所有实验共享如下可训练面,保证不同运行间差异仅来自表 1 列出的三个 lever。 + 表 2 给出训练面分层;表 3 给出训练超参。 +

+ +
+

表 2. 训练面分层。LoRA / partial FT / frozen + 三类合计 286M 可训练参数。所有运行共享此结构。

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
类别覆盖模块参数量备注
LoRA wrap + 注意力:q_proj_mot_gen, k_proj_mot_gen, + v_proj_mot_gen, o_proj_mot_gen
+ MLP:mlp_mot_gen.{gate,up,down}_proj +
~204M共 294 wrap;r = 64, α = 64
Partial fine-tune + fm_modules.timestep_embedder,
+ fm_modules.noise_scale_embedder,
+ fm_modules.vision_model_mot_gen,
+ fm_modules.fm_head +
~82Mfm_modules 子树
FrozenUnderstand path 全部模块及其它未列出的所有权重不更新
+
+ +
+

表 3. 训练超参。所有运行共享。

+ + + + + + + + + + + + + + + + + + + + + +
训练步数6000学习率5 × 10−5
优化器PagedAdamW8bitBatch size1 (native resolution)
Gradient accumulation1Seed固定 (cross-run)
+
+ +

3.3 受控变量

+

+ 本文设五组运行做单变量消融,其中四组训练运行共享数据与可训练面 (表 1)。 +

+ +
+

表 1. 五组消融运行的算法配置。baseline + Baseline 为本仓库默认;运行 (a)–(d) 对应官方算法层逐项与全部替换。

+ + + + + + + + + + + + + + + + + +
运行Lt 分布ptext, pboth训练 prefix 注释
v18 baselinex0uniform0, 0含 think sidecar 长前缀
(a) +velocity lossvuniform0, 0含 think sidecar 长前缀
(b) +logit-normal tx0logit-normal0, 0含 think sidecar 长前缀
(c) +cond. dropoutx0uniform0.10, 0.10含 think sidecar 长前缀
(d) full official (a + b + c)vlogit-normal0.10, 0.10含 think sidecar 长前缀
+
+ +

3.4 评估协议

+

+ 训练时记录每步 active loss、x0-MSE、v-MSE、t 统计量、 + 以及 dropout 路由计数;121-step 滑动平均后绘制。采样评估使用一组固定的 12 条 + 自然风景 prompts,在统一的 7 个 bucket 分辨率下做 50 步 Euler、 + cfg_scale = 4.0、timestep_shift = 3.0。 + 我们额外引入 prompt 前缀格式变量:v1 prompts 保留与 baseline 训练分布一致的 + 简短前缀;v2 prompts 在每条句首嵌入 artist anchor,与含 think sidecar 的训练 + caption 分布对齐。 +

+ +

4. 训练动力学结果

+ +

+ 图 1–3 给出 x0-MSE、active loss 与 v-MSE 三条标量曲线。x0-MSE 是唯一可跨损 + 失类型公平比较的指标:它是所有运行共同执行的诊断量。可见运行 (b) 与 (d) 在 + x0-MSE 上系统性高于 Baseline,差距贯穿整个训练,并非过 + 渡期偏差。 +

+ +
+ x0 MSE curves +

+ 图 1.  x0-MSE 在五组运行中的演化 (121-step 滑动平均)。 + (b) 与 (d) 系统性偏高,表明把 t 密度搬向较 noisy 区间会降低 + 模型在更清晰图像状态下学习颜色与构图的机会。 +

+
+ +
+ active loss curves +

+ 图 2.  active loss 的绝对值不可跨目标比较 (velocity 损失与 + x0-MSE 单位不同)。该图仅用于检查每组运行内部是否存在有效下降趋势, + 所有运行均下降稳定,未出现训练发散。 +

+
+ +
+ v MSE curves +

+ 图 3.  v-MSE 作为诊断量在所有运行中都被计算。可观察到 v-MSE + 在 t → 1 端出现极端尖峰,与 §2.1 中给出的 + (1 − t)−2 加权一致。在 velocity 训练的运行 (a) 与 (d) + 中这些尖峰直接进入梯度,被高权重区间支配。 +

+
+ +
+
+ t distribution +

+ 图 4.  实测 t 分布。logit-normal 运行 (b) 与 + (d) 的经验均值约 0.34,与理论值 σ(−0.8) ≈ 0.31 一致;其余运行 + 约为 0.50。 +

+
+
+ dropout counts +

+ 图 5.  含 condition dropout 的运行 (c) 与 (d) 中每步走 + cond / uncond 分支的累计次数。最终约 21% 步走 unconditional prefix,与设定的 + ptext + pboth + = 0.20 在大样本统计上吻合。 +

+
+
+ +

5. 视觉评估结果

+ +

+ 训练标量只反映单步去噪行为,无法替代从纯噪声开始的多步采样。在固定 seed 与 + identical sampling 超参下,我们对每组运行生成 12 张 1024–2048 像素级风景图。 +

+ +

5.1 同 prompt 多配方对比

+ +

+ 图 6 为 forest / wildflower meadow / dandelion field 三个最能体现风格 fingerprint + 的 prompt 在五组运行 (含 Baseline) 下的 contact sheet。关键观察: +

+

+ (i) Baseline 与运行 (c) 在所有三个 prompt 上都保持暖橙地平线、 + 青蓝高空、清晰前景轮廓的视觉指纹。 + (ii) 运行 (a) 在 dandelion field 上系统性向 photoreal 夜景偏移: + cyan 天空消失、puffball 密度降低、暖金核被压暗。 + (iii) 运行 (b) 在所有 prompt 上整体 palette 偏冷,紫蓝山失去层 + 次。 + (iv) 运行 (d) 同时承袭 (a) 与 (b) 的劣化,并叠加 composition + cropping:dandelion 图中失去远景 forest mass 与天空。 +

+ +
+
visual same-prompt contact sheet (1/3)
+
visual same-prompt contact sheet (2/3)
+
visual same-prompt contact sheet (3/3) +

+ 图 6.  同 prompt × 五组运行的 contact sheet (上→中→下三页连读)。 + 每列为一组运行,每行为一条 prompt。列名直接对应表 1 的算法配置差异。 + 风格 fingerprint 由暖金地平线、青蓝高空、painterly 厚笔触三要素共同定义; + 运行 (a) 与 (d) 出现 atmospheric drift,运行 (b) 出现 palette 偏冷。 +

+
+
+ +

5.2 局部细节:图像重建质量劣化

+ +

+ §5.1 给出的是整图尺度上 palette / brushwork / 构图层面的偏移。但在原始 + 分辨率下查看局部,三个 v19 系列方案相比 v18 baseline 还呈现三类图像 + 重建质量层面的劣化,且每一类都与具体 lever 的训练分布偏置直接挂钩: +

+ +

+ 以下分别从三个 prompt 截取细节区域 (图 7、图 8、图 9)。每图以 v18 baseline 在最 + 上,按 (a) (b) (d) 顺序向下排列。 +

+ +

5.2.1 Sample 00 森林暗部:竖向 striation 格纹 (运行 (d) 主)

+

+ Sample 00 (dense old-growth forest interior at dawn) 的画面左下、左侧粗树干、 + 右侧暗部都是低光强、低 SNR 区域,是 v-loss + logit-normal 组合下竖向格纹伪 + 影最先暴露的位置。运行 (d) 的左侧粗树干表面与右下苔藓层均出现可见的等间距 + 竖线。 +

+ +
+
+
v18baseline
+ v18 forest bottom +
+
+
(a)+v-loss
+ v19a forest bottom +
+
+
(b)+logit-normal t
+ v19b forest bottom +
+
+
(c)+cond. dropout
+ v19c forest bottom +
+
+
(d)+full official
+ v19 forest bottom with vertical striations +
+

+ 图 7.  Sample 00 画面下半 (左粗树干 + 中间地面 + 右下苔藓 + + 右侧上层树叶) 五向对比。 + v18:painterly chunky 厚笔触树干 + 暖色秋叶碎片 + 多层 silhouette + 清晰可分; + (a):god-rays 与 mist 大幅放大主导画面, + painterly chunky 厚笔触被 atmospheric haze 稀释成更软的笔触;树干与 + 地面 silhouette 基本保留,主要问题是风格 softening 与雾化过度; + (b):painterly 厚笔触与树形保留较好,但 + 左下暗部 + 右下暗部仍有未完全恢复的轻度竖向条纹 + (强度远低于 (d));高光 god-rays 中段也出现轻度结构化痕迹; + (c):五行中重建质量最接近 v18 — painterly chunky 树干 + 完整地面 + 碎叶都保留,god-rays 强度略偏 (a) 但 painterly 信息完整,无格纹或条纹伪影; + (d):左侧粗树干表面 + 中景树干 + 右下苔藓暗部出现明显竖向 striation 格纹, + 暗部 banding 最严重,painterly 信息几乎完全缺失。 +

+
+ +

5.2.2 Sample 02 湖岸:树 silhouette 与水面倒影竖纹

+

+ Sample 02 (deep mountain lake at dawn) 的两侧针叶林 silhouette 与中央镜面倒影 + 是考察树形 painterly 是否退化、倒影是否出现竖纹的核心区域。运行 (d) 在两侧 + 树林区 + 水面倒影区均呈现 screen-door 竖向条纹。 +

+ +
+
+
v18baseline
+ v18 lakeside trees + reflection +
+
+
(a)+v-loss
+ v19a lakeside trees + reflection +
+
+
(b)+logit-normal t
+ v19b lakeside trees + reflection +
+
+
(c)+cond. dropout
+ v19c lakeside trees + reflection +
+
+
(d)+full official
+ v19 lakeside trees + reflection with vertical striations +
+

+ 图 8.  Sample 02 中段 (左右两侧 painted tree silhouette + 中央 + 水面倒影 + 远雾) 五向对比。 + v18:painterly chunky 树形 + 厚笔触倒影 + 远山三层 silhouette + 清晰可分; + (a):整图 photoreal 化,painterly 笔触退化为 stock-photo 树形, + 倒影保留 photoreal 镜面但失去厚笔触; + (b):painted 风格部分保留,但远雾 + 倒影出现微弱竖向条纹; + (c):painted 风格保留较好,树形+倒影皆有 painterly 触感,未观察 + 到明显伪影; + (d):左侧山体 + 右侧针叶林 silhouette + 整个倒影区出现 screen-door 竖向 striation 格纹, + 覆盖范围广。 +

+
+ +

5.2.3 Sample 07 蒲公英主体:grass 区竖向条纹 + bokeh 退化

+

+ Sample 07 (backlit dandelion field) 主体为前景蒲公英 + 草丛 + 背光,是考察 + painterly 笔触是否被替换为 photographic bokeh、草丛暗区是否出现条纹的核心 + 区域。 +

+ +
+
+
v18baseline
+ v18 dandelion main subject +
+
+
(a)+v-loss
+ v19a dandelion main subject +
+
+
(b)+logit-normal t
+ v19b dandelion main subject +
+
+
(c)+cond. dropout
+ v19c dandelion main subject +
+
+
(d)+full official
+ v19 dandelion main subject with cross-hatch grid +
+

+ 图 9.  Sample 07 主体区 (前景蒲公英 + 草丛暗部 + 暖背光) 五向 + 对比。 + v18:painterly chunky 笔触草叶 + painted 蒲公英 + 远景树线 + 可见; + (a):painterly 完全退化为 photographic bokeh, + 前景蒲公英变疏散小,背景暗化呈相机散景; + (b):painterly 较为完整 — chunky 草叶 + painted 蒲公英都还在, + 主要变化是视野相对收紧 (前景密度减少);草丛暗部存在轻度的竖向 streak,密度 + 低于 (d); + (c):painterly fingerprint 接近 v18 — 蒲公英密度高、chunky + 草叶 + 浮散种子全保留,无观察到格纹伪影; + (d):草丛暗部出现密集 cross-hatch 格纹 + (竖向 striation 为主,叠加微横向 banding),painterly 信息几乎完全缺失,整图 + 带 halftone 质感。 +

+
+ +

5.3 Prompt 前缀格式作为混淆变量

+ +

+ 我们在受控分析中发现,若 inference prompt 的 artist anchor 位置与训练 caption + 不一致,会引入与 lever 无关但视觉量级相当的差异。因此对 Baseline 与运行 (c) 额外 + 做两组 prompt 集对比,结果见图 10。该结果说明把 prompt 写法与训练分布对齐是 + 正确归因 lever 效果的必要前置条件,并解释了我们对 README 默认配置中保留 + 显式 style.trigger 与忽略 think sidecar 的选择。 +

+ +
+
baseline prompt-set contact sheet (1/4)
+
baseline prompt-set contact sheet (2/4)
+
baseline prompt-set contact sheet (3/4)
+
baseline prompt-set contact sheet (4/4) +

+ 图 10.  Baseline 在两组 prompt 集下的采样 (上→下四页连读)。 + 左:与训练同分布的短前缀 (含显式 style trigger);右:分布外的长 prompt。 + 完整采样在分布内稳定;这表明 Baseline 的强 fingerprint 不依赖额外 prefix 文本。 +

+
+
+ +

6. 讨论

+ +

6.1 Velocity 损失的隐式重加权

+

+ 由式 (2),Lv = + 𝔼 [‖xθx02 + · (1 − t)−2],velocity 损失等价于在 x0-MSE + 上施加一个 (1 − t)−2 的权重函数。在 uniform + t 下,该权重在 t ∈ [0.9, 1.0] 区间承担约 64% + 的总积分质量。意即在 velocity 训练中,模型有大半的梯度信号被分配到 "接近 clean + 图但仍残留少量噪声" 的窗口。对于风格 LoRA 而言,颜色调性、构图、笔触特征均在 + t ∈ [0.3, 0.7] 这一中频窗口形成;velocity 损失显著削减了该窗口的 + 有效更新量,因而出现了 §5.1 (ii) 报告的 photoreal 夜景偏移。 +

+ +

6.2 Logit-normal 与小数据匹配的张力

+

+ logit-normal (μ = −0.8, σ = 0.8) 将 + 𝔼[t] 从 0.5 移到 0.34,即把更多概率质量推向 noisy 端。该选择在 + 大规模数据下使得各 t 段都获得足够样本数;但在 56 张图、6000 + 步、batch=1 的小数据规模下,每个 mid-t bin 实际获得的有效样 + 本数本就稀疏,进一步的密度偏移导致 palette 与 mid-frequency texture 统计学习 + 不足,对应 §5.1 (iii) 报告的整体偏冷。 +

+ +

6.3 Condition dropout 的正面作用

+

+ 与上述两项不同,condition dropout 在小数据风格 LoRA 上未观察到风格损伤。其作 + 用机制是把 unconditional branch 与 conditional branch 的相对几何在训练期同时 + 暴露给模型,从而使 inference 时 CFG 的 uncond + s · (cond − uncond) + 推算成立。在 s = 4.0 这一相对较强的 CFG 下,没有 dropout 训练 + 的 LoRA 经常把输出推到 LoRA 未覆盖的方向,导致结构 artefact。运行 (c) 与 Baseline + 的视觉对比显示,前者在风格指纹保持不变的前提下,前景结构 (puffball 形状、 + 树干轮廓) 更清晰、稳定。 +

+ +

6.4 Prompt 前缀格式的去混淆作用

+

+ §5.3 显示 prompt 中 artist anchor 的位置 (句首 vs. 句末 tag vs. 完全缺失) 与训练 + caption 分布对齐与否直接决定 LoRA delta 是否被正确激活。这一变量与目标 lever + 正交,但视觉量级与 lever 效果相当;任何不控制该变量的 lever 评估都会被 + prompt-format mismatch 污染。本文 §5.1、§5.2 与 §5.3 结论均在 v2 anchored prompt + 下复测确认。 +

+ +

7. 结论与发布决策

+

+ 在 SenseNova-U1-8B-MoT × 56 张 Hayateluc 风格 LoRA 微调上: + (1) velocity 损失系统性向高 t 端搬移梯度密度,导致天空与大面积 + 渐变区域出现规则横向条纹 banding 与 painterly→photoreal 退化; + (2) logit-normal 时间步在小数据下加剧 mid-t 样本稀疏,引发 palette + 偏冷与大面积平滑色块的 speckle 噪点; + (3) (1) + (2) 叠加进一步引入远景元素缺失与构图压缩等结构扭曲,劣化超出 style + drift 范畴; + (4) condition dropout 是三项中唯一无损反而有益的算法层; + (5) prompt 前缀格式与训练分布的匹配是评估前置条件,必须先控制。 +

+

+ 据此,我们将 configs/default.yaml 保留为本地基线 + (x0 + uniform t + 0 dropout + 显式 style trigger + 忽略 think + sidecar),并以 configs/official_alignment.yaml 形式发布官方对齐组 + 合,标注其用途为研究复现。该选择不否定公开报告中的算法设计,只反映其与本仓 + 库目标 (小数据风格 LoRA) 的不匹配。 +

+ +

默认配置摘要

+
data:
+  use_think_labels: false
+
+style:
+  trigger: "my style"
+  prompt_template: official
+
+lora:
+  preset: attn_mlp_no_head
+
+unfreeze:
+  - '^fm_modules\.timestep_embedder\.'
+  - '^fm_modules\.noise_scale_embedder\.'
+  - '^fm_modules\.vision_model_mot_gen\.'
+  - '^fm_modules\.fm_head\.'
+
+train:
+  loss_type: x0
+  t_dist: uniform
+  cond_dropout_text: 0.0
+  cond_dropout_both: 0.0
+
+ +

参考文献

+
    +
  1. SenseNova-U1 Technical Report. OpenSenseNova, 2026. + https://github.com/OpenSenseNova/SenseNova-U1
  2. +
  3. Lipman, Y. et al. Flow Matching for Generative Modeling. ICLR, 2023.
  4. +
  5. Esser, P. et al. Scaling Rectified Flow Transformers for High-Resolution + Image Synthesis. ICML, 2024 — 提出 logit-normal t 采样.
  6. +
  7. Ho, J. & Salimans, T. Classifier-Free Diffusion Guidance. + NeurIPS Workshop on Deep Generative Models, 2021.
  8. +
  9. Hu, E. J. et al. LoRA: Low-Rank Adaptation of Large Language Models. + ICLR, 2022.
  10. +
+ +
+ + + diff --git a/docs/small_data_style_ablation.pdf b/docs/small_data_style_ablation.pdf new file mode 100644 index 0000000000000000000000000000000000000000..1d198a65e725e612c420e76fc24dfabef7cde4c6 --- /dev/null +++ b/docs/small_data_style_ablation.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:446f868e6c5fe4518cebe60baf5d5caed50ff0f1b0bad87b693265550b19daad +size 9161869 diff --git a/train.sh b/train.sh index 5a0a0f8e7b33fbfea6d214d4fd01dcb9fc05357e..2e73db31e4d309ebd0608b2db6fba52007d2844d 100755 --- a/train.sh +++ b/train.sh @@ -2,12 +2,12 @@ # train.sh — quick-launch a LoRA training run from a YAML config. # # Usage: -# ./train.sh # uses configs/default.yaml -# ./train.sh configs/v16c.yaml -# ./train.sh configs/v16c.yaml --steps 12000 # extra args forwarded +# ./train.sh # uses configs/default.yaml +# ./train.sh configs/my_style.yaml +# ./train.sh configs/my_style.yaml --steps 12000 # extra args forwarded # # Long-running training tip: -# setsid nohup ./train.sh configs/v16c.yaml run.log 2>&1 & +# setsid nohup ./train.sh configs/my_style.yaml run.log 2>&1 & # disown # (a bare `nohup &` can be SIGHUP'd when the SSH/IDE session disconnects). diff --git a/train_u1/README.md b/train_u1/README.md index 9aee6cde1825ba0f5613f82f03b464aed8dffc4a..24572dce6679e4daf5740cf0fdd4ed88473b306d 100644 --- a/train_u1/README.md +++ b/train_u1/README.md @@ -49,7 +49,10 @@ train_u1/ ``` **未实现 / pending**: -- `model/peft_targets.py` — LoRA target 工具(实验 C 时落地) +- `model/peft_targets.py` — 独立 LoRA target 工具仍未拆出;当前 + `model/lora.py` 已内置 8B dense target 与实验性 A3B generation MoE + target grammar(`gen_moe_mlp` / `gen_moe_router`)。A3B/MoE 不是当前 + 主训练路径,必须等可实例化 `mlp_mot_gen.experts.*` 的公开 runtime。 - `scripts/train_balanced.py` — 48GB 平衡场景(实验 C 后视效果决定) - `scripts/sample_t2i.py` 完整管线(依赖 wrapper 的 `t2i_generate` 钩子) diff --git a/train_u1/config.py b/train_u1/config.py index 83f647daf38708619b4315f3e9f91b018caa5d6a..eccd841e8fc007d1d1457eeace8fc0f0bb176e09 100644 --- a/train_u1/config.py +++ b/train_u1/config.py @@ -15,12 +15,15 @@ A single YAML file describes a complete LoRA training run. Example:: prompt_template: official # or 'plain' lora: - preset: default # = attn+mlp+fm_head all at r=64,a=64 - # spec: "attn=r64a64;mlp=r64a64;fm_head=r64a64;mlp_mot_gen.down_proj=off" + preset: attn_mlp_no_head # small-data baseline: LoRA attn+mlp only + # spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off" dropout: 0.0 unfreeze: # full-finetune (non-LoRA) regex patterns - [] # default: LoRA-only training + - '^fm_modules\\.timestep_embedder\\.' + - '^fm_modules\\.noise_scale_embedder\\.' + - '^fm_modules\\.vision_model_mot_gen\\.' + - '^fm_modules\\.fm_head\\.' train: steps: 6000 @@ -29,6 +32,10 @@ A single YAML file describes a complete LoRA training run. Example:: shuffle: true grad_accum: 1 checkpoint_every: 600 + loss_type: x0 + t_dist: uniform + cond_dropout_text: 0.0 + cond_dropout_both: 0.0 runtime: keep_kvs_on_gpu: true @@ -68,6 +75,7 @@ class DataConfig: snap_bucket: bool = True n_samples: int | None = None # default: use entire dataset sample_buckets_file: str | None = None + use_think_labels: bool = False @dataclass @@ -78,7 +86,7 @@ class StyleConfig: @dataclass class LoRAConfig: - preset: str | None = "default" # one of LORA_PRESETS + preset: str | None = "attn_mlp_no_head" # one of LORA_PRESETS spec: str | None = None # overrides preset if set dropout: float = 0.0 @@ -107,6 +115,25 @@ class TrainConfig: grad_accum: int = 1 checkpoint_every: int = 600 checkpoint_dir: str | None = None # default: artifacts/{run_name}/checkpoints + # FM loss objective. Default is the local small-data baseline (`x0`) because + # the ablation study showed that official-style v-loss is not a good + # small-data style-training default. `v` remains available for explicit + # official alignment experiments. + # Choose one of `x0` | `v` | `x0_huber` | `v_huber`. + loss_type: str = "x0" + huber_delta: float = 1.0 + # FM `t`-sampling distribution. Default is uniform for the same local + # baseline reason. `logit_normal` is kept for report-alignment ablations. + t_dist: str = "uniform" + t_logit_mean: float = -0.8 + t_logit_std: float = 0.8 + # CFG / condition dropout. `cond_dropout_text` drops text condition only; + # `cond_dropout_both` is the additional unconditional bucket from the + # report. In the current pure-T2I trainer there is no separate reference + # image condition, so both modes use the sampler's unconditional prompt + # prefix while preserving separate log labels. + cond_dropout_text: float = 0.0 + cond_dropout_both: float = 0.0 @dataclass @@ -119,13 +146,22 @@ class RuntimeConfig: upstream_lora_skip: tuple[str, ...] = () +def _default_unfreeze_patterns() -> list[str]: + return [ + r"^fm_modules\.timestep_embedder\.", + r"^fm_modules\.noise_scale_embedder\.", + r"^fm_modules\.vision_model_mot_gen\.", + r"^fm_modules\.fm_head\.", + ] + + @dataclass class TrainRunConfig: run_name: str = "my_run" data: DataConfig = field(default_factory=DataConfig) style: StyleConfig = field(default_factory=StyleConfig) lora: LoRAConfig = field(default_factory=LoRAConfig) - unfreeze: list[str] = field(default_factory=list) + unfreeze: list[str] = field(default_factory=_default_unfreeze_patterns) train: TrainConfig = field(default_factory=TrainConfig) runtime: RuntimeConfig = field(default_factory=RuntimeConfig) diff --git a/train_u1/data/collators.py b/train_u1/data/collators.py index cdf06c866fbc086a402eda389d64a09a39feca3f..4058a01160bec53048f14f42ebd05c56135a8eb0 100644 --- a/train_u1/data/collators.py +++ b/train_u1/data/collators.py @@ -21,8 +21,6 @@ from typing import Iterable import torch -import math - from train_u1.constants import ( FM_OUTPUT_DIM, NOISE_SCALE_BASE_IMAGE_SEQ_LEN, @@ -50,7 +48,23 @@ class CollatorConfig: # then batch=1). image_hw: tuple[int, int] | None = (512, 512) t_eps: float = T_EPS_DEFAULT - t_dist: str = "uniform" # uniform on (t_eps, 1] for MVP + # **Default = `logit_normal` (mean=-0.8, std=0.8)** to match SenseNova-U1 + # report Table 2: + # u ~ Normal(t_logit_mean, t_logit_std); t = sigmoid(u) + # clamped to [t_eps, 1 - t_eps]. Biases t toward the *low* end (near-clean); + # combined with v-loss this recovers the official training density. + # `uniform` on (t_eps, 1] is kept for back-compat / ablation. + t_dist: str = "logit_normal" + t_logit_mean: float = -0.8 + t_logit_std: float = 0.8 + # Classifier-free guidance condition dropout. Keep the collator default at + # zero so eval/smoke/diagnostic callers stay fully conditional unless they + # opt in. The official training entry point passes 0.10/0.10 from TrainConfig. + # For pure T2I training there is no separate reference-image condition, so + # `text_image` uses the same unconditional prompt path as `text` while + # recording the intended mode. + cond_dropout_text: float = 0.0 + cond_dropout_both: float = 0.0 add_noise_scale: bool = True # Base noise_scale value (config.noise_scale = 1.0). The *effective* per-sample # noise_scale is computed at collator runtime as @@ -80,6 +94,12 @@ class CollatorConfig: # Format used: f"{style_trigger}, {original_caption}". style_trigger: str = "" + def __post_init__(self) -> None: + if self.cond_dropout_text < 0 or self.cond_dropout_both < 0: + raise ValueError("condition dropout probabilities must be non-negative") + if self.cond_dropout_text + self.cond_dropout_both > 1.0: + raise ValueError("cond_dropout_text + cond_dropout_both must be <= 1.0") + class SenseNovaU1Collator: """Stateful collator: holds tokenizer + config, callable on a list of `T2ISample`. @@ -94,6 +114,7 @@ class SenseNovaU1Collator: self.tok = tokenizer self.cfg = cfg or CollatorConfig() self._gen = torch.Generator().manual_seed(self.cfg.seed) + self._cond_gen = torch.Generator().manual_seed(self.cfg.seed + 10_003) if self.cfg.prompt_template == "official": if model is None or not hasattr(model, "_build_t2i_query"): raise ValueError( @@ -147,10 +168,87 @@ class SenseNovaU1Collator: if self.cfg.t_dist == "uniform": t = torch.rand(batch_size, generator=self._gen) t = t * (1.0 - self.cfg.t_eps) + self.cfg.t_eps + elif self.cfg.t_dist == "logit_normal": + # u ~ N(mu, sigma) -> t = sigmoid(u) in (0, 1) + u = torch.randn(batch_size, generator=self._gen) + u = u * self.cfg.t_logit_std + self.cfg.t_logit_mean + t = torch.sigmoid(u) + t = t.clamp(min=self.cfg.t_eps, max=1.0 - self.cfg.t_eps) else: raise NotImplementedError(f"t_dist={self.cfg.t_dist}") return t + def _sample_condition_modes(self, batch_size: int) -> list[str]: + p_text = float(self.cfg.cond_dropout_text) + p_both = float(self.cfg.cond_dropout_both) + if p_text == 0.0 and p_both == 0.0: + return ["none"] * batch_size + u = torch.rand(batch_size, generator=self._cond_gen) + modes: list[str] = [] + for v in u.tolist(): + if v < p_text: + modes.append("text") + elif v < p_text + p_both: + modes.append("text_image") + else: + modes.append("none") + return modes + + @staticmethod + def _prefix_cache_key(mode: str) -> str: + if mode == "none": + return "cond" + if mode in ("text", "text_image"): + return "uncond" + raise ValueError(f"unknown condition dropout mode {mode!r}") + + def _render_prompts( + self, + samples: list[T2ISample], + condition_modes: list[str], + ) -> list[str]: + prompts: list[str] = [] + for s, mode in zip(samples, condition_modes): + if mode not in ("none", "text", "text_image"): + raise ValueError(f"unknown condition dropout mode {mode!r}") + drop_text = mode in ("text", "text_image") + if drop_text: + raw_prompt = "" + elif self.cfg.style_trigger: + raw_prompt = f"{self.cfg.style_trigger}, {s.prompt}" + else: + raw_prompt = s.prompt + + if self._build_t2i_query is not None: + if drop_text: + # Match the sampler's unconditional CFG prefix exactly: + # `_build_t2i_query("", append_text="")`. + prompts.append(self._build_t2i_query("", append_text="")) + continue + # Per-sample think injection: when the dataset supplies a + # `think` text, render it INSIDE the otherwise-empty + # `` block of the official prompt template. + # This makes training distribution match inference-time + # `--think-mode`, where the model autoregressively fills the + # same window with ~250-400 reasoning tokens. Without this, + # the gen tower sees an unfamiliar prefix length/content + # at inference and the LoRA delta is calibrated against + # the wrong cond-KV distribution. + if s.think: + append_text = f"\n{s.think}\n\n\n" + else: + append_text = self._gen_append + prompts.append( + self._build_t2i_query( + raw_prompt, + system_message=self._sys_msg_for_gen, + append_text=append_text, + ) + ) + else: + prompts.append(" " if drop_text else raw_prompt) + return prompts + @staticmethod def _check_image_hw(image_hw: tuple[int, int]) -> None: H, W = image_hw @@ -163,7 +261,12 @@ class SenseNovaU1Collator: # ------------------------------------------------------------------ # # Main entry # # ------------------------------------------------------------------ # - def __call__(self, samples: list[T2ISample]) -> dict[str, torch.Tensor]: + def __call__( + self, + samples: list[T2ISample], + *, + condition_modes: list[str] | None = None, + ) -> dict[str, torch.Tensor]: cfg = self.cfg if cfg.enforce_batch_one and len(samples) != 1: raise ValueError( @@ -196,38 +299,16 @@ class SenseNovaU1Collator: # 1) text → ids + per-sample lengths. With enforce_batch_one we know # `len(samples) == 1` so no batch padding is applied — `L_text` is # exactly this prompt's length (matches upstream `_build_t2i_text_inputs`). - # Apply style trigger BEFORE chat-template wrap so the trigger lives - # inside the user-message portion of the chat (not in system or - # assistant). Identical formatting must be replicated at sample time. - if cfg.style_trigger: - raw_prompts = [f"{cfg.style_trigger}, {s.prompt}" for s in samples] + if condition_modes is None: + condition_modes = self._sample_condition_modes(len(samples)) else: - raw_prompts = [s.prompt for s in samples] - if self._build_t2i_query is not None: - prompts = [] - for rp, s in zip(raw_prompts, samples): - # Per-sample think injection: when the dataset supplies a - # `think` text, render it INSIDE the otherwise-empty - # `` block of the official prompt template. - # This makes training distribution match inference-time - # `--think-mode`, where the model autoregressively fills the - # same window with ~250-400 reasoning tokens. Without this, - # the gen tower sees an unfamiliar prefix length/content - # at inference and the LoRA delta is calibrated against - # the wrong cond-KV distribution. - if s.think: - append_text = f"\n{s.think}\n\n\n" - else: - append_text = self._gen_append - prompts.append( - self._build_t2i_query( - rp, - system_message=self._sys_msg_for_gen, - append_text=append_text, - ) + condition_modes = list(condition_modes) + if len(condition_modes) != len(samples): + raise ValueError( + f"condition_modes length {len(condition_modes)} != samples length {len(samples)}" ) - else: - prompts = list(raw_prompts) + prompts = self._render_prompts(samples, condition_modes) + prefix_cache_keys = [self._prefix_cache_key(m) for m in condition_modes] input_ids, text_lens = self._tokenize(prompts) B = input_ids.shape[0] L_text = input_ids.shape[1] @@ -298,6 +379,12 @@ class SenseNovaU1Collator: "noisy_pixel_values": noisy_pixel_values, # (B, 3, H, W) "noisy_grid_hw": noisy_grid_hw, # (B, 2) "noise_scale": noise_scale, # (B,) or None + "cond_drop_text": torch.tensor( + [m in ("text", "text_image") for m in condition_modes], + dtype=torch.bool, + ), + "cond_drop_mode": condition_modes, + "prefix_cache_key": prefix_cache_keys, "sample_ids": [s.sample_id for s in samples], "text_lens": text_lens, "token_hw": (token_h, token_w), diff --git a/train_u1/data/datasets.py b/train_u1/data/datasets.py index 32144ddd5d72a9abb74c52ddff2dc08cc450c2b5..6791813f6085e489fef7fa208d10becde9ff8d1a 100644 --- a/train_u1/data/datasets.py +++ b/train_u1/data/datasets.py @@ -141,6 +141,7 @@ class PairedFolderT2IDataset(Dataset): prompt_template: str | None = None, image_extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"), snap_bucket: bool = False, + use_think_labels: bool = True, ): self.folder = Path(folder) if not self.folder.is_dir(): @@ -148,6 +149,7 @@ class PairedFolderT2IDataset(Dataset): self.cap_max_pixels = cap_max_pixels self.prompt_template = prompt_template self.snap_bucket = snap_bucket + self.use_think_labels = use_think_labels pairs: list[tuple[Path, Path, str]] = [] for ext in image_extensions: @@ -172,11 +174,13 @@ class PairedFolderT2IDataset(Dataset): with open(txt_path, encoding="utf-8") as f: raw = f.read() caption, think_text = parse_caption_and_think(raw) + if not self.use_think_labels: + think_text = None if self.prompt_template: caption = self.prompt_template.format(caption=caption) # Legacy fallback: `.think.txt` separate sidecar (deprecated; # `parse_caption_and_think` is the preferred path). - if think_text is None: + if self.use_think_labels and think_text is None: think_path = img_path.with_suffix(".think.txt") if think_path.is_file(): with open(think_path, encoding="utf-8") as f: @@ -257,6 +261,7 @@ class ArrowT2IDataset(Dataset): cap_max_pixels: int | None = None, prompt_template: str | None = None, snap_bucket: bool = False, + use_think_labels: bool = True, ): try: import pyarrow.parquet as pq # noqa: F401 @@ -269,6 +274,7 @@ class ArrowT2IDataset(Dataset): self.cap_max_pixels = cap_max_pixels self.prompt_template = prompt_template self.snap_bucket = snap_bucket + self.use_think_labels = use_think_labels self._table = None # lazy-loaded self._n: int | None = None @@ -311,6 +317,8 @@ class ArrowT2IDataset(Dataset): sample_id = row["sample_id"][0] caption = row["caption"][0] think = (row.get("think") or [None])[0] or None + if not self.use_think_labels: + think = None if "image" in self._table.column_names and row["image"][0] is not None: from PIL import Image diff --git a/train_u1/model/lora.py b/train_u1/model/lora.py index 1ab27ae0b67e7c456f04a1d06acce32b28a2e463..1917daca5b2eec4119e423974958bd220bf06253 100644 --- a/train_u1/model/lora.py +++ b/train_u1/model/lora.py @@ -17,6 +17,16 @@ Wrapped modules supported (per-module rank/alpha/enable independently): Patch decoder (×2): fm_modules.fm_head.0 fm_modules.fm_head.2 +Experimental A3B/MoE target grammar (requires an A3B runtime whose modules +match the public checkpoint names): + + Generation MoE experts: + mlp_mot_gen.experts.*.gate_proj + mlp_mot_gen.experts.*.up_proj + mlp_mot_gen.experts.*.down_proj + Generation MoE router: + mlp_mot_gen.gate + The adapter is implemented as `y = base(x) + scaling * lora_up(lora_down(x))` with `scaling = alpha / r`. Initial state: `lora_down` kaiming uniform, `lora_up` zeros — so the wrapped module starts at exactly the base output. @@ -50,17 +60,42 @@ import torch.nn as nn ATTN_TARGETS = ("q_proj_mot_gen", "k_proj_mot_gen", "v_proj_mot_gen", "o_proj_mot_gen") MLP_TARGETS = ("mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj", "mlp_mot_gen.down_proj") FM_HEAD_TARGETS = ("fm_modules.fm_head.0", "fm_modules.fm_head.2") +GEN_MOE_MLP_TARGETS = ( + "mlp_mot_gen.experts.*.gate_proj", + "mlp_mot_gen.experts.*.up_proj", + "mlp_mot_gen.experts.*.down_proj", +) +GEN_MOE_ROUTER_TARGETS = ("mlp_mot_gen.gate",) +GEN_MOE_TARGETS = GEN_MOE_MLP_TARGETS + GEN_MOE_ROUTER_TARGETS -ALL_KNOWN_TARGETS = ATTN_TARGETS + MLP_TARGETS + FM_HEAD_TARGETS +DENSE_KNOWN_TARGETS = ATTN_TARGETS + MLP_TARGETS + FM_HEAD_TARGETS +ALL_KNOWN_TARGETS = DENSE_KNOWN_TARGETS + GEN_MOE_TARGETS # Convenience expansions used by the CLI parser (`attn`, `mlp`, `fm_head`). TARGET_GROUPS: dict[str, tuple[str, ...]] = { "attn": ATTN_TARGETS, "mlp": MLP_TARGETS, "fm_head": FM_HEAD_TARGETS, - "all": ALL_KNOWN_TARGETS, + # A3B generation-side MoE aliases. These are deliberately separate from + # `mlp`/`all` so existing 8B configs remain byte-for-byte semantic matches. + "gen_moe_mlp": GEN_MOE_MLP_TARGETS, + "moe_mlp": GEN_MOE_MLP_TARGETS, + "gen_moe_router": GEN_MOE_ROUTER_TARGETS, + "moe_router": GEN_MOE_ROUTER_TARGETS, + "gen_moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS, + "moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS, + "all": DENSE_KNOWN_TARGETS, } +_GEN_MOE_EXPERT_TARGET_RE = re.compile( + r"^mlp_mot_gen\.experts\.(?P\*|\d+)\." + r"(?Pgate_proj|up_proj|down_proj)$" +) + + +def _is_known_target(target: str) -> bool: + return target in ALL_KNOWN_TARGETS or _GEN_MOE_EXPERT_TARGET_RE.match(target) is not None + # --------------------------------------------------------------------------- # # Spec types # @@ -71,7 +106,8 @@ TARGET_GROUPS: dict[str, tuple[str, ...]] = { class LoRASpec: """Per-target LoRA configuration. - `target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix). + `target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix) or + an A3B MoE expert target like `mlp_mot_gen.experts.0.gate_proj`. `r` is the LoRA rank. `alpha` is the LoRA alpha; `scaling = alpha / r`. `dropout` applies to the input before `lora_down`. `enabled=False` lets a preset entry be turned off without removing it. @@ -84,7 +120,7 @@ class LoRASpec: enabled: bool = True def __post_init__(self) -> None: - if self.target not in ALL_KNOWN_TARGETS: + if not _is_known_target(self.target): raise ValueError( f"unknown LoRA target {self.target!r}. " f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}" @@ -219,6 +255,50 @@ def _walk_mlp_targets(model: nn.Module, target_name: str): yield sub, leaf, idx +def _walk_moe_mlp_targets(model: nn.Module, target_name: str): + """Yield generation-side MoE expert projections for A3B-style modules. + + `target_name` is `mlp_mot_gen.experts.*.gate_proj` or a single expert + target such as `mlp_mot_gen.experts.7.down_proj`. The walker is intentionally + shape/runtime agnostic: if the loaded model has no `experts` ModuleList, it + yields nothing so 8B dense configs are unaffected. + """ + m = _GEN_MOE_EXPERT_TARGET_RE.match(target_name) + if m is None: + raise ValueError(f"invalid MoE expert target {target_name!r}") + expert_selector = m.group("expert") + leaf = m.group("leaf") + + layers = model.language_model.model.layers + for layer_idx, layer in enumerate(layers): + sub = getattr(layer, "mlp_mot_gen", None) + experts = getattr(sub, "experts", None) + if experts is None: + continue + if expert_selector == "*": + expert_indices = range(len(experts)) + else: + expert_idx = int(expert_selector) + if expert_idx >= len(experts): + continue + expert_indices = (expert_idx,) + for expert_idx in expert_indices: + expert = experts[expert_idx] + if hasattr(expert, leaf): + yield expert, leaf, layer_idx + + +def _walk_moe_router_targets(model: nn.Module, target_name: str): + """Yield generation-side MoE router gates (`mlp_mot_gen.gate`) per layer.""" + if target_name != "mlp_mot_gen.gate": + raise ValueError(f"invalid MoE router target {target_name!r}") + layers = model.language_model.model.layers + for idx, layer in enumerate(layers): + sub = getattr(layer, "mlp_mot_gen", None) + if sub is not None and hasattr(sub, "gate"): + yield sub, "gate", idx + + def _walk_fm_head_targets(model: nn.Module, target_name: str): """Yield `(parent, attr, idx)` for each fm_head linear matching target_name. @@ -251,6 +331,10 @@ def _resolve_target_walker(target: str): return _walk_attn_targets if target in MLP_TARGETS: return _walk_mlp_targets + if _GEN_MOE_EXPERT_TARGET_RE.match(target): + return _walk_moe_mlp_targets + if target in GEN_MOE_ROUTER_TARGETS: + return _walk_moe_router_targets if target in FM_HEAD_TARGETS: return _walk_fm_head_targets raise ValueError(f"no walker for target {target!r}") @@ -327,10 +411,7 @@ def apply_lora_specs( # --------------------------------------------------------------------------- # -_SPEC_TOK_RE = re.compile( - r"^(?P[A-Za-z0-9_.]+)" - r"(?:=(?P.+))?$" -) +_SPEC_TOK_RE = re.compile(r"^(?P[A-Za-z0-9_.*]+)(?:=(?P.+))?$") _RA_RE = re.compile(r"^r(?P\d+)(?:a(?P\d+(?:\.\d+)?))?$") @@ -343,14 +424,16 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]: - `off` disable a target (overrides earlier entries) - `r=N,a=M` alternative comma form (more readable) - Group expansions: `attn`, `mlp`, `fm_head`, `all` expand to their member - targets, all sharing the same body. + Group expansions: `attn`, `mlp`, `fm_head`, `gen_moe_mlp`, + `gen_moe_router`, `gen_moe_all`, `all` expand to their member targets, + all sharing the same body. Examples:: attn=r64a64;mlp=r64a64 q_proj_mot_gen=r128a128; k_proj_mot_gen=r128a128 all=r64a64; mlp_mot_gen.down_proj=off + gen_moe_mlp=r8a8; gen_moe_router=r4a4 fm_head=r=128,a=128 """ specs: dict[str, LoRASpec] = {} @@ -366,7 +449,7 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]: targets = TARGET_GROUPS.get(target, (target,)) for t in targets: - if t not in ALL_KNOWN_TARGETS: + if not _is_known_target(t): raise ValueError( f"unknown LoRA target {t!r}. " f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}" @@ -434,8 +517,22 @@ LORA_PRESETS: dict[str, str] = { # Attn + MLP only (no fm_head); equivalent to our pre-v16c v15a recipe. "attn_mlp": "attn=r64a64;mlp=r64a64", + # **Safe presets** that explicitly drop fm_head from the trained surface. + # The technical report's grid-artifact discussion attributes artifacts to + # the final FFN + MLP head independently modelling disjoint 32×32 patches, + # and notes that the official T2I RL stage freezes the generation-branch + # MLP head and the last three transformer layers for exactly this reason. + # Use these when you want to avoid touching the head at all. + "attn_only_no_head": "attn=r64a64", + "attn_mlp_no_head": "attn=r64a64;mlp=r64a64", + # Exact upstream 8-step distill LoRA shape (rank 128 alpha 128). "official_r128": "attn=r128a128;mlp=r128a128;fm_head=r128a128", + + # Experimental A3B/MoE coverage. Small ranks are intentional: covering all + # 48 layers × 32 gen experts × 3 projections gets large quickly. + "a3b_moe_r8": "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8", + "a3b_moe_router_r8": "gen_moe_router=r8a8", } diff --git a/train_u1/model/losses.py b/train_u1/model/losses.py index 2e9c3ecb7d339ed2e9997c089673541bb1537e2d..9683fe2dc542d7c1d434cf383fc6bdf65e44efde 100644 --- a/train_u1/model/losses.py +++ b/train_u1/model/losses.py @@ -1,11 +1,14 @@ """Training losses for the FM step. -Two primaries (report §5 / §2.1): -- `fm_loss_x0(x_pred, x0_patch)` — MVP recommended. MSE on clean patches. -- `fm_loss_v(v_pred, v_target)` — velocity-target ablation. +Two primaries (report Eq. (5) / Table 2): +- `fm_loss_x0(x_pred, x0_patch)` — MSE on clean patches (legacy MVP default). +- `fm_loss_v(v_pred, v_target)` — MSE on velocity (matches the official + x-predict + v-loss training objective; equivalent to + `MSE(x_pred - x0) / (1 - t)^2`, i.e. an x0-MSE re-weighted by `(1-t)^-2`). -Plus optional Huber variants for outlier robustness, and a tiny CE -guardrail for the unified-training scenario (Phase C). +Plus Huber variants and an `fm_loss` dispatcher that selects by `loss_type`. + +CE guardrail kept for the Phase C unified-training scenario. """ from __future__ import annotations @@ -35,6 +38,65 @@ def fm_loss_v_huber(v_pred: torch.Tensor, v_target: torch.Tensor, delta: float = return F.huber_loss(v_pred.float(), v_target.float(), delta=delta) +def compute_v_target( + x0_patch: torch.Tensor, + z_t: torch.Tensor, + t: torch.Tensor, + *, + t_eps: float = 1e-3, +) -> torch.Tensor: + """Closed-form velocity target for rectified-flow / linear-z_t. + + Report Eq. (5): `v* = (x0 - z_t) / (1 - t)` with `z_t = t x0 + (1-t) eps`. + `t` is expected to be a (B,) tensor — we broadcast to (B, 1, 1) to match + the (B, N, D) patch tensors. + """ + if x0_patch.shape != z_t.shape: + raise ValueError(f"shape mismatch x0 {x0_patch.shape} vs z_t {z_t.shape}") + t = t.to(x0_patch.dtype) + while t.dim() < x0_patch.dim(): + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp(min=t_eps) + return (x0_patch - z_t) / denom + + +# --------------------------------------------------------------------------- # +# Dispatcher # +# --------------------------------------------------------------------------- # + +VALID_LOSS_TYPES = ("x0", "v", "x0_huber", "v_huber") + + +def fm_loss( + *, + loss_type: str, + x_pred: torch.Tensor, + x0_patch: torch.Tensor, + v_pred: torch.Tensor | None = None, + v_target: torch.Tensor | None = None, + huber_delta: float = 1.0, +) -> torch.Tensor: + """Single entry point selecting one of the four FM losses. + + `x0` / `x0_huber` only need `x_pred` + `x0_patch`. + `v` / `v_huber` require `v_pred` + `v_target` (caller computes them via + `compute_v_target` from the same `(x0, z_t, t)` used to build the batch). + """ + if loss_type == "x0": + return fm_loss_x0(x_pred, x0_patch) + if loss_type == "x0_huber": + return fm_loss_x0_huber(x_pred, x0_patch, delta=huber_delta) + if loss_type == "v": + if v_pred is None or v_target is None: + raise ValueError("loss_type='v' requires v_pred and v_target") + return fm_loss_v(v_pred, v_target) + if loss_type == "v_huber": + if v_pred is None or v_target is None: + raise ValueError("loss_type='v_huber' requires v_pred and v_target") + return fm_loss_v_huber(v_pred, v_target, delta=huber_delta) + raise ValueError(f"unknown loss_type {loss_type!r}; valid: {VALID_LOSS_TYPES}") + + def text_ce_guardrail( logits: torch.Tensor, labels: torch.Tensor, diff --git a/train_u1/model/patching.py b/train_u1/model/patching.py index 29fd847742f0fcd46e79808abf6b51c849b5327f..3d316edf5ae34cea240be2b138e9c74fb911c9f8 100644 --- a/train_u1/model/patching.py +++ b/train_u1/model/patching.py @@ -96,7 +96,7 @@ def linear_z_t(x0_patch: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> to """Linear-flow interpolation `z_t = t * x0 + (1 - t) * eps`. 合理推断 — `_t2i_predict_v` divides `(x_pred - z_t) / (1 - t)` to recover - velocity, which is the standard linear-flow form. See report §0.1 (5). + velocity, which is the standard linear-flow form. See report Eq. (5). Shapes: x0_patch : (B, N, FM_OUTPUT_DIM) eps : same as x0_patch diff --git a/train_u1/model/wrapper.py b/train_u1/model/wrapper.py index 54e204f660b00078d176bde5f6c461d958c4ff2f..159951e2394d8dc0c0558a0c3f4481937ed9d227 100644 --- a/train_u1/model/wrapper.py +++ b/train_u1/model/wrapper.py @@ -17,7 +17,7 @@ References (commit df86ca90): - modeling_qwen3.py L152-164, L739-1001 公开证据显示 — every step references public source. -合理推断 — only the `MSE(x_pred, x0)` loss head, per report §0.1 (5). +合理推断 — only the `MSE(x_pred, x0)` loss head, per report Eq. (5). """ from __future__ import annotations diff --git a/train_u1/scripts/inspect_lora_targets.py b/train_u1/scripts/inspect_lora_targets.py new file mode 100644 index 0000000000000000000000000000000000000000..4365efa7d4faad0779da9835477de7c6f71d833b --- /dev/null +++ b/train_u1/scripts/inspect_lora_targets.py @@ -0,0 +1,192 @@ +"""Metadata-only LoRA target estimator for dense 8B and A3B/MoE configs. + +This script intentionally reads only ``config.json``. It does not instantiate +the model and does not download safetensors shards, so it is safe to run against +large A3B repositories before the public MoE runtime is usable locally. +""" +from __future__ import annotations + +import argparse +import json +import re +from dataclasses import dataclass +from pathlib import Path + +from train_u1.model.lora import parse_lora_spec_str + + +_MOE_EXPERT_TARGET_RE = re.compile( + r"^mlp_mot_gen\.experts\.(?P\*|\d+)\." + r"(?Pgate_proj|up_proj|down_proj)$" +) + + +@dataclass(frozen=True) +class ShapeEstimate: + modules: int + params: int + note: str = "" + + +def _load_config(model: str, cache_dir: str | None = None) -> dict: + path = Path(model) + if path.is_dir(): + return json.loads((path / "config.json").read_text()) + if path.is_file(): + return json.loads(path.read_text()) + + from huggingface_hub import hf_hub_download + + cfg_path = hf_hub_download(model, "config.json", cache_dir=cache_dir) + return json.loads(Path(cfg_path).read_text()) + + +def _linear_lora_params(in_features: int, out_features: int, rank: int) -> int: + return rank * (in_features + out_features) + + +def _llm_config(cfg: dict) -> dict: + return cfg.get("llm_config", cfg) + + +def _fm_output_dim(cfg: dict) -> int: + patch_size = int(cfg.get("patch_size", 16)) + downsample_ratio = float(cfg.get("downsample_ratio", 0.5)) + merge_size = int(1 / downsample_ratio) + return 3 * (patch_size * merge_size) ** 2 + + +def _estimate_target(cfg: dict, target: str, rank: int) -> ShapeEstimate: + llm = _llm_config(cfg) + layers = int(llm.get("num_hidden_layers", 0)) + hidden = int(llm.get("hidden_size", 0)) + head_dim = int(llm.get("head_dim", hidden // max(int(llm.get("num_attention_heads", 1)), 1))) + n_heads = int(llm.get("num_attention_heads", 0)) + n_kv = int(llm.get("num_key_value_heads", n_heads)) + q_out = n_heads * head_dim + kv_out = n_kv * head_dim + intermediate = int(llm.get("intermediate_size", 0)) + moe_intermediate = int(llm.get("moe_intermediate_size", 0)) + gen_experts = int(llm.get("gen_num_experts", 0) or 0) + + if target == "q_proj_mot_gen": + return ShapeEstimate(layers, layers * _linear_lora_params(hidden, q_out, rank)) + if target in {"k_proj_mot_gen", "v_proj_mot_gen"}: + return ShapeEstimate(layers, layers * _linear_lora_params(hidden, kv_out, rank)) + if target == "o_proj_mot_gen": + return ShapeEstimate(layers, layers * _linear_lora_params(q_out, hidden, rank)) + + if target in {"mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj"}: + return ShapeEstimate(layers, layers * _linear_lora_params(hidden, intermediate, rank)) + if target == "mlp_mot_gen.down_proj": + return ShapeEstimate(layers, layers * _linear_lora_params(intermediate, hidden, rank)) + + m = _MOE_EXPERT_TARGET_RE.match(target) + if m is not None: + if gen_experts <= 0 or moe_intermediate <= 0: + return ShapeEstimate(0, 0, "config has no generation MoE experts") + expert_selector = m.group("expert") + if expert_selector == "*": + n_experts = gen_experts + else: + expert_idx = int(expert_selector) + n_experts = 1 if expert_idx < gen_experts else 0 + if n_experts == 0: + return ShapeEstimate(0, 0, "selected expert is outside gen_num_experts") + leaf = m.group("leaf") + if leaf in {"gate_proj", "up_proj"}: + per_module = _linear_lora_params(hidden, moe_intermediate, rank) + else: + per_module = _linear_lora_params(moe_intermediate, hidden, rank) + modules = layers * n_experts + return ShapeEstimate(modules, modules * per_module) + + if target == "mlp_mot_gen.gate": + if gen_experts <= 0: + return ShapeEstimate(0, 0, "config has no generation MoE router") + return ShapeEstimate(layers, layers * _linear_lora_params(hidden, gen_experts, rank)) + + if target == "fm_modules.fm_head.0": + return ShapeEstimate(1, _linear_lora_params(hidden, 4096, rank)) + if target == "fm_modules.fm_head.2": + return ShapeEstimate(1, _linear_lora_params(4096, _fm_output_dim(cfg), rank)) + + return ShapeEstimate(0, 0, "unknown estimator target") + + +def _model_kind(cfg: dict) -> str: + llm = _llm_config(cfg) + arch = ",".join(llm.get("architectures") or []) + if "Moe" in arch or llm.get("gen_num_experts"): + return "a3b_moe" + return "dense" + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--model", required=True, help="HF repo id, local model dir, or config.json path") + ap.add_argument("--spec", required=True, help="LoRA spec string, e.g. 'attn=r8a8;gen_moe_mlp=r8a8'") + ap.add_argument("--cache-dir", default="hf_cache") + ap.add_argument("--json", action="store_true", help="Emit machine-readable JSON") + args = ap.parse_args() + + cfg = _load_config(args.model, cache_dir=args.cache_dir) + specs = parse_lora_spec_str(args.spec) + + rows = [] + total_modules = 0 + total_params = 0 + for spec in specs: + if not spec.enabled: + continue + est = _estimate_target(cfg, spec.target, spec.r) + row = { + "target": spec.target, + "rank": spec.r, + "alpha": spec.alpha, + "modules": est.modules, + "lora_params": est.params, + "note": est.note, + } + rows.append(row) + total_modules += est.modules + total_params += est.params + + out = { + "model": args.model, + "model_kind": _model_kind(cfg), + "llm_architectures": _llm_config(cfg).get("architectures"), + "spec": args.spec, + "targets": rows, + "total_modules": total_modules, + "total_lora_params": total_params, + "approx_checkpoint_mb_fp32": total_params * 4 / 1e6, + "runtime_note": ( + "metadata estimate only; end-to-end A3B requires a runtime with " + "mlp_mot_gen.experts.* modules" + ), + } + + if args.json: + print(json.dumps(out, indent=2, ensure_ascii=False)) + return + + print(f"model: {out['model']}") + print(f"kind: {out['model_kind']} llm_arch={out['llm_architectures']}") + print(f"spec: {out['spec']}") + print() + print(f"{'target':45s} {'r':>4s} {'modules':>8s} {'lora params':>14s} note") + print("-" * 90) + for row in rows: + print( + f"{row['target']:45s} {row['rank']:4d} {row['modules']:8d} " + f"{row['lora_params']:14,d} {row['note']}" + ) + print("-" * 90) + print(f"{'total':45s} {'':4s} {total_modules:8d} {total_params:14,d}") + print(f"approx checkpoint size if saved fp32: {out['approx_checkpoint_mb_fp32']:.1f} MB") + print(out["runtime_note"]) + + +if __name__ == "__main__": + main() diff --git a/train_u1/scripts/train_bf16_offload.py b/train_u1/scripts/train_bf16_offload.py index 564d25d749ae5903f8c3381556f717e1404251f2..9ccc2be333d007851895826da1b8e995e019d013 100644 --- a/train_u1/scripts/train_bf16_offload.py +++ b/train_u1/scripts/train_bf16_offload.py @@ -35,6 +35,7 @@ import argparse import gc import json import os +import re import sys import time from pathlib import Path @@ -46,7 +47,7 @@ from train_u1.constants import MODEL_ID, MODEL_SHA from train_u1.data.collators import CollatorConfig, SenseNovaU1Collator, to_device from train_u1.data.datasets import ArrowT2IDataset, PairedFolderT2IDataset from train_u1.model.loader import _resolve_local_snapshot, load_neo_chat -from train_u1.model.losses import fm_loss_x0 +from train_u1.model.losses import VALID_LOSS_TYPES, compute_v_target, fm_loss, fm_loss_x0, fm_loss_v from train_u1.model.lora import ( LORA_PRESETS, apply_lora_specs, @@ -81,18 +82,20 @@ def _trainable_params(model): def precompute_prefix_kvs( model, samples, - tokenizer, collator, *, cuda_device: str, cpu_device: str, classify: dict, + include_uncond: bool, verbose: bool = True, -) -> list: - """Run prefix forward for each sample with ordinary tower on GPU. +) -> dict[str, object]: + """Run prefix forward for conditional and optional unconditional CFG prefixes. - Returns list[ DynamicCache (with K/V tensors on `cpu_device`) ] aligned - with `samples` order. + Returns a dict with: + - ``cond``: per-sample DynamicCache list aligned with ``samples`` order + - ``uncond``: one shared unconditional DynamicCache when ``include_uncond`` + is true, otherwise ``None`` """ if verbose: sz_prefix = _bytes_for(model, classify["prefix"]) / 1e9 @@ -107,9 +110,15 @@ def precompute_prefix_kvs( flush=True, ) - kvs = [] + # Prefix precompute only needs text inputs; collator still builds full FM + # batches, so preserve its training RNG streams around these calls. + collator_gen_state = collator._gen.get_state() + collator_cond_gen_state = collator._cond_gen.get_state() + + cond_kvs = [] + uncond_kv = None for i, sample in enumerate(samples): - batch = collator([sample]) + batch = collator([sample], condition_modes=["none"]) # Force prefix-relevant inputs to GPU for the no_grad forward. input_ids = batch["input_ids"].to(cuda_device) text_indexes = batch["text_indexes"].to(cuda_device) @@ -125,7 +134,7 @@ def precompute_prefix_kvs( kv = prefix_out.past_key_values # Move KV to CPU (small — ~70 MB per sample at L_text~400) _cache_to(kv, cpu_device) - kvs.append(kv) + cond_kvs.append(kv) if verbose: print( f"[bf16-offload] precomputed prefix KV for sample {i+1}/{len(samples)} " @@ -133,9 +142,65 @@ def precompute_prefix_kvs( flush=True, ) + if include_uncond and samples: + # The unconditional CFG prefix is prompt-independent, so cache it once + # and share it across all samples/steps. + batch = collator([samples[0]], condition_modes=["text"]) + input_ids = batch["input_ids"].to(cuda_device) + text_indexes = batch["text_indexes"].to(cuda_device) + attn_mask_prefix = batch["attn_mask_prefix"].to(cuda_device) + with torch.no_grad(): + prefix_out = model.language_model.model( + input_ids=input_ids, + indexes=text_indexes, + attention_mask={"full_attention": attn_mask_prefix}, + use_cache=True, + ) + uncond_kv = prefix_out.past_key_values + _cache_to(uncond_kv, cpu_device) + if verbose: + print( + f"[bf16-offload] precomputed shared uncond prefix KV " + f"(L_text={input_ids.shape[1]})", + flush=True, + ) + if verbose: print(f"[bf16-offload] precompute total: {time.time()-t0:.1f}s", flush=True) - return kvs + collator._gen.set_state(collator_gen_state) + collator._cond_gen.set_state(collator_cond_gen_state) + return {"cond": cond_kvs, "uncond": uncond_kv} + + +def _select_prefix_kv(prefix_kvs: dict[str, object], idx: int, prefix_key: str): + if prefix_key == "cond": + return prefix_kvs["cond"][idx] + if prefix_key == "uncond": + kv = prefix_kvs.get("uncond") + if kv is None: + raise RuntimeError("missing unconditional prefix KV; condition dropout requires include_uncond=True") + return kv + raise ValueError(f"unknown prefix_cache_key {prefix_key!r}") + + +def _guard_static_prefix_unfreeze(unfreeze_patterns: list[str], classify: dict[str, list[str]]) -> None: + """Static prefix KVs make prefix/unused params effectively non-trainable.""" + if not unfreeze_patterns: + return + compiled = [(pat, re.compile(pat)) for pat in unfreeze_patterns] + blocked_names = classify.get("prefix", []) + classify.get("unused", []) + hits: list[tuple[str, str]] = [] + for pat, cre in compiled: + for name in blocked_names: + if cre.search(name): + hits.append((pat, name)) + break + if hits: + details = "; ".join(f"{pat!r} matched {name}" for pat, name in hits[:5]) + raise ValueError( + "unfreeze patterns cannot target prefix/unused tower params while static prefix KV cache is enabled; " + f"{details}. Restrict unfreeze to gen-side modules or use a non-cached training path." + ) def evict_ordinary_load_gen(model, classify, cuda_device, cpu_device, verbose=True): @@ -224,6 +289,11 @@ def main() -> int: ap.add_argument("--cap-max-pixels", type=int, default=None) ap.add_argument("--n-samples", type=int, default=None, help="Cap on dataset size (default: use entire data_dir).") + ap.add_argument("--use-think-labels", action="store_true", default=None, + help="Use embedded/sidecar think labels during training.") + ap.add_argument("--no-use-think-labels", dest="use_think_labels", + action="store_false", default=None, + help="Ignore embedded/sidecar think labels during training.") ap.add_argument("--steps", type=int, default=None) ap.add_argument("--lr", type=float, default=None) ap.add_argument("--grad-accum", type=int, default=None) @@ -273,6 +343,24 @@ def main() -> int: ap.add_argument("--upstream-lora-path", default=None, help="Bake-in merge an upstream-format LoRA into the bf16 base before training. " "Useful for stacking: train new style on top of the 8-step distill LoRA.") + # ---- Official-consistency knobs (report Eq. (5) / Table 2). ---- + ap.add_argument("--loss-type", default=None, + choices=(None, *VALID_LOSS_TYPES), + help="FM loss objective: x0 (legacy) | v (official) | x0_huber | v_huber.") + ap.add_argument("--huber-delta", type=float, default=None, + help="Delta for x0_huber / v_huber.") + ap.add_argument("--t-dist", default=None, + choices=(None, "uniform", "logit_normal"), + help="t-sampling distribution. Official: logit_normal (mean=-0.8, std=0.8).") + ap.add_argument("--t-logit-mean", type=float, default=None, + help="Mean for logit_normal t-sampling (default: -0.8 per report).") + ap.add_argument("--t-logit-std", type=float, default=None, + help="Std for logit_normal t-sampling (default: 0.8 per report).") + ap.add_argument("--cond-dropout-text", type=float, default=None, + help="Train-time CFG dropout probability for text-only condition drop. Official: 0.10.") + ap.add_argument("--cond-dropout-both", type=float, default=None, + help="Additional unconditional bucket for text+image condition drop. Official: 0.10; " + "pure T2I maps this to the unconditional prompt prefix.") args = ap.parse_args() # ---- Resolve config: load YAML if given, then overlay CLI flags ---- @@ -291,6 +379,7 @@ def main() -> int: _override(cfg.data, "cap_max_pixels", args.cap_max_pixels) _override(cfg.data, "snap_bucket", args.snap_bucket) _override(cfg.data, "n_samples", args.n_samples) + _override(cfg.data, "use_think_labels", args.use_think_labels) _override(cfg.style, "trigger", args.style_trigger) _override(cfg.style, "prompt_template", args.prompt_template) _override(cfg.lora, "preset", args.lora_preset) @@ -310,6 +399,13 @@ def main() -> int: _override(cfg.runtime, "device", args.device) _override(cfg.runtime, "cpu_device", args.cpu_device) _override(cfg.runtime, "upstream_lora_path", args.upstream_lora_path) + _override(cfg.train, "loss_type", args.loss_type) + _override(cfg.train, "huber_delta", args.huber_delta) + _override(cfg.train, "t_dist", args.t_dist) + _override(cfg.train, "t_logit_mean", args.t_logit_mean) + _override(cfg.train, "t_logit_std", args.t_logit_std) + _override(cfg.train, "cond_dropout_text", args.cond_dropout_text) + _override(cfg.train, "cond_dropout_both", args.cond_dropout_both) # `data_dir` is the only truly required field. if not cfg.data.data_dir or cfg.data.data_dir == "dataset/my_style": @@ -339,6 +435,7 @@ def main() -> int: a.data_dir = cfg.data.data_dir a.cap_max_pixels = cfg.data.cap_max_pixels a.n_samples = cfg.data.n_samples + a.use_think_labels = cfg.data.use_think_labels a.steps = cfg.train.steps a.lr = cfg.train.lr a.grad_accum = cfg.train.grad_accum @@ -357,8 +454,29 @@ def main() -> int: a.upstream_lora_path = cfg.runtime.upstream_lora_path a.lora_specs = cfg.lora.resolved_specs() a.unfreeze_patterns = list(cfg.unfreeze) + a.loss_type = cfg.train.loss_type + a.huber_delta = cfg.train.huber_delta + a.t_dist = cfg.train.t_dist + a.t_logit_mean = cfg.train.t_logit_mean + a.t_logit_std = cfg.train.t_logit_std + a.cond_dropout_text = cfg.train.cond_dropout_text + a.cond_dropout_both = cfg.train.cond_dropout_both args = a # rebind so the body below uses the resolved config + if args.loss_type not in VALID_LOSS_TYPES: + ap.error(f"loss_type must be one of {VALID_LOSS_TYPES}, got {args.loss_type!r}") + if args.t_dist not in ("uniform", "logit_normal"): + ap.error(f"t_dist must be 'uniform' or 'logit_normal', got {args.t_dist!r}") + if args.cond_dropout_text < 0 or args.cond_dropout_both < 0: + ap.error("condition dropout probabilities must be non-negative") + if args.cond_dropout_text + args.cond_dropout_both > 1.0: + ap.error("cond_dropout_text + cond_dropout_both must be <= 1.0") + print(f"[bf16-offload] loss_type={args.loss_type} t_dist={args.t_dist}" + + (f" (mean={args.t_logit_mean}, std={args.t_logit_std})" + if args.t_dist == "logit_normal" else ""), flush=True) + print(f"[bf16-offload] cond_dropout: text={args.cond_dropout_text:.3f} " + f"both={args.cond_dropout_both:.3f}", flush=True) + torch.manual_seed(args.seed) # ---- Phase 0: load bf16 model to CPU ---- @@ -389,12 +507,14 @@ def main() -> int: _data_path, cap_max_pixels=args.cap_max_pixels, snap_bucket=args.snap_bucket, + use_think_labels=args.use_think_labels, ) else: ds = PairedFolderT2IDataset( args.data_dir, cap_max_pixels=args.cap_max_pixels, snap_bucket=args.snap_bucket, + use_think_labels=args.use_think_labels, ) n_use = len(ds) if args.n_samples is None else min(args.n_samples, len(ds)) samples = [ds[i] for i in range(n_use)] @@ -406,11 +526,17 @@ def main() -> int: enforce_batch_one=True, prompt_template=args.prompt_template, style_trigger=args.style_trigger, + t_dist=args.t_dist, + t_logit_mean=args.t_logit_mean, + t_logit_std=args.t_logit_std, + cond_dropout_text=args.cond_dropout_text, + cond_dropout_both=args.cond_dropout_both, ), model=model if args.prompt_template == "official" else None, ) print(f"[bf16-offload] dataset: {len(ds)} pairs (using {n_use}) " - f"snap_bucket={args.snap_bucket} style_trigger={args.style_trigger!r}", flush=True) + f"snap_bucket={args.snap_bucket} use_think_labels={args.use_think_labels} " + f"style_trigger={args.style_trigger!r}", flush=True) # ---- Optional: bake-in merge an upstream-format LoRA into the base # before our wrap, so we train on top of (e.g.) the 8-step distill LoRA. ---- @@ -438,14 +564,21 @@ def main() -> int: # tower classification classify = classify_module_paths(model) + legacy_static_scenarios = {"mvp", "mvp_aux", "gen_vision", "aux_no_head"} + if args.scenario not in legacy_static_scenarios: + try: + _guard_static_prefix_unfreeze(args.unfreeze_patterns, classify) + except ValueError as e: + ap.error(str(e)) move_param_set(model, classify["unused"], args.cpu_device) move_param_set(model, classify["shared"], args.device) # ---- Phase 1: precompute prefix KVs (ordinary tower briefly on GPU) ---- prefix_kvs_cpu = precompute_prefix_kvs( - model, samples, tok, collator, + model, samples, collator, cuda_device=args.device, cpu_device=args.cpu_device, classify=classify, + include_uncond=(args.cond_dropout_text + args.cond_dropout_both) > 0.0, ) # ---- Phase 2: ordinary → CPU permanently, gen → GPU permanently ---- @@ -488,9 +621,12 @@ def main() -> int: # If requested, move all prefix KVs to GPU permanently if args.keep_kvs_on_gpu: - for kv in prefix_kvs_cpu: + for kv in prefix_kvs_cpu["cond"]: _cache_to(kv, args.device) - print(f"[bf16-offload] all {len(prefix_kvs_cpu)} prefix KVs kept on GPU") + if prefix_kvs_cpu["uncond"] is not None: + _cache_to(prefix_kvs_cpu["uncond"], args.device) + n_prefix_kv = len(prefix_kvs_cpu["cond"]) + (1 if prefix_kvs_cpu["uncond"] is not None else 0) + print(f"[bf16-offload] all {n_prefix_kv} prefix KVs kept on GPU") # ---- Phase 5: optimizer + training loop ---- wrapper = TrainingWrapper(model) @@ -520,7 +656,7 @@ def main() -> int: print(f"\n[bf16-offload] starting training ({args.steps} steps, n_samples={n_use})", flush=True) t0 = time.time() losses: list[float] = [] - loss_tensors_buf: list[torch.Tensor] = [] # deferred .item() — flushed at log boundaries + loss_tensors_buf: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] log_records_buf: list[dict] = [] # paired JSON record per buffered loss json_lines_buf: list[str] = [] # serialized JSON awaiting batched file write rng = torch.Generator().manual_seed(args.seed) @@ -587,14 +723,32 @@ def main() -> int: sample = samples[idx] token_h, token_w = batch["token_hw"] - # Fetch this sample's pre-computed prefix KV - kv = prefix_kvs_cpu[idx] + # Fetch this sample's pre-computed prefix KV. Condition dropout uses a + # shared unconditional prefix; normal samples use their per-sample cond KV. + prefix_key = batch.get("prefix_cache_key", ["cond"])[0] + cond_drop_mode = batch.get("cond_drop_mode", ["none"])[0] + kv = _select_prefix_kv(prefix_kvs_cpu, idx, prefix_key) if not args.keep_kvs_on_gpu: _cache_to(kv, args.device) out = wrapper.forward_t2i_step(batch, prefix_kv=kv) - loss = fm_loss_x0(out.x_pred, batch["x0_patch"]) / args.grad_accum + # v target for both training and diagnostics. Cheap (one /). + v_target = compute_v_target(batch["x0_patch"], out.z_t, batch["t"], t_eps=wrapper.t_eps) + loss = fm_loss( + loss_type=args.loss_type, + x_pred=out.x_pred, + x0_patch=batch["x0_patch"], + v_pred=out.v_pred, + v_target=v_target, + huber_delta=args.huber_delta, + ) / args.grad_accum loss.backward() + # Diagnostic: always compute the *other* MSE (no grad) so the log + # carries both `x0_mse` and `v_mse` regardless of which one we train. + with torch.no_grad(): + x0_mse_t = fm_loss_x0(out.x_pred, batch["x0_patch"]).detach() + v_mse_t = fm_loss_v(out.v_pred, v_target).detach() + t_vec = batch["t"].detach() do_step = ((step + 1) % args.grad_accum) == 0 if do_step: @@ -606,19 +760,43 @@ def main() -> int: # Defer loss.item() (forces GPU→CPU sync) — keep tensor on GPU, # call .item() only at log-print boundaries. Tensor list is tiny. - loss_tensors_buf.append(loss.detach() * args.grad_accum) # un-scaled + loss_tensors_buf.append(( + loss.detach() * args.grad_accum, # un-scaled active loss + x0_mse_t, v_mse_t, t_vec, + )) log_records_buf.append({ "step": step, "sample_idx": idx, "sample_id": sample.sample_id, "token_h": int(token_h), "token_w": int(token_w), + "cond_drop_mode": cond_drop_mode, + "prefix_cache_key": prefix_key, }) is_log_boundary = (step < 5 or step % 10 == 0 or step == args.steps - 1) if is_log_boundary: - # Sync materialize buffered losses (cheap; PyTorch batches the H2D) - new_losses = [t.item() for t in loss_tensors_buf] + # Sync materialize buffered losses + diagnostics in one shot. + new_losses: list[float] = [] + new_x0: list[float] = [] + new_v: list[float] = [] + new_t_stats: list[tuple[float, float, float, float]] = [] + for loss_t, x0_t, vt, t_vec_t in loss_tensors_buf: + new_losses.append(loss_t.item()) + new_x0.append(x0_t.item()) + new_v.append(vt.item()) + tv = t_vec_t.float() + new_t_stats.append(( + tv.mean().item(), + tv.std(unbiased=False).item() if tv.numel() > 1 else 0.0, + tv.min().item(), + tv.max().item(), + )) losses.extend(new_losses) - for rec, lv in zip(log_records_buf, new_losses): + for rec, lv, x0v, vv, tstat in zip( + log_records_buf, new_losses, new_x0, new_v, new_t_stats + ): rec["loss"] = lv + rec["x0_mse"] = x0v + rec["v_mse"] = vv + rec["t_mean"], rec["t_std"], rec["t_min"], rec["t_max"] = tstat json_lines_buf.append(json.dumps(rec)) loss_tensors_buf.clear() log_records_buf.clear() @@ -628,6 +806,8 @@ def main() -> int: cur_w = batch['noisy_pixel_values'].shape[3] print( f"[bf16-offload] step={step:4d} loss={losses[-1]:.4f} " + f"x0={new_x0[-1]:.4f} v={new_v[-1]:.4f} " + f"t̄={new_t_stats[-1][0]:.3f} " f"sample={sample.sample_id} hw=({cur_h},{cur_w}) tokens={token_h*token_w} " f"elapsed={elapsed:.1f}s", flush=True, diff --git a/train_u1/tests/test_collator.py b/train_u1/tests/test_collator.py index bb8267ef6310fd7f362b96cb25c0216d0f2fcd7d..bdd0368b5a19cc567a06b853723fe28acc50db7b 100644 --- a/train_u1/tests/test_collator.py +++ b/train_u1/tests/test_collator.py @@ -59,6 +59,8 @@ def test_collator_shapes_match_forward_contract(): assert batch["position_indexes"].shape == (3, L_text + N) assert batch["attn_mask"].shape == (1, 1, L_text + N, L_text + N) assert batch["attn_mask_prefix"].shape == (1, 1, L_text, L_text) + assert batch["cond_drop_mode"] == ["none"] + assert batch["prefix_cache_key"] == ["cond"] def test_collator_rejects_batch_gt_one_in_native_mode(): @@ -129,3 +131,54 @@ def test_collator_attn_mask_block_invariants(): # text rows are blind to image rows (image t > text t at higher arange) assert not can_attend[:L_text, L_text:].any() + + +def test_condition_dropout_text_mode_uses_uncond_prefix_key(): + ds = SyntheticT2ITinyDataset(n=1, image_hw=(64, 64)) + cfg = CollatorConfig( + image_hw=(64, 64), + seed=0, + cond_dropout_text=1.0, + cond_dropout_both=0.0, + ) + collator = SenseNovaU1Collator(_MockTokenizer(), cfg=cfg) + + batch = collator([ds[0]]) + + assert batch["cond_drop_mode"] == ["text"] + assert batch["prefix_cache_key"] == ["uncond"] + assert batch["cond_drop_text"].tolist() == [True] + + +def test_condition_dropout_both_mode_is_logged_separately(): + ds = SyntheticT2ITinyDataset(n=1, image_hw=(64, 64)) + cfg = CollatorConfig( + image_hw=(64, 64), + seed=0, + cond_dropout_text=0.0, + cond_dropout_both=1.0, + ) + collator = SenseNovaU1Collator(_MockTokenizer(), cfg=cfg) + + batch = collator([ds[0]]) + + assert batch["cond_drop_mode"] == ["text_image"] + assert batch["prefix_cache_key"] == ["uncond"] + assert batch["cond_drop_text"].tolist() == [True] + + +def test_condition_dropout_can_be_forced_off(): + ds = SyntheticT2ITinyDataset(n=1, image_hw=(64, 64)) + cfg = CollatorConfig( + image_hw=(64, 64), + seed=0, + cond_dropout_text=0.0, + cond_dropout_both=0.0, + ) + collator = SenseNovaU1Collator(_MockTokenizer(), cfg=cfg) + + batch = collator([ds[0]]) + + assert batch["cond_drop_mode"] == ["none"] + assert batch["prefix_cache_key"] == ["cond"] + assert batch["cond_drop_text"].tolist() == [False] diff --git a/train_u1/tests/test_dataset_format.py b/train_u1/tests/test_dataset_format.py index 12fff1a6929bf5368d78d7c906a00cd9ceeb1045..94d12f287ce410100025f9540ceb4d391ad2ce2b 100644 --- a/train_u1/tests/test_dataset_format.py +++ b/train_u1/tests/test_dataset_format.py @@ -94,6 +94,23 @@ def test_paired_folder_dataset_legacy_split(tmp_path: Path) -> None: assert s.think == "legacy think text" +def test_paired_folder_can_ignore_think_labels(tmp_path: Path) -> None: + """Training defaults can keep prefixes short even if sidecar think files exist.""" + from PIL import Image + + from train_u1.data.datasets import PairedFolderT2IDataset + + img = Image.new("RGB", (64, 64), color=(0, 0, 0)) + img.save(tmp_path / "ignore_001.jpg") + (tmp_path / "ignore_001.txt").write_text("caption\n---think---\nembedded think") + (tmp_path / "ignore_001.think.txt").write_text("legacy think") + + ds = PairedFolderT2IDataset(tmp_path, use_think_labels=False) + s = ds[0] + assert s.prompt == "caption" + assert s.think is None + + def test_arrow_dataset_roundtrip(tmp_path: Path) -> None: """ArrowT2IDataset reads back what `dataset_tools pack-arrow` writes.""" pa = pytest.importorskip("pyarrow") @@ -123,6 +140,9 @@ def test_arrow_dataset_roundtrip(tmp_path: Path) -> None: assert s.think == "think a" assert s.image.shape[0] == 3 # CHW + ds_no_think = ArrowT2IDataset(out, use_think_labels=False) + assert ds_no_think[0].think is None + def test_think_delimiter_regex_compiled() -> None: """Sanity: regex matches the canonical marker.""" diff --git a/train_u1/tests/test_lora_spec.py b/train_u1/tests/test_lora_spec.py index c232ef2a37c50bf881510243e28160f255505e6d..bb2e3a41a8847fa93a68719b76e589314bbbb637 100644 --- a/train_u1/tests/test_lora_spec.py +++ b/train_u1/tests/test_lora_spec.py @@ -4,13 +4,19 @@ from __future__ import annotations from pathlib import Path import pytest +import torch.nn as nn from train_u1.config import load_train_config from train_u1.model.lora import ( ALL_KNOWN_TARGETS, ATTN_TARGETS, + DENSE_KNOWN_TARGETS, + GEN_MOE_MLP_TARGETS, + GEN_MOE_ROUTER_TARGETS, LORA_PRESETS, LoRASpec, + LoraAdapter, + apply_lora_specs, parse_lora_spec_str, resolve_preset, ) @@ -47,7 +53,7 @@ def test_off_disables() -> None: def test_group_expansion_all_three() -> None: specs = parse_lora_spec_str("all=r64a64") targets = {s.target for s in specs} - assert targets == set(ALL_KNOWN_TARGETS) + assert targets == set(DENSE_KNOWN_TARGETS) def test_unknown_target_rejected() -> None: @@ -59,8 +65,8 @@ def test_preset_default_matches_official_coverage() -> None: """Default preset must match the official 8-step LoRA's module coverage.""" specs = resolve_preset("default") targets = {s.target for s in specs} - # 168 attn + 126 mlp + 2 fm_head ≡ ALL_KNOWN_TARGETS at the per-target level. - assert targets == set(ALL_KNOWN_TARGETS) + # 168 attn + 126 mlp + 2 fm_head: stable 8B dense coverage. + assert targets == set(DENSE_KNOWN_TARGETS) # All at rank 64 / alpha 64 (our reduction from upstream's r=128). for s in specs: assert s.r == 64 @@ -76,24 +82,12 @@ def test_official_r128_preset() -> None: def test_yaml_default_config() -> None: cfg = load_train_config(Path(__file__).parent.parent.parent / "configs" / "default.yaml") - assert cfg.lora.preset == "default" + assert cfg.lora.preset == "attn_mlp_no_head" assert cfg.style.prompt_template == "official" specs = cfg.lora.resolved_specs() - assert len(specs) == 9 # 4 attn + 3 mlp + 2 fm_head + assert len(specs) == 7 # 4 attn + 3 mlp; fm_head is full-FT'd separately -def test_yaml_v16c_config() -> None: - cfg = load_train_config(Path(__file__).parent.parent.parent / "configs" / "v16c.yaml") - # v16c uses an explicit spec (attn+mlp only — fm_head is full-FT) - assert cfg.lora.spec is not None - specs = cfg.lora.resolved_specs() - targets = {s.target for s in specs} - assert "fm_modules.fm_head.0" not in targets - assert "q_proj_mot_gen" in targets - assert "mlp_mot_gen.gate_proj" in targets - # And vision_model + fm_head are in the full-FT regex list - assert any("vision_model_mot_gen" in p for p in cfg.unfreeze) - assert any("fm_head" in p for p in cfg.unfreeze) def test_dropout_propagates_to_specs() -> None: @@ -115,3 +109,86 @@ def test_preset_list_includes_default() -> None: assert "attn_only" in LORA_PRESETS assert "attn_mlp" in LORA_PRESETS assert "official_r128" in LORA_PRESETS + assert "a3b_moe_r8" in LORA_PRESETS + + +class _DummyExpert(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gate_proj = nn.Linear(8, 4, bias=False) + self.up_proj = nn.Linear(8, 4, bias=False) + self.down_proj = nn.Linear(4, 8, bias=False) + + +class _DummyMoEMLP(nn.Module): + def __init__(self, n_experts: int = 3) -> None: + super().__init__() + self.experts = nn.ModuleList([_DummyExpert() for _ in range(n_experts)]) + self.gate = nn.Linear(8, n_experts, bias=False) + + +class _DummyAttn(nn.Module): + def __init__(self) -> None: + super().__init__() + self.q_proj_mot_gen = nn.Linear(8, 8, bias=False) + self.k_proj_mot_gen = nn.Linear(8, 2, bias=False) + self.v_proj_mot_gen = nn.Linear(8, 2, bias=False) + self.o_proj_mot_gen = nn.Linear(8, 8, bias=False) + + +class _DummyLayer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.self_attn = _DummyAttn() + self.mlp_mot_gen = _DummyMoEMLP() + + +class _DummyInnerModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = nn.ModuleList([_DummyLayer(), _DummyLayer()]) + + +class _DummyLanguageModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = _DummyInnerModel() + + +class _DummyA3BModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.language_model = _DummyLanguageModel() + + +def test_moe_target_groups_parse() -> None: + specs = parse_lora_spec_str("gen_moe_mlp=r8a8;gen_moe_router=r4a4") + targets = {s.target for s in specs} + assert targets == set(GEN_MOE_MLP_TARGETS + GEN_MOE_ROUTER_TARGETS) + assert all(s.r == 8 for s in specs if s.target in GEN_MOE_MLP_TARGETS) + assert all(s.r == 4 for s in specs if s.target in GEN_MOE_ROUTER_TARGETS) + + +def test_moe_specific_expert_target_parse() -> None: + [spec] = parse_lora_spec_str("mlp_mot_gen.experts.0.gate_proj=r2a2") + assert spec.target == "mlp_mot_gen.experts.0.gate_proj" + assert spec.r == 2 + assert spec.alpha == 2.0 + + +def test_apply_moe_lora_specs_on_dummy_model() -> None: + model = _DummyA3BModel() + specs = parse_lora_spec_str( + "mlp_mot_gen.experts.*.gate_proj=r2a2;" + "mlp_mot_gen.experts.0.down_proj=r2a2;" + "gen_moe_router=r2a2" + ) + + report = apply_lora_specs(model, specs) + + # 2 layers × 3 experts for gate_proj, plus 2 layers × expert 0 down_proj, + # plus 2 layer routers. + assert report.n_wrapped == 10 + assert isinstance(model.language_model.model.layers[0].mlp_mot_gen.experts[0].gate_proj, LoraAdapter) + assert isinstance(model.language_model.model.layers[0].mlp_mot_gen.experts[0].down_proj, LoraAdapter) + assert isinstance(model.language_model.model.layers[0].mlp_mot_gen.gate, LoraAdapter) diff --git a/train_u1/tests/test_official_consistency.py b/train_u1/tests/test_official_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..59e4bfaaf33274bfb5d9cb50df7366cb5150918b --- /dev/null +++ b/train_u1/tests/test_official_consistency.py @@ -0,0 +1,289 @@ +"""Tests for the SenseNova-U1 official-consistency knobs: +- v-loss / Huber dispatcher +- logit-normal t sampler +- `attn_mlp_no_head` / `attn_only_no_head` LoRA presets +- YAML round-trip of the new TrainConfig fields +""" +from __future__ import annotations + +import math +from pathlib import Path + +import pytest +import torch + +from train_u1.config import ( + TrainRunConfig, + dump_train_config, + load_train_config, +) +from train_u1.data.collators import CollatorConfig, SenseNovaU1Collator +from train_u1.model.lora import ( + FM_HEAD_TARGETS, + LORA_PRESETS, + resolve_preset, +) +from train_u1.model.losses import ( + VALID_LOSS_TYPES, + compute_v_target, + fm_loss, + fm_loss_v, + fm_loss_x0, +) +from train_u1.model.patching import linear_z_t + + +# --------------------------------------------------------------------------- # +# Loss dispatcher # +# --------------------------------------------------------------------------- # + +def _toy_batch(B: int = 1, N: int = 8, D: int = 4, *, seed: int = 0): + g = torch.Generator().manual_seed(seed) + x0 = torch.randn(B, N, D, generator=g) + eps = torch.randn(B, N, D, generator=g) + t = torch.rand(B, generator=g).clamp(min=0.05, max=0.95) + t_b = t.view(B, 1, 1) + z_t = linear_z_t(x0, eps, t) + x_pred = x0 + 0.1 * torch.randn(B, N, D, generator=g) + v_target = compute_v_target(x0, z_t, t) + v_pred = (x_pred - z_t) / (1.0 - t_b).clamp(min=1e-3) + return x0, z_t, t, x_pred, v_pred, v_target + + +def test_fm_loss_dispatcher_x0() -> None: + x0, z_t, t, x_pred, v_pred, v_target = _toy_batch() + l = fm_loss(loss_type="x0", x_pred=x_pred, x0_patch=x0, v_pred=v_pred, v_target=v_target) + expected = fm_loss_x0(x_pred, x0) + assert torch.isclose(l, expected) + + +def test_fm_loss_dispatcher_v() -> None: + x0, z_t, t, x_pred, v_pred, v_target = _toy_batch() + l = fm_loss(loss_type="v", x_pred=x_pred, x0_patch=x0, v_pred=v_pred, v_target=v_target) + expected = fm_loss_v(v_pred, v_target) + assert torch.isclose(l, expected) + + +def test_fm_loss_v_equiv_to_reweighted_x0() -> None: + """`MSE(v) == MSE((x_pred - x0)/(1-t))` — the (1-t)^-2 re-weight identity.""" + x0, z_t, t, x_pred, _v_pred, v_target = _toy_batch(B=1, N=128, D=8, seed=42) + t_b = t.view(1, 1, 1) + # Reconstruct v_pred from x_pred and z_t in the exact same way wrapper does: + v_pred = (x_pred - z_t) / (1.0 - t_b).clamp(min=1e-3) + lhs = fm_loss_v(v_pred, v_target) + rhs = ((x_pred - x0) / (1.0 - t_b).clamp(min=1e-3)).pow(2).mean() + assert torch.isclose(lhs, rhs, atol=1e-6, rtol=1e-5) + + +def test_v_target_uses_same_t_eps_as_v_pred_near_t_one() -> None: + """High-t uniform ablations must clamp target and prediction identically.""" + x0 = torch.tensor([[[1.0, -1.0]]]) + z_t = torch.tensor([[[0.5, -0.25]]]) + t = torch.tensor([0.99]) + x_pred = torch.tensor([[[0.75, -0.5]]]) + t_eps = 0.05 + denom = torch.tensor([[[t_eps]]]) + + v_target = compute_v_target(x0, z_t, t, t_eps=t_eps) + v_pred = (x_pred - z_t) / denom + + assert torch.allclose(v_target, (x0 - z_t) / denom) + assert torch.isfinite(fm_loss_v(v_pred, v_target)) + + +def test_fm_loss_huber_variants() -> None: + x0, z_t, t, x_pred, v_pred, v_target = _toy_batch() + lh_x0 = fm_loss(loss_type="x0_huber", x_pred=x_pred, x0_patch=x0, + v_pred=v_pred, v_target=v_target, huber_delta=0.5) + lh_v = fm_loss(loss_type="v_huber", x_pred=x_pred, x0_patch=x0, + v_pred=v_pred, v_target=v_target, huber_delta=0.5) + assert lh_x0.item() >= 0 + assert lh_v.item() >= 0 + + +def test_fm_loss_unknown_raises() -> None: + with pytest.raises(ValueError, match="unknown loss_type"): + fm_loss(loss_type="bogus", x_pred=torch.zeros(1, 1, 1), + x0_patch=torch.zeros(1, 1, 1)) + + +def test_fm_loss_v_requires_v_args() -> None: + x0 = torch.zeros(1, 1, 1) + with pytest.raises(ValueError, match="requires v_pred"): + fm_loss(loss_type="v", x_pred=x0, x0_patch=x0) + + +def test_compute_v_target_formula() -> None: + """Spot-check `v* = (x0 - z_t)/(1-t)` is what compute_v_target returns.""" + x0 = torch.tensor([[[1.0, 2.0]]]) + z_t = torch.tensor([[[0.3, 0.5]]]) + t = torch.tensor([0.25]) + v = compute_v_target(x0, z_t, t) + # (1-0.25) = 0.75 + expected = torch.tensor([[[(1.0 - 0.3) / 0.75, (2.0 - 0.5) / 0.75]]]) + assert torch.allclose(v, expected) + + +def test_valid_loss_types_constant() -> None: + assert "x0" in VALID_LOSS_TYPES + assert "v" in VALID_LOSS_TYPES + assert "x0_huber" in VALID_LOSS_TYPES + assert "v_huber" in VALID_LOSS_TYPES + + +# --------------------------------------------------------------------------- # +# Collator t sampler # +# --------------------------------------------------------------------------- # + +class _FakeTokenizer: + pad_token_id = 0 + + def __call__(self, prompts, return_tensors="pt", padding=True): + # Return a dummy ids tensor of shape (B, L). Lengths don't matter for + # _sample_t — the test only inspects t. + return {"input_ids": torch.zeros(len(prompts), 1, dtype=torch.long)} + + +def test_t_sampler_uniform_range() -> None: + cfg = CollatorConfig(t_dist="uniform", seed=123, t_eps=0.01) + c = SenseNovaU1Collator(_FakeTokenizer(), cfg=cfg) + t = c._sample_t(2000) + assert t.min().item() >= cfg.t_eps - 1e-6 + assert t.max().item() <= 1.0 + 1e-6 + # uniform mean ~ (eps + 1) / 2 = 0.505 ; allow a wide tolerance + assert 0.45 < t.mean().item() < 0.55 + + +def test_t_sampler_logit_normal_stats() -> None: + cfg = CollatorConfig( + t_dist="logit_normal", t_logit_mean=-0.8, t_logit_std=0.8, seed=7, t_eps=1e-3, + ) + c = SenseNovaU1Collator(_FakeTokenizer(), cfg=cfg) + t = c._sample_t(5000) + # All in (eps, 1-eps) + assert t.min().item() >= cfg.t_eps - 1e-6 + assert t.max().item() <= 1.0 - cfg.t_eps + 1e-6 + # Median(sigmoid(N(-0.8, 0.8))) = sigmoid(-0.8) ≈ 0.310 — strongly biased low. + # With 5000 samples, empirical mean should land near 0.34 (Monte Carlo of + # E[sigmoid(N(-0.8, 0.8))]); accept anything in [0.28, 0.40]. + assert 0.28 < t.mean().item() < 0.40 + # Median much closer to 0.31 than to 0.5 (uniform mean). + assert t.median().item() < 0.42 + + +def test_t_sampler_logit_normal_eps_clamp() -> None: + """Even tiny tail of sigmoid is clipped to [t_eps, 1 - t_eps].""" + cfg = CollatorConfig(t_dist="logit_normal", t_logit_mean=0.0, t_logit_std=10.0, + seed=0, t_eps=0.05) + c = SenseNovaU1Collator(_FakeTokenizer(), cfg=cfg) + t = c._sample_t(2000) + assert t.min().item() >= 0.05 - 1e-6 + assert t.max().item() <= 0.95 + 1e-6 + + +def test_t_sampler_invalid_dist() -> None: + cfg = CollatorConfig(t_dist="bogus") + c = SenseNovaU1Collator(_FakeTokenizer(), cfg=cfg) + with pytest.raises(NotImplementedError): + c._sample_t(4) + + +def test_collator_defaults_do_not_enable_condition_dropout() -> None: + """Diagnostics using CollatorConfig directly stay fully conditional by default.""" + cfg = CollatorConfig() + assert cfg.cond_dropout_text == 0.0 + assert cfg.cond_dropout_both == 0.0 + + +# --------------------------------------------------------------------------- # +# LoRA presets # +# --------------------------------------------------------------------------- # + +def test_no_head_presets_registered() -> None: + assert "attn_only_no_head" in LORA_PRESETS + assert "attn_mlp_no_head" in LORA_PRESETS + + +def test_no_head_presets_exclude_fm_head() -> None: + for name in ("attn_only_no_head", "attn_mlp_no_head"): + specs = resolve_preset(name) + targets = {s.target for s in specs} + for fmt in FM_HEAD_TARGETS: + assert fmt not in targets, f"{name} unexpectedly includes {fmt}" + + +def test_attn_mlp_no_head_targets() -> None: + specs = resolve_preset("attn_mlp_no_head") + targets = {s.target for s in specs} + assert "q_proj_mot_gen" in targets + assert "mlp_mot_gen.gate_proj" in targets + assert all(s.r == 64 and s.alpha == 64.0 for s in specs) + + +# --------------------------------------------------------------------------- # +# YAML round-trip of new fields # +# --------------------------------------------------------------------------- # + +def test_yaml_roundtrip_loss_and_t(tmp_path: Path) -> None: + cfg = TrainRunConfig() + cfg.train.loss_type = "v" + cfg.train.huber_delta = 1.5 + cfg.train.t_dist = "logit_normal" + cfg.train.t_logit_mean = -0.8 + cfg.train.t_logit_std = 0.8 + cfg.train.cond_dropout_text = 0.2 + cfg.train.cond_dropout_both = 0.05 + + path = tmp_path / "rt.yaml" + dump_train_config(cfg, path) + cfg2 = load_train_config(path) + assert cfg2.train.loss_type == "v" + assert cfg2.train.huber_delta == 1.5 + assert cfg2.train.t_dist == "logit_normal" + assert math.isclose(cfg2.train.t_logit_mean, -0.8) + assert math.isclose(cfg2.train.t_logit_std, 0.8) + assert math.isclose(cfg2.train.cond_dropout_text, 0.2) + assert math.isclose(cfg2.train.cond_dropout_both, 0.05) + + +def test_yaml_defaults_use_local_baseline() -> None: + """A config without explicit FM knobs uses the v16c local baseline.""" + cfg = TrainRunConfig() + assert cfg.train.loss_type == "x0" + assert cfg.train.t_dist == "uniform" + assert math.isclose(cfg.train.t_logit_mean, -0.8) + assert math.isclose(cfg.train.t_logit_std, 0.8) + assert math.isclose(cfg.train.cond_dropout_text, 0.0) + assert math.isclose(cfg.train.cond_dropout_both, 0.0) + assert cfg.data.use_think_labels is False + + +def test_default_yaml_uses_local_baseline() -> None: + """The shipped `configs/default.yaml` uses the v16c local baseline.""" + repo_root = Path(__file__).resolve().parents[2] + p = repo_root / "configs" / "default.yaml" + if not p.exists(): # source checkout might not include configs/ + pytest.skip(f"{p} not present") + cfg = load_train_config(p) + assert cfg.lora.preset == "attn_mlp_no_head" + assert cfg.data.use_think_labels is False + assert cfg.train.loss_type == "x0" + assert cfg.train.t_dist == "uniform" + assert math.isclose(cfg.train.cond_dropout_text, 0.0) + assert math.isclose(cfg.train.cond_dropout_both, 0.0) + + +def test_official_alignment_yaml_uses_official_knobs() -> None: + """Official report knobs are still available as an explicit optional config.""" + repo_root = Path(__file__).resolve().parents[2] + p = repo_root / "configs" / "official_alignment.yaml" + if not p.exists(): + pytest.skip(f"{p} not present") + cfg = load_train_config(p) + assert cfg.data.use_think_labels is True + assert cfg.train.loss_type == "v" + assert cfg.train.t_dist == "logit_normal" + assert math.isclose(cfg.train.t_logit_mean, -0.8) + assert math.isclose(cfg.train.t_logit_std, 0.8) + assert math.isclose(cfg.train.cond_dropout_text, 0.10) + assert math.isclose(cfg.train.cond_dropout_both, 0.10) diff --git a/train_u1/tests/test_train_bf16_offload_helpers.py b/train_u1/tests/test_train_bf16_offload_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..343518240c153deb088ecdc516f9ffc2f46495be --- /dev/null +++ b/train_u1/tests/test_train_bf16_offload_helpers.py @@ -0,0 +1,59 @@ +"""Pure helper tests for the bf16 offload trainer.""" +from __future__ import annotations + +import pytest + +from train_u1.scripts.train_bf16_offload import ( + _guard_static_prefix_unfreeze, + _select_prefix_kv, +) + + +def test_select_prefix_kv_routes_cond_and_uncond() -> None: + cond0 = object() + cond1 = object() + uncond = object() + prefix_kvs = {"cond": [cond0, cond1], "uncond": uncond} + + assert _select_prefix_kv(prefix_kvs, 1, "cond") is cond1 + assert _select_prefix_kv(prefix_kvs, 0, "uncond") is uncond + + +def test_select_prefix_kv_requires_uncond_cache() -> None: + with pytest.raises(RuntimeError, match="missing unconditional prefix KV"): + _select_prefix_kv({"cond": [object()], "uncond": None}, 0, "uncond") + + +def test_select_prefix_kv_rejects_unknown_key() -> None: + with pytest.raises(ValueError, match="unknown prefix_cache_key"): + _select_prefix_kv({"cond": [object()], "uncond": object()}, 0, "bogus") + + +def test_static_prefix_guard_rejects_prefix_unfreeze() -> None: + classify = { + "prefix": ["language_model.model.layers.0.self_attn.q_proj.weight"], + "unused": ["vision_model.patch_embed.proj.weight"], + "gen": ["fm_modules.fm_head.0.weight"], + } + with pytest.raises(ValueError, match="static prefix KV cache"): + _guard_static_prefix_unfreeze([r"language_model\.model\.layers"], classify) + + +def test_static_prefix_guard_rejects_unused_unfreeze() -> None: + classify = { + "prefix": ["language_model.model.layers.0.self_attn.q_proj.weight"], + "unused": ["vision_model.patch_embed.proj.weight"], + "gen": ["fm_modules.fm_head.0.weight"], + } + with pytest.raises(ValueError, match="static prefix KV cache"): + _guard_static_prefix_unfreeze([r"^vision_model\."], classify) + + +def test_static_prefix_guard_allows_gen_unfreeze() -> None: + classify = { + "prefix": ["language_model.model.layers.0.self_attn.q_proj.weight"], + "unused": ["vision_model.patch_embed.proj.weight"], + "gen": ["fm_modules.fm_head.0.weight"], + } + _guard_static_prefix_unfreeze([r"^fm_modules\.fm_head\."], classify) +