Renu05 committed
Commit bf1598b · verified · 1 Parent(s): 3a47659

Update inference_2.py

Files changed (1): inference_2.py +216 -216
inference_2.py CHANGED
@@ -1,216 +1,216 @@
 import os
 import cv2
 import onnx
 import torch
 import argparse
 import numpy as np
 import torch.nn as nn
 from models.TMC import ETMC
 from models import image

 from onnx2pytorch import ConvertModel

 onnx_model = onnx.load('checkpoints/efficientnet.onnx')
 pytorch_model = ConvertModel(onnx_model)

 # Set random seed for reproducibility.
 torch.manual_seed(42)


 # Define the audio_args dictionary
 audio_args = {
     'nb_samp': 64600,
     'first_conv': 1024,
     'in_channels': 1,
     'filts': [20, [20, 20], [20, 128], [128, 128]],
     'blocks': [2, 4],
     'nb_fc_node': 1024,
     'gru_node': 1024,
     'nb_gru_layer': 3,
     'nb_classes': 2
 }


 def get_args(parser):
     parser.add_argument("--batch_size", type=int, default=8)
     parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
     parser.add_argument("--LOAD_SIZE", type=int, default=256)
     parser.add_argument("--FINE_SIZE", type=int, default=224)
     parser.add_argument("--dropout", type=float, default=0.2)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--hidden", nargs="*", type=int, default=[])
     parser.add_argument("--hidden_sz", type=int, default=768)
     parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
     parser.add_argument("--img_hidden_sz", type=int, default=1024)
     parser.add_argument("--include_bn", type=int, default=True)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--lr_factor", type=float, default=0.3)
     parser.add_argument("--lr_patience", type=int, default=10)
     parser.add_argument("--max_epochs", type=int, default=500)
     parser.add_argument("--n_workers", type=int, default=12)
     parser.add_argument("--name", type=str, default="MMDF")
     parser.add_argument("--num_image_embeds", type=int, default=1)
     parser.add_argument("--patience", type=int, default=20)
     parser.add_argument("--savedir", type=str, default="./savepath/")
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--n_classes", type=int, default=2)
     parser.add_argument("--annealing_epoch", type=int, default=10)
     parser.add_argument("--device", type=str, default='cpu')
     parser.add_argument("--pretrained_image_encoder", type=bool, default=False)
     parser.add_argument("--freeze_image_encoder", type=bool, default=False)
     parser.add_argument("--pretrained_audio_encoder", type=bool, default=False)
     parser.add_argument("--freeze_audio_encoder", type=bool, default=False)
     parser.add_argument("--augment_dataset", type=bool, default=True)

     for key, value in audio_args.items():
         parser.add_argument(f"--{key}", type=type(value), default=value)

 def model_summary(args):
     '''Prints the model summary.'''
     model = ETMC(args)

     for name, layer in model.named_modules():
         print(name, layer)

 def load_multimodal_model(args):
     '''Loads the multimodal model.'''
     model = ETMC(args)
     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
     model.load_state_dict(ckpt, strict=True)
     model.eval()
     return model

 def load_img_modality_model(args):
     '''Loads the image modality model.'''
     rgb_encoder = pytorch_model

     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
     rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict=True)
     rgb_encoder.eval()
     return rgb_encoder

 def load_spec_modality_model(args):
     spec_encoder = image.RawNet(args)
     ckpt = torch.load('checkpoints/model.pth', map_location=torch.device('cpu'))
     spec_encoder.load_state_dict(ckpt['spec_encoder'], strict=True)
     spec_encoder.eval()
     return spec_encoder


 # Load models.
 parser = argparse.ArgumentParser(description="Inference models")
 get_args(parser)
 args, remaining_args = parser.parse_known_args()
 assert remaining_args == [], remaining_args

 spec_model = load_spec_modality_model(args)

 img_model = load_img_modality_model(args)


 def preprocess_img(face):
     face = face / 255
     face = cv2.resize(face, (256, 256))
     # face = face.transpose(2, 0, 1) #(W, H, C) -> (C, W, H)
     face_pt = torch.unsqueeze(torch.Tensor(face), dim=0)
     return face_pt

 def preprocess_audio(audio_file):
     audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
     return audio_pt

 def deepfakes_spec_predict(input_audio):
     x, _ = input_audio
     audio = preprocess_audio(x)
     spec_grads = spec_model.forward(audio)
     spec_grads_inv = np.exp(spec_grads.cpu().detach().numpy().squeeze())

     # multimodal_grads = multimodal.spec_depth[0].forward(spec_grads)

     # out = nn.Softmax()(multimodal_grads)
     # max = torch.argmax(out, dim=-1)  # Index of the max value in the tensor.
     # max_value = out[max]  # Actual value of the tensor.
     max_value = np.argmax(spec_grads_inv)

     if max_value > 0.5:
         preds = round(100 - (max_value * 100), 3)
         text2 = f"The audio is REAL."

     else:
         preds = round(max_value * 100, 3)
         text2 = f"The audio is FAKE."

     return text2

 def deepfakes_image_predict(input_image):
     face = preprocess_img(input_image)
     print(f"Face shape is: {face.shape}")
     img_grads = img_model.forward(face)
     img_grads = img_grads.cpu().detach().numpy()
     img_grads_np = np.squeeze(img_grads)

     if img_grads_np[0] > 0.5:
         preds = round(img_grads_np[0] * 100, 3)
-        text2 = f"The image is REAL. \nConfidence score is: {preds}"
+        text2 = f"The image is REAL."

     else:
         preds = round(img_grads_np[1] * 100, 3)
-        text2 = f"The image is FAKE. \nConfidence score is: {preds}"
+        text2 = f"The image is FAKE."

     return text2


 def preprocess_video(input_video, n_frames=3):
     v_cap = cv2.VideoCapture(input_video)
     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))

     # Pick 'n_frames' evenly spaced frames to sample
     if n_frames is None:
         sample = np.arange(0, v_len)
     else:
         sample = np.linspace(0, v_len - 1, n_frames).astype(int)

     # Loop through frames.
     frames = []
     for j in range(v_len):
         success = v_cap.grab()
         if j in sample:
             # Load frame
             success, frame = v_cap.retrieve()
             if not success:
                 continue
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frame = preprocess_img(frame)
             frames.append(frame)
     v_cap.release()
     return frames


 def deepfakes_video_predict(input_video):
     '''Performs inference on a video.'''
     video_frames = preprocess_video(input_video)
     real_faces_list = []
     fake_faces_list = []

     for face in video_frames:
         # face = preprocess_img(face)

         img_grads = img_model.forward(face)
         img_grads = img_grads.cpu().detach().numpy()
         img_grads_np = np.squeeze(img_grads)
         real_faces_list.append(img_grads_np[0])
         fake_faces_list.append(img_grads_np[1])

     real_faces_mean = np.mean(real_faces_list)
     fake_faces_mean = np.mean(fake_faces_list)

     if real_faces_mean > 0.5:
         preds = round(real_faces_mean * 100, 3)
-        text2 = f"The video is REAL. \nConfidence score is: {preds}%"
+        text2 = f"The video is REAL."

     else:
         preds = round(fake_faces_mean * 100, 3)
-        text2 = f"The video is FAKE. \nConfidence score is: {preds}%"
+        text2 = f"The video is FAKE."

     return text2
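A few review notes on the surrounding code; the snippets below are illustrative sketches, not part of the commit.

On the boolean flags: argparse with type=bool does not parse textual values the way the defaults suggest, because bool("False") is True (any non-empty string is truthy), so --freeze_image_encoder False still yields True. The usual workaround is an explicit converter; a minimal sketch, assuming a str2bool helper that is not part of this repo:

import argparse

def str2bool(v):
    # Map textual booleans explicitly, so "--flag False" really yields False.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean, got {v!r}')

parser = argparse.ArgumentParser()
# Instead of: parser.add_argument("--freeze_image_encoder", type=bool, default=False)
parser.add_argument("--freeze_image_encoder", type=str2bool, default=False)
print(parser.parse_args(['--freeze_image_encoder', 'False']).freeze_image_encoder)  # False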
 
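On deepfakes_spec_predict: np.argmax(spec_grads_inv) returns a class index (0 or 1), which is then compared against the probability threshold 0.5, and the preds percentage is computed but never used. The np.exp hints the model may already return log-probabilities; the commented-out lines suggest the intended reading was a softmax over two logits. A sketch of that reading, under the assumption (worth checking against the training label order) that index 0 means REAL and index 1 means FAKE, as in the image branch:

import torch
import torch.nn.functional as F

def audio_verdict(spec_grads):
    # Sketch: softmax the two logits, then report the winning class.
    # Assumes spec_model returns raw 2-class logits; the 0=REAL/1=FAKE
    # mapping is an assumption, not a fact from this repo.
    probs = F.softmax(spec_grads.detach().cpu().squeeze(), dim=-1)
    idx = int(torch.argmax(probs))
    confidence = round(float(probs[idx]) * 100, 3)  # available if the score is wanted again
    return f"The audio is {'REAL' if idx == 0 else 'FAKE'}."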
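On preprocess_img: with the transpose commented out, the model receives NHWC tensors of shape (1, 256, 256, 3); note the inline comment mislabels the axes, since OpenCV frames are (H, W, C), not (W, H, C). Whether NHWC is correct depends on how checkpoints/efficientnet.onnx was exported; a model exported from PyTorch would normally expect NCHW. A channel-first variant, in case the converted model's input turns out to be (1, 3, 256, 256) (an assumption to verify against the ONNX graph):

import cv2
import numpy as np
import torch

def preprocess_img_chw(face):
    # Hypothetical channel-first variant of preprocess_img.
    face = cv2.resize(face, (256, 256)).astype(np.float32) / 255.0
    face = face.transpose(2, 0, 1)              # (H, W, C) -> (C, H, W)
    return torch.from_numpy(face).unsqueeze(0)  # (1, 3, 256, 256)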
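On the frame sampler in preprocess_video: np.linspace(0, v_len - 1, n_frames).astype(int) picks evenly spaced frames including the first and last. For example:

import numpy as np

# A 100-frame clip with the default n_frames=3 samples first, middle, last.
print(np.linspace(0, 99, 3).astype(int))  # [ 0 49 99]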
+
188
+
189
+ def deepfakes_video_predict(input_video):
190
+ '''Perform inference on a video.'''
191
+ video_frames = preprocess_video(input_video)
192
+ real_faces_list = []
193
+ fake_faces_list = []
194
+
195
+ for face in video_frames:
196
+ # face = preprocess_img(face)
197
+
198
+ img_grads = img_model.forward(face)
199
+ img_grads = img_grads.cpu().detach().numpy()
200
+ img_grads_np = np.squeeze(img_grads)
201
+ real_faces_list.append(img_grads_np[0])
202
+ fake_faces_list.append(img_grads_np[1])
203
+
204
+ real_faces_mean = np.mean(real_faces_list)
205
+ fake_faces_mean = np.mean(fake_faces_list)
206
+
207
+ if real_faces_mean > 0.5:
208
+ preds = round(real_faces_mean * 100, 3)
209
+ text2 = f"The video is REAL."
210
+
211
+ else:
212
+ preds = round(fake_faces_mean * 100, 3)
213
+ text2 = f"The video is FAKE."
214
+
215
+ return text2
216
+
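Finally, the module loads both checkpoints at import time, so a smoke test only needs the predict calls. A sketch, assuming the checkpoints are in place; the media paths are illustrative:

import cv2
from inference_2 import deepfakes_image_predict, deepfakes_video_predict

frame = cv2.cvtColor(cv2.imread('samples/face.jpg'), cv2.COLOR_BGR2RGB)
print(deepfakes_image_predict(frame))
print(deepfakes_video_predict('samples/clip.mp4'))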