Code walkthrough (key design decisions)
utils_ursa_inputs.py

build_ursa_inputs(transformer, txt_ids, visual_tokens, latents_shape, device)
Faithfully reproduces the token-concatenation logic of URSAPipeline.__call__:

img_ids   = pad(latents_flat + lm_vocab_size, (1, 0), value=bov_token_id)
input_ids = cat([txt_ids, img_ids], dim=1)
blk_pos   = flex_rope.get_pos(latents_shape, L)
rope_pos  = cat([txt_pos, blk_pos[0]]).unsqueeze(0).expand(B, -1, -1)
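The input_ids half of this concatenation can be sketched as follows. This is a minimal sketch, not the actual implementation: the function name `build_input_ids` is invented here, and the RoPE-position handling via `flex_rope` is omitted.

```python
import torch
import torch.nn.functional as F


def build_input_ids(txt_ids, visual_tokens, lm_vocab_size, bov_token_id):
    """Sketch of the token concatenation (RoPE positions omitted).

    Visual codebook indices are shifted past the LM vocabulary, a BOV
    (begin-of-visual) token is prepended, and the image block is appended
    to the text ids.
    """
    latents_flat = visual_tokens.flatten(1)  # [B, N] codebook indices
    img_ids = F.pad(latents_flat + lm_vocab_size, (1, 0), value=bov_token_id)
    return torch.cat([txt_ids, img_ids], dim=1)  # [B, T + 1 + N]
```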
extract_visual_logits(logits, N, K)
Guard against pitfall #1: z = logits[:, -(N+1):-1] (causal slice), then slice the vocabulary axis again only if the last dimension does not already equal K.
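A sketch of that guarded slice, assuming logits has shape [B, L, V] and that when V > K the visual vocabulary occupies the last K entries (both assumptions, not confirmed by the script):

```python
import torch


def extract_visual_logits(logits: torch.Tensor, N: int, K: int) -> torch.Tensor:
    """Take the logits predicting the N visual tokens, then keep only the
    K visual-vocabulary entries if the last axis has not been sliced yet."""
    # Causal shift: position i predicts token i+1, so the predictions for
    # the N visual tokens live at positions [-(N+1), -1).
    z = logits[:, -(N + 1):-1]
    # Only slice the vocabulary axis when it still spans the full vocab.
    if z.shape[-1] != K:
        z = z[..., -K:]  # assumption: visual vocab sits at the tail
    return z
```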
sample_t_curriculum: for the first 10k steps samples t = 1-(1-u)^2, biased toward large t; afterwards it falls back to uniform sampling.
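The curriculum can be sketched directly from the formula above (the signature and the warmup default are assumptions):

```python
import torch


def sample_t_curriculum(step: int, batch_size: int,
                        warmup_steps: int = 10_000) -> torch.Tensor:
    """Sample noise levels t in [0, 1]: biased large early, uniform later."""
    u = torch.rand(batch_size)
    if step < warmup_steps:
        # t = 1 - (1 - u)^2 concentrates mass near t = 1 (E[t] = 2/3)
        return 1.0 - (1.0 - u) ** 2
    return u
```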
train_onestep_ursa_dimo.py training loop
Each step runs 9 stages that follow the full DiMO pipeline:

| Stage | Operation | Gradients |
| --- | --- | --- |
| 1-2 | tokenize + sample x_init (80% uniform / 20% corrupt) | none |
| 3 | student 1-step forward on x_init → x_hat, logp, H | ✅ student |
| 4 | add_noise(x_hat, t) → x_t | none (discrete sampling cuts the graph) |
| 5 | teacher on x_t → p_T | none (no_grad) |
| 6 | aux on x_t → Jeffrey(p_T, p_A) → backward → aux update | ✅ aux only |
| 7 | student on x_t → KL(p_T ‖ p_S_t) | ✅ student |
| 8 | REINFORCE: r = -loss_aux, adv = r - EMA, loss_pg = -(adv·logp) | ✅ student (via logp) |
| 9 | L_s = λ_pg·loss_pg + λ_kd·loss_kd - λ_ent·H → student update | ✅ student |
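Stage 8 can be sketched as below. `EMABaseline`, its decay, and the pre-update use of the baseline are assumptions for illustration, not the script's actual names:

```python
import torch


class EMABaseline:
    """Exponential-moving-average reward baseline (hypothetical helper)."""

    def __init__(self, decay: float = 0.99):
        self.decay, self.value, self.initialized = decay, 0.0, False

    def update(self, reward: float) -> float:
        if not self.initialized:
            self.value, self.initialized = reward, True
        else:
            self.value = self.decay * self.value + (1 - self.decay) * reward
        return self.value


def policy_gradient_loss(loss_aux: torch.Tensor, logp: torch.Tensor,
                         baseline: EMABaseline) -> torch.Tensor:
    # Stage 8: reward is the negative aux loss; advantage subtracts the
    # EMA baseline (here: the baseline from previous steps, an assumption).
    r = -float(loss_aux.detach())
    adv = r - baseline.value
    baseline.update(r)
    return -(adv * logp).mean()
```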
Example run commands
End-to-end smoke test (single GPU, 17 frames at 256×256, 2000 steps):
python scripts/train_onestep_ursa_dimo.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
  --num_frames 17 --height 256 --width 256 \
  --batch_size 1 --num_steps 2000 \
  --log_every 50 --save_every 500 \
  --out_dir ./outputs/dimo_test
Evaluation (1-step student vs 25-step teacher):
python scripts/eval_onestep_ursa.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --student_ckpt ./outputs/dimo_test/final/student.pt \
  --num_frames 17 --height 256 --width 256 \
  --teacher_steps 25 \
  --out_dir ./outputs/eval
Scaling up to full resolution (49 frames, 320×512):
python scripts/train_onestep_ursa_dimo.py \
  --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
  --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
  --num_frames 49 --height 320 --width 512 \
  --batch_size 2 --num_steps 50000 \
  --lambda_ent 0.01 --t_curriculum_steps 10000 \
  --mixed_precision bf16 --out_dir ./outputs/dimo_full
Three stability mechanisms (all required)

| Mechanism | Purpose |
| --- | --- |
| t curriculum | keep t biased large for the first 10k steps; the teacher distribution is sharper there, the KD signal stronger, and the early student avoids a random walk |
| p_init mixing | 20% of each batch uses corrupt(x_hat_prev, r=0.2), so the student learns "one-step repair" |
| entropy regularizer λ_ent | start at 0.01; raise to 0.05 if tok_entropy is seen to drop |
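The `corrupt` operation in the p_init mixing row can be sketched as follows; the signature and the use of uniform-random replacement tokens are assumptions:

```python
import torch


def corrupt(x: torch.Tensor, vocab_size: int, r: float = 0.2) -> torch.Tensor:
    """Replace a fraction ~r of the token ids with random codebook ids,
    giving the student corrupted inputs to learn one-step repair on."""
    mask = torch.rand_like(x, dtype=torch.float) < r   # which positions to hit
    noise = torch.randint_like(x, vocab_size)          # random replacement ids
    return torch.where(mask, noise, x)
```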
|
|
8-GPU launch command
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
  --machine_rank 0 --num_machines 1 --num_processes 8 \
  scripts/train_distill_dimo.py \
  config=./configs/distill_dimo.yaml \
  experiment.output_dir=./experiments/distill_dimo \
  distill.teacher_ckpt=/gfs/space/private/fengzl/World_Model/URSA-1.7B \
  distill.prompt_source=/gfs/space/private/fengzl/World_Model/Koala-36M-v1 \
  distill.batch_size_per_gpu=1
|
|
Smoke test (50 steps, saves a checkpoint)
accelerate launch --num_processes 8 --mixed_precision bf16 \
  scripts/train_distill_dimo.py \
  config="./configs/distill_dimo.yaml" \
  experiment.output_dir="./experiments/smoke" \
  distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
  distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \
  training.max_train_steps=50 \
  experiment.save_every=50
|
|
Load student.pt for 1-step inference

import torch
from diffnext.pipelines import URSAPipeline

pipe = URSAPipeline.from_pretrained(
    "/path/to/URSA-1.7B-IBQ1024",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Replace the transformer weights with the student
state = torch.load(
    "experiments/distill_dimo/checkpoints/final/student.pt", map_location="cuda"
)
pipe.transformer.load_state_dict(state, strict=True)

# 1-step generation (num_inference_steps=1)
frames = pipe(
    prompt="a dog running on a beach",
    height=256,
    width=256,
    num_frames=17,
    num_inference_steps=1,
    guidance_scale=3.0,
).frames
|
|
|
|
Latest: after changing the resolution and CFG
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
  --machine_rank 0 --num_machines 1 --num_processes 8 \
  scripts/train_distill_dimo.py \
  config="./configs/distill_dimo.yaml" \
  experiment.output_dir="./experiments/distill_dimo" \
  distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
  distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"