change model
Browse files- script.py +4 -4
- src/resnet_model.py +1 -1
script.py
CHANGED
|
@@ -53,11 +53,11 @@ print('Define Model')
|
|
| 53 |
# model.load_state_dict(torch.load(model_path, map_location=device))
|
| 54 |
|
| 55 |
# RESNET MODEL
|
| 56 |
-
model = ResNet_LogSpec(sample_rate=24000, return_emb=False).to(device)
|
| 57 |
-
model_path = './checkpoints/RESNET_LOGSPEC_ALL_DATA_FS_24000.pth'
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
## LCNN MODEL
|
| 63 |
# model = LCNN(return_emb=False, fs=24000).to(device)
|
|
|
|
| 53 |
# model.load_state_dict(torch.load(model_path, map_location=device))
|
| 54 |
|
| 55 |
# RESNET MODEL
|
| 56 |
+
# model = ResNet_LogSpec(sample_rate=24000, return_emb=False).to(device)
|
| 57 |
+
# model_path = './checkpoints/RESNET_LOGSPEC_ALL_DATA_FS_24000.pth'
|
| 58 |
|
| 59 |
+
model = ResNet_MelSpec(sample_rate=24000, return_emb=False).to(device)
|
| 60 |
+
model_path = './checkpoints/RESNET_MELSPEC_ALL_DATA_FS_24000.pth'
|
| 61 |
|
| 62 |
## LCNN MODEL
|
| 63 |
# model = LCNN(return_emb=False, fs=24000).to(device)
|
src/resnet_model.py
CHANGED
|
@@ -102,7 +102,7 @@ class ResNet_MelSpec(nn.Module):
|
|
| 102 |
def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
|
| 103 |
super(ResNet_MelSpec, self).__init__()
|
| 104 |
|
| 105 |
-
self.threshold = 0.
|
| 106 |
self.return_emb = return_emb
|
| 107 |
|
| 108 |
if sample_rate == 16000:
|
|
|
|
| 102 |
def __init__(self, sample_rate=16000, return_emb=False, num_class=2):
|
| 103 |
super(ResNet_MelSpec, self).__init__()
|
| 104 |
|
| 105 |
+
self.threshold = 0.4
|
| 106 |
self.return_emb = return_emb
|
| 107 |
|
| 108 |
if sample_rate == 16000:
|