Superxixixi committed
Commit d828569 · 1 Parent(s): 849d787

Delete legacy

legacy/talkNet_multi_multicard.py DELETED
@@ -1,124 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import sys, time, numpy, os, subprocess, pandas, tqdm
-
-from loss_multi import lossAV, lossA, lossV
-from model.talkNetModel import talkNetModel
-
-import pytorch_lightning as pl
-from torch import distributed as dist
-
-
-class talkNet(pl.LightningModule):
-
-    def __init__(self, cfg):
-        super(talkNet, self).__init__()
-        self.model = talkNetModel().cuda()
-        self.cfg = cfg
-        self.lossAV = lossAV().cuda()
-        self.lossA = lossA().cuda()
-        self.lossV = lossV().cuda()
-        print(
-            time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
-            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
-        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                    step_size=1,
-                                                    gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
-
-    def training_step(self, batch, batch_idx):
-        audioFeature, visualFeature, labels, masks = batch
-        # b: batch, s: speaker tracks, t: temporal length
-        b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
-        # fold the speaker dimension into the batch dimension
-        audioFeature = audioFeature.repeat(1, s, 1, 1)
-        audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
-        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
-        labels = labels.view(b * s, *labels.shape[2:])
-        masks = masks.view(b * s, *masks.shape[2:])
-
-        audioEmbed = self.model.forward_audio_frontend(audioFeature)  # feedForward
-        visualEmbed = self.model.forward_visual_frontend(visualFeature)
-        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
-        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-        outsA = self.model.forward_audio_backend(audioEmbed)
-        outsV = self.model.forward_visual_backend(visualEmbed)
-        labels = labels.reshape((-1))
-        nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
-        nlossA = self.lossA.forward(outsA, labels, masks)
-        nlossV = self.lossV.forward(outsV, labels, masks)
-        loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
-
-    def training_epoch_end(self, training_step_outputs):
-        self.saveParameters(
-            os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
-
-    def evaluate_network(self, loader):
-        self.eval()
-        predScores = []
-        self.model = self.model.cuda()
-        self.lossAV = self.lossAV.cuda()
-        self.lossA = self.lossA.cuda()
-        self.lossV = self.lossV.cuda()
-        evalCsvSave = self.cfg.evalCsvSave
-        evalOrig = self.cfg.evalOrig
-        for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
-            with torch.no_grad():
-                b, s = visualFeature.shape[0], visualFeature.shape[1]
-                t = visualFeature.shape[2]
-                audioFeature = audioFeature.repeat(1, s, 1, 1)
-                audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
-                visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
-                labels = labels.view(b * s, *labels.shape[2:])
-                masks = masks.view(b * s, *masks.shape[2:])
-                audioEmbed = self.model.forward_audio_frontend(audioFeature.cuda())
-                visualEmbed = self.model.forward_visual_frontend(visualFeature.cuda())
-                audioEmbed, visualEmbed = self.model.forward_cross_attention(
-                    audioEmbed, visualEmbed)
-                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-                labels = labels.reshape((-1)).cuda()
-                # score only the first speaker track of each clip
-                outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
-                labels = labels.view(b, s, t)[:, 0, :].view(b * t)
-                masks = masks.view(b, s, t)[:, 0, :].view(b * t)
-                _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
-                predScore = predScore.detach().cpu().numpy()
-                predScores.extend(predScore)
-        evalLines = open(evalOrig).read().splitlines()[1:]
-        labels = []
-        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
-        scores = pandas.Series(predScores)
-        evalRes = pandas.read_csv(evalOrig)
-        evalRes['score'] = scores
-        evalRes['label'] = labels
-        evalRes.drop(['label_id'], axis=1, inplace=True)
-        evalRes.drop(['instance_id'], axis=1, inplace=True)
-        evalRes.to_csv(evalCsvSave, index=False)
-        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
-                                                                                      evalCsvSave)
-        mAP = float(
-            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
-        return mAP
-
-    def saveParameters(self, path):
-        torch.save(self.state_dict(), path)
-
-    def loadParameters(self, path):
-        selfState = self.state_dict()
-        loadedState = torch.load(path)
-        for name, param in loadedState.items():
-            origName = name
-            if name not in selfState:
-                name = name.replace("module.", "")
-                if name not in selfState:
-                    print("%s is not in the model." % origName)
-                    continue
-            if selfState[name].size() != loadedState[origName].size():
-                sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
-                                 (origName, selfState[name].size(), loadedState[origName].size()))
-                continue
-            selfState[name].copy_(param)
 
legacy/talkNet_multicard.py DELETED
@@ -1,146 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import sys, time, numpy, os, subprocess, pandas, tqdm
-
-from loss import lossAV, lossA, lossV
-from model.talkNetModel import talkNetModel
-
-import pytorch_lightning as pl
-from torch import distributed as dist
-
-
-class talkNet(pl.LightningModule):
-
-    def __init__(self, cfg):
-        super(talkNet, self).__init__()
-        self.cfg = cfg
-        self.model = talkNetModel()
-        self.lossAV = lossAV()
-        self.lossA = lossA()
-        self.lossV = lossV()
-        print(
-            time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
-            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
-        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                    step_size=1,
-                                                    gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
-
-    def training_step(self, batch, batch_idx):
-        audioFeature, visualFeature, labels = batch
-        audioEmbed = self.model.forward_audio_frontend(audioFeature[0])  # feedForward
-        visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
-        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
-        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-        outsA = self.model.forward_audio_backend(audioEmbed)
-        outsV = self.model.forward_visual_backend(visualEmbed)
-        labels = labels[0].reshape((-1))
-        nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
-        nlossA = self.lossA.forward(outsA, labels)
-        nlossV = self.lossV.forward(outsV, labels)
-        loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-
-        return loss
-
-    def training_epoch_end(self, training_step_outputs):
-        self.saveParameters(
-            os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
-
-    def validation_step(self, batch, batch_idx):
-        audioFeature, visualFeature, labels, indices = batch
-        audioEmbed = self.model.forward_audio_frontend(audioFeature[0])
-        visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
-        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
-        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-        labels = labels[0].reshape((-1))
-        loss, predScore, _, _ = self.lossAV.forward(outsAV, labels)
-        predScore = predScore[:, -1:].detach().cpu().numpy()
-        # self.log("val_loss", loss)
-
-        return predScore
-
-    def validation_epoch_end(self, validation_step_outputs):
-        evalCsvSave = self.cfg.evalCsvSave
-        evalOrig = self.cfg.evalOrig
-        predScores = []
-
-        for out in validation_step_outputs:  # batch size = 1
-            predScores.extend(out)
-
-        evalLines = open(evalOrig).read().splitlines()[1:]
-        labels = []
-        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
-        scores = pandas.Series(predScores)
-        evalRes = pandas.read_csv(evalOrig)
-        print(len(evalRes), len(predScores), len(evalLines))
-        evalRes['score'] = scores
-        evalRes['label'] = labels
-        evalRes.drop(['label_id'], axis=1, inplace=True)
-        evalRes.drop(['instance_id'], axis=1, inplace=True)
-        evalRes.to_csv(evalCsvSave, index=False)
-        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
-                                                                                      evalCsvSave)
-        mAP = float(
-            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
-        print("validation mAP: {}".format(mAP))
-
-    def saveParameters(self, path):
-        torch.save(self.state_dict(), path)
-
-    def loadParameters(self, path):
-        selfState = self.state_dict()
-        loadedState = torch.load(path, map_location='cpu')
-        for name, param in loadedState.items():
-            origName = name
-            if name not in selfState:
-                name = name.replace("module.", "")
-                if name not in selfState:
-                    print("%s is not in the model." % origName)
-                    continue
-            if selfState[name].size() != loadedState[origName].size():
-                sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
-                                 (origName, selfState[name].size(), loadedState[origName].size()))
-                continue
-            selfState[name].copy_(param)
-
-    def evaluate_network(self, loader):
-        self.eval()
-        self.model = self.model.cuda()
-        self.lossAV = self.lossAV.cuda()
-        self.lossA = self.lossA.cuda()
-        self.lossV = self.lossV.cuda()
-        predScores = []
-        evalCsvSave = self.cfg.evalCsvSave
-        evalOrig = self.cfg.evalOrig
-        for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
-            with torch.no_grad():
-                audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
-                visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
-                audioEmbed, visualEmbed = self.model.forward_cross_attention(
-                    audioEmbed, visualEmbed)
-                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-                labels = labels[0].reshape((-1)).cuda()
-                _, predScore, _, _ = self.lossAV.forward(outsAV, labels)
-                predScore = predScore[:, 1].detach().cpu().numpy()
-                predScores.extend(predScore)
-        evalLines = open(evalOrig).read().splitlines()[1:]
-        labels = []
-        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
-        scores = pandas.Series(predScores)
-        evalRes = pandas.read_csv(evalOrig)
-        evalRes['score'] = scores
-        evalRes['label'] = labels
-        evalRes.drop(['label_id'], axis=1, inplace=True)
-        evalRes.drop(['instance_id'], axis=1, inplace=True)
-        evalRes.to_csv(evalCsvSave, index=False)
-        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
-                                                                                      evalCsvSave)
-        mAP = float(
-            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
-        return mAP
 
legacy/talkNet_orig.py DELETED
@@ -1,102 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import sys, time, numpy, os, subprocess, pandas, tqdm
-
-from loss import lossAV, lossA, lossV
-from model.talkNetModel import talkNetModel
-
-
-class talkNet(nn.Module):
-
-    def __init__(self, lr=0.0001, lrDecay=0.95, **kwargs):
-        super(talkNet, self).__init__()
-        self.model = talkNetModel().cuda()
-        self.lossAV = lossAV().cuda()
-        self.lossA = lossA().cuda()
-        self.lossV = lossV().cuda()
-        self.optim = torch.optim.Adam(self.parameters(), lr=lr)
-        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size=1, gamma=lrDecay)
-        print(
-            time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
-            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
-
-    def train_network(self, loader, epoch, **kwargs):
-        self.train()
-        self.scheduler.step(epoch - 1)
-        index, top1, loss = 0, 0, 0
-        lr = self.optim.param_groups[0]['lr']
-        for num, (audioFeature, visualFeature, labels) in enumerate(loader, start=1):
-            self.zero_grad()
-            audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())  # feedForward
-            visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
-            audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
-            outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-            outsA = self.model.forward_audio_backend(audioEmbed)
-            outsV = self.model.forward_visual_backend(visualEmbed)
-            labels = labels[0].reshape((-1)).cuda()  # Loss
-            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
-            nlossA = self.lossA.forward(outsA, labels)
-            nlossV = self.lossV.forward(outsV, labels)
-            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
-            loss += nloss.detach().cpu().numpy()
-            top1 += prec
-            nloss.backward()
-            self.optim.step()
-            index += len(labels)
-            sys.stderr.write(time.strftime("%m-%d %H:%M:%S") +
-                             " [%2d] Lr: %5f, Training: %.2f%%, " % (epoch, lr, 100 * (num / loader.__len__())) +
-                             " Loss: %.5f, ACC: %2.2f%% \r" % (loss / (num), 100 * (top1 / index)))
-            sys.stderr.flush()
-        sys.stdout.write("\n")
-        return loss / num, lr
-
-    def evaluate_network(self, loader, evalCsvSave, evalOrig, **kwargs):
-        self.eval()
-        predScores = []
-        for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
-            with torch.no_grad():
-                audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
-                visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
-                audioEmbed, visualEmbed = self.model.forward_cross_attention(
-                    audioEmbed, visualEmbed)
-                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
-                labels = labels[0].reshape((-1)).cuda()
-                _, predScore, _, _ = self.lossAV.forward(outsAV, labels)
-                predScore = predScore[:, 1].detach().cpu().numpy()
-                predScores.extend(predScore)
-        evalLines = open(evalOrig).read().splitlines()[1:]
-        labels = []
-        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
-        scores = pandas.Series(predScores)
-        evalRes = pandas.read_csv(evalOrig)
-        evalRes['score'] = scores
-        evalRes['label'] = labels
-        evalRes.drop(['label_id'], axis=1, inplace=True)
-        evalRes.drop(['instance_id'], axis=1, inplace=True)
-        evalRes.to_csv(evalCsvSave, index=False)
-        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
-                                                                                      evalCsvSave)
-        mAP = float(
-            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
-        return mAP
-
-    def saveParameters(self, path):
-        torch.save(self.state_dict(), path)
-
-    def loadParameters(self, path):
-        selfState = self.state_dict()
-        loadedState = torch.load(path)
-        for name, param in loadedState.items():
-            origName = name
-            if name not in selfState:
-                name = name.replace("module.", "")
-                if name not in selfState:
-                    print("%s is not in the model." % origName)
-                    continue
-            if selfState[name].size() != loadedState[origName].size():
-                sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
-                                 (origName, selfState[name].size(), loadedState[origName].size()))
-                continue
-            selfState[name].copy_(param)
 
legacy/trainTalkNet_multicard.py DELETED
@@ -1,171 +0,0 @@
-import time, os, torch, argparse, warnings, glob
-
-from utils.tools import *
-from dlhammer import bootstrap
-import pytorch_lightning as pl
-from pytorch_lightning import Trainer, seed_everything
-from pytorch_lightning.callbacks import ModelCheckpoint
-os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-
-
-class MyCollator(object):
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-
-    def __call__(self, data):
-        audiofeatures = [item[0] for item in data]
-        visualfeatures = [item[1] for item in data]
-        labels = [item[2] for item in data]
-        masks = [item[3] for item in data]
-        cut_limit = self.cfg.MODEL.CLIP_LENGTH
-        # pad audio
-        lengths = torch.tensor([t.shape[1] for t in audiofeatures])
-        max_len = max(lengths)
-        padded_audio = torch.stack([
-            torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
-            for i in audiofeatures
-        ], 0)
-
-        if max_len > cut_limit * 4:
-            padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
-
-        # pad video
-        lengths = torch.tensor([t.shape[1] for t in visualfeatures])
-        max_len = max(lengths)
-        padded_video = torch.stack([
-            torch.cat(
-                [i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
-            for i in visualfeatures
-        ], 0)
-        padded_labels = torch.stack(
-            [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
-        padded_masks = torch.stack(
-            [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)
-
-        if max_len > cut_limit:
-            padded_video = padded_video[:, :, :cut_limit, ...]
-            padded_labels = padded_labels[:, :, :cut_limit, ...]
-            padded_masks = padded_masks[:, :, :cut_limit, ...]
-        return padded_audio, padded_video, padded_labels, padded_masks
-
-
-class DataPrep(pl.LightningDataModule):
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-
-    def train_dataloader(self):
-        cfg = self.cfg
-
-        if self.cfg.MODEL.NAME == "baseline":
-            from dataLoader import train_loader, val_loader
-            loader = train_loader(trialFileName=cfg.trainTrialAVA,
-                                  audioPath=os.path.join(cfg.audioPathAVA, 'train'),
-                                  visualPath=os.path.join(cfg.visualPathAVA, 'train'),
-                                  batchSize=2500)
-        elif self.cfg.MODEL.NAME == "multi":
-            from dataLoader_multiperson import train_loader, val_loader
-            loader = train_loader(trialFileName=cfg.trainTrialAVA,
-                                  audioPath=os.path.join(cfg.audioPathAVA, 'train'),
-                                  visualPath=os.path.join(cfg.visualPathAVA, 'train'),
-                                  num_speakers=cfg.MODEL.NUM_SPEAKERS)
-        if cfg.MODEL.NAME == "baseline":
-            trainLoader = torch.utils.data.DataLoader(
-                loader,
-                batch_size=1,
-                shuffle=True,
-                num_workers=4,
-            )
-        elif cfg.MODEL.NAME == "multi":
-            collator = MyCollator(cfg)
-            trainLoader = torch.utils.data.DataLoader(loader,
-                                                      batch_size=1,
-                                                      shuffle=True,
-                                                      num_workers=4,
-                                                      collate_fn=collator)
-
-        return trainLoader
-
-    def val_dataloader(self):
-        cfg = self.cfg
-        loader = val_loader(trialFileName=cfg.evalTrialAVA,
-                            audioPath=os.path.join(cfg.audioPathAVA, cfg.evalDataType),
-                            visualPath=os.path.join(cfg.visualPathAVA, cfg.evalDataType))
-        valLoader = torch.utils.data.DataLoader(loader,
-                                                batch_size=cfg.VAL.BATCH_SIZE,
-                                                shuffle=False,
-                                                num_workers=16)
-        return valLoader
-
-
-def main():
-    # The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
-    cfg = bootstrap(print_cfg=False)
-    print(cfg)
-
-    warnings.filterwarnings("ignore")
-    seed_everything(42, workers=True)
-
-    cfg = init_args(cfg)
-
-    # checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(cfg.WORKSPACE, "model"),
-    #                                       save_top_k=-1,
-    #                                       filename='{epoch}')
-
-    data = DataPrep(cfg)
-
-    trainer = Trainer(
-        gpus=int(cfg.TRAIN.TRAINER_GPU),
-        precision=32,
-        # callbacks=[checkpoint_callback],
-        max_epochs=25,
-        replace_sampler_ddp=True)
-    # val_trainer = Trainer(deterministic=True, num_sanity_val_steps=-1, gpus=1)
-    if cfg.downloadAVA == True:
-        preprocess_AVA(cfg)
-        quit()
-
-    # if cfg.RESUME:
-    #     modelfiles = glob.glob('%s/model_0*.model' % cfg.modelSavePath)
-    #     modelfiles.sort()
-    #     if len(modelfiles) >= 1:
-    #         print("Model %s loaded from previous state!" % modelfiles[-1])
-    #         epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
-    #         s = talkNet(cfg)
-    #         s.loadParameters(modelfiles[-1])
-    #     else:
-    #         epoch = 1
-    #         s = talkNet(cfg)
-    epoch = 1
-    if cfg.MODEL.NAME == "baseline":
-        from talkNet_multicard import talkNet
-    elif cfg.MODEL.NAME == "multi":
-        from talkNet_multi import talkNet
-
-    s = talkNet(cfg)
-
-    # scoreFile = open(cfg.scoreSavePath, "a+")
-
-    trainer.fit(s, train_dataloaders=data.train_dataloader())
-
-    modelfiles = glob.glob('%s/*.pth' % os.path.join(cfg.WORKSPACE, "model"))
-
-    modelfiles.sort()
-    for path in modelfiles:
-        s.loadParameters(path)
-        prec = trainer.validate(s, data.val_dataloader())
-
-    # if epoch % cfg.testInterval == 0:
-    #     s.saveParameters(cfg.modelSavePath + "/model_%04d.model" % epoch)
-    #     trainer.validate(dataloaders=valLoader)
-    #     print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, mAP %2.2f%%" % (epoch, mAPs[-1]))
-    #     scoreFile.write("%d epoch, LOSS %f, mAP %2.2f%%\n" % (epoch, loss, mAPs[-1]))
-    #     scoreFile.flush()
-
-
-if __name__ == '__main__':
-    main()
 
legacy/train_multi.py DELETED
@@ -1,156 +0,0 @@
-import time, os, torch, argparse, warnings, glob
-
-from dataLoader_multiperson import train_loader, val_loader
-from utils.tools import *
-from talkNet_multi import talkNet
-
-
-def collate_fn_padding(data):
-    audiofeatures = [item[0] for item in data]
-    visualfeatures = [item[1] for item in data]
-    labels = [item[2] for item in data]
-    masks = [item[3] for item in data]
-    cut_limit = 200
-    # pad audio
-    lengths = torch.tensor([t.shape[1] for t in audiofeatures])
-    max_len = max(lengths)
-    padded_audio = torch.stack([
-        torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
-        for i in audiofeatures
-    ], 0)
-
-    if max_len > cut_limit * 4:
-        padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
-
-    # pad video
-    lengths = torch.tensor([t.shape[1] for t in visualfeatures])
-    max_len = max(lengths)
-    padded_video = torch.stack([
-        torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
-        for i in visualfeatures
-    ], 0)
-    padded_labels = torch.stack(
-        [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
-    padded_masks = torch.stack(
-        [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)
-
-    if max_len > cut_limit:
-        padded_video = padded_video[:, :, :cut_limit, ...]
-        padded_labels = padded_labels[:, :, :cut_limit, ...]
-        padded_masks = padded_masks[:, :, :cut_limit, ...]
-    # print(padded_audio.shape, padded_video.shape, padded_labels.shape, padded_masks.shape)
-    return padded_audio, padded_video, padded_labels, padded_masks
-
-
-def main():
-    # The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
-    warnings.filterwarnings("ignore")
-
-    parser = argparse.ArgumentParser(description="TalkNet Training")
-    # Training details
-    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
-    parser.add_argument('--lrDecay', type=float, default=0.95, help='Learning rate decay rate')
-    parser.add_argument('--maxEpoch', type=int, default=25, help='Maximum number of epochs')
-    parser.add_argument('--testInterval',
-                        type=int,
-                        default=1,
-                        help='Test and save every [testInterval] epochs')
-    parser.add_argument(
-        '--batchSize',
-        type=int,
-        default=2500,
-        help='Dynamic batch size, default is 2500 frames; other batch sizes (such as 1500) will not affect the performance')
-    parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
-    parser.add_argument('--num_speakers', type=int, default=5, help='num_speakers')
-    parser.add_argument('--nDataLoaderThread', type=int, default=4, help='Number of loader threads')
-    # Data path
-    parser.add_argument('--dataPathAVA',
-                        type=str,
-                        default="/data08/AVA",
-                        help='Save path of AVA dataset')
-    parser.add_argument('--savePath', type=str, default="exps/exp1")
-    # Data selection
-    parser.add_argument('--evalDataType',
-                        type=str,
-                        default="val",
-                        help='Only for AVA, to choose the dataset for evaluation, val or test')
-    # For download dataset only, for evaluation only
-    parser.add_argument('--downloadAVA',
-                        dest='downloadAVA',
-                        action='store_true',
-                        help='Only download AVA dataset and do related preprocess')
-    parser.add_argument('--evaluation',
-                        dest='evaluation',
-                        action='store_true',
-                        help='Only do evaluation by using pretrained model [pretrain_AVA.model]')
-    args = parser.parse_args()
-    # Data loader
-    args = init_args(args)
-
-    if args.downloadAVA == True:
-        preprocess_AVA(args)
-        quit()
-
-    loader = train_loader(trialFileName=args.trainTrialAVA,
-                          audioPath=os.path.join(args.audioPathAVA, 'train'),
-                          visualPath=os.path.join(args.visualPathAVA, 'train'),
-                          # num_speakers=args.num_speakers,
-                          **vars(args))
-    trainLoader = torch.utils.data.DataLoader(loader,
-                                              batch_size=args.batch_size,
-                                              shuffle=True,
-                                              num_workers=args.nDataLoaderThread,
-                                              collate_fn=collate_fn_padding)
-
-    loader = val_loader(trialFileName=args.evalTrialAVA,
-                        audioPath=os.path.join(args.audioPathAVA, args.evalDataType),
-                        visualPath=os.path.join(args.visualPathAVA, args.evalDataType),
-                        **vars(args))
-    valLoader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=16)
-
-    if args.evaluation == True:
-        download_pretrain_model_AVA()
-        s = talkNet(**vars(args))
-        s.loadParameters('pretrain_AVA.model')
-        print("Model %s loaded from previous state!" % ('pretrain_AVA.model'))
-        mAP = s.evaluate_network(loader=valLoader, **vars(args))
-        print("mAP %2.2f%%" % (mAP))
-        quit()
-
-    modelfiles = glob.glob('%s/model_0*.model' % args.modelSavePath)
-    modelfiles.sort()
-    if len(modelfiles) >= 1:
-        print("Model %s loaded from previous state!" % modelfiles[-1])
-        epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
-        s = talkNet(epoch=epoch, **vars(args))
-        s.loadParameters(modelfiles[-1])
-    else:
-        epoch = 1
-        s = talkNet(epoch=epoch, **vars(args))
-
-    mAPs = []
-    scoreFile = open(args.scoreSavePath, "a+")
-
-    while (1):
-        loss, lr = s.train_network(epoch=epoch, loader=trainLoader, **vars(args))
-
-        if epoch % args.testInterval == 0:
-            s.saveParameters(args.modelSavePath + "/model_%04d.model" % epoch)
-            mAPs.append(s.evaluate_network(epoch=epoch, loader=valLoader, **vars(args)))
-            print(time.strftime("%Y-%m-%d %H:%M:%S"),
-                  "%d epoch, mAP %2.2f%%, bestmAP %2.2f%%" % (epoch, mAPs[-1], max(mAPs)))
-            scoreFile.write("%d epoch, LR %f, LOSS %f, mAP %2.2f%%, bestmAP %2.2f%%\n" %
-                            (epoch, lr, loss, mAPs[-1], max(mAPs)))
-            scoreFile.flush()
-
-        if epoch >= args.maxEpoch:
-            quit()
-
-        epoch += 1
-
-
-if __name__ == '__main__':
-    main()