deepdetection / lipfd /train.py
akagtag's picture
Prepare Hugging Face Space deployment
de5d6bb
"""
train.py — Train LipFDNet on the AVLips v1.0 dataset.
Extracts a frame + audio sample from each video on-the-fly, trains the tiny
LipFDNet, saves ckpt.pth, then uploads to akagtag/LipFD-checkpoint.
Usage:
python lipfd/train.py # full dataset
python lipfd/train.py --max-per-class 200 # quick smoke-test (CPU ~10 min)
python lipfd/train.py --epochs 5 # default 5 epochs
"""
from __future__ import annotations
import argparse
import os
import random
import subprocess
import sys
import tarfile
import tempfile
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from lipfd.model import LipFDNet # noqa: E402
# Compressed AVLips v1.0 dataset archive, expected in the same directory as this script.
ARCHIVE = Path(__file__).parent / "AVLips v1.0.tar.xz"
# Where the trained weights are written (also next to this script).
CKPT_OUT = Path(__file__).parent / "ckpt.pth"
# Hugging Face model repository that receives the checkpoint.
HF_REPO = "akagtag/LipFD-checkpoint"
# ── helpers ───────────────────────────────────────────────────────────────────
def _decode_ppm(raw: bytes) -> np.ndarray:
    """Decode an 8-bit binary PPM (``P6``) image into an (H, W, 3) uint8 array.

    Raises:
        ValueError: if ``raw`` is not a well-formed 8-bit ``P6`` image.
    """
    # The header is tiny, so only a short prefix needs to be tokenised.
    fields = raw[:64].split(None, 3)
    if len(fields) < 4 or fields[0] != b"P6":
        raise ValueError("not a binary PPM image")
    width, height = int(fields[1]), int(fields[2])  # ValueError on junk
    n_bytes = width * height * 3
    if width <= 0 or height <= 0 or len(raw) <= n_bytes:
        raise ValueError("truncated or empty PPM image")
    # The pixel payload is the trailing n_bytes; everything before it must be
    # exactly the header. Slicing from the end keeps pixel bytes that happen
    # to be whitespace from being mistaken for header separators.
    if raw[:-n_bytes].split() != [b"P6", fields[1], fields[2], b"255"]:
        raise ValueError("malformed PPM header")
    pixels = np.frombuffer(raw, dtype=np.uint8, offset=len(raw) - n_bytes)
    return pixels.reshape(height, width, 3)


def _ffmpeg_frame(path: str) -> np.ndarray:
    """Return one RGB frame of the video at ``path`` (black 64x64 on failure)."""
    # Frame 15 is roughly the middle of the short clips; a clip with fewer
    # frames yields no output for it, so retry with the very first frame.
    for index in (15, 0):
        cmd = [
            "ffmpeg", "-i", path,
            "-vf", f"select=eq(n\\,{index})",
            "-frames:v", "1",
            # PPM carries width/height in its header, so frames of any aspect
            # ratio can be decoded (headerless rawvideo cannot tell 4:3 from 1:1).
            "-f", "image2pipe", "-c:v", "ppm", "-pix_fmt", "rgb24",
            "-loglevel", "error",
            "pipe:1",
        ]
        try:
            done = subprocess.run(
                cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=10
            )
        except (OSError, subprocess.SubprocessError):
            break  # ffmpeg missing or hung: retrying will not help
        try:
            return _decode_ppm(done.stdout)
        except ValueError:
            continue
    return np.zeros((64, 64, 3), dtype=np.uint8)


def _ffmpeg_audio_rms(path: str) -> float:
    """Return the RMS of the mono 16 kHz audio track at ``path`` (0.0 on failure)."""
    cmd = [
        "ffmpeg", "-i", path, "-ac", "1", "-ar", "16000",
        "-f", "f32le", "-loglevel", "error", "pipe:1",
    ]
    try:
        done = subprocess.run(
            cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=10
        )
    except (OSError, subprocess.SubprocessError):
        return 0.0
    pcm = done.stdout
    # Drop a trailing partial sample so frombuffer cannot fail on the length.
    samples = np.frombuffer(pcm[: len(pcm) - len(pcm) % 4], dtype=np.float32)
    if samples.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(samples ** 2) + 1e-9))


def _extract_frame_and_audio(video_bytes: bytes) -> tuple[np.ndarray, float]:
    """Extract one RGB frame and the audio RMS level from raw video bytes.

    The bytes are written to a temporary ``.mp4`` file which is decoded with
    the ``ffmpeg`` command-line tool.  Extraction is best-effort: anything
    that cannot be decoded is replaced by a neutral value instead of raising,
    and a failure on the audio side no longer discards a decoded frame.

    Args:
        video_bytes: Contents of a video file.

    Returns:
        ``(frame, rms)``.  ``frame`` is an ``(H, W, 3)`` uint8 array at the
        video's native resolution (frame 15, else frame 0), or an all-zero
        64x64 frame when nothing could be decoded.  ``rms`` is the RMS level
        of the audio track, or ``0.0`` when no audio could be decoded.
    """
    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp = tmp_file.name
    try:
        # Close the file before ffmpeg opens it; the finally below removes it
        # even if the write itself fails.
        with tmp_file:
            tmp_file.write(video_bytes)
        frame = _ffmpeg_frame(tmp)
        rms = _ffmpeg_audio_rms(tmp)
    finally:
        Path(tmp).unlink(missing_ok=True)
    return frame, rms
class AVLipsDataset(Dataset):
def __init__(self, archive: Path, max_per_class: int | None = None):
self.archive = archive
self.samples: list[tuple[str, str | None, int]] = [] # (video, wav, label)
with tarfile.open(archive, "r:xz") as tf:
names = tf.getnames()
real_v = [n for n in names if "/0_real/" in n and n.endswith(".mp4")]
fake_v = [n for n in names if "/1_fake/" in n and n.endswith(".mp4")]
# Build wav lookup: AVLips/wav/0_real/578.wav
wav_lookup: dict[str, str] = {}
for n in names:
if n.endswith(".wav"):
stem = Path(n).stem
wav_lookup[stem] = n
random.shuffle(real_v)
random.shuffle(fake_v)
if max_per_class:
real_v = real_v[:max_per_class]
fake_v = fake_v[:max_per_class]
for v in real_v:
wav = wav_lookup.get(Path(v).stem)
self.samples.append((v, wav, 0))
for v in fake_v:
wav = wav_lookup.get(Path(v).stem)
self.samples.append((v, wav, 1))
random.shuffle(self.samples)
print(f"Dataset: {len(real_v)} real, {len(fake_v)} fake")
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
name, wav_name, label = self.samples[idx]
with tarfile.open(self.archive, "r:xz") as tf:
fobj = tf.extractfile(name)
data = fobj.read() if fobj else b""
# Use bundled WAV if available (better audio than ffmpeg extraction)
rms = 0.0
if wav_name:
try:
wobj = tf.extractfile(wav_name)
if wobj:
wav_data = wobj.read()
samples_np = np.frombuffer(wav_data[44:], dtype=np.int16).astype(np.float32) / 32768.0
rms = float(np.sqrt(np.mean(samples_np ** 2) + 1e-9))
except Exception:
pass
frame, rms_fallback = _extract_frame_and_audio(data)
if rms == 0.0:
rms = rms_fallback
# Visual: resize to 32x32, normalise
from PIL import Image # type: ignore
import torchvision.transforms as T # type: ignore
pil = Image.fromarray(frame).resize((32, 32))
vis = T.ToTensor()(pil) # (3, 32, 32)
audio = torch.tensor([rms], dtype=torch.float32)
return vis, audio, torch.tensor(label, dtype=torch.float32)
# ── training ──────────────────────────────────────────────────────────────────
def train(
    epochs: int = 5,
    max_per_class: int | None = None,
    lr: float = 1e-3,
    *,
    batch_size: int = 16,
    val_fraction: float = 0.1,
):
    """Train LipFDNet on the AVLips archive and keep the best checkpoint.

    The dataset is split randomly into train/validation parts; after every
    epoch the model is evaluated and its weights are written to ``CKPT_OUT``
    whenever validation accuracy improves (the first epoch is always saved).

    Args:
        epochs: Number of passes over the training split.
        max_per_class: Optional cap on videos per class (quick smoke tests).
        lr: Adam learning rate.
        batch_size: Mini-batch size for both loaders.
        val_fraction: Share of the dataset held out for validation (at least
            one sample, and always leaving at least one for training).

    Returns:
        The best validation accuracy seen.

    Raises:
        ValueError: if the dataset has fewer than two samples, so it cannot
            be split into a non-empty train and validation part.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Training on {device}")

    dataset = AVLipsDataset(ARCHIVE, max_per_class=max_per_class)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 samples to train and validate, found {len(dataset)}"
        )
    n_val = max(1, int(len(dataset) * val_fraction))
    n_val = min(n_val, len(dataset) - 1)  # never leave the train split empty
    train_ds, val_ds = torch.utils.data.random_split(
        dataset, [len(dataset) - n_val, n_val]
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model = LipFDNet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_val_acc = 0.0
    saved = False  # whether this run has written a checkpoint yet
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for vis, audio, labels in train_loader:
            vis, audio, labels = vis.to(device), audio.to(device), labels.to(device)
            opt.zero_grad()
            # assumes the model returns one logit per sample, shaped like
            # `labels` - TODO confirm against LipFDNet.forward
            logits = model(vis, audio)
            loss = criterion(logits, labels)
            loss.backward()
            opt.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for vis, audio, labels in val_loader:
                vis, audio, labels = vis.to(device), audio.to(device), labels.to(device)
                # logit > 0 is equivalent to sigmoid(logit) > 0.5
                preds = (model(vis, audio) > 0).float()
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / max(total, 1)
        scheduler.step()
        print(f"Epoch {epoch}/{epochs} loss={total_loss/len(train_loader):.4f} val_acc={val_acc:.3f}")

        # Always keep the first epoch: with a strict ">" against 0.0 a run
        # whose accuracy stays at 0.0 would never write a checkpoint, and a
        # stale file from an earlier run could be uploaded instead.
        if not saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), CKPT_OUT)
            saved = True
            print(f" ✓ Saved checkpoint (val_acc={val_acc:.3f})")

    print(f"\nTraining complete. Best val_acc={best_val_acc:.3f}")
    if saved:
        print(f"Checkpoint saved to: {CKPT_OUT}")
    return best_val_acc
def upload():
    """Upload the file at ``CKPT_OUT`` to the ``HF_REPO`` model repo as ``ckpt.pth``."""
    # Imported lazily so training works without huggingface_hub installed.
    from huggingface_hub import HfApi  # type: ignore

    local_checkpoint = str(CKPT_OUT)
    client = HfApi()
    client.upload_file(
        path_or_fileobj=local_checkpoint,
        path_in_repo="ckpt.pth",
        repo_id=HF_REPO,
        repo_type="model",
    )
    print(f"Uploaded ckpt.pth to {HF_REPO}")
if __name__ == "__main__":
    # Command-line entry point: train, then (unless --no-upload) push the
    # resulting checkpoint to the Hugging Face Hub.
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--max-per-class", type=int, default=None,
                        help="Limit videos per class (e.g. 200 for quick test)")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--no-upload", action="store_true",
                        help="Skip HF upload after training")
    args = parser.parse_args()

    train(epochs=args.epochs, max_per_class=args.max_per_class, lr=args.lr)

    if not args.no_upload:
        # Only upload when a checkpoint file is actually present on disk.
        if CKPT_OUT.exists():
            upload()
        else:
            print("No checkpoint found — skipping upload")