Official-consistency ablation: code, configs, paper-style writeup
Browse files · Mirror of GitHub commit 555d1c0. See https://github.com/Yidhar/sensenova-u1-lora-trainer/commit/555d1c0 for the full message.
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +33 -0
- README.md +86 -29
- configs/default.yaml +31 -14
- configs/official_alignment.yaml +57 -0
- docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg +3 -0
- docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg +3 -0
- docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg +3 -0
- docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg +3 -0
- docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg +3 -0
- docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg +3 -0
- docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg +3 -0
- docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg +3 -0
- docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg +3 -0
- docs/assets/small_data_style_ablation/detail00_v18.jpg +3 -0
- docs/assets/small_data_style_ablation/detail00_v19.jpg +3 -0
- docs/assets/small_data_style_ablation/detail00_v19a.jpg +3 -0
- docs/assets/small_data_style_ablation/detail00_v19b.jpg +3 -0
- docs/assets/small_data_style_ablation/detail00_v19c.jpg +3 -0
- docs/assets/small_data_style_ablation/detail01_v18.jpg +3 -0
- docs/assets/small_data_style_ablation/detail01_v19.jpg +3 -0
- docs/assets/small_data_style_ablation/detail01_v19a.jpg +3 -0
- docs/assets/small_data_style_ablation/detail01_v19b.jpg +3 -0
- docs/assets/small_data_style_ablation/detail02_v18.jpg +3 -0
- docs/assets/small_data_style_ablation/detail02_v19.jpg +3 -0
- docs/assets/small_data_style_ablation/detail02_v19a.jpg +3 -0
- docs/assets/small_data_style_ablation/detail02_v19b.jpg +3 -0
- docs/assets/small_data_style_ablation/detail02_v19c.jpg +3 -0
- docs/assets/small_data_style_ablation/detail07_v18.jpg +3 -0
- docs/assets/small_data_style_ablation/detail07_v19.jpg +3 -0
- docs/assets/small_data_style_ablation/detail07_v19a.jpg +3 -0
- docs/assets/small_data_style_ablation/detail07_v19b.jpg +3 -0
- docs/assets/small_data_style_ablation/detail07_v19c.jpg +3 -0
- docs/assets/small_data_style_ablation/tb_active_loss_curves.png +0 -0
- docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png +0 -0
- docs/assets/small_data_style_ablation/tb_t_distribution.png +0 -0
- docs/assets/small_data_style_ablation/tb_v_mse_curves.png +0 -0
- docs/assets/small_data_style_ablation/tb_x0_mse_curves.png +0 -0
- docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg +3 -0
- docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg +3 -0
- docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg +3 -0
- docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg +3 -0
- docs/small_data_style_ablation.html +941 -0
- docs/small_data_style_ablation.pdf +3 -0
- train.sh +4 -4
- train_u1/README.md +4 -1
- train_u1/config.py +41 -5
- train_u1/data/collators.py +121 -34
- train_u1/data/datasets.py +9 -1
- train_u1/model/lora.py +108 -11
- train_u1/model/losses.py +67 -5
.gitattributes
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
docs/assets/small_data_style_ablation/detail00_v18.jpg filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
docs/assets/small_data_style_ablation/detail00_v19.jpg filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
docs/assets/small_data_style_ablation/detail00_v19a.jpg filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
docs/assets/small_data_style_ablation/detail00_v19b.jpg filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
docs/assets/small_data_style_ablation/detail00_v19c.jpg filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
docs/assets/small_data_style_ablation/detail01_v18.jpg filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
docs/assets/small_data_style_ablation/detail01_v19.jpg filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
docs/assets/small_data_style_ablation/detail01_v19a.jpg filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
docs/assets/small_data_style_ablation/detail01_v19b.jpg filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
docs/assets/small_data_style_ablation/detail02_v18.jpg filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
docs/assets/small_data_style_ablation/detail02_v19.jpg filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
docs/assets/small_data_style_ablation/detail02_v19a.jpg filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
docs/assets/small_data_style_ablation/detail02_v19b.jpg filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
docs/assets/small_data_style_ablation/detail02_v19c.jpg filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
docs/assets/small_data_style_ablation/detail07_v18.jpg filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
docs/assets/small_data_style_ablation/detail07_v19.jpg filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
docs/assets/small_data_style_ablation/detail07_v19a.jpg filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
docs/assets/small_data_style_ablation/detail07_v19b.jpg filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
docs/assets/small_data_style_ablation/detail07_v19c.jpg filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
docs/small_data_style_ablation.pdf filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -16,7 +16,7 @@ checkpoints drop straight into the official inference scripts.
|
|
| 16 |
```
|
| 17 |
|
| 18 |
Fits on a 32 GB GPU (RTX 5090 / A100-40 / RTX 6000 Ada). Peak VRAM ~20 GB
|
| 19 |
-
on the
|
| 20 |
|
| 21 |
---
|
| 22 |
|
|
@@ -25,16 +25,26 @@ on the maintainer's 56-image hayateluc dataset at 2048².
|
|
| 25 |
- **Config-first**: every run is one YAML file (`configs/default.yaml`).
|
| 26 |
- **Per-module rank + enable**: each LoRA target (`q_proj_mot_gen`, `mlp_mot_gen.down_proj`,
|
| 27 |
`fm_modules.fm_head.0`, …) takes its own rank / alpha / on-off independently.
|
| 28 |
-
- **
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
- **Upstream-format save**: load straight into `examples/t2i/inference.py`
|
| 33 |
via `--lora_path`, or stack with the official 8-step LoRA.
|
| 34 |
- **bf16 training, not 4/8-bit**. Earlier 4-bit nf4 LoRA training produced
|
| 35 |
grid artefacts and limb collapse on the gen tower; switching the base to
|
| 36 |
bf16 (with offload + static prefix-KV cache) eliminated both.
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
---
|
| 39 |
|
| 40 |
## Hardware
|
|
@@ -48,6 +58,29 @@ on the maintainer's 56-image hayateluc dataset at 2048².
|
|
| 48 |
`bitsandbytes>=0.45` and `torch>=2.9` must be linked against your CUDA
|
| 49 |
runtime. On RTX 5090 (sm_120) you'll likely need the cu128 torch wheel.
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
---
|
| 52 |
|
| 53 |
## Install
|
|
@@ -85,10 +118,10 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot
|
|
| 85 |
└── … └── …
|
| 86 |
```
|
| 87 |
|
| 88 |
-
Each `.txt` is a single-paragraph natural-language caption.
|
| 89 |
-
artist
|
| 90 |
-
|
| 91 |
-
|
| 92 |
|
| 93 |
**Optional**: append a `<think>...</think>` reasoning label inside the
|
| 94 |
same `.txt` after a `---think---` delimiter line:
|
|
@@ -101,12 +134,11 @@ HF_HOME=$PWD/hf_cache python -m train_u1.scripts.install_modeling_into_snapshot
|
|
| 101 |
6. **Explicit Prompt:** ...
|
| 102 |
```
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
`
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
6-section format.
|
| 110 |
|
| 111 |
**Parquet/arrow shards** (recommended for ≥ ~10k images, e.g. 1M
|
| 112 |
scaling):
|
|
@@ -192,6 +224,7 @@ data:
|
|
| 192 |
data_dir: dataset/my_style
|
| 193 |
cap_max_pixels: 4194304 # 2048² hard cap per image
|
| 194 |
snap_bucket: true # snap to upstream bucket grid
|
|
|
|
| 195 |
# n_samples: 56 # cap dataset size; default = use everything
|
| 196 |
|
| 197 |
style:
|
|
@@ -199,13 +232,15 @@ style:
|
|
| 199 |
prompt_template: official # 'official' (recommended) | 'plain'
|
| 200 |
|
| 201 |
lora:
|
| 202 |
-
preset:
|
| 203 |
-
# spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off
|
| 204 |
dropout: 0.0
|
| 205 |
|
| 206 |
unfreeze: # full-FT (non-LoRA) regex patterns
|
| 207 |
- '^fm_modules\.timestep_embedder\.'
|
| 208 |
- '^fm_modules\.noise_scale_embedder\.'
|
|
|
|
|
|
|
| 209 |
|
| 210 |
train:
|
| 211 |
steps: 6000
|
|
@@ -214,6 +249,15 @@ train:
|
|
| 214 |
shuffle: true
|
| 215 |
grad_accum: 1
|
| 216 |
checkpoint_every: 600
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
runtime:
|
| 219 |
keep_kvs_on_gpu: true
|
|
@@ -259,30 +303,41 @@ q_proj_mot_gen=r=128,a=64;k_proj_mot_gen=r=64,a=64 # asymmetric ranks
|
|
| 259 |
|
| 260 |
| Preset | Coverage | Trainable LoRA params | Use when |
|
| 261 |
|---|---|---|---|
|
| 262 |
-
| `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M |
|
| 263 |
| `attn_only` | 168 attn, r=64 | ~50 M | ablation |
|
| 264 |
| `attn_mlp` | attn + mlp (no fm_head), r=64 | ~75 M | when fm_head is full-FT'd separately |
|
|
|
|
|
|
|
| 265 |
| `official_r128` | exact upstream shape (r=128 across all 296 wraps) | ~298 M | parameter-matching upstream's 8-step LoRA |
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
---
|
| 268 |
|
| 269 |
## Stack with the official 8-step distill LoRA
|
| 270 |
|
| 271 |
Upstream released a step-distillation LoRA that brings inference down to 8
|
| 272 |
-
NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
```yaml
|
| 275 |
-
# configs/stack_8step.yaml (already in this repo)
|
| 276 |
runtime:
|
| 277 |
upstream_lora_path: hf_cache/.../SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors
|
| 278 |
upstream_lora_skip: ['fm_modules.fm_head'] # don't clobber our fm_head LoRA
|
| 279 |
```
|
| 280 |
|
| 281 |
-
At sample time,
|
| 282 |
|
| 283 |
```bash
|
| 284 |
-
./sample.sh configs/
|
| 285 |
-
artifacts/
|
| 286 |
--prompt "…" \
|
| 287 |
--upstream-lora-path SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors \
|
| 288 |
--upstream-lora-skip fm_modules.fm_head \
|
|
@@ -302,9 +357,8 @@ At sample time, also pass the same upstream LoRA:
|
|
| 302 |
├── pyproject.toml # package metadata
|
| 303 |
├── LICENSE # Apache-2.0
|
| 304 |
├── configs/
|
| 305 |
-
│ ├── default.yaml #
|
| 306 |
-
│
|
| 307 |
-
│ └── stack_8step.yaml # train on top of 8-step distill LoRA
|
| 308 |
├── train_u1/ # importable package
|
| 309 |
│ ├── config.py # YAML config schema
|
| 310 |
│ ├── constants.py # pinned MODEL_SHA / CODE_COMMIT / arch constants
|
|
@@ -314,7 +368,7 @@ At sample time, also pass the same upstream LoRA:
|
|
| 314 |
│ │ ├── lora_io.py # save/load + upstream merge
|
| 315 |
│ │ ├── loader.py # bf16 base load + tower offload
|
| 316 |
│ │ ├── wrapper.py # forward_t2i_step
|
| 317 |
-
│ │ ├── losses.py # fm_loss_x0
|
| 318 |
│ │ ├── patching.py # patchify/unpatchify
|
| 319 |
│ │ └── …
|
| 320 |
│ ├── scripts/
|
|
@@ -325,7 +379,10 @@ At sample time, also pass the same upstream LoRA:
|
|
| 325 |
│ │ └── install_modeling_into_snapshot.py
|
| 326 |
│ └── tests/
|
| 327 |
├── docs/
|
| 328 |
-
│
|
|
|
|
|
|
|
|
|
|
| 329 |
├── artifacts/ # local-only: checkpoints + sweeps (gitignored)
|
| 330 |
├── dataset/ # local-only: image+caption pairs (gitignored)
|
| 331 |
├── hf_cache/ # local-only: HF snapshot (gitignored)
|
|
@@ -346,4 +403,4 @@ At sample time, also pass the same upstream LoRA:
|
|
| 346 |
release; consumed via the `upstream_lora_path` mechanism.
|
| 347 |
- **This trainer** is licensed under Apache-2.0 (see `LICENSE`).
|
| 348 |
|
| 349 |
-
**Thanks to comfy.org for the GPU power support. The open-source community will not forget.**
|
|
|
|
| 16 |
```
|
| 17 |
|
| 18 |
Fits on a 32 GB GPU (RTX 5090 / A100-40 / RTX 6000 Ada). Peak VRAM ~20 GB
|
| 19 |
+
on the train dataset at 2048².
|
| 20 |
|
| 21 |
---
|
| 22 |
|
|
|
|
| 25 |
- **Config-first**: every run is one YAML file (`configs/default.yaml`).
|
| 26 |
- **Per-module rank + enable**: each LoRA target (`q_proj_mot_gen`, `mlp_mot_gen.down_proj`,
|
| 27 |
`fm_modules.fm_head.0`, …) takes its own rank / alpha / on-off independently.
|
| 28 |
+
- **Experimental MoE target grammar**: A3B-style generation experts can be
|
| 29 |
+
addressed explicitly (`gen_moe_mlp`, `gen_moe_router`,
|
| 30 |
+
`mlp_mot_gen.experts.*.gate_proj`) without changing the stable 8B main path.
|
| 31 |
+
- **Default = small-data style baseline**: `configs/default.yaml` uses
|
| 32 |
+
`x0 + uniform t + no condition dropout`, short captions, LoRA on attn+mlp,
|
| 33 |
+
and full fine-tuning of the timestep/noise embedders, gen vision bridge, and
|
| 34 |
+
fm_head.
|
| 35 |
+
- **Official-alignment recipe is optional**: `configs/official_alignment.yaml`
|
| 36 |
+
keeps the public report knobs together for research ablations, but it is not
|
| 37 |
+
the safest first run for small style datasets.
|
| 38 |
- **Upstream-format save**: load straight into `examples/t2i/inference.py`
|
| 39 |
via `--lora_path`, or stack with the official 8-step LoRA.
|
| 40 |
- **bf16 training, not 4/8-bit**. Earlier 4-bit nf4 LoRA training produced
|
| 41 |
grid artefacts and limb collapse on the gen tower; switching the base to
|
| 42 |
bf16 (with offload + static prefix-KV cache) eliminated both.
|
| 43 |
|
| 44 |
+
See the ablation write-up with training curves and sample grids:
|
| 45 |
+
[`docs/small_data_style_ablation.html`](docs/small_data_style_ablation.html)
|
| 46 |
+
or [`docs/small_data_style_ablation.pdf`](docs/small_data_style_ablation.pdf).
|
| 47 |
+
|
| 48 |
---
|
| 49 |
|
| 50 |
## Hardware
|
|
|
|
| 58 |
`bitsandbytes>=0.45` and `torch>=2.9` must be linked against your CUDA
|
| 59 |
runtime. On RTX 5090 (sm_120) you'll likely need the cu128 torch wheel.
|
| 60 |
|
| 61 |
+
### A3B / MoE Status
|
| 62 |
+
|
| 63 |
+
The trainer now has experimental target grammar for future
|
| 64 |
+
`SenseNova-U1-A3B-MoT` generation-side MoE LoRA work:
|
| 65 |
+
|
| 66 |
+
```yaml
|
| 67 |
+
lora:
|
| 68 |
+
spec: "attn=r8a8;gen_moe_mlp=r8a8;gen_moe_router=r8a8"
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
This is a compatibility layer, not the main training path and not an end-to-end
|
| 72 |
+
A3B training claim. The stable release target remains `SenseNova-U1-8B-MoT`;
|
| 73 |
+
A3B training depends on public MoE runtime support that can instantiate the
|
| 74 |
+
`mlp_mot_gen.experts.*` modules.
|
| 75 |
+
|
| 76 |
+
Before training, estimate MoE LoRA size from metadata only:
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
python -m train_u1.scripts.inspect_lora_targets \
|
| 80 |
+
--model path/to/A3B/config.json \
|
| 81 |
+
--spec "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8"
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
---
|
| 85 |
|
| 86 |
## Install
|
|
|
|
| 118 |
└── … └── …
|
| 119 |
```
|
| 120 |
|
| 121 |
+
Each `.txt` is a single-paragraph natural-language caption. Put the style
|
| 122 |
+
or artist anchor in a stable way and keep `style.trigger` aligned with how
|
| 123 |
+
you will sample later. The default config prepends that trigger to every
|
| 124 |
+
caption.
|
| 125 |
|
| 126 |
**Optional**: append a `<think>...</think>` reasoning label inside the
|
| 127 |
same `.txt` after a `---think---` delimiter line:
|
|
|
|
| 134 |
6. **Explicit Prompt:** ...
|
| 135 |
```
|
| 136 |
|
| 137 |
+
Think labels are **ignored by default** because low-quality or highly
|
| 138 |
+
templated think text can dominate the prefix and hurt style binding. To use
|
| 139 |
+
them, set `data.use_think_labels: true` and evaluate with the same think
|
| 140 |
+
distribution at sample time. Do this only when your think labels are
|
| 141 |
+
curated and repeatable.
|
|
|
|
| 142 |
|
| 143 |
**Parquet/arrow shards** (recommended for ≥ ~10k images, e.g. 1M
|
| 144 |
scaling):
|
|
|
|
| 224 |
data_dir: dataset/my_style
|
| 225 |
cap_max_pixels: 4194304 # 2048² hard cap per image
|
| 226 |
snap_bucket: true # snap to upstream bucket grid
|
| 227 |
+
use_think_labels: false # keep prefixes short by default
|
| 228 |
# n_samples: 56 # cap dataset size; default = use everything
|
| 229 |
|
| 230 |
style:
|
|
|
|
| 232 |
prompt_template: official # 'official' (recommended) | 'plain'
|
| 233 |
|
| 234 |
lora:
|
| 235 |
+
preset: attn_mlp_no_head # attn+mlp LoRA; fm_head is full-FT below
|
| 236 |
+
# spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
|
| 237 |
dropout: 0.0
|
| 238 |
|
| 239 |
unfreeze: # full-FT (non-LoRA) regex patterns
|
| 240 |
- '^fm_modules\.timestep_embedder\.'
|
| 241 |
- '^fm_modules\.noise_scale_embedder\.'
|
| 242 |
+
- '^fm_modules\.vision_model_mot_gen\.'
|
| 243 |
+
- '^fm_modules\.fm_head\.'
|
| 244 |
|
| 245 |
train:
|
| 246 |
steps: 6000
|
|
|
|
| 249 |
shuffle: true
|
| 250 |
grad_accum: 1
|
| 251 |
checkpoint_every: 600
|
| 252 |
+
# Small-data style baseline. See docs/small_data_style_ablation.html before
|
| 253 |
+
# switching to the official-alignment recipe.
|
| 254 |
+
loss_type: x0
|
| 255 |
+
t_dist: uniform
|
| 256 |
+
t_logit_mean: -0.8
|
| 257 |
+
t_logit_std: 0.8
|
| 258 |
+
# huber_delta: 1.0 # only used for *_huber
|
| 259 |
+
cond_dropout_text: 0.0
|
| 260 |
+
cond_dropout_both: 0.0
|
| 261 |
|
| 262 |
runtime:
|
| 263 |
keep_kvs_on_gpu: true
|
|
|
|
| 303 |
|
| 304 |
| Preset | Coverage | Trainable LoRA params | Use when |
|
| 305 |
|---|---|---|---|
|
| 306 |
+
| `default` | 168 attn + 126 mlp + 2 fm_head, all r=64 | ~75 M | match upstream 8-step LoRA coverage |
|
| 307 |
| `attn_only` | 168 attn, r=64 | ~50 M | ablation |
|
| 308 |
| `attn_mlp` | attn + mlp (no fm_head), r=64 | ~75 M | when fm_head is full-FT'd separately |
|
| 309 |
+
| `attn_only_no_head` | alias for `attn_only`; explicit no-fm_head intent | ~50 M | conservative small-data style training |
|
| 310 |
+
| `attn_mlp_no_head` | alias for `attn_mlp`; explicit no-fm_head intent | ~75 M | conservative small-data style training |
|
| 311 |
| `official_r128` | exact upstream shape (r=128 across all 296 wraps) | ~298 M | parameter-matching upstream's 8-step LoRA |
|
| 312 |
|
| 313 |
+
The shipped `configs/default.yaml` uses `attn_mlp_no_head` and full-FTs
|
| 314 |
+
`fm_head` separately because that was the most stable small-data baseline in
|
| 315 |
+
our ablations. The `default` preset name inside the LoRA parser still means
|
| 316 |
+
"match upstream 8-step LoRA coverage"; use it only when that exact module
|
| 317 |
+
coverage is what you want. For report-alignment research, start from
|
| 318 |
+
`configs/official_alignment.yaml`.
|
| 319 |
+
|
| 320 |
---
|
| 321 |
|
| 322 |
## Stack with the official 8-step distill LoRA
|
| 323 |
|
| 324 |
Upstream released a step-distillation LoRA that brings inference down to 8
|
| 325 |
+
NFE at `cfg_scale=1.0`. You can train your own style LoRA **on top** of it
|
| 326 |
+
by setting `runtime.upstream_lora_path` in your YAML — at training time we
|
| 327 |
+
bake the official 8-step delta into the bf16 base (skipping `fm_head` so
|
| 328 |
+
we don't clobber our own fm_head LoRA), then wrap our LoRA on top.
|
| 329 |
|
| 330 |
```yaml
|
|
|
|
| 331 |
runtime:
|
| 332 |
upstream_lora_path: hf_cache/.../SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors
|
| 333 |
upstream_lora_skip: ['fm_modules.fm_head'] # don't clobber our fm_head LoRA
|
| 334 |
```
|
| 335 |
|
| 336 |
+
At sample time, pass the same upstream LoRA and use 8 steps at cfg=1.0:
|
| 337 |
|
| 338 |
```bash
|
| 339 |
+
./sample.sh configs/my_style.yaml \
|
| 340 |
+
artifacts/my_style/trainable_state.safetensors \
|
| 341 |
--prompt "…" \
|
| 342 |
--upstream-lora-path SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors \
|
| 343 |
--upstream-lora-skip fm_modules.fm_head \
|
|
|
|
| 357 |
├── pyproject.toml # package metadata
|
| 358 |
├── LICENSE # Apache-2.0
|
| 359 |
├── configs/
|
| 360 |
+
│ ├── default.yaml # recommended small-data style baseline
|
| 361 |
+
│ └── official_alignment.yaml # optional report-alignment research config
|
|
|
|
| 362 |
├── train_u1/ # importable package
|
| 363 |
│ ├── config.py # YAML config schema
|
| 364 |
│ ├── constants.py # pinned MODEL_SHA / CODE_COMMIT / arch constants
|
|
|
|
| 368 |
│ │ ├── lora_io.py # save/load + upstream merge
|
| 369 |
│ │ ├── loader.py # bf16 base load + tower offload
|
| 370 |
│ │ ├── wrapper.py # forward_t2i_step
|
| 371 |
+
│ │ ├── losses.py # fm_loss_x0 / fm_loss_v / fm_loss dispatcher
|
| 372 |
│ │ ├── patching.py # patchify/unpatchify
|
| 373 |
│ │ └── …
|
| 374 |
│ ├── scripts/
|
|
|
|
| 379 |
│ │ └── install_modeling_into_snapshot.py
|
| 380 |
│ └── tests/
|
| 381 |
├── docs/
|
| 382 |
+
│ ├── SETUP.md # data layout, design rationale, pinned-upstream details
|
| 383 |
+
│ ├── small_data_style_ablation.html
|
| 384 |
+
│ ├── small_data_style_ablation.pdf
|
| 385 |
+
│ └── assets/ # figures used by the ablation document
|
| 386 |
├── artifacts/ # local-only: checkpoints + sweeps (gitignored)
|
| 387 |
├── dataset/ # local-only: image+caption pairs (gitignored)
|
| 388 |
├── hf_cache/ # local-only: HF snapshot (gitignored)
|
|
|
|
| 403 |
release; consumed via the `upstream_lora_path` mechanism.
|
| 404 |
- **This trainer** is licensed under Apache-2.0 (see `LICENSE`).
|
| 405 |
|
| 406 |
+
**Thanks to comfy.org for the GPU power support. The open-source community will not forget.**
|
configs/default.yaml
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
# SenseNova-U1 LoRA trainer — default config
|
| 2 |
#
|
| 3 |
-
# This is the recommended starting point
|
| 4 |
-
#
|
| 5 |
-
#
|
| 6 |
-
#
|
|
|
|
|
|
|
| 7 |
#
|
| 8 |
# Edit `data.data_dir`, `style.trigger`, and `run_name` and you should be
|
| 9 |
# good to go for a 32 GB single-GPU run.
|
|
@@ -17,6 +19,9 @@ data:
|
|
| 17 |
cap_max_pixels: 4194304
|
| 18 |
# Snap each image to nearest official bucket so train shape == infer shape.
|
| 19 |
snap_bucket: true
|
|
|
|
|
|
|
|
|
|
| 20 |
# n_samples: cap on dataset size (omit / null = use entire data_dir)
|
| 21 |
|
| 22 |
style:
|
|
@@ -24,20 +29,21 @@ style:
|
|
| 24 |
prompt_template: official # 'official' (recommended) | 'plain'
|
| 25 |
|
| 26 |
lora:
|
| 27 |
-
#
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
# Or override per-target:
|
| 31 |
-
# spec: "attn=r64a64;mlp=r64a64;
|
| 32 |
dropout: 0.0
|
| 33 |
|
| 34 |
unfreeze:
|
| 35 |
-
# Full-finetune (non-LoRA) regex patterns
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
[]
|
| 41 |
|
| 42 |
train:
|
| 43 |
steps: 6000
|
|
@@ -47,6 +53,17 @@ train:
|
|
| 47 |
grad_accum: 1
|
| 48 |
checkpoint_every: 600 # 0 = only save final state
|
| 49 |
# checkpoint_dir: null # default: artifacts/{run_name}/checkpoints
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
runtime:
|
| 52 |
keep_kvs_on_gpu: true # 56 samples × ~59 MB = ~3.3 GB; keep on GPU for speed
|
|
|
|
| 1 |
# SenseNova-U1 LoRA trainer — default config
|
| 2 |
#
|
| 3 |
+
# This is the recommended starting point for small/medium style training.
|
| 4 |
+
# It follows the local small-data baseline: x0 loss + uniform t + no train-time
|
| 5 |
+
# condition dropout, explicit style trigger, LoRA on attn+mlp, and full-FT
|
| 6 |
+
# of ts/ns/vision/fm_head. Public official-consistency knobs are available
|
| 7 |
+
# in configs/official_alignment.yaml for research ablations, but they are
|
| 8 |
+
# not the default because the ablation study showed worse small-data sampling.
|
| 9 |
#
|
| 10 |
# Edit `data.data_dir`, `style.trigger`, and `run_name` and you should be
|
| 11 |
# good to go for a 32 GB single-GPU run.
|
|
|
|
| 19 |
cap_max_pixels: 4194304
|
| 20 |
# Snap each image to nearest official bucket so train shape == infer shape.
|
| 21 |
snap_bucket: true
|
| 22 |
+
# Keep the prefix short by default. Set true only when your think labels are
|
| 23 |
+
# curated and your sampling path uses the same think distribution.
|
| 24 |
+
use_think_labels: false
|
| 25 |
# n_samples: cap on dataset size (omit / null = use entire data_dir)
|
| 26 |
|
| 27 |
style:
|
|
|
|
| 29 |
prompt_template: official # 'official' (recommended) | 'plain'
|
| 30 |
|
| 31 |
lora:
|
| 32 |
+
# Small-data baseline: LoRA on attn+mlp only. fm_head is full-FT'd below.
|
| 33 |
+
# 'default' remains available when you explicitly want upstream 8-step
|
| 34 |
+
# coverage including fm_head LoRA.
|
| 35 |
+
# A3B/MoE presets are experimental compatibility helpers, not the main path.
|
| 36 |
+
preset: attn_mlp_no_head
|
| 37 |
# Or override per-target:
|
| 38 |
+
# spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
|
| 39 |
dropout: 0.0
|
| 40 |
|
| 41 |
unfreeze:
|
| 42 |
+
# Full-finetune (non-LoRA) regex patterns for the small-data baseline.
|
| 43 |
+
- '^fm_modules\.timestep_embedder\.'
|
| 44 |
+
- '^fm_modules\.noise_scale_embedder\.'
|
| 45 |
+
- '^fm_modules\.vision_model_mot_gen\.'
|
| 46 |
+
- '^fm_modules\.fm_head\.'
|
|
|
|
| 47 |
|
| 48 |
train:
|
| 49 |
steps: 6000
|
|
|
|
| 53 |
grad_accum: 1
|
| 54 |
checkpoint_every: 600 # 0 = only save final state
|
| 55 |
# checkpoint_dir: null # default: artifacts/{run_name}/checkpoints
|
| 56 |
+
# Local small-data baseline. See docs/small_data_style_ablation.html before changing these:
|
| 57 |
+
# official-style v-loss/logit-normal/dropout is an optional ablation path,
|
| 58 |
+
# not the small-data default.
|
| 59 |
+
loss_type: x0 # 'x0' | 'v' | 'x0_huber' | 'v_huber'
|
| 60 |
+
t_dist: uniform # 'uniform' | 'logit_normal'
|
| 61 |
+
t_logit_mean: -0.8 # only used for logit_normal
|
| 62 |
+
t_logit_std: 0.8 # only used for logit_normal
|
| 63 |
+
# huber_delta: 1.0 # only used for *_huber
|
| 64 |
+
# CFG / condition dropout. Keep off by default for small style datasets.
|
| 65 |
+
cond_dropout_text: 0.0
|
| 66 |
+
cond_dropout_both: 0.0
|
| 67 |
|
| 68 |
runtime:
|
| 69 |
keep_kvs_on_gpu: true # 56 samples × ~59 MB = ~3.3 GB; keep on GPU for speed
|
configs/official_alignment.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Optional official-consistency research config.
|
| 2 |
+
#
|
| 3 |
+
# This keeps the public SenseNova-U1 report knobs together:
|
| 4 |
+
# - v-loss
|
| 5 |
+
# - logit-normal t sampler, mean=-0.8 std=0.8
|
| 6 |
+
# - train-time condition dropout, text=0.10 and both=0.10
|
| 7 |
+
# - think labels enabled when the dataset provides them
|
| 8 |
+
#
|
| 9 |
+
# Do not treat this as the small-data style-training default. The ablation
|
| 10 |
+
# study showed weaker sampling than the default baseline on the 56-image
|
| 11 |
+
# Hayateluc task. Use this config when you explicitly want to study report
|
| 12 |
+
# alignment, not when you want the safest first run.
|
| 13 |
+
|
| 14 |
+
run_name: official_alignment
|
| 15 |
+
|
| 16 |
+
data:
|
| 17 |
+
data_dir: dataset/my_style
|
| 18 |
+
cap_max_pixels: 4194304
|
| 19 |
+
snap_bucket: true
|
| 20 |
+
use_think_labels: true
|
| 21 |
+
|
| 22 |
+
style:
|
| 23 |
+
trigger: "" # prefer captions that carry their own style anchor
|
| 24 |
+
prompt_template: official
|
| 25 |
+
|
| 26 |
+
lora:
|
| 27 |
+
# Matches upstream 8-step LoRA coverage at r=64.
|
| 28 |
+
preset: default
|
| 29 |
+
dropout: 0.0
|
| 30 |
+
|
| 31 |
+
unfreeze:
|
| 32 |
+
- '^fm_modules\.timestep_embedder\.'
|
| 33 |
+
- '^fm_modules\.noise_scale_embedder\.'
|
| 34 |
+
- '^fm_modules\.vision_model_mot_gen\.'
|
| 35 |
+
- '^fm_modules\.fm_head\.'
|
| 36 |
+
|
| 37 |
+
train:
|
| 38 |
+
steps: 6000
|
| 39 |
+
lr: 5.0e-5
|
| 40 |
+
seed: 0
|
| 41 |
+
shuffle: true
|
| 42 |
+
grad_accum: 1
|
| 43 |
+
checkpoint_every: 600
|
| 44 |
+
loss_type: v
|
| 45 |
+
t_dist: logit_normal
|
| 46 |
+
t_logit_mean: -0.8
|
| 47 |
+
t_logit_std: 0.8
|
| 48 |
+
cond_dropout_text: 0.10
|
| 49 |
+
cond_dropout_both: 0.10
|
| 50 |
+
|
| 51 |
+
runtime:
|
| 52 |
+
# Long think prefixes and unconditional-prefix precompute can push VRAM up.
|
| 53 |
+
# Offload KVs and checkpoint all layers unless you have confirmed headroom.
|
| 54 |
+
keep_kvs_on_gpu: false
|
| 55 |
+
gc_skip_last: 0
|
| 56 |
+
device: cuda
|
| 57 |
+
cpu_device: cpu
|
docs/assets/small_data_style_ablation/artifact_baseline_meadow.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/artifact_baseline_pampas.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/artifact_v19a_grid_meadow.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/artifact_v19b_noise_pampas.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail00_v18.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail00_v19.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail00_v19a.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail00_v19b.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail00_v19c.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail01_v18.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail01_v19.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail01_v19a.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail01_v19b.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail02_v18.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail02_v19.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail02_v19a.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail02_v19b.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail02_v19c.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail07_v18.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail07_v19.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail07_v19a.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail07_v19b.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/detail07_v19c.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/tb_active_loss_curves.png
ADDED
|
docs/assets/small_data_style_ablation/tb_condition_dropout_counts.png
ADDED
|
docs/assets/small_data_style_ablation/tb_t_distribution.png
ADDED
|
docs/assets/small_data_style_ablation/tb_v_mse_curves.png
ADDED
|
docs/assets/small_data_style_ablation/tb_x0_mse_curves.png
ADDED
|
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg
ADDED
|
Git LFS Details
|
docs/assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg
ADDED
|
Git LFS Details
|
docs/small_data_style_ablation.html
ADDED
|
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="zh-CN">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8">
|
| 5 |
+
<title>对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究</title>
|
| 6 |
+
<style>
|
| 7 |
+
@page {
|
| 8 |
+
size: A4;
|
| 9 |
+
margin: 22mm 18mm 22mm 18mm;
|
| 10 |
+
}
|
| 11 |
+
:root {
|
| 12 |
+
--ink: #1a1a1a;
|
| 13 |
+
--muted: #555;
|
| 14 |
+
--line: #c8c8c8;
|
| 15 |
+
--rule: #2a2a2a;
|
| 16 |
+
--soft: #f6f6f4;
|
| 17 |
+
--accent: #6b3410;
|
| 18 |
+
}
|
| 19 |
+
* { box-sizing: border-box; }
|
| 20 |
+
body {
|
| 21 |
+
margin: 0;
|
| 22 |
+
color: var(--ink);
|
| 23 |
+
font-family: "Source Han Serif SC", "Noto Serif CJK SC", "STSong",
|
| 24 |
+
"Latin Modern Roman", "Linux Libertine O", "Times New Roman", Georgia, serif;
|
| 25 |
+
font-size: 10.5pt;
|
| 26 |
+
line-height: 1.55;
|
| 27 |
+
background: white;
|
| 28 |
+
text-align: justify;
|
| 29 |
+
hyphens: auto;
|
| 30 |
+
}
|
| 31 |
+
h1, h2, h3, h4 {
|
| 32 |
+
font-family: "Source Han Sans SC", "Noto Sans CJK SC",
|
| 33 |
+
"Latin Modern Sans", "Helvetica Neue", Arial, sans-serif;
|
| 34 |
+
font-weight: 700;
|
| 35 |
+
line-height: 1.25;
|
| 36 |
+
page-break-after: avoid;
|
| 37 |
+
color: var(--ink);
|
| 38 |
+
}
|
| 39 |
+
h1 {
|
| 40 |
+
font-size: 19pt;
|
| 41 |
+
margin: 0 0 6pt;
|
| 42 |
+
text-align: center;
|
| 43 |
+
letter-spacing: 0.5pt;
|
| 44 |
+
}
|
| 45 |
+
h2 {
|
| 46 |
+
font-size: 13pt;
|
| 47 |
+
margin: 24pt 0 8pt;
|
| 48 |
+
padding-bottom: 4pt;
|
| 49 |
+
border-bottom: 0.7pt solid var(--rule);
|
| 50 |
+
}
|
| 51 |
+
h3 {
|
| 52 |
+
font-size: 11pt;
|
| 53 |
+
margin: 14pt 0 4pt;
|
| 54 |
+
}
|
| 55 |
+
h4 {
|
| 56 |
+
font-size: 10pt;
|
| 57 |
+
margin: 10pt 0 3pt;
|
| 58 |
+
font-style: italic;
|
| 59 |
+
font-weight: 600;
|
| 60 |
+
}
|
| 61 |
+
p { margin: 5pt 0; text-indent: 2em; }
|
| 62 |
+
p.noindent, h1 + p, h2 + p, h3 + p, h4 + p,
|
| 63 |
+
.abstract p, .figure p, .table-wrap p, .equation,
|
| 64 |
+
li > p, .meta { text-indent: 0; }
|
| 65 |
+
.meta {
|
| 66 |
+
text-align: center;
|
| 67 |
+
color: var(--muted);
|
| 68 |
+
font-size: 10pt;
|
| 69 |
+
margin: 0 0 14pt;
|
| 70 |
+
}
|
| 71 |
+
.affil {
|
| 72 |
+
text-align: center;
|
| 73 |
+
color: var(--muted);
|
| 74 |
+
font-size: 9.5pt;
|
| 75 |
+
margin: 0 0 18pt;
|
| 76 |
+
}
|
| 77 |
+
.abstract {
|
| 78 |
+
margin: 4pt 8% 18pt;
|
| 79 |
+
padding: 10pt 14pt;
|
| 80 |
+
border: 0.5pt solid var(--line);
|
| 81 |
+
background: var(--soft);
|
| 82 |
+
}
|
| 83 |
+
.abstract h3 {
|
| 84 |
+
margin: 0 0 4pt;
|
| 85 |
+
text-align: center;
|
| 86 |
+
font-size: 10.5pt;
|
| 87 |
+
letter-spacing: 1pt;
|
| 88 |
+
text-transform: uppercase;
|
| 89 |
+
}
|
| 90 |
+
.abstract p { text-indent: 0; margin: 4pt 0; font-size: 10pt; }
|
| 91 |
+
.keywords {
|
| 92 |
+
font-size: 9.5pt;
|
| 93 |
+
margin-top: 8pt;
|
| 94 |
+
}
|
| 95 |
+
.keywords strong { letter-spacing: 0.5pt; }
|
| 96 |
+
code, pre, .mono {
|
| 97 |
+
font-family: "Latin Modern Mono", "DejaVu Sans Mono", "Consolas", monospace;
|
| 98 |
+
}
|
| 99 |
+
code { font-size: 9.5pt; background: #f0f0ec; padding: 0 2pt; border-radius: 2pt; }
|
| 100 |
+
pre {
|
| 101 |
+
background: #f4f4f0;
|
| 102 |
+
color: #1a1a1a;
|
| 103 |
+
padding: 8pt 10pt;
|
| 104 |
+
font-size: 9pt;
|
| 105 |
+
border: 0.5pt solid var(--line);
|
| 106 |
+
border-radius: 3pt;
|
| 107 |
+
overflow: hidden;
|
| 108 |
+
white-space: pre-wrap;
|
| 109 |
+
page-break-inside: avoid;
|
| 110 |
+
}
|
| 111 |
+
table {
|
| 112 |
+
width: 100%;
|
| 113 |
+
border-collapse: collapse;
|
| 114 |
+
margin: 4pt 0 6pt;
|
| 115 |
+
font-size: 9.5pt;
|
| 116 |
+
page-break-inside: avoid;
|
| 117 |
+
}
|
| 118 |
+
th, td {
|
| 119 |
+
border-top: 0.5pt solid var(--rule);
|
| 120 |
+
border-bottom: 0.5pt solid var(--rule);
|
| 121 |
+
padding: 4pt 6pt;
|
| 122 |
+
vertical-align: top;
|
| 123 |
+
text-align: left;
|
| 124 |
+
}
|
| 125 |
+
th {
|
| 126 |
+
background: transparent;
|
| 127 |
+
font-weight: 600;
|
| 128 |
+
border-top: 0.8pt solid var(--rule);
|
| 129 |
+
border-bottom: 0.8pt solid var(--rule);
|
| 130 |
+
}
|
| 131 |
+
tbody tr td { border-top: none; }
|
| 132 |
+
tbody tr:last-child td { border-bottom: 0.8pt solid var(--rule); }
|
| 133 |
+
table.bordered td, table.bordered th { border: 0.5pt solid var(--line); }
|
| 134 |
+
|
| 135 |
+
.figure {
|
| 136 |
+
page-break-inside: avoid;
|
| 137 |
+
margin: 10pt 0 12pt;
|
| 138 |
+
}
|
| 139 |
+
.figure img {
|
| 140 |
+
display: block;
|
| 141 |
+
width: 100%;
|
| 142 |
+
max-width: 100%;
|
| 143 |
+
margin: 0 auto;
|
| 144 |
+
border: 0.5pt solid var(--line);
|
| 145 |
+
background: white;
|
| 146 |
+
}
|
| 147 |
+
.figure.center img { margin: 0 auto; }
|
| 148 |
+
.figure.narrow img { width: 78%; }
|
| 149 |
+
.figure.full img { width: 100%; }
|
| 150 |
+
/* Multi-slice figures: a tall composite split into A4-fitting slices.
|
| 151 |
+
Each slice is its own atomic page-break-inside-avoid; the caption
|
| 152 |
+
sits on the last slice. */
|
| 153 |
+
.figure.sliced {
|
| 154 |
+
page-break-inside: auto;
|
| 155 |
+
}
|
| 156 |
+
.figure.sliced .slice {
|
| 157 |
+
page-break-inside: avoid;
|
| 158 |
+
margin: 0 0 4pt;
|
| 159 |
+
}
|
| 160 |
+
.figure.sliced .slice img {
|
| 161 |
+
width: 100%;
|
| 162 |
+
max-height: 245mm;
|
| 163 |
+
object-fit: contain;
|
| 164 |
+
}
|
| 165 |
+
.figure.sliced.narrow .slice img {
|
| 166 |
+
width: 65%;
|
| 167 |
+
margin: 0 auto;
|
| 168 |
+
display: block;
|
| 169 |
+
}
|
| 170 |
+
/* Side-by-side / stacked closeup with row labels */
|
| 171 |
+
.closeup-row {
|
| 172 |
+
display: grid;
|
| 173 |
+
grid-template-columns: 70pt 1fr;
|
| 174 |
+
align-items: center;
|
| 175 |
+
gap: 8pt;
|
| 176 |
+
margin: 0 0 4pt;
|
| 177 |
+
page-break-inside: avoid;
|
| 178 |
+
}
|
| 179 |
+
.closeup-row .label {
|
| 180 |
+
font-size: 9.5pt;
|
| 181 |
+
font-family: "Source Han Sans SC", "Noto Sans CJK SC", sans-serif;
|
| 182 |
+
color: var(--ink);
|
| 183 |
+
text-align: right;
|
| 184 |
+
padding-right: 4pt;
|
| 185 |
+
}
|
| 186 |
+
.closeup-row .label .tag {
|
| 187 |
+
font-weight: 700;
|
| 188 |
+
font-size: 11pt;
|
| 189 |
+
display: block;
|
| 190 |
+
margin-bottom: 1pt;
|
| 191 |
+
}
|
| 192 |
+
.closeup-row .label .sub {
|
| 193 |
+
color: var(--muted);
|
| 194 |
+
font-size: 8.5pt;
|
| 195 |
+
}
|
| 196 |
+
.closeup-row img {
|
| 197 |
+
width: 100%;
|
| 198 |
+
max-height: 70mm;
|
| 199 |
+
object-fit: cover;
|
| 200 |
+
object-position: center;
|
| 201 |
+
border: 0.5pt solid var(--line);
|
| 202 |
+
display: block;
|
| 203 |
+
}
|
| 204 |
+
.closeup-figure {
|
| 205 |
+
page-break-inside: avoid;
|
| 206 |
+
margin: 10pt 0 14pt;
|
| 207 |
+
}
|
| 208 |
+
.caption {
|
| 209 |
+
color: var(--ink);
|
| 210 |
+
font-size: 9pt;
|
| 211 |
+
margin: 4pt 0 0;
|
| 212 |
+
line-height: 1.45;
|
| 213 |
+
text-align: left;
|
| 214 |
+
text-indent: 0;
|
| 215 |
+
}
|
| 216 |
+
.caption strong { font-size: 9pt; letter-spacing: 0.3pt; }
|
| 217 |
+
|
| 218 |
+
.grid-2 {
|
| 219 |
+
display: grid;
|
| 220 |
+
grid-template-columns: 1fr 1fr;
|
| 221 |
+
gap: 8pt;
|
| 222 |
+
margin: 8pt 0;
|
| 223 |
+
}
|
| 224 |
+
.grid-2 .figure { margin: 0; }
|
| 225 |
+
.table-wrap { page-break-inside: avoid; margin: 6pt 0 14pt; }
|
| 226 |
+
.table-title { font-size: 9.5pt; margin: 0 0 4pt; text-indent: 0; }
|
| 227 |
+
.table-title strong { letter-spacing: 0.3pt; }
|
| 228 |
+
|
| 229 |
+
.equation {
|
| 230 |
+
margin: 8pt 0;
|
| 231 |
+
text-align: center;
|
| 232 |
+
font-size: 10.5pt;
|
| 233 |
+
page-break-inside: avoid;
|
| 234 |
+
}
|
| 235 |
+
.eq-row {
|
| 236 |
+
display: flex;
|
| 237 |
+
align-items: center;
|
| 238 |
+
justify-content: center;
|
| 239 |
+
gap: 14pt;
|
| 240 |
+
}
|
| 241 |
+
.eq-label { font-style: italic; color: var(--muted); font-size: 9.5pt; }
|
| 242 |
+
/* Static math rendering — no MathJax, works in weasyprint PDF */
|
| 243 |
+
.m, em.m { font-style: italic; font-family: "Latin Modern Math", "STIX Two Math", "Cambria Math", "Times New Roman", serif; }
|
| 244 |
+
.m-up { font-style: normal; }
|
| 245 |
+
.m-cal { font-style: italic; font-family: "Latin Modern Math", "STIX Two Math", "Cambria Math", cursive, serif; }
|
| 246 |
+
.frac {
|
| 247 |
+
display: inline-block;
|
| 248 |
+
vertical-align: -0.35em;
|
| 249 |
+
text-align: center;
|
| 250 |
+
font-size: 90%;
|
| 251 |
+
margin: 0 3pt;
|
| 252 |
+
line-height: 1.1;
|
| 253 |
+
white-space: nowrap;
|
| 254 |
+
}
|
| 255 |
+
.frac .num,
|
| 256 |
+
.frac .den {
|
| 257 |
+
display: block;
|
| 258 |
+
white-space: nowrap;
|
| 259 |
+
padding: 0 4pt;
|
| 260 |
+
}
|
| 261 |
+
.frac .num { border-bottom: 0.6pt solid var(--ink); padding-bottom: 1pt; }
|
| 262 |
+
.frac .den { padding-top: 1pt; }
|
| 263 |
+
sub, sup { line-height: 0; font-size: 75%; }
|
| 264 |
+
.eq-bracket {
|
| 265 |
+
font-size: 150%;
|
| 266 |
+
vertical-align: -0.25em;
|
| 267 |
+
font-weight: 400;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
ol.refs { padding-left: 1.2em; font-size: 9.5pt; line-height: 1.5; }
|
| 271 |
+
ol.refs li { margin: 3pt 0; }
|
| 272 |
+
|
| 273 |
+
.page-break { page-break-before: always; }
|
| 274 |
+
.small { font-size: 9pt; color: var(--muted); }
|
| 275 |
+
section.body { padding: 0 4pt; }
|
| 276 |
+
|
| 277 |
+
/* Column layout for compact paragraphs (not used full-width to keep figures full) */
|
| 278 |
+
.two-col {
|
| 279 |
+
column-count: 2;
|
| 280 |
+
column-gap: 14pt;
|
| 281 |
+
column-rule: 0.3pt solid var(--line);
|
| 282 |
+
}
|
| 283 |
+
.two-col p:first-of-type { margin-top: 0; }
|
| 284 |
+
</style>
|
| 285 |
+
</head>
|
| 286 |
+
<body>
|
| 287 |
+
|
| 288 |
+
<header>
|
| 289 |
+
<h1>对 SenseNova-U1 官方训练配置在小数据风格 LoRA 场景的消融研究</h1>
|
| 290 |
+
<p class="meta">
|
| 291 |
+
内部技术报告 · 2026 年 5 月 · train_u1 工具链
|
| 292 |
+
</p>
|
| 293 |
+
<p class="affil">
|
| 294 |
+
针对 SenseNova-U1-8B-MoT 基模的 LoRA / Partial Fine-tuning 训练器维护团队
|
| 295 |
+
</p>
|
| 296 |
+
|
| 297 |
+
<div class="abstract">
|
| 298 |
+
<h3>Abstract</h3>
|
| 299 |
+
<p>
|
| 300 |
+
SenseNova-U1 公开技术报告披露的 text-to-image 流匹配训练在三个关键算法层
|
| 301 |
+
上做出特定选择:以速度 (velocity) 为目标的损失函数、logit-normal 时间步采样、
|
| 302 |
+
以及 condition dropout 形式的无分类器引导校准。在大规模预训练或后训阶段,
|
| 303 |
+
这些选择共同支撑了模型最终行为。但对小数据风格 LoRA 微调而言,是否应当继承
|
| 304 |
+
这一组合并未由现有公开文献回答。本文以 56 张 Hayateluc 风格图像与 8B-MoT 基模
|
| 305 |
+
为受控环境,对三个官方算法层逐一做单变量消融,并与 v18 本地基线 (x<sub>0</sub>-MSE
|
| 306 |
+
损失、uniform 时间步、零 dropout) 进行对比,同时纳入 prompt 前缀格式作为额外
|
| 307 |
+
受控变量。所有训练运行 6000 步,并在统一 prompt 集上做多步采样以做视觉评估。
|
| 308 |
+
结果表明,逐项替换为官方设置后,<strong>训练 loss 不发散,但完整采样图像呈
|
| 309 |
+
现可观察的图像重建质量劣化</strong>:(i) velocity 损失导致天空与大面积渐变上
|
| 310 |
+
规则横向条纹 (banding) 与 photoreal 化;(ii) logit-normal 时间步导致大面积平
|
| 311 |
+
滑色块上高频 speckle 噪点与色阶离散化;(iii) 两者组合时进一步出现远景元素缺
|
| 312 |
+
失、构图被压缩的结构扭曲 (structural distortion),已超出 "style drift" 范畴而
|
| 313 |
+
属重建质量缺陷。condition dropout 是三项中唯一未引入可见劣化的项,反而轻
|
| 314 |
+
微改善图像结构质量。基于该结论,我们将仓库的发布默认值保持为本地基线,并
|
| 315 |
+
把官方对齐组合作为可选的 ablation 配置发布,供研究复现使用。
|
| 316 |
+
</p>
|
| 317 |
+
<p class="keywords">
|
| 318 |
+
<strong>Keywords —</strong>
|
| 319 |
+
流匹配 (flow matching) ·
|
| 320 |
+
LoRA 微调 ·
|
| 321 |
+
消融研究 ·
|
| 322 |
+
classifier-free guidance ·
|
| 323 |
+
风格迁移 ·
|
| 324 |
+
SenseNova-U1
|
| 325 |
+
</p>
|
| 326 |
+
</div>
|
| 327 |
+
</header>
|
| 328 |
+
|
| 329 |
+
<section class="body">
|
| 330 |
+
|
| 331 |
+
<h2>1. 引言</h2>
|
| 332 |
+
<p>
|
| 333 |
+
SenseNova-U1 公开报告 [1] 在 text-to-image 流匹配训练中采用三个具体的算法层:
|
| 334 |
+
速度损失 (velocity loss)、logit-normal 时间步采样、以及在条件 prefix 上的 dropout
|
| 335 |
+
作为 classifier-free guidance (CFG) 的校准信号。这一组合在该报告所对应的大规模
|
| 336 |
+
数据与多阶段训练管线下是合理且互相支撑的设计选择。
|
| 337 |
+
</p>
|
| 338 |
+
<p>
|
| 339 |
+
与之相对,在小数据风格 LoRA 微调场景下,训练目标的属性发生了实质变化:训练数据
|
| 340 |
+
规模通常在 10<sup>1</sup>–10<sup>2</sup> 张量级,目的是从基模注入一个相对集中的
|
| 341 |
+
视觉风格 fingerprint,而非更新一个完整的世界模型。这种情况下,将官方训练配置整
|
| 342 |
+
体迁移过来是否仍最优,并不显然。
|
| 343 |
+
</p>
|
| 344 |
+
<p>
|
| 345 |
+
本文以 SenseNova-U1-8B-MoT 为基模,在 56 张 Hayateluc 风格图像数据上做受控消
|
| 346 |
+
融。具体贡献为三点:
|
| 347 |
+
</p>
|
| 348 |
+
<p>
|
| 349 |
+
<strong>(i)</strong> 提出一组单变量消融配置,分别替换官方三个算法层中的一个,并
|
| 350 |
+
与 v18 本地基线对照,使得每对差异可被归因到单一 lever。
|
| 351 |
+
<strong>(ii)</strong> 指出 prompt 前缀格式 (训练时 caption 中风格 anchor 的位置)
|
| 352 |
+
是一个独立而强的混淆变量;在没有控制该变量前,单纯的 lever 替换会得到误导性的视
|
| 353 |
+
觉差异。
|
| 354 |
+
<strong>(iii)</strong> 基于训练标量曲线与多步采样视觉评估,识别每个官方 lever
|
| 355 |
+
引入的具体图像重建质量缺陷:velocity 损失→横纹 + photoreal 化、logit-normal
|
| 356 |
+
时间步→噪点 + 色阶离散化、两者叠加→结构扭曲。condition dropout 是唯一对小数据
|
| 357 |
+
风格 LoRA 无损甚至有益的官方算法层。
|
| 358 |
+
</p>
|
| 359 |
+
|
| 360 |
+
<h2>2. 背景与符号</h2>
|
| 361 |
+
|
| 362 |
+
<h3>2.1 流匹配训练目标</h3>
|
| 363 |
+
<p>
|
| 364 |
+
设 <em class="m">x</em><sub>0</sub> 为目标图像 patch,<em class="m">ε</em> 为各向同
|
| 365 |
+
性高斯噪声,时间步 <em class="m">t</em> ∈ [<em class="m">t</em><sub class="m">ε</sub>,
|
| 366 |
+
1 − <em class="m">t</em><sub class="m">ε</sub>]。本仓库与上游推理共用如下线性插
|
| 367 |
+
值约定 (linear-<em class="m">z</em> schedule,<em class="m">t</em> 越接近 1 越接
|
| 368 |
+
近干净图):
|
| 369 |
+
</p>
|
| 370 |
+
<div class="equation">
|
| 371 |
+
<span style="white-space: nowrap;">
|
| 372 |
+
<em class="m">z</em><sub class="m">t</sub> =
|
| 373 |
+
<em class="m">t</em> <em class="m">x</em><sub>0</sub> +
|
| 374 |
+
(1−<em class="m">t</em>) <em class="m">ε</em>,
|
| 375 |
+
</span>
|
| 376 |
+
  
|
| 377 |
+
<span style="white-space: nowrap;">
|
| 378 |
+
<em class="m">v</em><sup>★</sup> =
|
| 379 |
+
<span class="frac">
|
| 380 |
+
<span class="num"><em class="m">x</em><sub>0</sub> − <em class="m">z</em><sub class="m">t</sub></span>
|
| 381 |
+
<span class="den">1 − <em class="m">t</em></span>
|
| 382 |
+
</span>.
|
| 383 |
+
</span>
|
| 384 |
+
 <span class="eq-label">(1)</span>
|
| 385 |
+
</div>
|
| 386 |
+
<p>
|
| 387 |
+
在 x<sub>0</sub>-MSE 损失下,训练目标为
|
| 388 |
+
<em class="m-cal">L</em><sub class="m">x</sub><sub>0</sub> =
|
| 389 |
+
𝔼 ‖<em class="m">x</em><sub class="m">θ</sub>(<em class="m">z</em><sub class="m">t</sub>,<em class="m">t</em>)
|
| 390 |
+
− <em class="m">x</em><sub>0</sub>‖<sup>2</sup>;
|
| 391 |
+
在 velocity 损失下,<em class="m-cal">L</em><sub class="m">v</sub> = 𝔼
|
| 392 |
+
‖<em class="m">v</em><sub class="m">θ</sub>(<em class="m">z</em><sub class="m">t</sub>,<em class="m">t</em>)
|
| 393 |
+
− <em class="m">v</em><sup>★</sup>‖<sup>2</sup>。代入
|
| 394 |
+
<em class="m">v</em><sup>★</sup> = (<em class="m">x</em><sub>0</sub> − <em class="m">z</em><sub class="m">t</sub>)/(1 − <em class="m">t</em>)
|
| 395 |
+
与 <em class="m">v</em><sub class="m">θ</sub> = (<em class="m">x</em><sub class="m">θ</sub> − <em class="m">z</em><sub class="m">t</sub>)/(1 − <em class="m">t</em>),可得二者关系为
|
| 396 |
+
</p>
|
| 397 |
+
<div class="equation">
|
| 398 |
+
<span style="white-space: nowrap;">
|
| 399 |
+
<em class="m-cal">L</em><sub class="m">v</sub> =
|
| 400 |
+
𝔼<span class="eq-bracket">[</span> 
|
| 401 |
+
<span class="frac">
|
| 402 |
+
<span class="num">‖<em class="m">x</em><sub class="m">θ</sub> − <em class="m">x</em><sub>0</sub>‖<sup>2</sup></span>
|
| 403 |
+
<span class="den">(1 − <em class="m">t</em>)<sup>2</sup></span>
|
| 404 |
+
</span>
|
| 405 |
+
 <span class="eq-bracket">]</span>.
|
| 406 |
+
</span>
|
| 407 |
+
 <span class="eq-label">(2)</span>
|
| 408 |
+
</div>
|
| 409 |
+
<p>
|
| 410 |
+
即 velocity 损失等价于以 (1 − <em class="m">t</em>)<sup>−2</sup> 重新加权的
|
| 411 |
+
x<sub>0</sub>-MSE。该权重在 <em class="m">t</em> → 1 (近 clean 端) 发散,因此训
|
| 412 |
+
练梯度的有效分布严重偏向高 <em class="m">t</em> 区间。
|
| 413 |
+
</p>
|
| 414 |
+
|
| 415 |
+
<h3>2.2 时间步采样分布</h3>
|
| 416 |
+
<p>
|
| 417 |
+
uniform 采样取 <em class="m">t</em> ∼ <em class="m-cal">U</em>(<em class="m">t</em><sub class="m">ε</sub>,
|
| 418 |
+
1 − <em class="m">t</em><sub class="m">ε</sub>)。logit-normal 采样定义为
|
| 419 |
+
<em class="m">u</em> ∼ <em class="m-cal">N</em>(<em class="m">μ</em>, <em class="m">σ</em><sup>2</sup>),
|
| 420 |
+
<em class="m">t</em> = <em class="m">σ</em>(<em class="m">u</em>),其中
|
| 421 |
+
<em class="m">σ</em>(·) 表示 sigmoid 函数 (区别于正态分布的标准差 <em class="m">σ</em>)。本文 logit-normal 取
|
| 422 |
+
<em class="m">μ</em> = −0.8, <em class="m">σ</em> = 0.8,对应
|
| 423 |
+
𝔼[<em class="m">t</em>] ≈ 0.34,将概率质量偏向较低 <em class="m">t</em>
|
| 424 |
+
(较 noisy 一侧)。
|
| 425 |
+
</p>
|
| 426 |
+
|
| 427 |
+
<h3>2.3 Condition dropout</h3>
|
| 428 |
+
<p>
|
| 429 |
+
设每步训练以独立概率 <em class="m">p</em><sub class="m-up">text</sub> 将文本条件
|
| 430 |
+
替换为空 prompt 对应的 prefix KV;以独立概率 <em class="m">p</em><sub class="m-up">both</sub>
|
| 431 |
+
走 “text + image” 全部 drop 的分支。在纯 T2I 场景下,这两个 drop 模式都会回退到
|
| 432 |
+
统一的 unconditional prefix。本文取
|
| 433 |
+
<em class="m">p</em><sub class="m-up">text</sub> = <em class="m">p</em><sub class="m-up">both</sub> = 0.10,
|
| 434 |
+
即约 20% 步使用无条件 prefix。
|
| 435 |
+
</p>
|
| 436 |
+
|
| 437 |
+
<h2>3. 实验配置</h2>
|
| 438 |
+
|
| 439 |
+
<h3>3.1 数据集与基模</h3>
|
| 440 |
+
<p>
|
| 441 |
+
训练数据为 56 张 Hayateluc 风格自然语言 caption 配对图像,分布在 7 个 aspect-ratio
|
| 442 |
+
bucket 上 (最大像素数 ≤ 2048<sup>2</sup>)。基模为 SenseNova-U1-8B-MoT [1],加载方式为
|
| 443 |
+
bf16 CPU 驻留 + 静态前缀 KV cache 的低显存 LoRA 训练,单卡 32 GB 峰值约 21 GB。
|
| 444 |
+
</p>
|
| 445 |
+
|
| 446 |
+
<h3>3.2 可训练面</h3>
|
| 447 |
+
<p>
|
| 448 |
+
所有实验共享如下可训练面,保证不同运行间差异仅来自表 1 列出的三个 lever。
|
| 449 |
+
表 2 给出训练面分层;表 3 给出训练超参。
|
| 450 |
+
</p>
|
| 451 |
+
|
| 452 |
+
<div class="table-wrap">
|
| 453 |
+
<p class="table-title"><strong>表 2.</strong> 训练面分层。LoRA / partial FT / frozen
|
| 454 |
+
三类;其中 LoRA 与 partial fine-tune 两类合计约 286M 可训练参数,frozen 部分不更新。所有运行共享此结构。</p>
|
| 455 |
+
<table class="bordered">
|
| 456 |
+
<thead>
|
| 457 |
+
<tr>
|
| 458 |
+
<th style="width: 14%;">类别</th>
|
| 459 |
+
<th>覆盖模块</th>
|
| 460 |
+
<th style="width: 14%;">参数量</th>
|
| 461 |
+
<th style="width: 18%;">备注</th>
|
| 462 |
+
</tr>
|
| 463 |
+
</thead>
|
| 464 |
+
<tbody>
|
| 465 |
+
<tr>
|
| 466 |
+
<td><strong>LoRA wrap</strong></td>
|
| 467 |
+
<td>
|
| 468 |
+
注意力:<code>q_proj_mot_gen</code>, <code>k_proj_mot_gen</code>,
|
| 469 |
+
<code>v_proj_mot_gen</code>, <code>o_proj_mot_gen</code>;<br>
|
| 470 |
+
MLP:<code>mlp_mot_gen.{gate,up,down}_proj</code>
|
| 471 |
+
</td>
|
| 472 |
+
<td>~204M</td>
|
| 473 |
+
<td>共 294 wrap;<em class="m">r</em> = 64, <em class="m">α</em> = 64</td>
|
| 474 |
+
</tr>
|
| 475 |
+
<tr>
|
| 476 |
+
<td><strong>Partial fine-tune</strong></td>
|
| 477 |
+
<td>
|
| 478 |
+
<code>fm_modules.timestep_embedder</code>,<br>
|
| 479 |
+
<code>fm_modules.noise_scale_embedder</code>,<br>
|
| 480 |
+
<code>fm_modules.vision_model_mot_gen</code>,<br>
|
| 481 |
+
<code>fm_modules.fm_head</code>
|
| 482 |
+
</td>
|
| 483 |
+
<td>~82M</td>
|
| 484 |
+
<td>仅 <code>fm_modules</code> 子树</td>
|
| 485 |
+
</tr>
|
| 486 |
+
<tr>
|
| 487 |
+
<td><strong>Frozen</strong></td>
|
| 488 |
+
<td>Understand path 全部模块及其它未列出的所有权重</td>
|
| 489 |
+
<td>—</td>
|
| 490 |
+
<td>不更新</td>
|
| 491 |
+
</tr>
|
| 492 |
+
</tbody>
|
| 493 |
+
</table>
|
| 494 |
+
</div>
|
| 495 |
+
|
| 496 |
+
<div class="table-wrap">
|
| 497 |
+
<p class="table-title"><strong>表 3.</strong> 训练超参。所有运行共享。</p>
|
| 498 |
+
<table class="bordered">
|
| 499 |
+
<tbody>
|
| 500 |
+
<tr>
|
| 501 |
+
<td style="width: 22%;"><strong>训练步数</strong></td>
|
| 502 |
+
<td>6000</td>
|
| 503 |
+
<td style="width: 22%;"><strong>学习率</strong></td>
|
| 504 |
+
<td>5 × 10<sup>−5</sup></td>
|
| 505 |
+
</tr>
|
| 506 |
+
<tr>
|
| 507 |
+
<td><strong>优化器</strong></td>
|
| 508 |
+
<td>PagedAdamW8bit</td>
|
| 509 |
+
<td><strong>Batch size</strong></td>
|
| 510 |
+
<td>1 (native resolution)</td>
|
| 511 |
+
</tr>
|
| 512 |
+
<tr>
|
| 513 |
+
<td><strong>Gradient accumulation</strong></td>
|
| 514 |
+
<td>1</td>
|
| 515 |
+
<td><strong>Seed</strong></td>
|
| 516 |
+
<td>固定 (cross-run)</td>
|
| 517 |
+
</tr>
|
| 518 |
+
</tbody>
|
| 519 |
+
</table>
|
| 520 |
+
</div>
|
| 521 |
+
|
| 522 |
+
<h3>3.3 受控变量</h3>
|
| 523 |
+
<p>
|
| 524 |
+
本文设五组运行:v18 基线,加上四组消融运行 (三组单变量替换与一组全部替换);五组运行共享数据与可训练面 (表 1)。
|
| 525 |
+
</p>
|
| 526 |
+
|
| 527 |
+
<div class="table-wrap">
|
| 528 |
+
<p class="table-title"><strong>表 1.</strong> 五组消融运行的算法配置。
|
| 529 |
+
Baseline 为本仓库默认;运行 (a)–(d) 对应官方算法层逐项与全部替换。</p>
|
| 530 |
+
<table>
|
| 531 |
+
<thead>
|
| 532 |
+
<tr>
|
| 533 |
+
<th>运行</th>
|
| 534 |
+
<th><em class="m-cal">L</em></th>
|
| 535 |
+
<th><em class="m">t</em> 分布</th>
|
| 536 |
+
<th><em class="m">p</em><sub class="m-up">text</sub>, <em class="m">p</em><sub class="m-up">both</sub></th>
|
| 537 |
+
<th>训练 prefix 注释</th>
|
| 538 |
+
</tr>
|
| 539 |
+
</thead>
|
| 540 |
+
<tbody>
|
| 541 |
+
<tr><td><strong>v18 baseline</strong></td><td>x<sub>0</sub></td><td>uniform</td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
|
| 542 |
+
<tr><td>(a) +velocity loss</td><td><strong>v</strong></td><td>uniform</td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
|
| 543 |
+
<tr><td>(b) +logit-normal <em class="m">t</em></td><td>x<sub>0</sub></td><td><strong>logit-normal</strong></td><td>0, 0</td><td>含 think sidecar 长前缀</td></tr>
|
| 544 |
+
<tr><td>(c) +cond. dropout</td><td>x<sub>0</sub></td><td>uniform</td><td><strong>0.10, 0.10</strong></td><td>含 think sidecar 长前缀</td></tr>
|
| 545 |
+
<tr><td>(d) full official (a + b + c)</td><td><strong>v</strong></td><td><strong>logit-normal</strong></td><td><strong>0.10, 0.10</strong></td><td>含 think sidecar 长前缀</td></tr>
|
| 546 |
+
</tbody>
|
| 547 |
+
</table>
|
| 548 |
+
</div>
|
| 549 |
+
|
| 550 |
+
<h3>3.4 评估协议</h3>
|
| 551 |
+
<p>
|
| 552 |
+
训练时记录每步 active loss、x<sub>0</sub>-MSE、v-MSE、<em class="m">t</em> 统计量、
|
| 553 |
+
以及 dropout 路由计数;121-step 滑动平均后绘制。采样评估使用一组固定的 12 条
|
| 554 |
+
自然风景 prompts,在统一的 7 个 bucket 分辨率下做 50 步 Euler、
|
| 555 |
+
<em class="m-up">cfg_scale</em> = 4.0、<em class="m-up">timestep_shift</em> = 3.0。
|
| 556 |
+
我们额外引入 prompt 前缀格式变量:v1 prompts 保留与 baseline 训练分布一致的
|
| 557 |
+
简短前缀;v2 prompts 在每条句首嵌入 artist anchor,与含 think sidecar 的训练
|
| 558 |
+
caption 分布对齐。
|
| 559 |
+
</p>
|
| 560 |
+
|
| 561 |
+
<h2>4. 训练动力学结果</h2>
|
| 562 |
+
|
| 563 |
+
<p>
|
| 564 |
+
图 1–3 给出 x<sub>0</sub>-MSE、active loss 与 v-MSE 三条标量曲线。x<sub>0</sub>-MSE 是唯一可跨损
|
| 565 |
+
失类型公平比较的指标:它是所有运行共同执行的诊断量。可见运行 (b) 与 (d) 在
|
| 566 |
+
x<sub>0</sub>-MSE 上系统性高于 Baseline,差距贯穿整个训练,并非过
|
| 567 |
+
渡期偏差。
|
| 568 |
+
</p>
|
| 569 |
+
|
| 570 |
+
<div class="figure">
|
| 571 |
+
<img src="assets/small_data_style_ablation/tb_x0_mse_curves.png" alt="x0 MSE curves">
|
| 572 |
+
<p class="caption">
|
| 573 |
+
<strong>图 1.</strong> x<sub>0</sub>-MSE 在五组运行中的演化 (121-step 滑动平均)。
|
| 574 |
+
(b) 与 (d) 系统性偏高,表明把 <em class="m">t</em> 密度搬向较 noisy 区间会降低
|
| 575 |
+
模型在更清晰图像状态下学习颜色与构图的机会。
|
| 576 |
+
</p>
|
| 577 |
+
</div>
|
| 578 |
+
|
| 579 |
+
<div class="figure">
|
| 580 |
+
<img src="assets/small_data_style_ablation/tb_active_loss_curves.png" alt="active loss curves">
|
| 581 |
+
<p class="caption">
|
| 582 |
+
<strong>图 2.</strong> active loss 的绝对值不可跨目标比较 (velocity 损失与
|
| 583 |
+
x<sub>0</sub>-MSE 单位不同)。该图仅用于检查每组运行内部是否存在有效下降趋势,
|
| 584 |
+
所有运行均下降稳定,未出现训练发散。
|
| 585 |
+
</p>
|
| 586 |
+
</div>
|
| 587 |
+
|
| 588 |
+
<div class="figure">
|
| 589 |
+
<img src="assets/small_data_style_ablation/tb_v_mse_curves.png" alt="v MSE curves">
|
| 590 |
+
<p class="caption">
|
| 591 |
+
<strong>图 3.</strong> v-MSE 作为诊断量在所有运行中都被计算。可观察到 v-MSE
|
| 592 |
+
在 <em class="m">t</em> → 1 端出现极端尖峰,与 §2.1 中给出的
|
| 593 |
+
(1 − <em class="m">t</em>)<sup>−2</sup> 加权一致。在 velocity 训练的运行 (a) 与 (d)
|
| 594 |
+
中这些尖峰直接进入梯度,被高权重区间支配。
|
| 595 |
+
</p>
|
| 596 |
+
</div>
|
| 597 |
+
|
| 598 |
+
<div class="grid-2">
|
| 599 |
+
<div class="figure">
|
| 600 |
+
<img src="assets/small_data_style_ablation/tb_t_distribution.png" alt="t distribution">
|
| 601 |
+
<p class="caption">
|
| 602 |
+
<strong>图 4.</strong> 实测 <em class="m">t</em> 分布。logit-normal 运行 (b) 与
|
| 603 |
+
(d) 的经验均值约 0.34,与 §2.2 给出的理论均值 𝔼[<em class="m">t</em>] ≈ 0.34 一致 (该分布的中位数为 sigmoid(−0.8) ≈ 0.31);其余运行
|
| 604 |
+
约为 0.50。
|
| 605 |
+
</p>
|
| 606 |
+
</div>
|
| 607 |
+
<div class="figure">
|
| 608 |
+
<img src="assets/small_data_style_ablation/tb_condition_dropout_counts.png" alt="dropout counts">
|
| 609 |
+
<p class="caption">
|
| 610 |
+
<strong>图 5.</strong> 含 condition dropout 的运行 (c) 与 (d) 中每步走
|
| 611 |
+
cond / uncond 分支的累计次数。最终约 21% 步走 unconditional prefix,与设定的
|
| 612 |
+
<em class="m">p</em><sub class="m-up">text</sub> + <em class="m">p</em><sub class="m-up">both</sub>
|
| 613 |
+
= 0.20 在大样本统计上吻合。
|
| 614 |
+
</p>
|
| 615 |
+
</div>
|
| 616 |
+
</div>
|
| 617 |
+
|
| 618 |
+
<h2 class="page-break">5. 视觉评估结果</h2>
|
| 619 |
+
|
| 620 |
+
<p>
|
| 621 |
+
训练标量只反映单步去噪行为,无法替代从纯噪声开始的多步采样。在固定 seed 与
|
| 622 |
+
identical sampling 超参下,我们对每组运行生成 12 张 1024–2048 像素级风景图。
|
| 623 |
+
</p>
|
| 624 |
+
|
| 625 |
+
<h3>5.1 同 prompt 多配方对比</h3>
|
| 626 |
+
|
| 627 |
+
<p>
|
| 628 |
+
图 6 为 forest / wildflower meadow / dandelion field 三个最能体现风格 fingerprint
|
| 629 |
+
的 prompt 在五组运行 (含 Baseline) 下的 contact sheet。关键观察:
|
| 630 |
+
</p>
|
| 631 |
+
<p>
|
| 632 |
+
<strong>(i)</strong> Baseline 与运行 (c) 在所有三个 prompt 上都保持暖橙地平线、
|
| 633 |
+
青蓝高空、清晰前景轮廓的视觉指纹。
|
| 634 |
+
<strong>(ii)</strong> 运行 (a) 在 dandelion field 上系统性向 photoreal 夜景偏移:
|
| 635 |
+
cyan 天空消失、puffball 密度降低、暖金核被压暗。
|
| 636 |
+
<strong>(iii)</strong> 运行 (b) 在所有 prompt 上整体 palette 偏冷,紫蓝山失去层
|
| 637 |
+
次。
|
| 638 |
+
<strong>(iv)</strong> 运行 (d) 同时承袭 (a) 与 (b) 的劣化,并叠加 composition
|
| 639 |
+
cropping:dandelion 图中失去远景 forest mass 与天空。
|
| 640 |
+
</p>
|
| 641 |
+
|
| 642 |
+
<div class="figure sliced">
|
| 643 |
+
<div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice1.jpg" alt="visual same-prompt contact sheet (1/3)"></div>
|
| 644 |
+
<div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice2.jpg" alt="visual same-prompt contact sheet (2/3)"></div>
|
| 645 |
+
<div class="slice"><img src="assets/small_data_style_ablation/visual_same_prompts_contactsheet_slice3.jpg" alt="visual same-prompt contact sheet (3/3)">
|
| 646 |
+
<p class="caption">
|
| 647 |
+
<strong>图 6.</strong> 同 prompt × 五组运行的 contact sheet (上→中→下三页连读)。
|
| 648 |
+
每列为一组运行,每行为一条 prompt。列名直接对应表 1 的算法配置差异。
|
| 649 |
+
风格 fingerprint 由暖金地平线、青蓝高空、painterly 厚笔触三要素共同定义;
|
| 650 |
+
运行 (a) 与 (d) 出现 atmospheric drift,运行 (b) 出现 palette 偏冷。
|
| 651 |
+
</p>
|
| 652 |
+
</div>
|
| 653 |
+
</div>
|
| 654 |
+
|
| 655 |
+
<h3>5.2 局部细节:图像重建质量劣化</h3>
|
| 656 |
+
|
| 657 |
+
<p>
|
| 658 |
+
§5.1 给出的是整图尺度上 palette / brushwork / 构图层面的偏移。但在原始
|
| 659 |
+
分辨率下查看局部,三个 v19 系列方案相比 v18 baseline 还呈现三类<strong>图像
|
| 660 |
+
重建质量</strong>层面的劣化,且每一类都与具体 lever 的训练分布偏置直接挂钩:
|
| 661 |
+
</p>
|
| 662 |
+
<ul>
|
| 663 |
+
<li><strong>横纹 (banding):</strong>云形或色温过渡上出现规则横向条纹,主要见于
|
| 664 |
+
运行 (a) 与 (d),与 velocity 损失的高 <em class="m">t</em> 加权 (§6.1) 一致。</li>
|
| 665 |
+
<li><strong>噪点 (speckle noise):</strong>大面积渐变上出现非自然色阶离散化、
|
| 666 |
+
细粒度高频纹理,主要见于运行 (b) 与 (d),与 logit-normal 时间步的低
|
| 667 |
+
<em class="m">t</em> 偏置 (§6.2) 一致。</li>
|
| 668 |
+
<li><strong>结构扭曲 (structural distortion):</strong>远景元素 (forest mass、
|
| 669 |
+
cloud column) 缺失或被压平、构图整体被裁切,集中见于运行 (d),是
|
| 670 |
+
上述两个分布偏置叠加效应。</li>
|
| 671 |
+
</ul>
|
| 672 |
+
<p>
|
| 673 |
+
以下分别从三个 prompt 截取细节区域 (图 7、图 8、图 9)。每图以 v18 baseline 在最
|
| 674 |
+
上,按 (a) (b) (c) (d) 顺序向下排列。
|
| 675 |
+
</p>
|
| 676 |
+
|
| 677 |
+
<h4>5.2.1 Sample 00 森林暗部:竖向 striation 格纹 (以运行 (d) 为主)</h4>
|
| 678 |
+
<p>
|
| 679 |
+
Sample 00 (dense old-growth forest interior at dawn) 的画面左下、左侧粗树干、
|
| 680 |
+
右侧暗部都是低光强、低 SNR 区域,是 v-loss + logit-normal 组合下竖向格纹伪
|
| 681 |
+
影最先暴露的位置。运行 (d) 的左侧粗树干表面与右下苔藓层均出现可见的等间距
|
| 682 |
+
竖线。
|
| 683 |
+
</p>
|
| 684 |
+
|
| 685 |
+
<div class="closeup-figure">
|
| 686 |
+
<div class="closeup-row">
|
| 687 |
+
<div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
|
| 688 |
+
<img src="assets/small_data_style_ablation/detail00_v18.jpg" alt="v18 forest bottom">
|
| 689 |
+
</div>
|
| 690 |
+
<div class="closeup-row">
|
| 691 |
+
<div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
|
| 692 |
+
<img src="assets/small_data_style_ablation/detail00_v19a.jpg" alt="v19a forest bottom">
|
| 693 |
+
</div>
|
| 694 |
+
<div class="closeup-row">
|
| 695 |
+
<div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
|
| 696 |
+
<img src="assets/small_data_style_ablation/detail00_v19b.jpg" alt="v19b forest bottom">
|
| 697 |
+
</div>
|
| 698 |
+
<div class="closeup-row">
|
| 699 |
+
<div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
|
| 700 |
+
<img src="assets/small_data_style_ablation/detail00_v19c.jpg" alt="v19c forest bottom">
|
| 701 |
+
</div>
|
| 702 |
+
<div class="closeup-row">
|
| 703 |
+
<div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
|
| 704 |
+
<img src="assets/small_data_style_ablation/detail00_v19.jpg" alt="v19 forest bottom with vertical striations">
|
| 705 |
+
</div>
|
| 706 |
+
<p class="caption">
|
| 707 |
+
<strong>图 7.</strong> Sample 00 画面下半 (左粗树干 + 中间地面 + 右下苔藓 +
|
| 708 |
+
右侧上层树叶) 五向对比。
|
| 709 |
+
<strong>v18:</strong>painterly chunky 厚笔触树干 + 暖色秋叶碎片 + 多层 silhouette
|
| 710 |
+
清晰可分;
|
| 711 |
+
<strong>(a):</strong><span style="color: #c53030; font-weight: 600;">god-rays 与 mist 大幅放大主导画面</span>,
|
| 712 |
+
painterly chunky 厚笔触被 atmospheric haze 稀释成更软的笔触;树干与
|
| 713 |
+
地面 silhouette 基本保留,主要问题是风格 softening 与雾化过度;
|
| 714 |
+
<strong>(b):</strong>painterly 厚笔触与树形保留较好,但
|
| 715 |
+
<span style="color: #c53030; font-weight: 600;">左下暗部 + 右下暗部仍有未完全恢复的轻度竖向条纹</span>
|
| 716 |
+
(强度远低于 (d));高光 god-rays 中段也出现轻度结构化痕迹;
|
| 717 |
+
<strong>(c):</strong>五行中重建质量最接近 v18 — painterly chunky 树干 + 完整地面
|
| 718 |
+
碎叶都保留,god-rays 强度略偏 (a) 但 painterly 信息完整,无格纹或条纹伪影;
|
| 719 |
+
<strong>(d):</strong><span style="color: #c53030; font-weight: 600;">左侧粗树干表面 + 中景树干 + 右下苔藓暗部出现明显竖向 striation 格纹</span>,
|
| 720 |
+
暗部 banding 最严重,painterly 信息几乎完全缺失。
|
| 721 |
+
</p>
|
| 722 |
+
</div>
|
| 723 |
+
|
| 724 |
+
<h4>5.2.2 Sample 02 湖岸:树 silhouette 与水面倒影竖纹</h4>
|
| 725 |
+
<p>
|
| 726 |
+
Sample 02 (deep mountain lake at dawn) 的两侧针叶林 silhouette 与中央镜面倒影
|
| 727 |
+
是考察树形 painterly 是否退化、倒影是否出现竖纹的核心区域。运行 (d) 在两侧
|
| 728 |
+
树林区 + 水面倒影区均呈现 screen-door 竖向条纹。
|
| 729 |
+
</p>
|
| 730 |
+
|
| 731 |
+
<div class="closeup-figure">
|
| 732 |
+
<div class="closeup-row">
|
| 733 |
+
<div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
|
| 734 |
+
<img src="assets/small_data_style_ablation/detail02_v18.jpg" alt="v18 lakeside trees + reflection">
|
| 735 |
+
</div>
|
| 736 |
+
<div class="closeup-row">
|
| 737 |
+
<div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
|
| 738 |
+
<img src="assets/small_data_style_ablation/detail02_v19a.jpg" alt="v19a lakeside trees + reflection">
|
| 739 |
+
</div>
|
| 740 |
+
<div class="closeup-row">
|
| 741 |
+
<div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
|
| 742 |
+
<img src="assets/small_data_style_ablation/detail02_v19b.jpg" alt="v19b lakeside trees + reflection">
|
| 743 |
+
</div>
|
| 744 |
+
<div class="closeup-row">
|
| 745 |
+
<div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
|
| 746 |
+
<img src="assets/small_data_style_ablation/detail02_v19c.jpg" alt="v19c lakeside trees + reflection">
|
| 747 |
+
</div>
|
| 748 |
+
<div class="closeup-row">
|
| 749 |
+
<div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
|
| 750 |
+
<img src="assets/small_data_style_ablation/detail02_v19.jpg" alt="v19 lakeside trees + reflection with vertical striations">
|
| 751 |
+
</div>
|
| 752 |
+
<p class="caption">
|
| 753 |
+
<strong>图 8.</strong> Sample 02 中段 (左右两侧 painted tree silhouette + 中央
|
| 754 |
+
水面倒影 + 远雾) 五向对比。
|
| 755 |
+
<strong>v18:</strong>painterly chunky 树形 + 厚笔触倒影 + 远山三层 silhouette
|
| 756 |
+
清晰可分;
|
| 757 |
+
<strong>(a):</strong>整图 photoreal 化,painterly 笔触退化为 stock-photo 树形,
|
| 758 |
+
倒影保留 photoreal 镜面但失去厚笔触;
|
| 759 |
+
<strong>(b):</strong>painted 风格部分保留,但远雾 + 倒影出现微弱竖向条纹;
|
| 760 |
+
<strong>(c):</strong>painted 风格保留较好,树形+倒影皆有 painterly 触感,未观察
|
| 761 |
+
到明显伪影;
|
| 762 |
+
<strong>(d):</strong><span style="color: #c53030; font-weight: 600;">左侧山体 + 右侧针叶林 silhouette + 整个倒影区出现 screen-door 竖向 striation 格纹</span>,
|
| 763 |
+
覆盖范围广。
|
| 764 |
+
</p>
|
| 765 |
+
</div>
|
| 766 |
+
|
| 767 |
+
<h4>5.2.3 Sample 07 蒲公英主体:grass 区竖向条纹 + bokeh 退化</h4>
|
| 768 |
+
<p>
|
| 769 |
+
Sample 07 (backlit dandelion field) 主体为前景蒲公英 + 草丛 + 背光,是考察
|
| 770 |
+
painterly 笔触是否被替换为 photographic bokeh、草丛暗区是否出现条纹的核心
|
| 771 |
+
区域。
|
| 772 |
+
</p>
|
| 773 |
+
|
| 774 |
+
<div class="closeup-figure">
|
| 775 |
+
<div class="closeup-row">
|
| 776 |
+
<div class="label"><span class="tag">v18</span><span class="sub">baseline</span></div>
|
| 777 |
+
<img src="assets/small_data_style_ablation/detail07_v18.jpg" alt="v18 dandelion main subject">
|
| 778 |
+
</div>
|
| 779 |
+
<div class="closeup-row">
|
| 780 |
+
<div class="label"><span class="tag">(a)</span><span class="sub">+v-loss</span></div>
|
| 781 |
+
<img src="assets/small_data_style_ablation/detail07_v19a.jpg" alt="v19a dandelion main subject">
|
| 782 |
+
</div>
|
| 783 |
+
<div class="closeup-row">
|
| 784 |
+
<div class="label"><span class="tag">(b)</span><span class="sub">+logit-normal <em class="m">t</em></span></div>
|
| 785 |
+
<img src="assets/small_data_style_ablation/detail07_v19b.jpg" alt="v19b dandelion main subject">
|
| 786 |
+
</div>
|
| 787 |
+
<div class="closeup-row">
|
| 788 |
+
<div class="label"><span class="tag">(c)</span><span class="sub">+cond. dropout</span></div>
|
| 789 |
+
<img src="assets/small_data_style_ablation/detail07_v19c.jpg" alt="v19c dandelion main subject">
|
| 790 |
+
</div>
|
| 791 |
+
<div class="closeup-row">
|
| 792 |
+
<div class="label"><span class="tag">(d)</span><span class="sub">+full official</span></div>
|
| 793 |
+
<img src="assets/small_data_style_ablation/detail07_v19.jpg" alt="v19 dandelion main subject with cross-hatch grid">
|
| 794 |
+
</div>
|
| 795 |
+
<p class="caption">
|
| 796 |
+
<strong>图 9.</strong> Sample 07 主体区 (前景蒲公英 + 草丛暗部 + 暖背光) 五向
|
| 797 |
+
对比。
|
| 798 |
+
<strong>v18:</strong>painterly chunky 笔触草叶 + painted 蒲公英 + 远景树线
|
| 799 |
+
可见;
|
| 800 |
+
<strong>(a):</strong><span style="color: #c53030; font-weight: 600;">painterly 完全退化为 photographic bokeh</span>,
|
| 801 |
+
前景蒲公英变得稀疏、细小,背景暗化呈相机散景;
|
| 802 |
+
<strong>(b):</strong>painterly 较为完整 — chunky 草叶 + painted 蒲公英都还在,
|
| 803 |
+
主要变化是视野相对收紧 (前景密度减少);草丛暗部存在轻度的竖向 streak,密度
|
| 804 |
+
低于 (d);
|
| 805 |
+
<strong>(c):</strong>painterly fingerprint 接近 v18 — 蒲公英密度高、chunky
|
| 806 |
+
草叶 + 浮散种子全保留,未观察到格纹伪影;
|
| 807 |
+
<strong>(d):</strong><span style="color: #c53030; font-weight: 600;">草丛暗部出现密集 cross-hatch 格纹</span>
|
| 808 |
+
(竖向 striation 为主,叠加微横向 banding),painterly 信息几乎完全缺失,整图
|
| 809 |
+
带 halftone 质感。
|
| 810 |
+
</p>
|
| 811 |
+
</div>
|
| 812 |
+
|
| 813 |
+
<h3>5.3 Prompt 前缀格式作为混淆变量</h3>
|
| 814 |
+
|
| 815 |
+
<p>
|
| 816 |
+
我们在受控分析中发现,若 inference prompt 的 artist anchor 位置与训练 caption
|
| 817 |
+
不一致,会引入与 lever 无关但视觉量级相当的差异。因此对 Baseline 与运行 (c) 额外
|
| 818 |
+
做两组 prompt 集对比,结果见图 10。该结果说明把 prompt 写法与训练分布对齐是
|
| 819 |
+
正确归因 lever 效果的必要前置条件,并解释了我们对 README 默认配置中保留
|
| 820 |
+
显式 <code>style.trigger</code> 与忽略 think sidecar 的选择。
|
| 821 |
+
</p>
|
| 822 |
+
|
| 823 |
+
<div class="figure sliced narrow">
|
| 824 |
+
<div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice1.jpg" alt="baseline prompt-set contact sheet (1/4)"></div>
|
| 825 |
+
<div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice2.jpg" alt="baseline prompt-set contact sheet (2/4)"></div>
|
| 826 |
+
<div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice3.jpg" alt="baseline prompt-set contact sheet (3/4)"></div>
|
| 827 |
+
<div class="slice"><img src="assets/small_data_style_ablation/baseline_prompt_sets_contactsheet_slice4.jpg" alt="baseline prompt-set contact sheet (4/4)">
|
| 828 |
+
<p class="caption">
|
| 829 |
+
<strong>图 10.</strong> Baseline 在两组 prompt 集下的采样 (上→下四页连读)。
|
| 830 |
+
左:与训练同分布的短前缀 (含显式 style trigger);右:分布外的长 prompt。
|
| 831 |
+
完整采样在分布内稳定;这表明 Baseline 的强 fingerprint 不依赖额外 prefix 文本。
|
| 832 |
+
</p>
|
| 833 |
+
</div>
|
| 834 |
+
</div>
|
| 835 |
+
|
| 836 |
+
<h2>6. 讨论</h2>
|
| 837 |
+
|
| 838 |
+
<h3>6.1 Velocity 损失的隐式重加权</h3>
|
| 839 |
+
<p>
|
| 840 |
+
由式 (2),<em class="m-cal">L</em><sub class="m">v</sub> =
|
| 841 |
+
𝔼 [‖<em class="m">x</em><sub class="m">θ</sub> − <em class="m">x</em><sub>0</sub>‖<sup>2</sup>
|
| 842 |
+
· (1 − <em class="m">t</em>)<sup>−2</sup>],velocity 损失等价于在 x<sub>0</sub>-MSE
|
| 843 |
+
上施加一个 (1 − <em class="m">t</em>)<sup>−2</sup> 的权重函数。在 uniform
|
| 844 |
+
<em class="m">t</em> 下,该权重在 <em class="m">t</em> ∈ [0.9, 1.0] 区间承担约 64%
|
| 845 |
+
的总积分质量。意即在 velocity 训练中,模型有大半的梯度信号被分配到 "接近 clean
|
| 846 |
+
图但仍残留少量噪声" 的窗口。对于风格 LoRA 而言,颜色调性、构图、笔触特征均在
|
| 847 |
+
<em class="m">t</em> ∈ [0.3, 0.7] 这一中频窗口形成;velocity 损失显著削减了该窗口的
|
| 848 |
+
有效更新量,因而出现了 §5.1 (ii) 报告的 photoreal 夜景偏移。
|
| 849 |
+
</p>
|
| 850 |
+
|
| 851 |
+
<h3>6.2 Logit-normal 与小数据匹配的张力</h3>
|
| 852 |
+
<p>
|
| 853 |
+
logit-normal (<em class="m">μ</em> = −0.8, <em class="m">σ</em> = 0.8) 将
|
| 854 |
+
𝔼[<em class="m">t</em>] 从 0.5 移到 0.34,即把更多概率质量推向 noisy 端。该选择在
|
| 855 |
+
大规模数据下使得各 <em class="m">t</em> 段都获得足够样本数;但在 56 张图、6000
|
| 856 |
+
步、batch=1 的小数据规模下,每个 mid-<em class="m">t</em> bin 实际获得的有效样
|
| 857 |
+
本数本就稀疏,进一步的密度偏移导致 palette 与 mid-frequency texture 统计学习
|
| 858 |
+
不足,对应 §5.1 (iii) 报告的整体偏冷。
|
| 859 |
+
</p>
|
| 860 |
+
|
| 861 |
+
<h3>6.3 Condition dropout 的正面作用</h3>
|
| 862 |
+
<p>
|
| 863 |
+
与上述两项不同,condition dropout 在小数据风格 LoRA 上未观察到风格损伤。其作
|
| 864 |
+
用机制是把 unconditional branch 与 conditional branch 的相对几何在训练期同时
|
| 865 |
+
暴露给模型,从而使 inference 时 CFG 的 uncond + <em class="m">s</em> · (cond − uncond)
|
| 866 |
+
推算成立。在 <em class="m">s</em> = 4.0 这一相对较强的 CFG 下,没有 dropout 训练
|
| 867 |
+
的 LoRA 经常把输出推到 LoRA 未覆盖的方向,导致结构 artefact。运行 (c) 与 Baseline
|
| 868 |
+
的视觉对比显示,前者在风格指纹保持不变的前提下,前景结构 (puffball 形状、
|
| 869 |
+
树干轮廓) 更清晰、稳定。
|
| 870 |
+
</p>
|
| 871 |
+
|
| 872 |
+
<h3>6.4 Prompt 前缀格式的去混淆作用</h3>
|
| 873 |
+
<p>
|
| 874 |
+
§5.3 显示 prompt 中 artist anchor 的位置 (句首 vs. 句末 tag vs. 完全缺失) 与训练
|
| 875 |
+
caption 分布对齐与否直接决定 LoRA delta 是否被正确激活。这一变量与目标 lever
|
| 876 |
+
正交,但视觉量级与 lever 效果相当;任何不控制该变量的 lever 评估都会被
|
| 877 |
+
prompt-format mismatch 污染。本文 §5.1、§5.2 与 §5.3 结论均在 v2 anchored prompt
|
| 878 |
+
下复测确认。
|
| 879 |
+
</p>
|
| 880 |
+
|
| 881 |
+
<h2>7. 结论与发布决策</h2>
|
| 882 |
+
<p>
|
| 883 |
+
在 SenseNova-U1-8B-MoT × 56 张 Hayateluc 风格 LoRA 微调上:
|
| 884 |
+
(1) velocity 损失系统性向高 <em class="m">t</em> 端搬移梯度密度,导致天空与大面积
|
| 885 |
+
渐变区域出现规则横向条纹 banding 与 painterly→photoreal 退化;
|
| 886 |
+
(2) logit-normal 时间步在小数据下加剧 mid-<em class="m">t</em> 样本稀疏,引发 palette
|
| 887 |
+
偏冷与大面积平滑色块的 speckle 噪点;
|
| 888 |
+
(3) (1) + (2) 叠加进一步引入远景元素缺失与构图压缩等结构扭曲,劣化超出 style
|
| 889 |
+
drift 范畴;
|
| 890 |
+
(4) condition dropout 是三项中唯一无损反而有益的算法层;
|
| 891 |
+
(5) prompt 前缀格式与训练分布的匹配是评估前置条件,必须先控制。
|
| 892 |
+
</p>
|
| 893 |
+
<p>
|
| 894 |
+
据此,我们将 <code>configs/default.yaml</code> 保留为本地基线
|
| 895 |
+
(<code>x0 + uniform t + 0 dropout</code> + 显式 style trigger + 忽略 think
|
| 896 |
+
sidecar),并以 <code>configs/official_alignment.yaml</code> 形式发布官方对齐组
|
| 897 |
+
合,标注其用途为研究复现。该选择不否定公开报告中的算法设计,只反映其与本仓
|
| 898 |
+
库目标 (小数据风格 LoRA) 的不匹配。
|
| 899 |
+
</p>
|
| 900 |
+
|
| 901 |
+
<h3>默认配置摘要</h3>
|
| 902 |
+
<pre>data:
|
| 903 |
+
use_think_labels: false
|
| 904 |
+
|
| 905 |
+
style:
|
| 906 |
+
trigger: "my style"
|
| 907 |
+
prompt_template: official
|
| 908 |
+
|
| 909 |
+
lora:
|
| 910 |
+
preset: attn_mlp_no_head
|
| 911 |
+
|
| 912 |
+
unfreeze:
|
| 913 |
+
- '^fm_modules\.timestep_embedder\.'
|
| 914 |
+
- '^fm_modules\.noise_scale_embedder\.'
|
| 915 |
+
- '^fm_modules\.vision_model_mot_gen\.'
|
| 916 |
+
- '^fm_modules\.fm_head\.'
|
| 917 |
+
|
| 918 |
+
train:
|
| 919 |
+
loss_type: x0
|
| 920 |
+
t_dist: uniform
|
| 921 |
+
cond_dropout_text: 0.0
|
| 922 |
+
cond_dropout_both: 0.0
|
| 923 |
+
</pre>
|
| 924 |
+
|
| 925 |
+
<h2>参考文献</h2>
|
| 926 |
+
<ol class="refs">
|
| 927 |
+
<li>SenseNova-U1 Technical Report. <em>OpenSenseNova</em>, 2026.
|
| 928 |
+
<span class="small">https://github.com/OpenSenseNova/SenseNova-U1</span></li>
|
| 929 |
+
<li>Lipman, Y. et al. Flow Matching for Generative Modeling. <em>ICLR</em>, 2023.</li>
|
| 930 |
+
<li>Esser, P. et al. Scaling Rectified Flow Transformers for High-Resolution
|
| 931 |
+
Image Synthesis. <em>ICML</em>, 2024 — 提出 logit-normal <em class="m">t</em> 采样.</li>
|
| 932 |
+
<li>Ho, J. & Salimans, T. Classifier-Free Diffusion Guidance.
|
| 933 |
+
<em>NeurIPS Workshop on Deep Generative Models</em>, 2021.</li>
|
| 934 |
+
<li>Hu, E. J. et al. LoRA: Low-Rank Adaptation of Large Language Models.
|
| 935 |
+
<em>ICLR</em>, 2022.</li>
|
| 936 |
+
</ol>
|
| 937 |
+
|
| 938 |
+
</section>
|
| 939 |
+
|
| 940 |
+
</body>
|
| 941 |
+
</html>
|
docs/small_data_style_ablation.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:446f868e6c5fe4518cebe60baf5d5caed50ff0f1b0bad87b693265550b19daad
|
| 3 |
+
size 9161869
|
train.sh
CHANGED
|
@@ -2,12 +2,12 @@
|
|
| 2 |
# train.sh — quick-launch a LoRA training run from a YAML config.
|
| 3 |
#
|
| 4 |
# Usage:
|
| 5 |
-
# ./train.sh
|
| 6 |
-
# ./train.sh configs/
|
| 7 |
-
# ./train.sh configs/
|
| 8 |
#
|
| 9 |
# Long-running training tip:
|
| 10 |
-
# setsid nohup ./train.sh configs/
|
| 11 |
# disown
|
| 12 |
# (a bare `nohup &` can be SIGHUP'd when the SSH/IDE session disconnects).
|
| 13 |
|
|
|
|
| 2 |
# train.sh — quick-launch a LoRA training run from a YAML config.
|
| 3 |
#
|
| 4 |
# Usage:
|
| 5 |
+
# ./train.sh # uses configs/default.yaml
|
| 6 |
+
# ./train.sh configs/my_style.yaml
|
| 7 |
+
# ./train.sh configs/my_style.yaml --steps 12000 # extra args forwarded
|
| 8 |
#
|
| 9 |
# Long-running training tip:
|
| 10 |
+
# setsid nohup ./train.sh configs/my_style.yaml </dev/null >run.log 2>&1 &
|
| 11 |
# disown
|
| 12 |
# (a bare `nohup &` can be SIGHUP'd when the SSH/IDE session disconnects).
|
| 13 |
|
train_u1/README.md
CHANGED
|
@@ -49,7 +49,10 @@ train_u1/
|
|
| 49 |
```
|
| 50 |
|
| 51 |
**未实现 / pending**:
|
| 52 |
-
- `model/peft_targets.py` — LoRA target 工具
|
|
|
|
|
|
|
|
|
|
| 53 |
- `scripts/train_balanced.py` — 48GB 平衡场景(实验 C 后视效果决定)
|
| 54 |
- `scripts/sample_t2i.py` 完整管线(依赖 wrapper 的 `t2i_generate` 钩子)
|
| 55 |
|
|
|
|
| 49 |
```
|
| 50 |
|
| 51 |
**未实现 / pending**:
|
| 52 |
+
- `model/peft_targets.py` — 独立 LoRA target 工具仍未拆出;当前
|
| 53 |
+
`model/lora.py` 已内置 8B dense target 与实验性 A3B generation MoE
|
| 54 |
+
target grammar(`gen_moe_mlp` / `gen_moe_router`)。A3B/MoE 不是当前
|
| 55 |
+
主训练路径,必须等可实例化 `mlp_mot_gen.experts.*` 的公开 runtime。
|
| 56 |
- `scripts/train_balanced.py` — 48GB 平衡场景(实验 C 后视效果决定)
|
| 57 |
- `scripts/sample_t2i.py` 完整管线(依赖 wrapper 的 `t2i_generate` 钩子)
|
| 58 |
|
train_u1/config.py
CHANGED
|
@@ -15,12 +15,15 @@ A single YAML file describes a complete LoRA training run. Example::
|
|
| 15 |
prompt_template: official # or 'plain'
|
| 16 |
|
| 17 |
lora:
|
| 18 |
-
preset:
|
| 19 |
-
# spec: "attn=r64a64;mlp=r64a64;
|
| 20 |
dropout: 0.0
|
| 21 |
|
| 22 |
unfreeze: # full-finetune (non-LoRA) regex patterns
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
train:
|
| 26 |
steps: 6000
|
|
@@ -29,6 +32,10 @@ A single YAML file describes a complete LoRA training run. Example::
|
|
| 29 |
shuffle: true
|
| 30 |
grad_accum: 1
|
| 31 |
checkpoint_every: 600
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
runtime:
|
| 34 |
keep_kvs_on_gpu: true
|
|
@@ -68,6 +75,7 @@ class DataConfig:
|
|
| 68 |
snap_bucket: bool = True
|
| 69 |
n_samples: int | None = None # default: use entire dataset
|
| 70 |
sample_buckets_file: str | None = None
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
@dataclass
|
|
@@ -78,7 +86,7 @@ class StyleConfig:
|
|
| 78 |
|
| 79 |
@dataclass
|
| 80 |
class LoRAConfig:
|
| 81 |
-
preset: str | None = "
|
| 82 |
spec: str | None = None # overrides preset if set
|
| 83 |
dropout: float = 0.0
|
| 84 |
|
|
@@ -107,6 +115,25 @@ class TrainConfig:
|
|
| 107 |
grad_accum: int = 1
|
| 108 |
checkpoint_every: int = 600
|
| 109 |
checkpoint_dir: str | None = None # default: artifacts/{run_name}/checkpoints
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
@dataclass
|
|
@@ -119,13 +146,22 @@ class RuntimeConfig:
|
|
| 119 |
upstream_lora_skip: tuple[str, ...] = ()
|
| 120 |
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
@dataclass
|
| 123 |
class TrainRunConfig:
|
| 124 |
run_name: str = "my_run"
|
| 125 |
data: DataConfig = field(default_factory=DataConfig)
|
| 126 |
style: StyleConfig = field(default_factory=StyleConfig)
|
| 127 |
lora: LoRAConfig = field(default_factory=LoRAConfig)
|
| 128 |
-
unfreeze: list[str] = field(default_factory=
|
| 129 |
train: TrainConfig = field(default_factory=TrainConfig)
|
| 130 |
runtime: RuntimeConfig = field(default_factory=RuntimeConfig)
|
| 131 |
|
|
|
|
| 15 |
prompt_template: official # or 'plain'
|
| 16 |
|
| 17 |
lora:
|
| 18 |
+
preset: attn_mlp_no_head # small-data baseline: LoRA attn+mlp only
|
| 19 |
+
# spec: "attn=r64a64;mlp=r64a64;mlp_mot_gen.down_proj=off"
|
| 20 |
dropout: 0.0
|
| 21 |
|
| 22 |
unfreeze: # full-finetune (non-LoRA) regex patterns
|
| 23 |
+
- '^fm_modules\\.timestep_embedder\\.'
|
| 24 |
+
- '^fm_modules\\.noise_scale_embedder\\.'
|
| 25 |
+
- '^fm_modules\\.vision_model_mot_gen\\.'
|
| 26 |
+
- '^fm_modules\\.fm_head\\.'
|
| 27 |
|
| 28 |
train:
|
| 29 |
steps: 6000
|
|
|
|
| 32 |
shuffle: true
|
| 33 |
grad_accum: 1
|
| 34 |
checkpoint_every: 600
|
| 35 |
+
loss_type: x0
|
| 36 |
+
t_dist: uniform
|
| 37 |
+
cond_dropout_text: 0.0
|
| 38 |
+
cond_dropout_both: 0.0
|
| 39 |
|
| 40 |
runtime:
|
| 41 |
keep_kvs_on_gpu: true
|
|
|
|
| 75 |
snap_bucket: bool = True
|
| 76 |
n_samples: int | None = None # default: use entire dataset
|
| 77 |
sample_buckets_file: str | None = None
|
| 78 |
+
use_think_labels: bool = False
|
| 79 |
|
| 80 |
|
| 81 |
@dataclass
|
|
|
|
| 86 |
|
| 87 |
@dataclass
|
| 88 |
class LoRAConfig:
|
| 89 |
+
preset: str | None = "attn_mlp_no_head" # one of LORA_PRESETS
|
| 90 |
spec: str | None = None # overrides preset if set
|
| 91 |
dropout: float = 0.0
|
| 92 |
|
|
|
|
| 115 |
grad_accum: int = 1
|
| 116 |
checkpoint_every: int = 600
|
| 117 |
checkpoint_dir: str | None = None # default: artifacts/{run_name}/checkpoints
|
| 118 |
+
# FM loss objective. Default is the local small-data baseline (`x0`) because
|
| 119 |
+
# the ablation study showed that official-style v-loss is not a good
|
| 120 |
+
# small-data style-training default. `v` remains available for explicit
|
| 121 |
+
# official alignment experiments.
|
| 122 |
+
# Choose one of `x0` | `v` | `x0_huber` | `v_huber`.
|
| 123 |
+
loss_type: str = "x0"
|
| 124 |
+
huber_delta: float = 1.0
|
| 125 |
+
# FM `t`-sampling distribution. Default is uniform for the same local
|
| 126 |
+
# baseline reason. `logit_normal` is kept for report-alignment ablations.
|
| 127 |
+
t_dist: str = "uniform"
|
| 128 |
+
t_logit_mean: float = -0.8
|
| 129 |
+
t_logit_std: float = 0.8
|
| 130 |
+
# CFG / condition dropout. `cond_dropout_text` drops text condition only;
|
| 131 |
+
# `cond_dropout_both` is the additional unconditional bucket from the
|
| 132 |
+
# report. In the current pure-T2I trainer there is no separate reference
|
| 133 |
+
# image condition, so both modes use the sampler's unconditional prompt
|
| 134 |
+
# prefix while preserving separate log labels.
|
| 135 |
+
cond_dropout_text: float = 0.0
|
| 136 |
+
cond_dropout_both: float = 0.0
|
| 137 |
|
| 138 |
|
| 139 |
@dataclass
|
|
|
|
| 146 |
upstream_lora_skip: tuple[str, ...] = ()
|
| 147 |
|
| 148 |
|
| 149 |
+
def _default_unfreeze_patterns() -> list[str]:
|
| 150 |
+
return [
|
| 151 |
+
r"^fm_modules\.timestep_embedder\.",
|
| 152 |
+
r"^fm_modules\.noise_scale_embedder\.",
|
| 153 |
+
r"^fm_modules\.vision_model_mot_gen\.",
|
| 154 |
+
r"^fm_modules\.fm_head\.",
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
@dataclass
|
| 159 |
class TrainRunConfig:
|
| 160 |
run_name: str = "my_run"
|
| 161 |
data: DataConfig = field(default_factory=DataConfig)
|
| 162 |
style: StyleConfig = field(default_factory=StyleConfig)
|
| 163 |
lora: LoRAConfig = field(default_factory=LoRAConfig)
|
| 164 |
+
unfreeze: list[str] = field(default_factory=_default_unfreeze_patterns)
|
| 165 |
train: TrainConfig = field(default_factory=TrainConfig)
|
| 166 |
runtime: RuntimeConfig = field(default_factory=RuntimeConfig)
|
| 167 |
|
train_u1/data/collators.py
CHANGED
|
@@ -21,8 +21,6 @@ from typing import Iterable
|
|
| 21 |
|
| 22 |
import torch
|
| 23 |
|
| 24 |
-
import math
|
| 25 |
-
|
| 26 |
from train_u1.constants import (
|
| 27 |
FM_OUTPUT_DIM,
|
| 28 |
NOISE_SCALE_BASE_IMAGE_SEQ_LEN,
|
|
@@ -50,7 +48,23 @@ class CollatorConfig:
|
|
| 50 |
# then batch=1).
|
| 51 |
image_hw: tuple[int, int] | None = (512, 512)
|
| 52 |
t_eps: float = T_EPS_DEFAULT
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
add_noise_scale: bool = True
|
| 55 |
# Base noise_scale value (config.noise_scale = 1.0). The *effective* per-sample
|
| 56 |
# noise_scale is computed at collator runtime as
|
|
@@ -80,6 +94,12 @@ class CollatorConfig:
|
|
| 80 |
# Format used: f"{style_trigger}, {original_caption}".
|
| 81 |
style_trigger: str = ""
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
class SenseNovaU1Collator:
|
| 85 |
"""Stateful collator: holds tokenizer + config, callable on a list of `T2ISample`.
|
|
@@ -94,6 +114,7 @@ class SenseNovaU1Collator:
|
|
| 94 |
self.tok = tokenizer
|
| 95 |
self.cfg = cfg or CollatorConfig()
|
| 96 |
self._gen = torch.Generator().manual_seed(self.cfg.seed)
|
|
|
|
| 97 |
if self.cfg.prompt_template == "official":
|
| 98 |
if model is None or not hasattr(model, "_build_t2i_query"):
|
| 99 |
raise ValueError(
|
|
@@ -147,10 +168,87 @@ class SenseNovaU1Collator:
|
|
| 147 |
if self.cfg.t_dist == "uniform":
|
| 148 |
t = torch.rand(batch_size, generator=self._gen)
|
| 149 |
t = t * (1.0 - self.cfg.t_eps) + self.cfg.t_eps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
else:
|
| 151 |
raise NotImplementedError(f"t_dist={self.cfg.t_dist}")
|
| 152 |
return t
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
@staticmethod
|
| 155 |
def _check_image_hw(image_hw: tuple[int, int]) -> None:
|
| 156 |
H, W = image_hw
|
|
@@ -163,7 +261,12 @@ class SenseNovaU1Collator:
|
|
| 163 |
# ------------------------------------------------------------------ #
|
| 164 |
# Main entry #
|
| 165 |
# ------------------------------------------------------------------ #
|
| 166 |
-
def __call__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
cfg = self.cfg
|
| 168 |
if cfg.enforce_batch_one and len(samples) != 1:
|
| 169 |
raise ValueError(
|
|
@@ -196,38 +299,16 @@ class SenseNovaU1Collator:
|
|
| 196 |
# 1) text → ids + per-sample lengths. With enforce_batch_one we know
|
| 197 |
# `len(samples) == 1` so no batch padding is applied — `L_text` is
|
| 198 |
# exactly this prompt's length (matches upstream `_build_t2i_text_inputs`).
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
# assistant). Identical formatting must be replicated at sample time.
|
| 202 |
-
if cfg.style_trigger:
|
| 203 |
-
raw_prompts = [f"{cfg.style_trigger}, {s.prompt}" for s in samples]
|
| 204 |
else:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
# Per-sample think injection: when the dataset supplies a
|
| 210 |
-
# `think` text, render it INSIDE the otherwise-empty
|
| 211 |
-
# `<think></think>` block of the official prompt template.
|
| 212 |
-
# This makes training distribution match inference-time
|
| 213 |
-
# `--think-mode`, where the model autoregressively fills the
|
| 214 |
-
# same window with ~250-400 reasoning tokens. Without this,
|
| 215 |
-
# the gen tower sees an unfamiliar prefix length/content
|
| 216 |
-
# at inference and the LoRA delta is calibrated against
|
| 217 |
-
# the wrong cond-KV distribution.
|
| 218 |
-
if s.think:
|
| 219 |
-
append_text = f"<think>\n{s.think}\n</think>\n\n<img>"
|
| 220 |
-
else:
|
| 221 |
-
append_text = self._gen_append
|
| 222 |
-
prompts.append(
|
| 223 |
-
self._build_t2i_query(
|
| 224 |
-
rp,
|
| 225 |
-
system_message=self._sys_msg_for_gen,
|
| 226 |
-
append_text=append_text,
|
| 227 |
-
)
|
| 228 |
)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
input_ids, text_lens = self._tokenize(prompts)
|
| 232 |
B = input_ids.shape[0]
|
| 233 |
L_text = input_ids.shape[1]
|
|
@@ -298,6 +379,12 @@ class SenseNovaU1Collator:
|
|
| 298 |
"noisy_pixel_values": noisy_pixel_values, # (B, 3, H, W)
|
| 299 |
"noisy_grid_hw": noisy_grid_hw, # (B, 2)
|
| 300 |
"noise_scale": noise_scale, # (B,) or None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
"sample_ids": [s.sample_id for s in samples],
|
| 302 |
"text_lens": text_lens,
|
| 303 |
"token_hw": (token_h, token_w),
|
|
|
|
| 21 |
|
| 22 |
import torch
|
| 23 |
|
|
|
|
|
|
|
| 24 |
from train_u1.constants import (
|
| 25 |
FM_OUTPUT_DIM,
|
| 26 |
NOISE_SCALE_BASE_IMAGE_SEQ_LEN,
|
|
|
|
| 48 |
# then batch=1).
|
| 49 |
image_hw: tuple[int, int] | None = (512, 512)
|
| 50 |
t_eps: float = T_EPS_DEFAULT
|
| 51 |
+
# **Default = `logit_normal` (mean=-0.8, std=0.8)** to match SenseNova-U1
|
| 52 |
+
# report Table 2:
|
| 53 |
+
# u ~ Normal(t_logit_mean, t_logit_std); t = sigmoid(u)
|
| 54 |
+
# clamped to [t_eps, 1 - t_eps]. Biases t toward the *low* end (near-noise);
|
| 55 |
+
# combined with v-loss this recovers the official training density.
|
| 56 |
+
# `uniform` on (t_eps, 1] is kept for back-compat / ablation.
|
| 57 |
+
t_dist: str = "logit_normal"
|
| 58 |
+
t_logit_mean: float = -0.8
|
| 59 |
+
t_logit_std: float = 0.8
|
| 60 |
+
# Classifier-free guidance condition dropout. Keep the collator default at
|
| 61 |
+
# zero so eval/smoke/diagnostic callers stay fully conditional unless they
|
| 62 |
+
# opt in. The official training entry point passes 0.10/0.10 from TrainConfig.
|
| 63 |
+
# For pure T2I training there is no separate reference-image condition, so
|
| 64 |
+
# `text_image` uses the same unconditional prompt path as `text` while
|
| 65 |
+
# recording the intended mode.
|
| 66 |
+
cond_dropout_text: float = 0.0
|
| 67 |
+
cond_dropout_both: float = 0.0
|
| 68 |
add_noise_scale: bool = True
|
| 69 |
# Base noise_scale value (config.noise_scale = 1.0). The *effective* per-sample
|
| 70 |
# noise_scale is computed at collator runtime as
|
|
|
|
| 94 |
# Format used: f"{style_trigger}, {original_caption}".
|
| 95 |
style_trigger: str = ""
|
| 96 |
|
| 97 |
+
def __post_init__(self) -> None:
|
| 98 |
+
if self.cond_dropout_text < 0 or self.cond_dropout_both < 0:
|
| 99 |
+
raise ValueError("condition dropout probabilities must be non-negative")
|
| 100 |
+
if self.cond_dropout_text + self.cond_dropout_both > 1.0:
|
| 101 |
+
raise ValueError("cond_dropout_text + cond_dropout_both must be <= 1.0")
|
| 102 |
+
|
| 103 |
|
| 104 |
class SenseNovaU1Collator:
|
| 105 |
"""Stateful collator: holds tokenizer + config, callable on a list of `T2ISample`.
|
|
|
|
| 114 |
self.tok = tokenizer
|
| 115 |
self.cfg = cfg or CollatorConfig()
|
| 116 |
self._gen = torch.Generator().manual_seed(self.cfg.seed)
|
| 117 |
+
self._cond_gen = torch.Generator().manual_seed(self.cfg.seed + 10_003)
|
| 118 |
if self.cfg.prompt_template == "official":
|
| 119 |
if model is None or not hasattr(model, "_build_t2i_query"):
|
| 120 |
raise ValueError(
|
|
|
|
| 168 |
if self.cfg.t_dist == "uniform":
|
| 169 |
t = torch.rand(batch_size, generator=self._gen)
|
| 170 |
t = t * (1.0 - self.cfg.t_eps) + self.cfg.t_eps
|
| 171 |
+
elif self.cfg.t_dist == "logit_normal":
|
| 172 |
+
# u ~ N(mu, sigma) -> t = sigmoid(u) in (0, 1)
|
| 173 |
+
u = torch.randn(batch_size, generator=self._gen)
|
| 174 |
+
u = u * self.cfg.t_logit_std + self.cfg.t_logit_mean
|
| 175 |
+
t = torch.sigmoid(u)
|
| 176 |
+
t = t.clamp(min=self.cfg.t_eps, max=1.0 - self.cfg.t_eps)
|
| 177 |
else:
|
| 178 |
raise NotImplementedError(f"t_dist={self.cfg.t_dist}")
|
| 179 |
return t
|
| 180 |
|
| 181 |
+
def _sample_condition_modes(self, batch_size: int) -> list[str]:
|
| 182 |
+
p_text = float(self.cfg.cond_dropout_text)
|
| 183 |
+
p_both = float(self.cfg.cond_dropout_both)
|
| 184 |
+
if p_text == 0.0 and p_both == 0.0:
|
| 185 |
+
return ["none"] * batch_size
|
| 186 |
+
u = torch.rand(batch_size, generator=self._cond_gen)
|
| 187 |
+
modes: list[str] = []
|
| 188 |
+
for v in u.tolist():
|
| 189 |
+
if v < p_text:
|
| 190 |
+
modes.append("text")
|
| 191 |
+
elif v < p_text + p_both:
|
| 192 |
+
modes.append("text_image")
|
| 193 |
+
else:
|
| 194 |
+
modes.append("none")
|
| 195 |
+
return modes
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def _prefix_cache_key(mode: str) -> str:
|
| 199 |
+
if mode == "none":
|
| 200 |
+
return "cond"
|
| 201 |
+
if mode in ("text", "text_image"):
|
| 202 |
+
return "uncond"
|
| 203 |
+
raise ValueError(f"unknown condition dropout mode {mode!r}")
|
| 204 |
+
|
| 205 |
+
def _render_prompts(
    self,
    samples: list[T2ISample],
    condition_modes: list[str],
) -> list[str]:
    """Render one prompt string per sample, honouring its condition-dropout mode.

    Args:
        samples: samples providing `.prompt` and `.think`.
        condition_modes: one of `"none"`, `"text"`, `"text_image"` per sample,
            aligned with `samples`.

    Returns:
        A list of rendered prompt strings, one per sample, in order.

    Raises:
        ValueError: if the two lists differ in length, or a mode is not one of
            the three known values.
    """
    # `zip` would silently truncate to the shorter list; fail loudly instead so
    # a misaligned caller cannot drop samples unnoticed.
    if len(condition_modes) != len(samples):
        raise ValueError(
            f"condition_modes length {len(condition_modes)} != samples length {len(samples)}"
        )

    prompts: list[str] = []
    for s, mode in zip(samples, condition_modes):
        if mode not in ("none", "text", "text_image"):
            raise ValueError(f"unknown condition dropout mode {mode!r}")
        drop_text = mode in ("text", "text_image")
        if drop_text:
            raw_prompt = ""
        elif self.cfg.style_trigger:
            # Prepend the configured style trigger to the caption.
            raw_prompt = f"{self.cfg.style_trigger}, {s.prompt}"
        else:
            raw_prompt = s.prompt

        if self._build_t2i_query is not None:
            if drop_text:
                # Match the sampler's unconditional CFG prefix exactly:
                # `_build_t2i_query("", append_text="<img>")`.
                prompts.append(self._build_t2i_query("", append_text="<img>"))
                continue
            # Per-sample think injection: when the dataset supplies a
            # `think` text, render it INSIDE the otherwise-empty
            # `<think></think>` block of the official prompt template.
            # This makes training distribution match inference-time
            # `--think-mode`, where the model autoregressively fills the
            # same window with ~250-400 reasoning tokens. Without this,
            # the gen tower sees an unfamiliar prefix length/content
            # at inference and the LoRA delta is calibrated against
            # the wrong cond-KV distribution.
            if s.think:
                append_text = f"<think>\n{s.think}\n</think>\n\n<img>"
            else:
                append_text = self._gen_append
            prompts.append(
                self._build_t2i_query(
                    raw_prompt,
                    system_message=self._sys_msg_for_gen,
                    append_text=append_text,
                )
            )
        else:
            # Plain-template path: a dropped prompt is rendered as a single
            # space rather than the empty string.
            # NOTE(review): presumably to keep the tokenizer input non-empty — confirm.
            prompts.append(" " if drop_text else raw_prompt)
    return prompts
|
| 251 |
+
|
| 252 |
@staticmethod
|
| 253 |
def _check_image_hw(image_hw: tuple[int, int]) -> None:
|
| 254 |
H, W = image_hw
|
|
|
|
| 261 |
# ------------------------------------------------------------------ #
|
| 262 |
# Main entry #
|
| 263 |
# ------------------------------------------------------------------ #
|
| 264 |
+
def __call__(
|
| 265 |
+
self,
|
| 266 |
+
samples: list[T2ISample],
|
| 267 |
+
*,
|
| 268 |
+
condition_modes: list[str] | None = None,
|
| 269 |
+
) -> dict[str, torch.Tensor]:
|
| 270 |
cfg = self.cfg
|
| 271 |
if cfg.enforce_batch_one and len(samples) != 1:
|
| 272 |
raise ValueError(
|
|
|
|
| 299 |
# 1) text → ids + per-sample lengths. With enforce_batch_one we know
|
| 300 |
# `len(samples) == 1` so no batch padding is applied — `L_text` is
|
| 301 |
# exactly this prompt's length (matches upstream `_build_t2i_text_inputs`).
|
| 302 |
+
if condition_modes is None:
|
| 303 |
+
condition_modes = self._sample_condition_modes(len(samples))
|
|
|
|
|
|
|
|
|
|
| 304 |
else:
|
| 305 |
+
condition_modes = list(condition_modes)
|
| 306 |
+
if len(condition_modes) != len(samples):
|
| 307 |
+
raise ValueError(
|
| 308 |
+
f"condition_modes length {len(condition_modes)} != samples length {len(samples)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
)
|
| 310 |
+
prompts = self._render_prompts(samples, condition_modes)
|
| 311 |
+
prefix_cache_keys = [self._prefix_cache_key(m) for m in condition_modes]
|
| 312 |
input_ids, text_lens = self._tokenize(prompts)
|
| 313 |
B = input_ids.shape[0]
|
| 314 |
L_text = input_ids.shape[1]
|
|
|
|
| 379 |
"noisy_pixel_values": noisy_pixel_values, # (B, 3, H, W)
|
| 380 |
"noisy_grid_hw": noisy_grid_hw, # (B, 2)
|
| 381 |
"noise_scale": noise_scale, # (B,) or None
|
| 382 |
+
"cond_drop_text": torch.tensor(
|
| 383 |
+
[m in ("text", "text_image") for m in condition_modes],
|
| 384 |
+
dtype=torch.bool,
|
| 385 |
+
),
|
| 386 |
+
"cond_drop_mode": condition_modes,
|
| 387 |
+
"prefix_cache_key": prefix_cache_keys,
|
| 388 |
"sample_ids": [s.sample_id for s in samples],
|
| 389 |
"text_lens": text_lens,
|
| 390 |
"token_hw": (token_h, token_w),
|
train_u1/data/datasets.py
CHANGED
|
@@ -141,6 +141,7 @@ class PairedFolderT2IDataset(Dataset):
|
|
| 141 |
prompt_template: str | None = None,
|
| 142 |
image_extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
|
| 143 |
snap_bucket: bool = False,
|
|
|
|
| 144 |
):
|
| 145 |
self.folder = Path(folder)
|
| 146 |
if not self.folder.is_dir():
|
|
@@ -148,6 +149,7 @@ class PairedFolderT2IDataset(Dataset):
|
|
| 148 |
self.cap_max_pixels = cap_max_pixels
|
| 149 |
self.prompt_template = prompt_template
|
| 150 |
self.snap_bucket = snap_bucket
|
|
|
|
| 151 |
|
| 152 |
pairs: list[tuple[Path, Path, str]] = []
|
| 153 |
for ext in image_extensions:
|
|
@@ -172,11 +174,13 @@ class PairedFolderT2IDataset(Dataset):
|
|
| 172 |
with open(txt_path, encoding="utf-8") as f:
|
| 173 |
raw = f.read()
|
| 174 |
caption, think_text = parse_caption_and_think(raw)
|
|
|
|
|
|
|
| 175 |
if self.prompt_template:
|
| 176 |
caption = self.prompt_template.format(caption=caption)
|
| 177 |
# Legacy fallback: `<id>.think.txt` separate sidecar (deprecated;
|
| 178 |
# `parse_caption_and_think` is the preferred path).
|
| 179 |
-
if think_text is None:
|
| 180 |
think_path = img_path.with_suffix(".think.txt")
|
| 181 |
if think_path.is_file():
|
| 182 |
with open(think_path, encoding="utf-8") as f:
|
|
@@ -257,6 +261,7 @@ class ArrowT2IDataset(Dataset):
|
|
| 257 |
cap_max_pixels: int | None = None,
|
| 258 |
prompt_template: str | None = None,
|
| 259 |
snap_bucket: bool = False,
|
|
|
|
| 260 |
):
|
| 261 |
try:
|
| 262 |
import pyarrow.parquet as pq # noqa: F401
|
|
@@ -269,6 +274,7 @@ class ArrowT2IDataset(Dataset):
|
|
| 269 |
self.cap_max_pixels = cap_max_pixels
|
| 270 |
self.prompt_template = prompt_template
|
| 271 |
self.snap_bucket = snap_bucket
|
|
|
|
| 272 |
self._table = None # lazy-loaded
|
| 273 |
self._n: int | None = None
|
| 274 |
|
|
@@ -311,6 +317,8 @@ class ArrowT2IDataset(Dataset):
|
|
| 311 |
sample_id = row["sample_id"][0]
|
| 312 |
caption = row["caption"][0]
|
| 313 |
think = (row.get("think") or [None])[0] or None
|
|
|
|
|
|
|
| 314 |
|
| 315 |
if "image" in self._table.column_names and row["image"][0] is not None:
|
| 316 |
from PIL import Image
|
|
|
|
| 141 |
prompt_template: str | None = None,
|
| 142 |
image_extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
|
| 143 |
snap_bucket: bool = False,
|
| 144 |
+
use_think_labels: bool = True,
|
| 145 |
):
|
| 146 |
self.folder = Path(folder)
|
| 147 |
if not self.folder.is_dir():
|
|
|
|
| 149 |
self.cap_max_pixels = cap_max_pixels
|
| 150 |
self.prompt_template = prompt_template
|
| 151 |
self.snap_bucket = snap_bucket
|
| 152 |
+
self.use_think_labels = use_think_labels
|
| 153 |
|
| 154 |
pairs: list[tuple[Path, Path, str]] = []
|
| 155 |
for ext in image_extensions:
|
|
|
|
| 174 |
with open(txt_path, encoding="utf-8") as f:
|
| 175 |
raw = f.read()
|
| 176 |
caption, think_text = parse_caption_and_think(raw)
|
| 177 |
+
if not self.use_think_labels:
|
| 178 |
+
think_text = None
|
| 179 |
if self.prompt_template:
|
| 180 |
caption = self.prompt_template.format(caption=caption)
|
| 181 |
# Legacy fallback: `<id>.think.txt` separate sidecar (deprecated;
|
| 182 |
# `parse_caption_and_think` is the preferred path).
|
| 183 |
+
if self.use_think_labels and think_text is None:
|
| 184 |
think_path = img_path.with_suffix(".think.txt")
|
| 185 |
if think_path.is_file():
|
| 186 |
with open(think_path, encoding="utf-8") as f:
|
|
|
|
| 261 |
cap_max_pixels: int | None = None,
|
| 262 |
prompt_template: str | None = None,
|
| 263 |
snap_bucket: bool = False,
|
| 264 |
+
use_think_labels: bool = True,
|
| 265 |
):
|
| 266 |
try:
|
| 267 |
import pyarrow.parquet as pq # noqa: F401
|
|
|
|
| 274 |
self.cap_max_pixels = cap_max_pixels
|
| 275 |
self.prompt_template = prompt_template
|
| 276 |
self.snap_bucket = snap_bucket
|
| 277 |
+
self.use_think_labels = use_think_labels
|
| 278 |
self._table = None # lazy-loaded
|
| 279 |
self._n: int | None = None
|
| 280 |
|
|
|
|
| 317 |
sample_id = row["sample_id"][0]
|
| 318 |
caption = row["caption"][0]
|
| 319 |
think = (row.get("think") or [None])[0] or None
|
| 320 |
+
if not self.use_think_labels:
|
| 321 |
+
think = None
|
| 322 |
|
| 323 |
if "image" in self._table.column_names and row["image"][0] is not None:
|
| 324 |
from PIL import Image
|
train_u1/model/lora.py
CHANGED
|
@@ -17,6 +17,16 @@ Wrapped modules supported (per-module rank/alpha/enable independently):
|
|
| 17 |
Patch decoder (×2):
|
| 18 |
fm_modules.fm_head.0 fm_modules.fm_head.2
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
The adapter is implemented as `y = base(x) + scaling * lora_up(lora_down(x))`
|
| 21 |
with `scaling = alpha / r`. Initial state: `lora_down` kaiming uniform,
|
| 22 |
`lora_up` zeros — so the wrapped module starts at exactly the base output.
|
|
@@ -50,17 +60,42 @@ import torch.nn as nn
|
|
| 50 |
ATTN_TARGETS = ("q_proj_mot_gen", "k_proj_mot_gen", "v_proj_mot_gen", "o_proj_mot_gen")
|
| 51 |
MLP_TARGETS = ("mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj", "mlp_mot_gen.down_proj")
|
| 52 |
FM_HEAD_TARGETS = ("fm_modules.fm_head.0", "fm_modules.fm_head.2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
# Convenience expansions used by the CLI parser (`attn`, `mlp`, `fm_head`).
|
| 57 |
TARGET_GROUPS: dict[str, tuple[str, ...]] = {
|
| 58 |
"attn": ATTN_TARGETS,
|
| 59 |
"mlp": MLP_TARGETS,
|
| 60 |
"fm_head": FM_HEAD_TARGETS,
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
}
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
# --------------------------------------------------------------------------- #
|
| 66 |
# Spec types #
|
|
@@ -71,7 +106,8 @@ TARGET_GROUPS: dict[str, tuple[str, ...]] = {
|
|
| 71 |
class LoRASpec:
|
| 72 |
"""Per-target LoRA configuration.
|
| 73 |
|
| 74 |
-
`target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix)
|
|
|
|
| 75 |
`r` is the LoRA rank. `alpha` is the LoRA alpha; `scaling = alpha / r`.
|
| 76 |
`dropout` applies to the input before `lora_down`.
|
| 77 |
`enabled=False` lets a preset entry be turned off without removing it.
|
|
@@ -84,7 +120,7 @@ class LoRASpec:
|
|
| 84 |
enabled: bool = True
|
| 85 |
|
| 86 |
def __post_init__(self) -> None:
|
| 87 |
-
if self.target
|
| 88 |
raise ValueError(
|
| 89 |
f"unknown LoRA target {self.target!r}. "
|
| 90 |
f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
|
|
@@ -219,6 +255,50 @@ def _walk_mlp_targets(model: nn.Module, target_name: str):
|
|
| 219 |
yield sub, leaf, idx
|
| 220 |
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
def _walk_fm_head_targets(model: nn.Module, target_name: str):
|
| 223 |
"""Yield `(parent, attr, idx)` for each fm_head linear matching target_name.
|
| 224 |
|
|
@@ -251,6 +331,10 @@ def _resolve_target_walker(target: str):
|
|
| 251 |
return _walk_attn_targets
|
| 252 |
if target in MLP_TARGETS:
|
| 253 |
return _walk_mlp_targets
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
if target in FM_HEAD_TARGETS:
|
| 255 |
return _walk_fm_head_targets
|
| 256 |
raise ValueError(f"no walker for target {target!r}")
|
|
@@ -327,10 +411,7 @@ def apply_lora_specs(
|
|
| 327 |
# --------------------------------------------------------------------------- #
|
| 328 |
|
| 329 |
|
| 330 |
-
_SPEC_TOK_RE = re.compile(
|
| 331 |
-
r"^(?P<target>[A-Za-z0-9_.]+)"
|
| 332 |
-
r"(?:=(?P<body>.+))?$"
|
| 333 |
-
)
|
| 334 |
_RA_RE = re.compile(r"^r(?P<r>\d+)(?:a(?P<alpha>\d+(?:\.\d+)?))?$")
|
| 335 |
|
| 336 |
|
|
@@ -343,14 +424,16 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]:
|
|
| 343 |
- `off` disable a target (overrides earlier entries)
|
| 344 |
- `r=N,a=M` alternative comma form (more readable)
|
| 345 |
|
| 346 |
-
Group expansions: `attn`, `mlp`, `fm_head`, `
|
| 347 |
-
|
|
|
|
| 348 |
|
| 349 |
Examples::
|
| 350 |
|
| 351 |
attn=r64a64;mlp=r64a64
|
| 352 |
q_proj_mot_gen=r128a128; k_proj_mot_gen=r128a128
|
| 353 |
all=r64a64; mlp_mot_gen.down_proj=off
|
|
|
|
| 354 |
fm_head=r=128,a=128
|
| 355 |
"""
|
| 356 |
specs: dict[str, LoRASpec] = {}
|
|
@@ -366,7 +449,7 @@ def parse_lora_spec_str(s: str) -> list[LoRASpec]:
|
|
| 366 |
|
| 367 |
targets = TARGET_GROUPS.get(target, (target,))
|
| 368 |
for t in targets:
|
| 369 |
-
if
|
| 370 |
raise ValueError(
|
| 371 |
f"unknown LoRA target {t!r}. "
|
| 372 |
f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
|
|
@@ -434,8 +517,22 @@ LORA_PRESETS: dict[str, str] = {
|
|
| 434 |
# Attn + MLP only (no fm_head); equivalent to our pre-v16c v15a recipe.
|
| 435 |
"attn_mlp": "attn=r64a64;mlp=r64a64",
|
| 436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
# Exact upstream 8-step distill LoRA shape (rank 128 alpha 128).
|
| 438 |
"official_r128": "attn=r128a128;mlp=r128a128;fm_head=r128a128",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
}
|
| 440 |
|
| 441 |
|
|
|
|
| 17 |
Patch decoder (×2):
|
| 18 |
fm_modules.fm_head.0 fm_modules.fm_head.2
|
| 19 |
|
| 20 |
+
Experimental A3B/MoE target grammar (requires an A3B runtime whose modules
|
| 21 |
+
match the public checkpoint names):
|
| 22 |
+
|
| 23 |
+
Generation MoE experts:
|
| 24 |
+
mlp_mot_gen.experts.*.gate_proj
|
| 25 |
+
mlp_mot_gen.experts.*.up_proj
|
| 26 |
+
mlp_mot_gen.experts.*.down_proj
|
| 27 |
+
Generation MoE router:
|
| 28 |
+
mlp_mot_gen.gate
|
| 29 |
+
|
| 30 |
The adapter is implemented as `y = base(x) + scaling * lora_up(lora_down(x))`
|
| 31 |
with `scaling = alpha / r`. Initial state: `lora_down` kaiming uniform,
|
| 32 |
`lora_up` zeros — so the wrapped module starts at exactly the base output.
|
|
|
|
| 60 |
ATTN_TARGETS = ("q_proj_mot_gen", "k_proj_mot_gen", "v_proj_mot_gen", "o_proj_mot_gen")
|
| 61 |
MLP_TARGETS = ("mlp_mot_gen.gate_proj", "mlp_mot_gen.up_proj", "mlp_mot_gen.down_proj")
|
| 62 |
FM_HEAD_TARGETS = ("fm_modules.fm_head.0", "fm_modules.fm_head.2")
|
| 63 |
+
GEN_MOE_MLP_TARGETS = (
|
| 64 |
+
"mlp_mot_gen.experts.*.gate_proj",
|
| 65 |
+
"mlp_mot_gen.experts.*.up_proj",
|
| 66 |
+
"mlp_mot_gen.experts.*.down_proj",
|
| 67 |
+
)
|
| 68 |
+
GEN_MOE_ROUTER_TARGETS = ("mlp_mot_gen.gate",)
|
| 69 |
+
GEN_MOE_TARGETS = GEN_MOE_MLP_TARGETS + GEN_MOE_ROUTER_TARGETS
|
| 70 |
|
| 71 |
+
DENSE_KNOWN_TARGETS = ATTN_TARGETS + MLP_TARGETS + FM_HEAD_TARGETS
|
| 72 |
+
ALL_KNOWN_TARGETS = DENSE_KNOWN_TARGETS + GEN_MOE_TARGETS
|
| 73 |
|
| 74 |
# Convenience expansions used by the CLI parser (`attn`, `mlp`, `fm_head`).
|
| 75 |
TARGET_GROUPS: dict[str, tuple[str, ...]] = {
|
| 76 |
"attn": ATTN_TARGETS,
|
| 77 |
"mlp": MLP_TARGETS,
|
| 78 |
"fm_head": FM_HEAD_TARGETS,
|
| 79 |
+
# A3B generation-side MoE aliases. These are deliberately separate from
|
| 80 |
+
# `mlp`/`all` so existing 8B configs remain byte-for-byte semantic matches.
|
| 81 |
+
"gen_moe_mlp": GEN_MOE_MLP_TARGETS,
|
| 82 |
+
"moe_mlp": GEN_MOE_MLP_TARGETS,
|
| 83 |
+
"gen_moe_router": GEN_MOE_ROUTER_TARGETS,
|
| 84 |
+
"moe_router": GEN_MOE_ROUTER_TARGETS,
|
| 85 |
+
"gen_moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS,
|
| 86 |
+
"moe_all": ATTN_TARGETS + GEN_MOE_TARGETS + FM_HEAD_TARGETS,
|
| 87 |
+
"all": DENSE_KNOWN_TARGETS,
|
| 88 |
}
|
| 89 |
|
| 90 |
+
_GEN_MOE_EXPERT_TARGET_RE = re.compile(
|
| 91 |
+
r"^mlp_mot_gen\.experts\.(?P<expert>\*|\d+)\."
|
| 92 |
+
r"(?P<leaf>gate_proj|up_proj|down_proj)$"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _is_known_target(target: str) -> bool:
    """Return True if `target` is an accepted LoRA target name.

    A target is accepted when it is listed verbatim in `ALL_KNOWN_TARGETS`, or
    when it matches the MoE expert grammar in `_GEN_MOE_EXPERT_TARGET_RE`
    (wildcard or numeric expert index followed by a projection leaf).
    """
    if target in ALL_KNOWN_TARGETS:
        return True
    # Per-expert names such as `mlp_mot_gen.experts.3.up_proj` are not
    # enumerated in the tuple, so fall back to the pattern.
    return _GEN_MOE_EXPERT_TARGET_RE.match(target) is not None
|
| 98 |
+
|
| 99 |
|
| 100 |
# --------------------------------------------------------------------------- #
|
| 101 |
# Spec types #
|
|
|
|
| 106 |
class LoRASpec:
|
| 107 |
"""Per-target LoRA configuration.
|
| 108 |
|
| 109 |
+
`target` is one of `ALL_KNOWN_TARGETS` (verbatim module-name suffix) or
|
| 110 |
+
an A3B MoE expert target like `mlp_mot_gen.experts.0.gate_proj`.
|
| 111 |
`r` is the LoRA rank. `alpha` is the LoRA alpha; `scaling = alpha / r`.
|
| 112 |
`dropout` applies to the input before `lora_down`.
|
| 113 |
`enabled=False` lets a preset entry be turned off without removing it.
|
|
|
|
| 120 |
enabled: bool = True
|
| 121 |
|
| 122 |
def __post_init__(self) -> None:
|
| 123 |
+
if not _is_known_target(self.target):
|
| 124 |
raise ValueError(
|
| 125 |
f"unknown LoRA target {self.target!r}. "
|
| 126 |
f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
|
|
|
|
| 255 |
yield sub, leaf, idx
|
| 256 |
|
| 257 |
|
| 258 |
+
def _walk_moe_mlp_targets(model: nn.Module, target_name: str):
    """Yield `(expert, leaf, layer_idx)` for generation-side MoE expert projections.

    `target_name` must match `_GEN_MOE_EXPERT_TARGET_RE`: either the wildcard
    form `mlp_mot_gen.experts.*.gate_proj` or a single expert such as
    `mlp_mot_gen.experts.7.down_proj`; anything else raises `ValueError`.

    Layers whose `mlp_mot_gen` is missing or has no `experts` attribute are
    skipped, as are out-of-range expert indices and experts lacking the
    requested leaf, so a model without MoE experts yields nothing.
    """
    parsed = _GEN_MOE_EXPERT_TARGET_RE.match(target_name)
    if parsed is None:
        raise ValueError(f"invalid MoE expert target {target_name!r}")
    selector = parsed.group("expert")
    leaf = parsed.group("leaf")
    wildcard = selector == "*"

    for layer_idx, layer in enumerate(model.language_model.model.layers):
        # getattr on a missing `mlp_mot_gen` (None) falls through to None too.
        experts = getattr(getattr(layer, "mlp_mot_gen", None), "experts", None)
        if experts is None:
            continue

        if wildcard:
            chosen = range(len(experts))
        else:
            wanted = int(selector)
            # An index beyond this layer's expert count selects nothing.
            chosen = (wanted,) if wanted < len(experts) else ()

        for expert_idx in chosen:
            expert = experts[expert_idx]
            if hasattr(expert, leaf):
                yield expert, leaf, layer_idx
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def _walk_moe_router_targets(model: nn.Module, target_name: str):
    """Yield `(parent, "gate", layer_idx)` for each layer's generation-side MoE router.

    Only the exact target name `mlp_mot_gen.gate` is accepted; any other value
    raises `ValueError`. Layers without `mlp_mot_gen`, or whose `mlp_mot_gen`
    has no `gate` attribute, are skipped.
    """
    if target_name != "mlp_mot_gen.gate":
        raise ValueError(f"invalid MoE router target {target_name!r}")

    for idx, layer in enumerate(model.language_model.model.layers):
        router_parent = getattr(layer, "mlp_mot_gen", None)
        if router_parent is None or not hasattr(router_parent, "gate"):
            continue
        yield router_parent, "gate", idx
|
| 300 |
+
|
| 301 |
+
|
| 302 |
def _walk_fm_head_targets(model: nn.Module, target_name: str):
|
| 303 |
"""Yield `(parent, attr, idx)` for each fm_head linear matching target_name.
|
| 304 |
|
|
|
|
| 331 |
return _walk_attn_targets
|
| 332 |
if target in MLP_TARGETS:
|
| 333 |
return _walk_mlp_targets
|
| 334 |
+
if _GEN_MOE_EXPERT_TARGET_RE.match(target):
|
| 335 |
+
return _walk_moe_mlp_targets
|
| 336 |
+
if target in GEN_MOE_ROUTER_TARGETS:
|
| 337 |
+
return _walk_moe_router_targets
|
| 338 |
if target in FM_HEAD_TARGETS:
|
| 339 |
return _walk_fm_head_targets
|
| 340 |
raise ValueError(f"no walker for target {target!r}")
|
|
|
|
| 411 |
# --------------------------------------------------------------------------- #
|
| 412 |
|
| 413 |
|
| 414 |
+
_SPEC_TOK_RE = re.compile(r"^(?P<target>[A-Za-z0-9_.*]+)(?:=(?P<body>.+))?$")
|
|
|
|
|
|
|
|
|
|
| 415 |
_RA_RE = re.compile(r"^r(?P<r>\d+)(?:a(?P<alpha>\d+(?:\.\d+)?))?$")
|
| 416 |
|
| 417 |
|
|
|
|
| 424 |
- `off` disable a target (overrides earlier entries)
|
| 425 |
- `r=N,a=M` alternative comma form (more readable)
|
| 426 |
|
| 427 |
+
Group expansions: `attn`, `mlp`, `fm_head`, `gen_moe_mlp`,
|
| 428 |
+
`gen_moe_router`, `gen_moe_all`, `all` expand to their member targets,
|
| 429 |
+
all sharing the same body.
|
| 430 |
|
| 431 |
Examples::
|
| 432 |
|
| 433 |
attn=r64a64;mlp=r64a64
|
| 434 |
q_proj_mot_gen=r128a128; k_proj_mot_gen=r128a128
|
| 435 |
all=r64a64; mlp_mot_gen.down_proj=off
|
| 436 |
+
gen_moe_mlp=r8a8; gen_moe_router=r4a4
|
| 437 |
fm_head=r=128,a=128
|
| 438 |
"""
|
| 439 |
specs: dict[str, LoRASpec] = {}
|
|
|
|
| 449 |
|
| 450 |
targets = TARGET_GROUPS.get(target, (target,))
|
| 451 |
for t in targets:
|
| 452 |
+
if not _is_known_target(t):
|
| 453 |
raise ValueError(
|
| 454 |
f"unknown LoRA target {t!r}. "
|
| 455 |
f"valid: {ALL_KNOWN_TARGETS} or groups {list(TARGET_GROUPS)}"
|
|
|
|
| 517 |
# Attn + MLP only (no fm_head); equivalent to our pre-v16c v15a recipe.
|
| 518 |
"attn_mlp": "attn=r64a64;mlp=r64a64",
|
| 519 |
|
| 520 |
+
# **Safe presets** that explicitly drop fm_head from the trained surface.
|
| 521 |
+
# The technical report's grid-artifact discussion attributes artifacts to
|
| 522 |
+
# the final FFN + MLP head independently modelling disjoint 32×32 patches,
|
| 523 |
+
# and notes that the official T2I RL stage freezes the generation-branch
|
| 524 |
+
# MLP head and the last three transformer layers for exactly this reason.
|
| 525 |
+
# Use these when you want to avoid touching the head at all.
|
| 526 |
+
"attn_only_no_head": "attn=r64a64",
|
| 527 |
+
"attn_mlp_no_head": "attn=r64a64;mlp=r64a64",
|
| 528 |
+
|
| 529 |
# Exact upstream 8-step distill LoRA shape (rank 128 alpha 128).
|
| 530 |
"official_r128": "attn=r128a128;mlp=r128a128;fm_head=r128a128",
|
| 531 |
+
|
| 532 |
+
# Experimental A3B/MoE coverage. Small ranks are intentional: covering all
|
| 533 |
+
# 48 layers × 32 gen experts × 3 projections gets large quickly.
|
| 534 |
+
"a3b_moe_r8": "attn=r8a8;gen_moe_mlp=r8a8;fm_head=r8a8",
|
| 535 |
+
"a3b_moe_router_r8": "gen_moe_router=r8a8",
|
| 536 |
}
|
| 537 |
|
| 538 |
|
train_u1/model/losses.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
"""Training losses for the FM step.
|
| 2 |
|
| 3 |
-
Two primaries (report
|
| 4 |
-
- `fm_loss_x0(x_pred, x0_patch)` —
|
| 5 |
-
- `fm_loss_v(v_pred, v_target)` — velocity
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
Plus
|
| 8 |
-
|
|
|
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
|
|
@@ -35,6 +38,65 @@ def fm_loss_v_huber(v_pred: torch.Tensor, v_target: torch.Tensor, delta: float =
|
|
| 35 |
return F.huber_loss(v_pred.float(), v_target.float(), delta=delta)
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def text_ce_guardrail(
|
| 39 |
logits: torch.Tensor,
|
| 40 |
labels: torch.Tensor,
|
|
|
|
| 1 |
"""Training losses for the FM step.
|
| 2 |
|
| 3 |
+
Two primaries (report Eq. (5) / Table 2):
|
| 4 |
+
- `fm_loss_x0(x_pred, x0_patch)` — MSE on clean patches (legacy MVP default).
|
| 5 |
+
- `fm_loss_v(v_pred, v_target)` — MSE on velocity (matches the official
|
| 6 |
+
x-predict + v-loss training objective; equivalent to
|
| 7 |
+
`MSE(x_pred - x0) / (1 - t)^2`, i.e. an x0-MSE re-weighted by `(1-t)^-2`).
|
| 8 |
|
| 9 |
+
Plus Huber variants and an `fm_loss` dispatcher that selects by `loss_type`.
|
| 10 |
+
|
| 11 |
+
CE guardrail kept for the Phase C unified-training scenario.
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 38 |
return F.huber_loss(v_pred.float(), v_target.float(), delta=delta)
|
| 39 |
|
| 40 |
|
| 41 |
+
def compute_v_target(
    x0_patch: torch.Tensor,
    z_t: torch.Tensor,
    t: torch.Tensor | float,
    *,
    t_eps: float = 1e-3,
) -> torch.Tensor:
    """Closed-form velocity target for rectified-flow / linear-z_t.

    Report Eq. (5): `v* = (x0 - z_t) / (1 - t)` with `z_t = t x0 + (1-t) eps`.

    Args:
        x0_patch: clean patches.
        z_t: noisy patches; must have exactly the same shape as `x0_patch`.
        t: timestep(s). Expected to be a (B,) tensor, which is broadcast to
            (B, 1, 1) to match (B, N, D) patch tensors. A Python scalar or a
            tensor on another device/dtype is also accepted and is coerced to
            the dtype and device of `x0_patch`.
        t_eps: lower bound applied to `1 - t` so the division stays finite as
            `t` approaches 1.

    Returns:
        The velocity target, same shape as `x0_patch`.

    Raises:
        ValueError: if `x0_patch` and `z_t` differ in shape.
    """
    if x0_patch.shape != z_t.shape:
        raise ValueError(f"shape mismatch x0 {x0_patch.shape} vs z_t {z_t.shape}")
    # Align dtype AND device with the patches: a `t` sampled on CPU would
    # otherwise fail to broadcast against accelerator-resident patches, and a
    # bare float has no `.to`.
    t = torch.as_tensor(t, dtype=x0_patch.dtype, device=x0_patch.device)
    # Append trailing singleton axes until `t` broadcasts over the patch dims.
    while t.dim() < x0_patch.dim():
        t = t.unsqueeze(-1)
    denom = (1.0 - t).clamp(min=t_eps)
    return (x0_patch - z_t) / denom
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# --------------------------------------------------------------------------- #
|
| 64 |
+
# Dispatcher #
|
| 65 |
+
# --------------------------------------------------------------------------- #
|
| 66 |
+
|
| 67 |
+
VALID_LOSS_TYPES = ("x0", "v", "x0_huber", "v_huber")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def fm_loss(
    *,
    loss_type: str,
    x_pred: torch.Tensor,
    x0_patch: torch.Tensor,
    v_pred: torch.Tensor | None = None,
    v_target: torch.Tensor | None = None,
    huber_delta: float = 1.0,
) -> torch.Tensor:
    """Dispatch to one of the four FM losses according to `loss_type`.

    - `"x0"` / `"x0_huber"` use only `x_pred` and `x0_patch`.
    - `"v"` / `"v_huber"` use `v_pred` and `v_target`, which the caller must
      supply (e.g. via `compute_v_target` on the same `(x0, z_t, t)` used to
      build the batch).

    `huber_delta` is forwarded to the Huber variants only.

    Raises:
        ValueError: if a velocity loss is requested without both `v_pred` and
            `v_target`, or if `loss_type` is not one of `VALID_LOSS_TYPES`.
    """
    # Clean-patch objectives need nothing beyond the prediction and target.
    if loss_type == "x0":
        return fm_loss_x0(x_pred, x0_patch)
    if loss_type == "x0_huber":
        return fm_loss_x0_huber(x_pred, x0_patch, delta=huber_delta)

    # Velocity objectives: both tensors are mandatory.
    velocity_missing = v_pred is None or v_target is None
    if loss_type == "v":
        if velocity_missing:
            raise ValueError("loss_type='v' requires v_pred and v_target")
        return fm_loss_v(v_pred, v_target)
    if loss_type == "v_huber":
        if velocity_missing:
            raise ValueError("loss_type='v_huber' requires v_pred and v_target")
        return fm_loss_v_huber(v_pred, v_target, delta=huber_delta)

    raise ValueError(f"unknown loss_type {loss_type!r}; valid: {VALID_LOSS_TYPES}")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
def text_ce_guardrail(
|
| 101 |
logits: torch.Tensor,
|
| 102 |
labels: torch.Tensor,
|