Code walkthrough (key design decisions)
utils_ursa_inputs.py
build_ursa_inputs(transformer, txt_ids, visual_tokens, latents_shape, device)
Strictly replicates the token concatenation logic of URSAPipeline.__call__:
img_ids = pad(latents_flat + lm_vocab_size, (1,0), value=bov_token_id)
input_ids = cat([txt_ids, img_ids], dim=1)
blk_pos = flex_rope.get_pos(latents_shape, L)
rope_pos = cat([txt_pos, blk_pos[0]]).unsqueeze(0).expand(B,-1,-1)
extract_visual_logits(logits, N, K)
Guard against pitfall 1: take z = logits[:, -(N+1):-1] (the causal slice), then decide whether to slice again depending on whether the last dimension equals K.
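The token layout and the causal slice described above can be checked on toy data. The sketch below uses plain Python lists instead of tensors so the indexing is easy to follow. The constants LM_VOCAB_SIZE, K and BOV_TOKEN_ID and the helper names build_input_ids, causal_slice_positions and visual_logits_row are made up for illustration. Reading the visual logits from the vocabulary range [LM_VOCAB_SIZE, LM_VOCAB_SIZE + K) is an assumption inferred from the offset in the concatenation code, not something confirmed from the repository.

```python
# Toy illustration with plain lists (the real code operates on torch tensors).
# All constants below are hypothetical values chosen for readability.
LM_VOCAB_SIZE = 100   # assumed size of the text vocabulary
K = 8                 # assumed visual codebook size
BOV_TOKEN_ID = 99     # assumed "begin of video" token id


def build_input_ids(txt_ids, latents_flat):
    """Mirror the concatenation: [text tokens, bov, shifted visual tokens]."""
    img_ids = [BOV_TOKEN_ID] + [v + LM_VOCAB_SIZE for v in latents_flat]
    return txt_ids + img_ids


def causal_slice_positions(seq_len, n_visual):
    """Positions whose logits predict the n_visual visual tokens.

    In a causal model the logits at position i predict token i + 1, so the
    predictions for the last N tokens live at positions [-(N+1):-1].
    """
    return list(range(seq_len))[-(n_visual + 1):-1]


def visual_logits_row(row):
    """Keep only K visual logits for one position.

    Assumption: when the head outputs the full vocabulary, the visual logits
    sit at [LM_VOCAB_SIZE, LM_VOCAB_SIZE + K). A row that already has K
    entries is returned unchanged.
    """
    if len(row) == K:
        return row
    return row[LM_VOCAB_SIZE:LM_VOCAB_SIZE + K]


txt_ids = [5, 6, 7]
latents_flat = [0, 3, 1, 2]            # N = 4 visual tokens
input_ids = build_input_ids(txt_ids, latents_flat)
positions = causal_slice_positions(len(input_ids), len(latents_flat))

print(input_ids)
print(positions)
```

The first selected position is the bov token and the last one is the second-to-last visual token, which is why the slice ends at -1 rather than at the end of the sequence.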
sample_t_curriculum: for the first 10k steps it samples t = 1-(1-u)^2, which is biased toward larger t; after that it reverts to uniform sampling.
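A minimal sketch of such a sampler, assuming u is drawn uniformly from [0, 1) and that the switch happens exactly at step 10,000. The signature (step, rng, curriculum_steps) is hypothetical and uses the standard library rather than torch; the function in the repository may take different arguments.

```python
import random


def sample_t_curriculum(step, rng, curriculum_steps=10_000):
    """Sample t in [0, 1): skewed toward 1 early on, uniform afterwards."""
    u = rng.random()
    if step < curriculum_steps:
        # 1 - (1 - u)^2 pushes mass toward larger t; its mean is 2/3.
        return 1.0 - (1.0 - u) ** 2
    return u


rng = random.Random(0)
early = [sample_t_curriculum(0, rng) for _ in range(20_000)]
late = [sample_t_curriculum(10_000, rng) for _ in range(20_000)]
mean_early = sum(early) / len(early)   # close to 2/3 in expectation
mean_late = sum(late) / len(late)      # close to 1/2 in expectation
print(mean_early, mean_late)
```

Since E[(1-u)^2] = 1/3 for uniform u, the curriculum phase has mean t of 2/3, compared with 1/2 for uniform sampling.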
train_onestep_ursa_dimo.py training loop
The 9 stages of each step correspond to the full procedure of the DiMO paper:
Stage	Operation	Gradient
1-2	tokenize + sample x_init (80% uniform / 20% corrupt)	none
3	student 1-step forward on x_init → x_hat, logp, H	✅ student
4	add_noise(x_hat, t) → x_t	none (cut off by discrete sampling)
5	teacher on x_t → p_T	none (no_grad)
6	aux on x_t → Jeffrey(p_T, p_A) → backward → aux update	✅ aux only
7	student on x_t → KL(p_T ‖ p_S_t)	✅ student
8	REINFORCE: r=-loss_aux, adv=r-EMA, loss_pg=-(adv·logp)	✅ student (via logp)
9	L_s = λ_pg·loss_pg + λ_kd·loss_kd - λ_ent·H → student update	✅ student
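The arithmetic of stages 6, 8 and 9 in the table above can be sketched on scalars. This is an illustration under stated assumptions, not the training script: the Jeffrey term is taken to be KL(p‖q) + KL(q‖p) without extra scaling, the EMA decay of 0.99 is made up, the advantage is computed against the baseline value from before the current update, and the names kl, jeffrey, EmaBaseline and student_loss are hypothetical. Plain floats are used, so there is no autograd here; in the real loop adv would be treated as a constant and the gradient would flow through logp.

```python
import math


def kl(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def jeffrey(p, q):
    """Symmetric divergence KL(p||q) + KL(q||p) (scaling is an assumption)."""
    return kl(p, q) + kl(q, p)


class EmaBaseline:
    """Exponential moving average of the reward, used as REINFORCE baseline."""

    def __init__(self, decay=0.99):  # decay value is hypothetical
        self.decay = decay
        self.value = None

    def update(self, reward):
        if self.value is None:
            self.value = reward
        else:
            self.value = self.decay * self.value + (1 - self.decay) * reward
        return self.value


def student_loss(loss_aux, logp, loss_kd, entropy, baseline,
                 lambda_pg=1.0, lambda_kd=1.0, lambda_ent=0.01):
    """Stages 8-9: reward, advantage, policy-gradient term, total loss."""
    reward = -loss_aux
    # Advantage against the previous baseline; zero on the very first step.
    adv = reward - baseline.value if baseline.value is not None else 0.0
    baseline.update(reward)
    loss_pg = -(adv * logp)
    return lambda_pg * loss_pg + lambda_kd * loss_kd - lambda_ent * entropy


p_T = [0.7, 0.2, 0.1]   # toy teacher distribution
p_A = [0.5, 0.3, 0.2]   # toy aux distribution
loss_aux = jeffrey(p_T, p_A)

baseline = EmaBaseline()
baseline.update(-1.0)   # pretend an earlier step had reward -1.0
total = student_loss(loss_aux=0.5, logp=-2.0, loss_kd=0.3,
                     entropy=1.5, baseline=baseline)
print(loss_aux, total)
```

In the usage lines, the reward is -0.5 and the previous baseline is -1.0, so the advantage is 0.5, loss_pg is 1.0, and the total is 1.0 + 0.3 - 0.015 = 1.285.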
Example commands
End-to-end smoke test (single GPU, 17 frames at 256×256, 2000 steps):
python scripts/train_onestep_ursa_dimo.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --prompt_file  /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
    --num_frames 17 --height 256 --width 256 \
    --batch_size 1  --num_steps 2000 \
    --log_every 50  --save_every 500 \
    --out_dir ./outputs/dimo_test

Evaluation (1-step student vs 25-step teacher):
python scripts/eval_onestep_ursa.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --student_ckpt ./outputs/dimo_test/final/student.pt \
    --num_frames 17 --height 256 --width 256 \
    --teacher_steps 25 \
    --out_dir ./outputs/eval

Scaling up to full resolution (49 frames at 320×512):
python scripts/train_onestep_ursa_dimo.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --prompt_file  /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
    --num_frames 49 --height 320 --width 512 \
    --batch_size 2  --num_steps 50000 \
    --lambda_ent 0.01  --t_curriculum_steps 10000 \
    --mixed_precision bf16 --out_dir ./outputs/dimo_full
    
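Three stability mechanisms (all three are required)
t curriculum: for the first 10k steps t is biased toward larger values, so the teacher distribution is sharper and the KD signal is stronger, which keeps the early student from drifting in a random walk.
p_init mixing: 20% of batches use corrupt(x_hat_prev, r=0.2), so the student learns "one-step repair".
Entropy regularization λ_ent: starts at 0.01; raise it to 0.05 if tok_entropy is observed to drop.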
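The second and third mechanisms can be sketched as follows. The notes do not spell out what corrupt does or how a drop in tok_entropy is detected, so both are assumptions here: corrupt is taken to resample each token uniformly with probability r, and the entropy check compares against a reference value with a made-up ratio of 0.8. The function names and signatures are hypothetical, and plain lists stand in for token tensors.

```python
import random


def corrupt(tokens, r, codebook_size, rng):
    """Resample each token uniformly with probability r (assumed semantics)."""
    return [rng.randrange(codebook_size) if rng.random() < r else tok
            for tok in tokens]


def sample_x_init(x_hat_prev, n_tokens, codebook_size, rng,
                  p_corrupt=0.2, r=0.2):
    """Pick the student's input for one batch: 80% uniform, 20% corrupt."""
    if x_hat_prev is not None and rng.random() < p_corrupt:
        return corrupt(x_hat_prev, r, codebook_size, rng)
    return [rng.randrange(codebook_size) for _ in range(n_tokens)]


def adjust_lambda_ent(lambda_ent, tok_entropy, reference_entropy,
                      drop_ratio=0.8, raised=0.05):
    """Raise lambda_ent once tok_entropy falls below drop_ratio * reference.

    The detection rule (drop_ratio and the reference value) is hypothetical;
    the notes only say to go from 0.01 to 0.05 when tok_entropy drops.
    """
    if tok_entropy < drop_ratio * reference_entropy:
        return max(lambda_ent, raised)
    return lambda_ent


rng = random.Random(0)
x_hat_prev = [1] * 1000
x_corrupt = corrupt(x_hat_prev, r=0.2, codebook_size=16, rng=rng)
n_changed = sum(a != b for a, b in zip(x_hat_prev, x_corrupt))
print(n_changed)                           # roughly 0.2 * 15/16 of 1000 tokens
print(adjust_lambda_ent(0.01, 2.0, 5.0))   # entropy dropped, so lambda is raised
```

Under these assumptions a corrupted batch keeps about 80% of the previous student output intact, so the student sees inputs that are mostly right and only has to repair the rest.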


8-GPU launch command
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
    --machine_rank 0 --num_machines 1 --num_processes 8 \
    scripts/train_distill_dimo.py \
    config=./configs/distill_dimo.yaml \
    experiment.output_dir=./experiments/distill_dimo \
    distill.teacher_ckpt=/gfs/space/private/fengzl/World_Model/URSA-1.7B \
    distill.prompt_source=/gfs/space/private/fengzl/World_Model/Koala-36M-v1 \
    distill.batch_size_per_gpu=1

Smoke test (50 steps, saves a checkpoint)
accelerate launch --num_processes 8 --mixed_precision bf16 \
    scripts/train_distill_dimo.py \
    config="./configs/distill_dimo.yaml" \
    experiment.output_dir="./experiments/smoke" \
    distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
    distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \
    training.max_train_steps=50 \
    experiment.save_every=50


Loading student.pt for 1-step inference
from diffnext.pipelines import URSAPipeline
import torch

pipe = URSAPipeline.from_pretrained(
    "/path/to/URSA-1.7B-IBQ1024", torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

# Replace the transformer weights with the student's
state = torch.load("experiments/distill_dimo/checkpoints/final/student.pt", map_location="cuda")
pipe.transformer.load_state_dict(state, strict=True)

# 1-step generation (num_inference_steps=1)
frames = pipe(
    prompt="a dog running on a beach",
    height=256, width=256, num_frames=17,
    num_inference_steps=1,
    guidance_scale=3.0,
).frames


Latest: after changing the resolution and cfg
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
    --machine_rank 0 --num_machines 1 --num_processes 8 \
    scripts/train_distill_dimo.py \
    config="./configs/distill_dimo.yaml" \
    experiment.output_dir="./experiments/distill_dimo" \
    distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
    distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"