""" train.py — Train LipFDNet on the AVLips v1.0 dataset. Extracts a frame + audio sample from each video on-the-fly, trains the tiny LipFDNet, saves ckpt.pth, then uploads to akagtag/LipFD-checkpoint. Usage: python lipfd/train.py # full dataset python lipfd/train.py --max-per-class 200 # quick smoke-test (CPU ~10 min) python lipfd/train.py --epochs 5 # default 5 epochs """ from __future__ import annotations import argparse import os import random import subprocess import sys import tarfile import tempfile from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from lipfd.model import LipFDNet # noqa: E402 ARCHIVE = Path(__file__).with_name("AVLips v1.0.tar.xz") CKPT_OUT = Path(__file__).with_name("ckpt.pth") HF_REPO = "akagtag/LipFD-checkpoint" # ── helpers ─────────────────────────────────────────────────────────────────── def _extract_frame_and_audio(video_bytes: bytes) -> tuple[np.ndarray, float]: """Extract middle frame (H,W,3 uint8) + RMS audio level from raw video bytes.""" with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: f.write(video_bytes) tmp = f.name try: # Frame: middle frame as raw RGB cmd = [ "ffmpeg", "-i", tmp, "-vf", "select=eq(n\\,15)", # frame 15 (≈middle for short clips) "-frames:v", "1", "-f", "rawvideo", "-pix_fmt", "rgb24", "-loglevel", "error", "pipe:1", ] r = subprocess.run(cmd, capture_output=True, timeout=10) raw = r.stdout frame: np.ndarray if len(raw) >= 3: side = int((len(raw) / 3) ** 0.5) if side * side * 3 == len(raw): frame = np.frombuffer(raw, dtype=np.uint8).reshape(side, side, 3) else: frame = np.zeros((64, 64, 3), dtype=np.uint8) else: frame = np.zeros((64, 64, 3), dtype=np.uint8) # Audio: RMS level as scalar proxy cmd2 = [ "ffmpeg", "-i", tmp, "-ac", "1", "-ar", "16000", "-f", "f32le", "-loglevel", "error", "pipe:1", ] r2 = subprocess.run(cmd2, capture_output=True, timeout=10) if r2.stdout: samples = np.frombuffer(r2.stdout, dtype=np.float32) rms = float(np.sqrt(np.mean(samples ** 2) + 1e-9)) else: rms = 0.0 except Exception: frame = np.zeros((64, 64, 3), dtype=np.uint8) rms = 0.0 finally: Path(tmp).unlink(missing_ok=True) return frame, rms class AVLipsDataset(Dataset): def __init__(self, archive: Path, max_per_class: int | None = None): self.archive = archive self.samples: list[tuple[str, str | None, int]] = [] # (video, wav, label) with tarfile.open(archive, "r:xz") as tf: names = tf.getnames() real_v = [n for n in names if "/0_real/" in n and n.endswith(".mp4")] fake_v = [n for n in names if "/1_fake/" in n and n.endswith(".mp4")] # Build wav lookup: AVLips/wav/0_real/578.wav wav_lookup: dict[str, str] = {} for n in names: if n.endswith(".wav"): stem = Path(n).stem wav_lookup[stem] = n random.shuffle(real_v) random.shuffle(fake_v) if max_per_class: real_v = real_v[:max_per_class] fake_v = fake_v[:max_per_class] for v in real_v: wav = wav_lookup.get(Path(v).stem) self.samples.append((v, wav, 0)) for v in fake_v: wav = wav_lookup.get(Path(v).stem) self.samples.append((v, wav, 1)) random.shuffle(self.samples) print(f"Dataset: {len(real_v)} real, {len(fake_v)} fake") def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int): name, wav_name, label = self.samples[idx] with tarfile.open(self.archive, "r:xz") as tf: fobj = tf.extractfile(name) data = fobj.read() if fobj else b"" # Use bundled WAV if available (better audio than ffmpeg extraction) rms = 0.0 if wav_name: try: wobj = tf.extractfile(wav_name) if wobj: wav_data = wobj.read() samples_np = np.frombuffer(wav_data[44:], dtype=np.int16).astype(np.float32) / 32768.0 rms = float(np.sqrt(np.mean(samples_np ** 2) + 1e-9)) except Exception: pass frame, rms_fallback = _extract_frame_and_audio(data) if rms == 0.0: rms = rms_fallback # Visual: resize to 32x32, normalise from PIL import Image # type: ignore import torchvision.transforms as T # type: ignore pil = Image.fromarray(frame).resize((32, 32)) vis = T.ToTensor()(pil) # (3, 32, 32) audio = torch.tensor([rms], dtype=torch.float32) return vis, audio, torch.tensor(label, dtype=torch.float32) # ── training ────────────────────────────────────────────────────────────────── def train(epochs: int = 5, max_per_class: int | None = None, lr: float = 1e-3): device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Training on {device}") dataset = AVLipsDataset(ARCHIVE, max_per_class=max_per_class) n_val = max(1, int(len(dataset) * 0.1)) train_ds, val_ds = torch.utils.data.random_split( dataset, [len(dataset) - n_val, n_val] ) train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0) val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0) model = LipFDNet().to(device) opt = torch.optim.Adam(model.parameters(), lr=lr) criterion = nn.BCEWithLogitsLoss() scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best_val_acc = 0.0 for epoch in range(1, epochs + 1): model.train() total_loss = 0.0 for vis, audio, labels in train_loader: vis, audio, labels = vis.to(device), audio.to(device), labels.to(device) opt.zero_grad() logits = model(vis, audio) loss = criterion(logits, labels) loss.backward() opt.step() total_loss += loss.item() # Validation model.eval() correct = total = 0 with torch.no_grad(): for vis, audio, labels in val_loader: vis, audio, labels = vis.to(device), audio.to(device), labels.to(device) preds = (model(vis, audio) > 0).float() correct += (preds == labels).sum().item() total += labels.size(0) val_acc = correct / max(total, 1) scheduler.step() print(f"Epoch {epoch}/{epochs} loss={total_loss/len(train_loader):.4f} val_acc={val_acc:.3f}") if val_acc > best_val_acc: best_val_acc = val_acc torch.save(model.state_dict(), CKPT_OUT) print(f" ✓ Saved checkpoint (val_acc={val_acc:.3f})") print(f"\nTraining complete. Best val_acc={best_val_acc:.3f}") print(f"Checkpoint saved to: {CKPT_OUT}") return best_val_acc def upload(): from huggingface_hub import HfApi # type: ignore api = HfApi() api.upload_file( path_or_fileobj=str(CKPT_OUT), path_in_repo="ckpt.pth", repo_id=HF_REPO, repo_type="model", ) print(f"Uploaded ckpt.pth to {HF_REPO}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--max-per-class", type=int, default=None, help="Limit videos per class (e.g. 200 for quick test)") parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--no-upload", action="store_true", help="Skip HF upload after training") args = parser.parse_args() train(epochs=args.epochs, max_per_class=args.max_per_class, lr=args.lr) if not args.no_upload: if CKPT_OUT.exists(): upload() else: print("No checkpoint found — skipping upload")