File size: 5,648 Bytes
0eef6aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

import torch
import torch.nn as nn

class ModelUtils:
    """Inspect the temporal geometry of a stack of 1-D convolutions.

    A model is described by ``layer_dims``: a sequence of
    ``(kernel_size, stride, padding)`` tuples in forward order (first layer
    first). Sample counts are converted to milliseconds with
    ``sample_rate`` (16 kHz unless overridden).
    """

    # Class-level fallback so instances created without ``__init__`` (e.g. via
    # ``ModelUtils.__new__``) still have a sampling rate.
    sample_rate = 16000

    def __init__(self, layer_dims, input_wav_length, frames_per_window, sample_rate=16000):
        """Store the model description used by the reporting helpers.

        Args:
            layer_dims: Sequence of ``(kernel_size, stride, padding)`` tuples,
                e.g. as returned by :meth:`extract_layer_dims`.
            input_wav_length: Number of input samples in one window (int or
                single-element tensor).
            frames_per_window: Number of output frames for one window (int or
                single-element tensor).
            sample_rate: Input sampling rate in Hz used for ms conversions.
        """
        self.layer_dims = layer_dims
        self.input_wav_length = input_wav_length
        self.frames_per_window = frames_per_window
        self.sample_rate = sample_rate

    @staticmethod
    def _as_number(value):
        """Convert a tensor to a plain Python value for arithmetic and display.

        Single-element tensors become a scalar via ``.item()``, other tensors
        become (nested) lists; non-tensors are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else value.tolist()
        return value

    def _samples_to_ms(self, samples):
        """Convert a sample count at ``self.sample_rate`` to milliseconds."""
        return (samples / self.sample_rate) * 1000

    @staticmethod
    def calculate_layer_sizes(layer_dims, input_sizes, layer_number=-1):
        """Compute conv output length(s) for the given input length(s).

        Applies ``out = (in + 2 * padding - kernel_size) // stride + 1`` layer
        by layer.

        Args:
            layer_dims: Sequence of ``(kernel_size, stride, padding)`` tuples.
            input_sizes: Input length(s): scalar, sequence or tensor. A 1-D
                input is promoted to shape ``[1, N]``.
            layer_number: 0-based index of the layer whose output is returned.
                ``-1`` (default), or any value that matches no layer index,
                returns the last layer's output.

        Returns:
            Tensor of output length(s). With empty ``layer_dims`` the
            (tensor-converted) input is returned unchanged.
        """
        # Ensure input_sizes is a tensor
        if not isinstance(input_sizes, torch.Tensor):
            input_sizes = torch.tensor(input_sizes)

        # Ensure input_sizes is 2D: [batch_size, input_size]
        if input_sizes.dim() == 1:
            input_sizes = input_sizes.unsqueeze(0)

        current_sizes = input_sizes

        for i, (kernel_size, stride, padding) in enumerate(layer_dims):
            output_sizes = ((current_sizes + 2 * padding - kernel_size) // stride) + 1

            if i == layer_number:
                return output_sizes

            current_sizes = output_sizes

        # If layer_number is -1 or greater than the number of layers, return the last layer's output
        return current_sizes

    @staticmethod
    def extract_layer_dims(model):
        """Collect ``(kernel_size, stride, padding)`` for every ``nn.Conv1d`` in ``model``.

        Modules are visited in ``model.modules()`` order. Dilation is folded
        into an effective kernel size ``dilation * (kernel_size - 1) + 1`` so
        that length and receptive-field calculations stay correct. String
        paddings are translated to a numeric equivalent: ``'valid'`` -> 0 and
        ``'same'`` -> half of the total padding ``effective_kernel - 1``. That
        half is a float when the total is odd, because PyTorch then pads
        asymmetrically and only the sum ``2 * padding`` enters the length
        formula.
        """
        layer_dims = []
        for module in model.modules():
            if not isinstance(module, nn.Conv1d):
                continue
            dilation = module.dilation[0]
            kernel_size = dilation * (module.kernel_size[0] - 1) + 1
            stride = module.stride[0]
            if isinstance(module.padding, str):
                # Conv1d stores string paddings verbatim; indexing the string
                # would yield a character instead of a number.
                if module.padding == "valid":
                    padding = 0
                else:  # 'same'
                    total = kernel_size - 1
                    padding = total // 2 if total % 2 == 0 else total / 2
            else:
                padding = module.padding[0]
            layer_dims.append((kernel_size, stride, padding))
        return layer_dims

    def get_receptive_field(self):
        """Calculate the receptive field size in milliseconds at ``self.sample_rate`` (16 kHz by default).

        Returns:
            dict with ``rf_ms``, ``rf_samples`` (input samples seen by one
            output frame), and ``total_stride`` / ``total_downsample_factor``
            (both the product of all strides).
        """
        rf_samples = 1
        total_stride = 1
        # Walk back from the output: a window of ``rf`` frames at a layer's
        # output covers ``(rf - 1) * stride + kernel`` frames of its input.
        for kernel, stride, _ in reversed(self.layer_dims):
            rf_samples = rf_samples * stride + (kernel - stride)
            total_stride *= stride
        return {
            'rf_ms': self._samples_to_ms(rf_samples),
            'rf_samples': rf_samples,
            'total_stride': total_stride,
            'total_downsample_factor': total_stride
        }

    def _layer_stats(self):
        """Return the cumulative receptive field and stride of every layer's output.

        Both quantities are measured in input samples. For layer ``n``
        (1-based) the receptive field is the number of input samples that
        influence one output frame of that layer. The stride is the distance
        in input samples between consecutive output frames of that layer.

        Returns:
            List of ``(layer_number, kernel, stride, padding, rf_samples,
            total_stride)`` tuples, first layer first.
        """
        stats = []
        rf_samples = 1
        total_stride = 1
        for number, (kernel, stride, padding) in enumerate(self.layer_dims, start=1):
            # Each extra kernel tap widens the RF by one frame of the previous
            # layer, i.e. by the previous layer's cumulative stride.
            rf_samples += (kernel - 1) * total_stride
            total_stride *= stride
            stats.append((number, kernel, stride, padding, rf_samples, total_stride))
        return stats

    def _temporal_summary(self):
        """Collect the window and frame-rate figures shared by both report methods."""
        input_samples = self._as_number(self.input_wav_length)
        input_size_seconds = input_samples / self.sample_rate
        frames_per_window = self._as_number(self.frames_per_window)
        frames_per_second = frames_per_window / input_size_seconds
        return {
            "input_samples": input_samples,
            "window_ms": input_size_seconds * 1000,
            "frames_per_window": frames_per_window,
            "frames_per_second": frames_per_second,
            "ms_per_frame": 1000 / frames_per_second,
            "rf_info": self.get_receptive_field(),
        }

    def print_model_info(self):
        """Print comprehensive information about the model's temporal characteristics.

        The layer-wise section lists, from the last layer to the first, each
        layer's output receptive field and cumulative stride in input samples.
        """
        summary = self._temporal_summary()
        rf_info = summary["rf_info"]
        print("\n=== Model Temporal Characteristics ===")
        print("\nInput Window:")
        print(f"• Window size: {summary['window_ms']:.1f}ms ({summary['input_samples']} samples)")
        print("\nTemporal Resolution:")
        print(f"• Frames per window: {summary['frames_per_window']}")
        print(f"• Frames per second: {summary['frames_per_second']:.1f} fps")
        print(f"• Time per frame: {summary['ms_per_frame']:.1f}ms")
        print("\nReceptive Field:")
        print(f"• Duration: {rf_info['rf_ms']:.1f}ms")
        print(f"• Samples: {rf_info['rf_samples']} samples")
        print(f"• Total downsampling factor: {rf_info['total_downsample_factor']}")
        print("\nLayer-wise Analysis:")
        for number, _kernel, _stride, _padding, rf_samples, total_stride in reversed(self._layer_stats()):
            ms_rf = self._samples_to_ms(rf_samples)
            ms_stride = self._samples_to_ms(total_stride)
            print(f"• Layer {number}: RF={ms_rf:.1f}ms, Stride={total_stride} samples ({ms_stride:.1f}ms)")

    def get_model_info(self):
        """Return comprehensive information about the model's temporal characteristics as a string.

        The layer-wise section lists, from the last layer to the first, each
        layer's input/output length, its (kernel, stride, padding), the
        cumulative stride (TR) and the receptive field (RF) of its output.
        """
        summary = self._temporal_summary()
        rf_info = summary["rf_info"]

        info = []
        info.append("\n=== Model Temporal Characteristics ===")
        info.append("Input Window:")
        info.append(f"• Window size: {summary['window_ms']:.1f}ms ({summary['input_samples']} samples)")
        info.append("\nTemporal Resolution:")
        info.append(f"• Frames per window: {summary['frames_per_window']}")
        info.append(f"• Frames per second: {summary['frames_per_second']:.1f} fps")
        info.append(f"• Time per frame: {summary['ms_per_frame']:.1f}ms")
        info.append("\nReceptive Field:")
        info.append(f"• Duration: {rf_info['rf_ms']:.1f}ms")
        info.append(f"• Samples: {rf_info['rf_samples']} samples")
        info.append(f"• Total downsampling factor: {rf_info['total_downsample_factor']}")
        info.append("\nLayer-wise Analysis:")

        stats = self._layer_stats()
        # sizes[0] is the raw input length; sizes[n] is layer n's output length.
        sizes = [summary["input_samples"]]
        sizes += [
            self._as_number(self.calculate_layer_sizes(self.layer_dims, self.input_wav_length, layer_number=idx))
            for idx in range(len(stats))
        ]

        for number, kernel, stride, padding, rf_samples, total_stride in reversed(stats):
            ms_rf = self._samples_to_ms(rf_samples)
            ms_stride = self._samples_to_ms(total_stride)
            input_dim = sizes[number - 1]
            output_dim = sizes[number]
            info.append(f"• Layer {number}: \t  Input={input_dim}, Output={output_dim},\t k={kernel}, s={stride}, p={padding},\t  TR={ms_stride:.1f}ms, RF={ms_rf:.1f}ms ")

        return "\n".join(info)