|
|
import pandas as pd |
|
|
from datasets import load_dataset |
|
|
import numpy as np |
|
|
import tqdm.auto as tqdm |
|
|
import os |
|
|
import io |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
import time |
|
|
import av |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(
    file_like: io.BytesIO,
    crop_size: int = -1,
    max_memory: int = 50 * 1024 * 1024,
    device: str = "cpu",
    every: int = 10,
) -> torch.Tensor:
    """Decode a video and return a subsampled stack of its frames.

    This is just a guide function; square center cropping may not be the most
    appropriate, 50 MB per video may not be enough, etc.

    Args:
        file_like: Video bytes. The stream is rewound to position 0 before
            it is handed to the decoder.
        crop_size: Size passed to ``transforms.CenterCrop`` for every kept
            frame. Values <= 0 disable cropping.
        max_memory: Budget in bytes for the kept frames, counted as float32
            tensors. The check runs after a frame has been stored, so the
            result can exceed the budget by at most one frame.
        device: Device on which each kept frame is stored.
        every: Sampling stride; frames whose index is a multiple of
            ``every`` are kept (index 0 is always kept).

    Returns:
        A float tensor of shape ``(num_frames, 3, H, W)`` built from RGB
        frames. Pixel values are converted to float but not rescaled.

    Raises:
        ValueError: If ``every`` is smaller than 1.
        RuntimeError: If no frame could be decoded from the input.
    """
    # Validate before touching the stream so a bad stride fails fast.
    if every < 1:
        raise ValueError(f"every must be a positive integer, got {every}")

    center_crop_transform = (
        transforms.CenterCrop(crop_size) if crop_size > 0 else None
    )

    file_like.seek(0)
    frames = []
    current_memory = 0

    # The context manager closes the container (and its decoder state) even
    # when we stop early or decoding raises.
    with av.open(file_like) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i % every != 0:
                continue

            # (H, W, 3) array -> (3, H, W) float tensor.
            frame_array = frame.to_ndarray(format="rgb24")
            frame_tensor = torch.from_numpy(frame_array).permute(2, 0, 1).float()

            if center_crop_transform is not None:
                frame_tensor = center_crop_transform(frame_tensor)

            frames.append(frame_tensor.to(device))

            # Account for the stored frame, then stop once the budget is hit.
            current_memory += frame_tensor.numel() * frame_tensor.element_size()
            if current_memory >= max_memory:
                break

    if not frames:
        # torch.stack would fail on an empty list with an opaque message.
        raise RuntimeError("no video frames could be decoded from the input")

    return torch.stack(frames)
|
|
|
|
|
|
|
|
class Model(torch.nn.Module):
    """Placeholder scoring model that emits one random score per batch item.

    The linear layer is created but not used by ``forward``; ``threshold`` is
    the cut-off callers compare the score against.
    """

    def __init__(self):
        """Create the (unused) linear layer and the decision threshold."""
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 5)
        self.threshold = 0.0

    def forward(self, x):
        """Return a 1-D tensor of standard-normal scores, one per row of ``x``.

        Only the leading dimension and the device of ``x`` are read; its
        values do not influence the output.
        """
        batch_size = x.shape[0]
        # Sampled on the default device first, then moved next to the input.
        scores = torch.randn(batch_size)
        return scores.to(x.device)
|
|
|
|
|
|
|
|
|
|
|
# Dataset location; the "test" split is opened in streaming mode and iterated
# one element at a time below.
DATASET_PATH = "/tmp/data"
dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)

# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
# Inference only: switch off training-time behaviour of layers.
model.eval()

# One dict per dataset element; becomes one CSV row each.
out = []
for el in tqdm.tqdm(dataset_remote):
    try:
        file_like = io.BytesIO(el["video"]["bytes"])
        # Frames are already placed on `device` by preprocess.
        tensor = preprocess(file_like, device=device)

        with torch.no_grad():
            # tensor[None] adds a leading batch dimension of size 1;
            # .item() copies the single score back to a Python float.
            score = model(tensor[None]).item()

        pred = "generated" if score > model.threshold else "real"
        out.append(dict(id=el["id"], pred=pred, score=score))
    except Exception as e:
        # Best-effort: a broken video must not abort the whole run. The row
        # is still emitted with only its id, so pred/score stay empty.
        print(e)
        print("failed", el["id"])
        out.append(dict(id=el["id"]))

pd.DataFrame(out).to_csv("submission.csv", index=False)