Spaces:
Paused
Paused
| """ | |
| train.py — Train LipFDNet on the AVLips v1.0 dataset. | |
| Extracts a frame + audio sample from each video on-the-fly, trains the tiny | |
| LipFDNet, saves ckpt.pth, then uploads to akagtag/LipFD-checkpoint. | |
| Usage: | |
| python lipfd/train.py # full dataset | |
| python lipfd/train.py --max-per-class 200 # quick smoke-test (CPU ~10 min) | |
| python lipfd/train.py --epochs 5 # default 5 epochs | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import random | |
| import subprocess | |
| import sys | |
| import tarfile | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, Dataset | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from lipfd.model import LipFDNet # noqa: E402 | |
# Dataset tarball, looked up next to this script; it is opened as an
# xz-compressed tar ("r:xz") by AVLipsDataset.
ARCHIVE = Path(__file__).with_name("AVLips v1.0.tar.xz")
# Where train() writes the model state dict, also next to this script.
CKPT_OUT = Path(__file__).with_name("ckpt.pth")
# Hugging Face Hub repo id that upload() pushes the checkpoint to.
HF_REPO = "akagtag/LipFD-checkpoint"
| # ── helpers ────────────────────────────────────────────────────────────────── | |
def _extract_frame_and_audio(video_bytes: bytes) -> tuple[np.ndarray, float]:
    """Extract one RGB frame and an RMS audio level from raw video bytes.

    The bytes are written to a temporary ``.mp4`` file and decoded with the
    ``ffmpeg`` command-line tool (two invocations, 10 s timeout each). The
    temporary file is always removed before returning.

    Args:
        video_bytes: Contents of a video file.

    Returns:
        ``(frame, rms)``:

        * ``frame`` -- ``(side, side, 3)`` ``uint8`` RGB array holding frame
          index 15 of the clip. Raw ffmpeg output carries no dimensions, so
          only square frames can be reshaped; anything else (non-square
          frame, clip without a frame 15, ffmpeg missing/failing/timing out)
          yields a black 64x64 placeholder.
        * ``rms`` -- root-mean-square level of the audio downmixed to mono at
          16 kHz, or ``0.0`` when no audio could be decoded.

        The two halves fall back independently: a failure while decoding
        audio does not discard a frame that was already extracted.
    """
    frame = np.zeros((64, 64, 3), dtype=np.uint8)
    rms = 0.0
    # Failures we expect from a best-effort decode: ffmpeg not installed or
    # not executable (OSError), timeout (SubprocessError), malformed output
    # buffer (ValueError).
    decode_errors = (OSError, subprocess.SubprocessError, ValueError)

    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp = tmp_file.name
    try:
        # Close the handle before ffmpeg opens the path.
        with tmp_file:
            tmp_file.write(video_bytes)

        # Frame: a single frame as raw RGB24 on stdout.
        frame_cmd = [
            "ffmpeg", "-i", tmp,
            "-vf", "select=eq(n\\,15)",  # frame index 15 (roughly the middle of short clips)
            "-frames:v", "1",
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-loglevel", "error",
            "pipe:1",
        ]
        try:
            raw = subprocess.run(frame_cmd, capture_output=True, timeout=10).stdout
            side = int((len(raw) / 3) ** 0.5)
            if side > 0 and side * side * 3 == len(raw):
                frame = np.frombuffer(raw, dtype=np.uint8).reshape(side, side, 3)
        except decode_errors:
            pass  # keep the black placeholder frame

        # Audio: mono 16 kHz float32 PCM on stdout, reduced to one RMS scalar.
        audio_cmd = [
            "ffmpeg", "-i", tmp, "-ac", "1", "-ar", "16000",
            "-f", "f32le", "-loglevel", "error", "pipe:1",
        ]
        try:
            pcm = subprocess.run(audio_cmd, capture_output=True, timeout=10).stdout
            # Drop a trailing partial sample so frombuffer sees whole float32s.
            pcm = pcm[: len(pcm) - len(pcm) % 4]
            if pcm:
                samples = np.frombuffer(pcm, dtype=np.float32)
                rms = float(np.sqrt(np.mean(samples ** 2) + 1e-9))
        except decode_errors:
            pass  # keep rms = 0.0
    finally:
        Path(tmp).unlink(missing_ok=True)
    return frame, rms
class AVLipsDataset(Dataset):
    """Frame + audio-level samples read on the fly from the AVLips tar.xz archive.

    Videos are the ``.mp4`` members whose path contains ``/0_real/`` (label 0)
    or ``/1_fake/`` (label 1). Each video is paired with a bundled ``.wav``
    member of the same class and file stem when one exists.

    Attributes:
        archive: Path of the xz-compressed tar archive.
        samples: Shuffled list of ``(video_member, wav_member_or_None, label)``.
    """

    _REAL_DIR = "0_real"
    _FAKE_DIR = "1_fake"

    def __init__(self, archive: Path, max_per_class: int | None = None):
        """Index the archive and build the shuffled sample list.

        Args:
            archive: Path to the ``.tar.xz`` dataset archive.
            max_per_class: Keep at most this many videos per class after
                shuffling. ``None`` (or ``0``) keeps every video.
        """
        self.archive = archive
        self.samples: list[tuple[str, str | None, int]] = []  # (video, wav, label)

        with tarfile.open(archive, "r:xz") as tf:
            names = tf.getnames()

        real_v = [n for n in names if "/0_real/" in n and n.endswith(".mp4")]
        fake_v = [n for n in names if "/1_fake/" in n and n.endswith(".mp4")]

        # WAV lookup, e.g. AVLips/wav/0_real/578.wav. Real and fake clips can
        # share a stem, so class-specific WAVs are keyed by (class dir, stem);
        # keying by stem alone would let one class overwrite the other.
        wav_by_class: dict[tuple[str, str], str] = {}
        # WAVs that are not inside a class directory stay reachable by stem.
        wav_by_stem: dict[str, str] = {}
        for n in names:
            if not n.endswith(".wav"):
                continue
            wav_path = Path(n)
            class_dir = wav_path.parent.name
            if class_dir in (self._REAL_DIR, self._FAKE_DIR):
                wav_by_class[(class_dir, wav_path.stem)] = n
            else:
                wav_by_stem[wav_path.stem] = n

        def find_wav(video: str, class_dir: str) -> str | None:
            stem = Path(video).stem
            return wav_by_class.get((class_dir, stem), wav_by_stem.get(stem))

        random.shuffle(real_v)
        random.shuffle(fake_v)
        if max_per_class:
            real_v = real_v[:max_per_class]
            fake_v = fake_v[:max_per_class]

        self.samples.extend((v, find_wav(v, self._REAL_DIR), 0) for v in real_v)
        self.samples.extend((v, find_wav(v, self._FAKE_DIR), 1) for v in fake_v)
        random.shuffle(self.samples)
        print(f"Dataset: {len(real_v)} real, {len(fake_v)} fake")

    def __len__(self) -> int:
        """Return the number of indexed videos."""
        return len(self.samples)

    @staticmethod
    def _wav_rms(wav_data: bytes) -> float | None:
        """Return the RMS level of a WAV file's samples, or ``None`` if it has none.

        NOTE(review): skips a fixed 44-byte header and reads the remainder as
        native-endian 16-bit PCM -- this assumes the canonical WAV layout;
        confirm against the dataset.
        """
        pcm = wav_data[44:]
        pcm = pcm[: len(pcm) - len(pcm) % 2]  # whole int16 samples only
        if not pcm:
            # An empty payload would make np.mean return NaN.
            return None
        samples = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
        return float(np.sqrt(np.mean(samples ** 2) + 1e-9))

    def __getitem__(self, idx: int):
        """Load one sample.

        Returns:
            ``(vis, audio, label)``: ``vis`` is the extracted frame resized to
            32x32 and converted with ``ToTensor``; ``audio`` is a one-element
            float32 tensor holding the RMS level; ``label`` is a scalar
            float32 tensor (0 = real, 1 = fake).
        """
        name, wav_name, label = self.samples[idx]
        rms: float | None = None
        # NOTE(review): the archive is reopened for every item; with xz
        # compression each member lookup presumably re-reads the archive from
        # the start -- consider extracting the selected members once up front.
        with tarfile.open(self.archive, "r:xz") as tf:
            fobj = tf.extractfile(name)
            data = fobj.read() if fobj else b""
            # Prefer the bundled WAV when available; the audio track decoded
            # from the video is only the fallback.
            if wav_name:
                try:
                    wobj = tf.extractfile(wav_name)
                    if wobj:
                        rms = self._wav_rms(wobj.read())
                except Exception:
                    # Deliberately best-effort: any failure here just means
                    # the ffmpeg-derived level below is used instead.
                    rms = None
        frame, rms_fallback = _extract_frame_and_audio(data)
        if rms is None:
            rms = rms_fallback
        # Visual: resize to 32x32, normalise
        from PIL import Image  # type: ignore
        import torchvision.transforms as T  # type: ignore
        pil = Image.fromarray(frame).resize((32, 32))
        vis = T.ToTensor()(pil)  # (3, 32, 32)
        audio = torch.tensor([rms], dtype=torch.float32)
        return vis, audio, torch.tensor(label, dtype=torch.float32)
| # ── training ───────────────────────────────────────────────────────────────── | |
def train(epochs: int = 5, max_per_class: int | None = None, lr: float = 1e-3):
    """Train LipFDNet on the AVLips archive and keep the best checkpoint.

    The dataset is split into train/validation parts (validation gets 10 %,
    at least one sample). The model is optimised with Adam,
    ``BCEWithLogitsLoss`` and a cosine-annealed learning rate, using batches
    of 16. After every epoch the validation accuracy is measured; the state
    dict is written to ``CKPT_OUT`` after the first epoch and again whenever
    the accuracy improves, so a completed run always leaves a checkpoint
    produced by this run.

    Args:
        epochs: Number of passes over the training split.
        max_per_class: Forwarded to ``AVLipsDataset`` to cap videos per class.
        lr: Initial Adam learning rate.

    Returns:
        The best validation accuracy observed (fraction in ``[0, 1]``).

    Raises:
        ValueError: If the dataset holds fewer than two samples, which is too
            few for a non-empty train and validation split.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Training on {device}")

    dataset = AVLipsDataset(ARCHIVE, max_per_class=max_per_class)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 samples for a train/validation split, got {len(dataset)}"
        )
    n_val = max(1, int(len(dataset) * 0.1))
    train_ds, val_ds = torch.utils.data.random_split(
        dataset, [len(dataset) - n_val, n_val]
    )
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0)

    model = LipFDNet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_val_acc = 0.0
    saved = False  # whether this run has written CKPT_OUT yet
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for vis, audio, labels in train_loader:
            vis, audio, labels = vis.to(device), audio.to(device), labels.to(device)
            opt.zero_grad()
            # Flatten so (B,), (B, 1) or 0-d model outputs all line up with
            # the (B,) label vector; a no-op when the output is already 1-D.
            logits = model(vis, audio).reshape(-1)
            loss = criterion(logits, labels)
            loss.backward()
            opt.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for vis, audio, labels in val_loader:
                vis, audio, labels = vis.to(device), audio.to(device), labels.to(device)
                # A logit above 0 corresponds to a sigmoid probability above 0.5.
                preds = (model(vis, audio).reshape(-1) > 0).float()
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / max(total, 1)
        scheduler.step()
        print(f"Epoch {epoch}/{epochs} loss={total_loss/len(train_loader):.4f} val_acc={val_acc:.3f}")

        # Always save after the first epoch: with a strict ">" against an
        # initial 0.0, a run whose accuracy stays at 0 would never write a
        # checkpoint and a stale file from an earlier run could be picked up.
        if not saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), CKPT_OUT)
            saved = True
            print(f" β Saved checkpoint (val_acc={val_acc:.3f})")

    print(f"\nTraining complete. Best val_acc={best_val_acc:.3f}")
    if saved:
        print(f"Checkpoint saved to: {CKPT_OUT}")
    return best_val_acc
def upload():
    """Push the checkpoint at ``CKPT_OUT`` to the ``HF_REPO`` model repo as ``ckpt.pth``."""
    # Imported lazily so training works without huggingface_hub installed.
    from huggingface_hub import HfApi  # type: ignore

    HfApi().upload_file(
        repo_id=HF_REPO,
        repo_type="model",
        path_in_repo="ckpt.pth",
        path_or_fileobj=str(CKPT_OUT),
    )
    print(f"Uploaded ckpt.pth to {HF_REPO}")
if __name__ == "__main__":
    # Command-line entry point: train, then upload unless told otherwise.
    cli = argparse.ArgumentParser()
    cli.add_argument("--epochs", type=int, default=5)
    cli.add_argument(
        "--max-per-class",
        type=int,
        default=None,
        help="Limit videos per class (e.g. 200 for quick test)",
    )
    cli.add_argument("--lr", type=float, default=1e-3)
    cli.add_argument(
        "--no-upload",
        action="store_true",
        help="Skip HF upload after training",
    )
    opts = cli.parse_args()

    train(epochs=opts.epochs, max_per_class=opts.max_per_class, lr=opts.lr)

    # Upload only when requested and a checkpoint file is actually on disk.
    if opts.no_upload:
        pass
    elif CKPT_OUT.exists():
        upload()
    else:
        print("No checkpoint found β skipping upload")