File size: 9,449 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Fine-tune ARB model on video/latent diffusion tasks using LoRA.

Freezes text/audio pipelines, adapts VideoHead + core MoE for
latent video diffusion fine-tuning. Uses pig-vae to encode training targets.

Designed for 8GB VRAM with batch_size=1.

Usage:
    python training/finetuning/diffusion.py \\
        --video-dir ./videos --steps 2000 --batch 1 \\
        --lora-rank 16 --run diffusion-finetune

Data format: directory of .mp4 or .avi files (encoded to latents via pig-vae; synthetic random latents are used instead if no files are found or pig-vae is unavailable).
"""
import os, sys, time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from torch.utils.tensorboard import SummaryWriter


def load_model(lora_rank=16, lora_alpha=32.0, max_moe_iters=1, target_modules=None):
    """Build the ARB model with MoE + VideoHead and attach LoRA adapters.

    The model is constructed with image, audio, VQ, graph and memory modules
    disabled and MoE enabled, then moved to CUDA. LoRA adapters are attached to
    every submodule whose name matches ``target_modules``. Freezing of the
    non-LoRA base weights is presumably done inside ``apply_lora_to_model`` —
    NOTE(review): confirm in training/finetuning/lora.py.

    Args:
        lora_rank: rank of each LoRA adapter.
        lora_alpha: LoRA alpha forwarded to ``apply_lora_to_model``.
        max_moe_iters: forwarded to ``ARBModel``.
        target_modules: iterable of module names to adapt; ``None`` selects
            the default list below.

    Returns:
        ``(model, lora_layers)`` — the CUDA model and whatever
        ``apply_lora_to_model`` returns.
    """
    from arbitor import ARBModel
    from training.finetuning.lora import apply_lora_to_model, count_lora_params

    if target_modules is None:
        # Default set of module names that receive LoRA adapters.
        target_modules = ['W_gate', 'W_transform', 'byte_head', 'router',
                          'shared_up', 'shared_expert_gate', 'shared_expert_up',
                          'video_head', 'diffusion_step', 'cross_attn',
                          'halt_unit', 'noise_embed']

    model = ARBModel(
        enable_image=False, enable_audio=False,
        enable_vq=False, enable_graph=False,
        enable_memory_modules=False, enable_moe=True,
        max_moe_iters=max_moe_iters,
    ).cuda()

    lora_layers = apply_lora_to_model(model, rank=lora_rank, alpha=lora_alpha,
                                      target_modules=list(target_modules))
    # count_lora_params returns (lora_params, total_params); only the first is reported.
    lora_p, _ = count_lora_params(model)
    print(f"  LoRA trainable: {lora_p:,} params ({lora_p/1e6:.2f}M)", flush=True)
    return model, lora_layers


def load_video_data(video_dir, max_samples=100, frames=16, res=256):
    """Load video clips from ``video_dir`` and encode them to pig-vae latents.

    ``*.mp4`` and ``*.avi`` files are processed in sorted (deterministic)
    order. Each clip is converted to a float tensor in [0, 1] laid out as
    (C, T, H, W), truncated to its first ``frames`` frames and cropped to the
    top-left ``res`` x ``res`` region, then encoded with pig-vae. Clips shorter
    than ``frames`` frames, or that fail to decode/encode, are skipped and
    reported. Encoding stops once ``max_samples`` latents have been collected.

    Falls back to synthetic random latents when no video files exist, pig-vae
    cannot be loaded, or no clip could be encoded.

    Returns:
        A list of latent tensors (moved to CPU), at most ``max_samples`` long.
    """
    import glob
    import torchvision.io

    files = sorted(
        glob.glob(os.path.join(video_dir, "*.mp4"))
        + glob.glob(os.path.join(video_dir, "*.avi"))
    )

    if not files:
        print(f"  No video files found in {video_dir}", flush=True)
        print(f"  Using synthetic random latents for smoke testing", flush=True)
        return _generate_synthetic(frames, res, max_samples)

    print(f"  Found {len(files)} video files", flush=True)

    try:
        from arbitor.encoders.pig_vae import load_vae
        vae = load_vae(device='cuda', quantize='int8')
    except Exception as e:
        # Best-effort: without the VAE we cannot produce real latents.
        print(f"  pig-vae not available: {e}", flush=True)
        print(f"  Using random latents (no video encoding)", flush=True)
        return _generate_synthetic(frames, res, min(max_samples, 50))
    else:
        print(f"  pig-vae loaded for encoding", flush=True)

    data = []
    skipped = 0
    for path in files:
        # Cap on *encoded* samples, so short/broken files don't eat the budget.
        if len(data) >= max_samples:
            break
        try:
            video, _, _ = torchvision.io.read_video(path, pts_unit='sec')
            # read_video returns (T, H, W, C) uint8; reorder to (C, T, H, W) in [0, 1].
            clip = video.permute(3, 0, 1, 2).float() / 255.0
            clip = clip[:, :frames, :res, :res]
            if clip.shape[1] < frames:
                skipped += 1
                continue
            with torch.no_grad():
                latents = vae.encode(clip.unsqueeze(0).cuda()).cpu()
            data.append(latents)
        except Exception as e:
            # Keep going on a bad file, but don't hide the failure.
            skipped += 1
            print(f"  Skipping {path}: {e}", flush=True)

    if skipped:
        print(f"  Skipped {skipped} unusable video files", flush=True)

    if not data:
        return _generate_synthetic(frames, res, min(max_samples, 50))

    print(f"  Encoded {len(data)} videos to latent space", flush=True)
    return data


def _generate_synthetic(frames, res, count):
    """Fallback: generate random latent targets for testing."""
    data = []
    for _ in range(count):
        latents = torch.randn(1, 16, 1, 32, 32)
        data.append(latents)
    print(f"  Generated {count} synthetic latent targets", flush=True)
    return data


def _match_latents(target, pred):
    """Resize or pad target latents to the current VideoHead output shape."""
    if target.shape[0] == 1 and pred.shape[0] > 1:
        target = target.expand(pred.shape[0], -1, -1, -1, -1).contiguous()
    if target.shape[1] != pred.shape[1]:
        if target.shape[1] > pred.shape[1]:
            target = target[:, :pred.shape[1]]
        else:
            pad = target.new_zeros(target.shape[0], pred.shape[1] - target.shape[1], *target.shape[2:])
            target = torch.cat([target, pad], dim=1)
    if target.shape[2:] != pred.shape[2:]:
        target = torch.nn.functional.interpolate(
            target, size=pred.shape[2:], mode="trilinear", align_corners=False
        )
    return target


if __name__ == "__main__":
    import argparse

    # Import up front so a missing dependency fails before training, not at
    # the first checkpoint.
    from training.finetuning.lora import save_lora

    parser = argparse.ArgumentParser(description="ARB video/diffusion fine-tuning")
    parser.add_argument("--video-dir", type=str, default=None,
                        help="Directory with .mp4/.avi files")
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--max-moe-iters", type=int, default=1)
    parser.add_argument("--run", type=str, default="diffusion-finetune")
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--frames", type=int, default=8)
    parser.add_argument("--res", type=int, default=128)
    parser.add_argument("--max-samples", type=int, default=100)
    args = parser.parse_args()

    print("Building model with VideoHead + LoRA...", flush=True)
    model, lora_layers = load_model(args.lora_rank, args.lora_alpha, args.max_moe_iters)

    # Collected once; shared by the optimizer and per-step gradient clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.steps)

    if args.video_dir:
        data = load_video_data(args.video_dir, args.max_samples, args.frames, args.res)
    else:
        data = _generate_synthetic(args.frames, args.res, 100)
    if not data:
        raise SystemExit("No training samples available; check --video-dir / --max-samples.")

    # 80/20 split. With 2+ samples each side gets at least one; a single
    # sample is used for both training and validation.
    n = int(0.8 * len(data))
    if len(data) > 1:
        n = min(max(1, n), len(data) - 1)
    train_data = data[:n] if n > 0 else data
    val_data = data[n:] if n < len(data) else data[:1]
    n_val = min(10, len(val_data))

    run_dir = f"models/checkpoints/{args.run}"
    os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(run_dir)
    step = 0
    best_val = float('inf')
    model.train()

    try:
        while step < args.steps:
            opt.zero_grad()
            accum_loss = 0.0

            for _ in range(args.accum):
                # Random byte tokens stand in for text conditioning of the VideoHead.
                text = torch.randint(0, 256, (args.batch, 10)).cuda()

                idx = torch.randint(0, len(train_data), (1,)).item()
                target_latents = train_data[idx].cuda()

                # Forward: embedding -> multimodal sequencer -> VideoHead latents.
                embedded = model.embedding(text)
                seq_out = model.multimodal_sequencer({'text': embedded})
                pred_latents = model.video_head(seq_out['text'])
                # Broadcasts a batch-1 target and fixes channel/spatial mismatches.
                target_latents = _match_latents(target_latents, pred_latents)

                loss_val = torch.nn.functional.mse_loss(pred_latents, target_latents)
                loss = loss_val / args.accum
                loss.backward()
                # Summing the *scaled* losses yields the mean over micro-batches,
                # which is directly comparable to the mean eval loss below.
                accum_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            scheduler.step()
            step += 1

            if args.eval_interval > 0 and step % args.eval_interval == 0:
                model.eval()
                val_loss = 0.0
                # Fixed-seed conditioning so every evaluation uses identical inputs
                # and val losses (and hence best-checkpoint choice) are comparable.
                eval_gen = torch.Generator().manual_seed(0)
                with torch.no_grad():
                    text_v = torch.randint(0, 256, (args.batch, 10), generator=eval_gen).cuda()
                    seq_v = model.multimodal_sequencer({'text': model.embedding(text_v)})
                    rel_v = seq_v['text']

                    for sample in val_data[:n_val]:
                        pred = model.video_head(rel_v)
                        target = _match_latents(sample.cuda(), pred)
                        val_loss += torch.nn.functional.mse_loss(pred, target).item()
                val_loss /= n_val

                writer.add_scalar("loss/train", accum_loss, step)
                writer.add_scalar("loss/eval", val_loss, step)

                if val_loss < best_val:
                    best_val = val_loss
                    save_lora(lora_layers, f"{run_dir}/best_lora.pt")

                print(f"step {step:>5d}/{args.steps}  train={accum_loss:.6f}  "
                      f"eval={val_loss:.6f}  best={best_val:.6f}", flush=True)
                model.train()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    save_lora(lora_layers, f"{run_dir}/final_lora.pt")
    print(f"Done. LoRA saved to {run_dir}/", flush=True)