codec-corruption-classifier

Binary image classifier that predicts whether a video frame contains severe codec corruption: block tearing, frame freezes, or other compression-artifact damage typically seen in low-bandwidth WiFi video streams (e.g. the video feed from a DJI Tello drone).

Trained on hand-labelled frames from indoor drone-mapping footage. Intended as a preprocessing filter for downstream structure-from-motion (SfM) and 3D reconstruction pipelines, where a single severely corrupted frame can pollute feature matching.

Architecture

  • Backbone: torchvision.models.mobilenet_v3_small (ImageNet-pretrained, IMAGENET1K_V1)
  • Head: the final nn.Linear in model.classifier is replaced with nn.Linear(in_features, 1), where in_features is 1024
  • Output: single logit; apply torch.sigmoid for P(corrupted)
  • Suggested threshold: 0.5
  • Params: 1.5M
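Because the head emits a single logit and the sigmoid is strictly increasing with sigmoid(0) = 0.5, the suggested 0.5 threshold is the same as asking whether the logit is positive. The stdlib-only sketch below spells out that decision rule; `p_corrupted`, `logit_cutoff`, and `is_corrupted` are hypothetical helper names, not part of the released model.

```python
import math


def p_corrupted(logit: float) -> float:
    """Logistic sigmoid, written to avoid overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def logit_cutoff(threshold: float) -> float:
    """Logit at which P(corrupted) equals `threshold` (the inverse sigmoid)."""
    return math.log(threshold / (1.0 - threshold))


def is_corrupted(logit: float, threshold: float = 0.5) -> bool:
    """True when P(corrupted) exceeds the threshold (0.5 is the card's suggestion)."""
    return p_corrupted(logit) > threshold


# At the default threshold the decision reduces to `logit > 0`.
assert logit_cutoff(0.5) == 0.0
for z in (-4.0, -0.25, 0.0, 0.25, 4.0):
    assert is_corrupted(z) == (z > 0.0)
```

To use a stricter or looser threshold t without computing probabilities, compare the raw logit against logit_cutoff(t), i.e. ln(t / (1 - t)).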

Preprocessing

Frames are letterboxed to 224×224 (resized to preserve the aspect ratio, then padded with black), then normalized with ImageNet per-channel mean and standard deviation.
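Written out, the normalization is the per-channel standardization that torchvision's transforms.Normalize applies after transforms.ToTensor has scaled pixel values to [0, 1], using the ImageNet constants from the Usage code. One consequence worth knowing: the black letterbox padding is not zero at the network input.

```latex
\begin{aligned}
\hat{x}_c &= \frac{x_c - \mu_c}{\sigma_c}, \qquad
\mu = (0.485,\ 0.456,\ 0.406), \quad \sigma = (0.229,\ 0.224,\ 0.225) \\
x = (0, 0, 0) \;&\Rightarrow\; \hat{x} = \left(-\tfrac{0.485}{0.229},\ -\tfrac{0.456}{0.224},\ -\tfrac{0.406}{0.225}\right) \approx (-2.118,\ -2.036,\ -1.804)
\end{aligned}
```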

from PIL import Image

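def letterbox(img: Image.Image, size: int = 224) -> Image.Image:
    """Resize img to fit inside a size x size square, preserving aspect ratio,
    and center it on a black canvas."""
    w, h = img.size
    scale = size / max(w, h)  # the longer side is scaled to fit `size`
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.BILINEAR)
    # Black RGB canvas; the resized frame is pasted centered on it.
    padded = Image.new("RGB", (size, size), (0, 0, 0))
    padded.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
    return padded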
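The letterbox arithmetic can also be run on its own, without PIL, to see where a frame lands inside the 224×224 input. The sketch below mirrors the scale, size, and offset computation of letterbox(); `LetterboxGeometry`, `letterbox_geometry`, and `padding_fraction` are hypothetical helpers, not part of the model repo.

```python
from typing import NamedTuple


class LetterboxGeometry(NamedTuple):
    scale: float  # resize factor applied to both sides
    new_w: int    # width of the resized frame on the canvas
    new_h: int    # height of the resized frame on the canvas
    left: int     # x offset of the pasted frame (padding on the left)
    top: int      # y offset of the pasted frame (padding on the top)


def letterbox_geometry(w: int, h: int, size: int = 224) -> LetterboxGeometry:
    """Same arithmetic as letterbox(), without touching any pixels."""
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    return LetterboxGeometry(scale, new_w, new_h, (size - new_w) // 2, (size - new_h) // 2)


def padding_fraction(g: LetterboxGeometry, size: int = 224) -> float:
    """Share of the size x size input that is black padding."""
    return 1.0 - (g.new_w * g.new_h) / (size * size)


g = letterbox_geometry(960, 720)
print(g.new_w, g.new_h, g.left, g.top)  # 224 168 0 28
print(padding_fraction(g))              # 0.25
```

For a 960×720 (4:3) frame, the resized image is 224×168 with 28 padding rows above and below, so a quarter of every network input is padding.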

Usage

import torch
from torch import nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

weights_path = hf_hub_download(
    repo_id="callum-sh/codec-corruption-classifier",
    filename="model.safetensors",
)
state = load_file(weights_path)

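# Build the architecture without ImageNet weights; the checkpoint supplies
# every parameter, including the 1-logit head.
model = mobilenet_v3_small(weights=None)
in_features = model.classifier[-1].in_features  # 1024 for mobilenet_v3_small
model.classifier[-1] = nn.Linear(in_features, 1)
model.load_state_dict(state)
model.eval()  # inference mode: disables dropout, uses BatchNorm running stats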

# letterbox() is the helper from the Preprocessing section; define it before running this.
tx = transforms.Compose([
    transforms.Lambda(lambda im: letterbox(im, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("frame.jpg").convert("RGB")
with torch.no_grad():
    logit = model(tx(img).unsqueeze(0))
    p_corrupted = torch.sigmoid(logit).item()
print(f"P(corrupted) = {p_corrupted:.3f}")

Intended use

Filter out corrupted frames before structure-from-motion: a frame with P(corrupted) > 0.5 (the suggested threshold) should be excluded from the SfM input set.

Not intended as a general-purpose image-quality predictor: it specifically targets codec artifacts, not blur, exposure, or motion noise.
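Applied to a whole sequence, the exclusion rule amounts to partitioning frames by score. A minimal sketch, assuming P(corrupted) has already been computed for each frame (for instance with the Usage code) and stored in a dict keyed by filename; `split_for_sfm` and the example filenames and scores are hypothetical.

```python
from typing import Dict, List, Tuple


def split_for_sfm(scores: Dict[str, float], threshold: float = 0.5) -> Tuple[List[str], List[str]]:
    """Partition frames into (keep, drop) by P(corrupted).

    Frames with P(corrupted) > threshold are dropped; a score exactly
    equal to the threshold is kept.
    """
    keep = sorted(name for name, p in scores.items() if p <= threshold)
    drop = sorted(name for name, p in scores.items() if p > threshold)
    return keep, drop


# Hypothetical scores for three frames.
keep, drop = split_for_sfm({"frame_0001.jpg": 0.02, "frame_0002.jpg": 0.91, "frame_0003.jpg": 0.50})
print(keep)  # ['frame_0001.jpg', 'frame_0003.jpg']
print(drop)  # ['frame_0002.jpg']
```

The keep list can then be written out as the image set handed to the SfM tool; the drop list is worth inspecting by eye, since long runs of dropped frames leave gaps in scene coverage.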

Safetensors weights: 1.53M params, tensor type F32