---
license: mit
tags:
- image-classification
- quality-assessment
- codec-corruption
- mobilenet
library_name: pytorch
pipeline_tag: image-classification
---

# codec-corruption-classifier

Binary image classifier that predicts whether a video frame contains severe **codec corruption**: block tearing, frame freezes, or other compression-artifact damage typically seen in low-bandwidth WiFi video streams (e.g. DJI Tello drone video feeds).

Trained on hand-labelled frames from indoor drone-mapping footage. Intended as a preprocessing filter for downstream SfM / 3D reconstruction pipelines, where a single severely corrupted frame can pollute feature matching.

## Architecture

- Backbone: `torchvision.models.mobilenet_v3_small` (ImageNet-pretrained, `IMAGENET1K_V1`)
- Head: replace the final `nn.Linear` in `classifier` with `nn.Linear(in_features, 1)`
- Output: a single logit; apply `torch.sigmoid` to get P(corrupted)
- Suggested threshold: `0.5`
- Params: 1.5M

## Preprocessing

Frames are letterboxed (aspect ratio preserved, padded with black) to 224×224, then normalized with ImageNet statistics.
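The normalization step applies `(x - mean) / std` per channel to pixel values that `ToTensor()` has already scaled to [0, 1]. One consequence worth knowing is that the black letterbox padding does not end up as zeros in the model input. The pure-Python sketch below shows the arithmetic for a single pixel; the helper `normalize_pixel` is hypothetical and only illustrates what `ToTensor()` followed by `Normalize()` does, it is not part of this repository.

```python
# ImageNet channel statistics, as passed to transforms.Normalize in the Usage example.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: tuple[float, float, float]) -> tuple[float, ...]:
    """Apply per-channel (x - mean) / std to one RGB pixel already scaled to [0, 1].

    Hypothetical illustration of ToTensor() + Normalize(); not the model's own code.
    """
    return tuple((x - m) / s for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


# Black letterbox padding is (0, 0, 0) before normalization...
pad = normalize_pixel((0.0, 0.0, 0.0))
# ...and strongly negative after it, so the padded border is not "empty" input.
print(tuple(round(v, 3) for v in pad))  # (-2.118, -2.036, -1.804)
```

Because the network was trained on letterboxed frames, it has seen this constant border during training; padding with a different color at inference time would likely shift its inputs away from what it learned.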
```python
from PIL import Image


def letterbox(img: Image.Image, size: int = 224) -> Image.Image:
    """Scale so the longer side fits `size`, then center on a black square canvas."""
    w, h = img.size
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.BILINEAR)
    padded = Image.new("RGB", (size, size), (0, 0, 0))
    # Center the resized frame; an odd leftover pixel goes to the right/bottom edge.
    padded.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
    return padded
```

## Usage

```python
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# `letterbox` is the function defined in the Preprocessing section above.

# Download the fine-tuned weights from the Hub.
weights_path = hf_hub_download(
    repo_id="callum-sh/codec-corruption-classifier",
    filename="model.safetensors",
)
state = load_file(weights_path)

# Rebuild the architecture: MobileNetV3-Small with a single-logit head.
model = mobilenet_v3_small(weights=None)
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
model.load_state_dict(state)
model.eval()

# Letterbox to 224x224, convert to a [0, 1] tensor, normalize with ImageNet stats.
tx = transforms.Compose([
    transforms.Lambda(lambda im: letterbox(im, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("frame.jpg").convert("RGB")
with torch.no_grad():
    logit = model(tx(img).unsqueeze(0))  # add a batch dimension: (1, 3, 224, 224)
p_corrupted = torch.sigmoid(logit).item()
print(f"P(corrupted) = {p_corrupted:.3f}")
```

## Intended use

Filter out frames before structure-from-motion: a frame with `P(corrupted) > 0.5` should be excluded from the SfM input set.

Not intended as a general-purpose image-quality predictor. It specifically targets *codec* artifacts, not blur, exposure, or motion noise.
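The exclusion rule can be sketched as a small helper that splits frames into a keep set and a drop set. This assumes you have already scored each frame with the Usage code and collected the results in a dict mapping frame path to P(corrupted); `split_frames` and the example scores are hypothetical and illustrative, not part of this repository or real model outputs.

```python
from typing import Mapping


def split_frames(
    probs: Mapping[str, float], threshold: float = 0.5
) -> tuple[list[str], list[str]]:
    """Partition frames into (keep, drop) lists by P(corrupted).

    A frame is dropped when P(corrupted) > threshold; a frame scoring exactly
    the threshold is kept. Input order is preserved in both lists.
    """
    keep: list[str] = []
    drop: list[str] = []
    for path, p in probs.items():
        (drop if p > threshold else keep).append(path)
    return keep, drop


# Hypothetical scores, in capture order.
scores = {"frame_0001.jpg": 0.03, "frame_0002.jpg": 0.91, "frame_0003.jpg": 0.50}
keep, drop = split_frames(scores)
print(keep)  # ['frame_0001.jpg', 'frame_0003.jpg']
print(drop)  # ['frame_0002.jpg']
```

The comparison is strict (`>`), matching the `P(corrupted) > 0.5` rule as written. Since dicts preserve insertion order, passing frames in capture order keeps the filtered SfM input in capture order as well.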