File size: 7,542 Bytes
c4306d5
 
24dd46b
 
5a773fc
 
 
 
e309e64
24dd46b
 
 
 
 
e0f4290
5870e2c
c4306d5
399daf2
210f83d
2d22374
210f83d
6d1502a
399daf2
e0f4290
b1fba4c
 
e309e64
b1fba4c
 
e309e64
 
b1fba4c
c4306d5
b1fba4c
5870e2c
c4306d5
 
b1fba4c
5870e2c
b1fba4c
c4306d5
 
e309e64
210f83d
5870e2c
422ae83
 
 
 
 
210f83d
 
422ae83
 
6d1502a
49c095a
 
6d1502a
d1fc27c
 
6d1502a
 
 
2be63b8
 
 
 
d50110e
76f7ec9
 
262a1dc
4d9c8e1
6d1502a
76f7ec9
d1fc27c
1385931
2be63b8
d1fc27c
 
 
 
 
 
 
7c2be5c
 
 
 
 
 
 
 
 
399daf2
2be63b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1fc27c
 
 
a2dfad9
5870e2c
2d22374
210f83d
e309e64
d1028ac
 
 
b1fba4c
0c8af5a
 
 
 
 
 
 
 
 
 
 
 
8feafa5
b1fba4c
 
c4306d5
 
b1fba4c
 
 
 
 
 
 
 
 
 
5a773fc
 
 
 
7c2be5c
 
 
 
5a773fc
 
 
 
 
 
 
 
 
 
7c2be5c
5a773fc
7c2be5c
 
 
5a773fc
 
 
 
 
 
 
 
 
 
 
0c8af5a
 
 
 
 
 
 
 
 
 
 
6bb7746
 
5a773fc
 
 
 
 
 
 
 
 
b1fba4c
 
637ed97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# --- Environment & imports -------------------------------------------------
import io
import time
import os

# Cap every BLAS/OpenMP backend to a single thread. These must be exported
# BEFORE numpy/torch are imported, or the native libraries will have already
# spawned their thread pools and the variables are ignored.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import torch
import tqdm.auto as tqdm
from glob import glob
import numpy as np
import pandas as pd

# Project-local modules — these must ship inside the model repo (see note below).
from models import Model
from datasets import load_dataset
from preprocess import preprocess, preprocess_FS

from src.rawnet_model import RawNet
from src.lcnn_model import LCNN
from src.resnet_model import ResNet_LogSpec, ResNet_MelSpec
from src.moe_model import UltimateMOE, MOE_attention, MOE_attention_FS

# Import your model and anything else you want
# You can even install other packages included in your repo
# However, during the evaluation the container will not have access to the internet.
# So you must include everything you need in your model repo. Common python libraries will be installed.
# Feel free to contact us to add dependencies to the requirements.txt
# For testing, this is the docker image that will be used https://github.com/huggingface/competitions/blob/main/Dockerfile
# It can be pulled here https://hub.docker.com/r/huggingface/competitions/tags


# Load the dataset. It is automatically downloaded to /tmp/data during
# evaluation; streaming=True iterates it lazily so the full set never has to
# be materialized at once.
print('Load Dataset')
# DATASET_PATH = "/tmp/data_test"
# dataset_remote = glob(os.path.join(DATASET_PATH, '*'))
DATASET_PATH = "/tmp/data"
dataset_remote = load_dataset(DATASET_PATH,split="test", streaming=True)

# Target device for both model and input tensors.
device = "cuda:0"
# device = "cpu"

print('Define Model')

# Several architectures were trained; only the 3-expert MOE below is active.
# The commented configurations are kept for quick checkpoint switching.

# # RAWNET2 MODEL
# config = {
#     "first_conv": 1024, "in_channels": 1, "filts": [20, [20, 20], [20, 128], [128, 128]],
#     "blocks": [2, 4], "nb_fc_node": 1024, "gru_node": 1024, "nb_gru_layer": 3, "nb_classes": 2
# }
# model = RawNet(config, device).to(device)
# model_path = './checkpoints/RAWNET_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
# model.load_state_dict(torch.load(model_path, map_location=device))

# RESNET MODEL
# model = ResNet_LogSpec(sample_rate=24000, return_emb=False).to(device)
# model_path = './checkpoints/RESNET_LOGSPEC_ALL_DATA_FS_24000.pth'

# model = ResNet_MelSpec(sample_rate=24000, return_emb=False).to(device)
# model_path = './checkpoints/RESNET_MELSPEC_ALL_DATA_FS_24000.pth'

## LCNN MODEL
# model = LCNN(return_emb=False, fs=24000).to(device)
# model_path = './checkpoints/LCNN_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
# model_path = './checkpoints/LCNN_ALL_DATA.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_AUG.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_TTS_AUG.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_TTS_MOD.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_HI_FREQ_22050.pth'

# model_path = './checkpoints/LCNN_ALL_DATA_FS_16000.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_FS_22050.pth'
# model_path = './checkpoints/LCNN_ALL_DATA_FS_24000.pth'

# model.load_state_dict(torch.load(model_path, map_location=device))

# # MOE MODEL
# Experts are built with return_emb=True so they emit embeddings; the
# MOE attention head fuses them into the final 2-class output. They are
# constructed on CPU and moved together with the MOE module below.
expert_1 = LCNN(return_emb=True, fs=24000)
expert_2 = ResNet_LogSpec(return_emb=True, sample_rate=24000)
expert_3 = ResNet_MelSpec(return_emb=True, sample_rate=24000)

model = MOE_attention(experts=[expert_1, expert_2, expert_3], device=device)
model_path = './checkpoints/MOE_TRANSF_3EXP_MODELS_AUG.pth'

# expert_1 = LCNN(return_emb=True, fs=16000).to(device)
# expert_2 = LCNN(return_emb=True, fs=22050).to(device)
# expert_3 = LCNN(return_emb=True, fs=24000).to(device)
# # expert_4 = LCNN(return_emb=True).to(device)
# # expert_5 = LCNN(return_emb=True).to(device)
# # expert_6 = LCNN(return_emb=True).to(device)
#
# model = MOE_attention_FS(experts=[expert_1, expert_2, expert_3], device=device)
# model_path = './checkpoints/MOE_TRANSF_3EXP_FS_AUG_NO_FREEZE.pth'

# # # model = UltimateMOE(experts=[expert_1, expert_2, expert_3, expert_4])
# # # model_path = './checkpoints/MOE_ULTIMATE.pth'
#
# # model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device)
# # # model_path = './checkpoints/MOE_ATTENTION.pth'
# # model_path = './checkpoints/MOE_TRANSF.pth'
#
# expert_7 = LCNN(return_emb=True).to(device)
# expert_8 = LCNN(return_emb=True).to(device)
# model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6, expert_7, expert_8], device=device, freezing=True)
# # model_path = './checkpoints/MOE_TRANSF_7EXP.pth'
# # model_path = './checkpoints/MOE_TRANSF_7EXP_AUG.pth'
# # model_path = './checkpoints/MOE_TRANSF_7EXP_AUG_NO_FREEZE.pth'
# # model_path = './checkpoints/MOE_TRANSF_8EXP_AUG.pth'
# model_path = './checkpoints/MOE_TRANSF_8EXP_AUG_NO_FREEZE.pth'

# Moving the MOE module also moves its registered expert submodules to `device`.
model = model.to(device)
# NOTE(review): torch.load unpickles the checkpoint — fine for our own files,
# never for untrusted ones.
model.load_state_dict(torch.load(model_path, map_location=device))

# Inference only: disable dropout and freeze batch-norm statistics.
model.eval()

print('Loaded Weights')

# # EVALUATE OLD MODEL
# del model
# model = Model().to(device)

# SAMPLING_RATE_CODES = {
#     8000: 2,
#     16000: 3,
#     22050: 5,
#     24000: 7,
#     32000: 11,
#     44100: 13,
#     48000: 17,
#     "other": 19
# }
#
# seen_frequencies = set()

# Iterate over the streamed dataset: decode each clip, score it with the
# model, threshold the soft score into a hard label, and record the per-item
# wall-clock time.
out = []
for el in tqdm.tqdm(dataset_remote):
# for el in dataset_remote:

    start_time = time.time()

    # each element is a dict
    # el["id"] id of example and el["audio"] contains the audio file
    # el["audio"]["bytes"] contains bytes from reading the raw audio
    # el["audio"]["path"] contains the filename. This is just for reference and you can't actually load it

    # if you are using libraries that expect a file. You can use a BytesIO object

    # try:

    # RUNNING ON HUGGINGFACE
    file_like = io.BytesIO(el["audio"]["bytes"])

    # Decode and resample to the 24 kHz input the active model was trained on.
    tensor, sr = preprocess(file_like, target_sr=24000)
    # tensor_16, tensor_22, tensor_24 = preprocess_FS(file_like)

    # # RUNNING LOCALLY
    # tensor = preprocess(el)

    with torch.no_grad():
        # soft decision (such as log likelihood score)
        # positive score correspond to synthetic prediction
        # negative score correspond to pristine prediction

        # # OLD MODEL
        # score = model(tensor.to(device)).cpu().item()

        # CUSTOM MODEL
        # Column 1 of the 2-class output is the "synthetic" score; the mean
        # below collapses any per-segment batch dimension into one scalar.
        score = model(tensor.to(device))[:, 1].cpu()
        # score = model(tensor_16.to(device), tensor_22.to(device), tensor_24.to(device))[:, 1].cpu()

        # NOTE(review): unconditional per-item debug prints slow the streamed
        # evaluation and spam the logs — consider removing before submission.
        print(f'SCORE OUT: {score}')
        score = score.mean().item()
        print(f'SCORE FINAL: {score}')

        # we require a hard decision to be submitted. so you need to pick a threshold
        pred = "generated" if score > model.threshold else "pristine"

    # append your prediction
    # "id" and "pred" are required. "score" will not be used in scoring but we encourage you to include it. We'll use it for analysis of the results

    # RUNNING ON HUGGINGFACE
    total_time = time.time() - start_time

    # freq = sr if sr in SAMPLING_RATE_CODES else "other"
    #
    # # Assign total_time: the code value on first occurrence, 0 otherwise
    # if freq not in seen_frequencies:
    #     total_time = SAMPLING_RATE_CODES[freq]
    #     seen_frequencies.add(freq)
    # # else:
    # #     total_time = 0
    # total_time = 1

    out.append(dict(id=el["id"], pred=pred, score=score, time=total_time))
    # # RUNNING LOCALLY
    # out.append(dict(id=el, pred=pred, score=score, time=time.time() - start_time))

    # except Exception as e:
    #     print(e)
    #     print("failed", el["id"])
    #     out.append(dict(id=el["id"], pred="none", score=None))
    #     # print("failed", el)
    #     # out.append(dict(id=el, pred="none", score=None))

# Persist the collected predictions as the competition submission file.
submission = pd.DataFrame(out)
submission.to_csv("submission.csv", index=False)