codec-corruption-classifier
Binary image classifier that predicts whether a video frame contains severe codec corruption: block tearing, frame freezes, or other compression-artifact damage typically seen in low-bandwidth WiFi video streams (e.g. DJI Tello drone video feeds).
Trained on hand-labelled frames from indoor drone-mapping footage. Intended as a preprocessing filter for downstream structure-from-motion (SfM) and 3D reconstruction pipelines, where a single severely corrupted frame can pollute feature matching.
Architecture
- Backbone: `torchvision.models.mobilenet_v3_small` (ImageNet-pretrained, `IMAGENET1K_V1` weights)
- Head: replace the final `nn.Linear` in `classifier` with `nn.Linear(in_features, 1)`
- Output: a single logit; apply `torch.sigmoid` to get P(corrupted)
- Suggested threshold: 0.5
- Params: ~1.5M
Preprocessing
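Frames are letterboxed to 224×224 (resized so the aspect ratio is preserved, then padded with black), converted to a tensor in [0, 1], and normalized with the ImageNet mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225].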
```python
from PIL import Image


def letterbox(img: Image.Image, size: int = 224) -> Image.Image:
    """Resize so the longer side equals `size`, then center on a black square canvas."""
    w, h = img.size
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.BILINEAR)
    # Black RGB canvas; the resized frame is pasted centered on it.
    padded = Image.new("RGB", (size, size), (0, 0, 0))
    padded.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
    return padded
```
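To make the preprocessing concrete, the sketch below reproduces the size and offset arithmetic of `letterbox` in plain Python (no image I/O), together with the value a black padding pixel takes after ImageNet normalization. `letterbox_geometry` and `normalize_pixel` are hypothetical helpers written for illustration. Note that padding does not map to zero: it becomes roughly (-2.12, -2.04, -1.80), the lowest value any channel can take.

```python
# Same arithmetic as letterbox(), without touching any pixels.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def letterbox_geometry(w: int, h: int, size: int = 224):
    """Hypothetical helper: resized (w, h) and paste offset used by letterbox()."""
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    offset = ((size - new_w) // 2, (size - new_h) // 2)
    return (new_w, new_h), offset


def normalize_pixel(rgb):
    """Hypothetical helper: ToTensor() scaling to [0, 1], then Normalize(mean, std)."""
    return tuple((c / 255 - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


print(letterbox_geometry(448, 252))  # 16:9 frame -> ((224, 126), (0, 49))
print(letterbox_geometry(448, 336))  # 4:3 frame  -> ((224, 168), (0, 28))
print([round(v, 3) for v in normalize_pixel((0, 0, 0))])  # [-2.118, -2.036, -1.804]
```

A 16:9 frame therefore spends 98 of its 224 rows on black bars, which the classifier saw during preprocessing and should treat as normal rather than as a frame freeze.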
Usage
```python
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# letterbox() is the function defined under Preprocessing above.

# Download the fine-tuned weights from the Hugging Face Hub.
weights_path = hf_hub_download(
    repo_id="callum-sh/codec-corruption-classifier",
    filename="model.safetensors",
)
state = load_file(weights_path)

# Rebuild the architecture (no ImageNet download) and swap in the 1-logit head.
model = mobilenet_v3_small(weights=None)
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
model.load_state_dict(state)
model.eval()

# Preprocessing as described above: letterbox, scale to [0, 1], ImageNet normalization.
tx = transforms.Compose([
    transforms.Lambda(lambda im: letterbox(im, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("frame.jpg").convert("RGB")
with torch.no_grad():
    logit = model(tx(img).unsqueeze(0))  # shape (1, 1)
p_corrupted = torch.sigmoid(logit).item()
print(f"P(corrupted) = {p_corrupted:.3f}")
```
Intended use
Filter out corrupted frames before structure-from-motion: any frame with P(corrupted) > 0.5 should be excluded from the SfM input set.
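The exclusion rule amounts to a strict threshold on P(corrupted). A minimal sketch, assuming the per-frame probabilities have already been computed (e.g. with the Usage code) and are held in a dict keyed by file name; `split_frames` and the example file names are hypothetical.

```python
from typing import Dict, List, Tuple


def split_frames(
    scores: Dict[str, float], threshold: float = 0.5
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: partition frames into (kept, dropped) by P(corrupted).

    Frames with P(corrupted) > threshold are dropped; a frame exactly at the
    threshold is kept, matching the strict ">" in the rule.
    """
    kept: List[str] = []
    dropped: List[str] = []
    # Sorting by name preserves frame order for zero-padded file names.
    for name in sorted(scores):
        if scores[name] > threshold:
            dropped.append(name)
        else:
            kept.append(name)
    return kept, dropped


scores = {"frame_0001.jpg": 0.03, "frame_0002.jpg": 0.91, "frame_0003.jpg": 0.50}
kept, dropped = split_frames(scores)
print(kept)     # ['frame_0001.jpg', 'frame_0003.jpg']
print(dropped)  # ['frame_0002.jpg']
```

Keeping the threshold as a parameter makes it easy to trade recall for precision, for instance raising it when the SfM pipeline is already robust to occasional outliers.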
Not intended as a general-purpose image-quality predictor: it specifically targets codec artifacts, not blur, exposure, or motion noise.