Spaces:
Configuration error

File size: 2,920 Bytes
e059c3c
 
 
 
 
 
 
95a33ce
 
 
95db469
ce0d4fb
 
95a33ce
95db469
 
 
 
 
 
 
 
e059c3c
 
fc405ab
e059c3c
 
95a33ce
e059c3c
ce0d4fb
 
 
 
e059c3c
ce0d4fb
e059c3c
 
 
 
ce0d4fb
 
e059c3c
 
 
 
 
 
 
 
 
 
 
 
 
 
ce0d4fb
 
 
 
 
e059c3c
ce0d4fb
 
 
 
 
 
 
 
 
e059c3c
 
 
 
ce0d4fb
 
 
 
e059c3c
 
 
ce0d4fb
 
e059c3c
 
 
 
95db469
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import shutil
import zipfile
from os.path import join, isfile, basename

import cv2
import numpy as np
import gradio as gr
import torch

from resnet50 import resnet18
from sampling_util import furthest_neighbours
from video_reader import video_reader

# Frozen feature extractor: a headless ResNet-18 backbone (projection head,
# prototypes and output dim all disabled) restored from a local checkpoint.
# NOTE(review): `resnet18` comes from the project-local `resnet50` module —
# presumably a SwAV-style implementation; confirm the kwarg semantics there.
model = resnet18(
    output_dim=0,
    nmb_prototypes=0,
    eval_mode=True,
    hidden_mlp=0,
    normalize=False)
model.load_state_dict(torch.load("model.pt"))  # expects model.pt in the CWD
model.eval()  # inference only — disable dropout/batch-norm updates
# Global average pooling collapses the backbone's spatial feature map
# to a single embedding vector per frame (used inside predict()).
avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))


def predict(input_file, downsample_size):
    """Select a diverse subset of frames from a video and return them zipped.

    The video is decoded twice: once at low resolution (width 100) to embed
    every frame with the frozen ResNet encoder, then once at full resolution
    to write out only the frames chosen by furthest-neighbour sampling.

    Args:
        input_file: Path to the uploaded video file.
        downsample_size: Number of frames to keep (coerced to int).

    Returns:
        Path to a zip archive containing the selected full-resolution frames.
    """
    downsample_size = int(downsample_size)

    # Recreate a clean output directory for the selected frames.
    selected_directory = os.path.join(os.getcwd(), "selected_images")
    if os.path.isdir(selected_directory):
        shutil.rmtree(selected_directory)
    os.mkdir(selected_directory)

    # Name the archive after the input video, extension replaced with .zip.
    # splitext/basename are robust to any extension length and to path
    # separators, unlike the original split('/')[-1][:-4].
    zip_path = os.path.splitext(basename(input_file))[0] + ".zip"

    # Per-channel normalisation statistics the encoder was trained with.
    mean = np.asarray([0.3156024, 0.33569682, 0.34337464])
    std = np.asarray([0.16568947, 0.17827448, 0.18925823])

    # Pass 1: embed every downsampled frame with the frozen encoder.
    img_vecs = []
    with torch.no_grad():
        for in_img in video_reader(input_file,
                                   targetFPS=9,
                                   targetWidth=100,
                                   to_rgb=True):
            in_img = in_img.astype(np.float32) / 255.
            in_img = (in_img - mean) / std
            in_img = np.transpose(in_img, (0, 3, 1, 2))  # NHWC -> NCHW
            in_img = torch.from_numpy(in_img)
            # Take the pooled embedding of the first (only) batch element.
            encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
            img_vecs.append(encoded)

    # Pick a maximally diverse subset of frame indices.
    img_vecs = np.asarray(img_vecs)
    rv_indices, _ = furthest_neighbours(
        img_vecs,
        downsample_size,
        seed=0)
    keep = np.zeros((img_vecs.shape[0],))
    keep[np.asarray(rv_indices)] = 1

    # Pass 2: re-read at full resolution and save only the selected frames,
    # numbered by their position in the frame stream.
    global_ctr = 0
    for img in video_reader(input_file,
                            targetFPS=9,
                            targetWidth=None,
                            to_rgb=False):
        if keep[global_ctr] == 1:
            cv2.imwrite(join(selected_directory, str(global_ctr) + ".jpg"), img)
        global_ctr += 1

    selected_files = [join(selected_directory, f)
                      for f in os.listdir(selected_directory)
                      if isfile(join(selected_directory, f))]

    # Context manager guarantees the archive handle is closed even if a
    # write raises (the original leaked the handle on error).
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for f in selected_files:
            zipf.write(f, basename(f))

    return zip_path


# Gradio UI: upload a video, choose how many frames to keep, download a zip.
# FIX: the Number component used `Label=` (capital L); gradio's keyword is
# lowercase `label`, so the field label was silently dropped / rejected.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.inputs.Video(label="Upload Video File"),
            gr.inputs.Number(label="Downsample size")],
    outputs=gr.outputs.File(label="Zip"))

demo.launch()