|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
class ModelUtils:
    """Helpers for analysing the temporal geometry of a stack of 1-D convolutions.

    A layer is described by a ``(kernel_size, stride, padding)`` tuple; a model
    is a sequence of such tuples ordered from the first (closest to the
    waveform) to the last layer.

    Attributes:
        layer_dims: Sequence of ``(kernel_size, stride, padding)`` tuples.
        input_wav_length: Number of samples in one input window.
        frames_per_window: Number of output frames produced for one window;
            either a plain number or a single-element tensor.
        sample_rate: Sampling rate in Hz used to convert samples to time.
    """

    # Class-level fallback so instances created without ``__init__`` (or
    # created before ``sample_rate`` existed) keep the historical 16 kHz.
    sample_rate = 16000

    def __init__(self, layer_dims, input_wav_length, frames_per_window, sample_rate=16000):
        """Store the layer description and the window geometry.

        Args:
            layer_dims: Sequence of ``(kernel_size, stride, padding)`` tuples.
            input_wav_length: Number of samples in one input window.
            frames_per_window: Output frames per window (number or
                single-element tensor).
            sample_rate: Sampling rate in Hz; defaults to 16000.

        Raises:
            ValueError: If ``sample_rate`` is not strictly positive.
        """
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")
        self.layer_dims = layer_dims
        self.input_wav_length = input_wav_length
        self.frames_per_window = frames_per_window
        self.sample_rate = sample_rate

    @staticmethod
    def calculate_layer_sizes(layer_dims, input_sizes, layer_number=-1):
        """Compute the output length(s) after a given convolution layer.

        Args:
            layer_dims: Iterable of ``(kernel_size, stride, padding)`` tuples.
            input_sizes: Input length(s) as a number, a sequence, or a tensor.
                A 1-D input gets a leading dimension added.
            layer_number: Zero-based index of the layer whose output size is
                returned. Negative values count from the end, so the default
                ``-1`` means the last layer. An index past the last layer
                also yields the last layer's output.

        Returns:
            A tensor holding the output length(s). With no layers the
            (tensor-converted) input is returned.
        """
        if not isinstance(input_sizes, torch.Tensor):
            input_sizes = torch.tensor(input_sizes)
        if input_sizes.dim() == 1:
            input_sizes = input_sizes.unsqueeze(0)

        layer_dims = list(layer_dims)
        if layer_number < 0:
            # Python-style negative indexing: -1 is the last layer, -2 the one before.
            layer_number += len(layer_dims)

        current_sizes = input_sizes
        for index, (kernel_size, stride, padding) in enumerate(layer_dims):
            # Standard convolution output-length formula (dilation not modelled).
            current_sizes = ((current_sizes + 2 * padding - kernel_size) // stride) + 1
            if index == layer_number:
                break
        return current_sizes

    @staticmethod
    def extract_layer_dims(model):
        """Collect ``(kernel_size, stride, padding)`` for every ``nn.Conv1d`` in ``model``.

        Layers are listed in the order produced by ``model.modules()``.
        String padding modes are converted to integers: ``"valid"`` becomes
        ``0`` and ``"same"`` becomes ``(kernel_size - 1) // 2``, the value for
        which the size formula in ``calculate_layer_sizes`` keeps the length
        unchanged.

        Raises:
            ValueError: If a layer uses ``padding="same"`` with an even kernel
                size, which no symmetric integer padding can express.
        """
        layer_dims = []
        for module in model.modules():
            if not isinstance(module, nn.Conv1d):
                continue
            kernel_size = module.kernel_size[0]
            stride = module.stride[0]
            padding = module.padding
            if isinstance(padding, str):
                # Indexing a string here would yield a character, not a size.
                if padding == "valid":
                    padding = 0
                elif (kernel_size - 1) % 2 == 0:
                    padding = (kernel_size - 1) // 2
                else:
                    raise ValueError(
                        f"padding={padding!r} with kernel_size={kernel_size} "
                        "cannot be expressed as a symmetric integer padding"
                    )
            else:
                padding = padding[0]
            layer_dims.append((kernel_size, stride, padding))
        return layer_dims

    @staticmethod
    def _to_number(value):
        """Unwrap a tensor with ``.item()``; return any other value unchanged."""
        if isinstance(value, torch.Tensor):
            return value.item()
        return value

    @staticmethod
    def _format_size(value):
        """Turn a size into a plain number (or nested list) for display."""
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else value.tolist()
        return value

    def _samples_to_ms(self, samples):
        """Convert a count of waveform samples to milliseconds."""
        return (samples / self.sample_rate) * 1000

    def _layer_geometry(self):
        """Return ``[(rf_samples, stride_samples), ...]`` for each layer's output.

        Both values are measured in waveform samples. They are accumulated
        from the first layer forward, so entry ``i`` describes the output of
        layer ``i + 1`` and the last entry describes the whole stack.
        """
        geometry = []
        rf_samples = 1
        jump = 1  # distance, in waveform samples, between neighbouring units
        for kernel, stride, _ in self.layer_dims:
            rf_samples += (kernel - 1) * jump
            jump *= stride
            geometry.append((rf_samples, jump))
        return geometry

    def _summary_lines(self, input_window_heading):
        """Build the report lines up to and including the layer-wise heading.

        Args:
            input_window_heading: Text of the "Input Window" heading line,
                which differs between the printed and the string report.
        """
        window_seconds = self.input_wav_length / self.sample_rate
        frames_per_second = self._to_number(self.frames_per_window / window_seconds)
        # A window that yields no frames has no finite frame duration.
        ms_per_frame = 1000 / frames_per_second if frames_per_second else float("inf")
        rf_info = self.get_receptive_field()
        return [
            "\n=== Model Temporal Characteristics ===",
            input_window_heading,
            f"• Window size: {window_seconds * 1000:.1f}ms ({self.input_wav_length} samples)",
            "\nTemporal Resolution:",
            f"• Frames per window: {self._to_number(self.frames_per_window)}",
            f"• Frames per second: {frames_per_second:.1f} fps",
            f"• Time per frame: {ms_per_frame:.1f}ms",
            "\nReceptive Field:",
            f"• Duration: {rf_info['rf_ms']:.1f}ms",
            f"• Samples: {rf_info['rf_samples']} samples",
            f"• Total downsampling factor: {rf_info['total_downsample_factor']}",
            "\nLayer-wise Analysis:",
        ]

    def get_receptive_field(self):
        """Calculate the receptive field of one output frame of the full stack.

        Returns:
            A dict with ``rf_ms`` (duration in milliseconds at
            ``self.sample_rate``), ``rf_samples``, ``total_stride`` and
            ``total_downsample_factor`` (the last two are the same value: the
            product of all strides). With no layers, the receptive field and
            stride are both 1 sample.
        """
        geometry = self._layer_geometry()
        rf_samples, total_stride = geometry[-1] if geometry else (1, 1)
        return {
            'rf_ms': self._samples_to_ms(rf_samples),
            'rf_samples': rf_samples,
            'total_stride': total_stride,
            'total_downsample_factor': total_stride
        }

    def print_model_info(self):
        """Print comprehensive information about the model's temporal characteristics.

        Layers are listed from the last to the first. For each one the
        receptive field and cumulative stride of its output are given in
        waveform samples and milliseconds.
        """
        lines = self._summary_lines("\nInput Window:")
        geometry = self._layer_geometry()
        for layer_index in range(len(geometry), 0, -1):
            rf_samples, stride_samples = geometry[layer_index - 1]
            ms_rf = self._samples_to_ms(rf_samples)
            ms_stride = self._samples_to_ms(stride_samples)
            lines.append(
                f"• Layer {layer_index}: RF={ms_rf:.1f}ms, Stride={stride_samples} samples ({ms_stride:.1f}ms)"
            )
        for line in lines:
            print(line)

    def get_model_info(self):
        """Return comprehensive information about the model's temporal characteristics as a string.

        Layers are listed from the last to the first. Each line gives the
        layer's input and output length for ``self.input_wav_length``, its
        kernel/stride/padding, and the temporal resolution (TR) and receptive
        field (RF) of its output in milliseconds.
        """
        info = self._summary_lines("Input Window:")

        # sizes[i] is the input length of layer i + 1; sizes[i + 1] its output.
        sizes = [self.input_wav_length]
        for dims in self.layer_dims:
            sizes.append(self.calculate_layer_sizes([dims], sizes[-1]))

        geometry = self._layer_geometry()
        for layer_index in range(len(geometry), 0, -1):
            kernel, stride, padding = self.layer_dims[layer_index - 1]
            rf_samples, stride_samples = geometry[layer_index - 1]
            ms_rf = self._samples_to_ms(rf_samples)
            ms_stride = self._samples_to_ms(stride_samples)
            input_dim = self._format_size(sizes[layer_index - 1])
            output_dim = self._format_size(sizes[layer_index])
            info.append(f"• Layer {layer_index}: \t Input={input_dim}, Output={output_dim},\t k={kernel}, s={stride}, p={padding},\t TR={ms_stride:.1f}ms, RF={ms_rf:.1f}ms ")

        return "\n".join(info)
|
|
|
|
|
|
|
|
|
|
|
|