CUPE-2i / model_utils.py
Tabahi's picture
Upload folder using huggingface_hub
0eef6aa verified
import torch
import torch.nn as nn
class ModelUtils:
def __init__(self, layer_dims, input_wav_length, frames_per_window):
self.layer_dims = layer_dims
self.input_wav_length = input_wav_length
self.frames_per_window = frames_per_window
@staticmethod
def calculate_layer_sizes(layer_dims, input_sizes, layer_number=-1):
# Ensure input_sizes is a tensor
if not isinstance(input_sizes, torch.Tensor):
input_sizes = torch.tensor(input_sizes)
# Ensure input_sizes is 2D: [batch_size, input_size]
if input_sizes.dim() == 1:
input_sizes = input_sizes.unsqueeze(0)
current_sizes = input_sizes
for i, (kernel_size, stride, padding) in enumerate(layer_dims):
output_sizes = ((current_sizes + 2 * padding - kernel_size) // stride) + 1
if i == layer_number:
return output_sizes
current_sizes = output_sizes
# If layer_number is -1 or greater than the number of layers, return the last layer's output
return current_sizes
@staticmethod
def extract_layer_dims(model):
layer_dims = []
for module in model.modules():
if isinstance(module, nn.Conv1d):
kernel_size = module.kernel_size[0]
stride = module.stride[0]
padding = module.padding[0]
layer_dims.append((kernel_size, stride, padding))
return layer_dims
def get_receptive_field(self):
"""Calculate receptive field size in milliseconds at 16kHz sampling rate"""
rf_samples = 1
total_stride = 1
for kernel, stride, _ in reversed(self.layer_dims):
rf_samples = rf_samples * stride + (kernel - stride)
total_stride *= stride
rf_ms = (rf_samples / 16000) * 1000
return {
'rf_ms': rf_ms,
'rf_samples': rf_samples,
'total_stride': total_stride,
'total_downsample_factor': total_stride
}
def print_model_info(self):
"""Print comprehensive information about the model's temporal characteristics"""
input_size_seconds = (self.input_wav_length / 16000)
frames_per_second = (self.frames_per_window / input_size_seconds).item()
ms_per_frame = 1000 / frames_per_second
rf_info = self.get_receptive_field()
print("\n=== Model Temporal Characteristics ===")
print("\nInput Window:")
print(f"• Window size: {input_size_seconds * 1000:.1f}ms ({self.input_wav_length} samples)")
print("\nTemporal Resolution:")
print(f"• Frames per window: {self.frames_per_window.item()}")
print(f"• Frames per second: {frames_per_second:.1f} fps")
print(f"• Time per frame: {ms_per_frame:.1f}ms")
print("\nReceptive Field:")
print(f"• Duration: {rf_info['rf_ms']:.1f}ms")
print(f"• Samples: {rf_info['rf_samples']} samples")
print(f"• Total downsampling factor: {rf_info['total_downsample_factor']}")
print("\nLayer-wise Analysis:")
curr_rf = 1
curr_stride = 1
for i, (kernel, stride, _) in enumerate(reversed(self.layer_dims)):
curr_rf = curr_rf * stride + (kernel - stride)
curr_stride *= stride
ms_rf = (curr_rf / 16000) * 1000
ms_stride = (curr_stride / 16000) * 1000
print(f"• Layer {len(self.layer_dims) - i}: RF={ms_rf:.1f}ms, Stride={curr_stride} samples ({ms_stride:.1f}ms)")
def get_model_info(self):
"""Return comprehensive information about the model's temporal characteristics as a string"""
input_size_seconds = (self.input_wav_length / 16000)
frames_per_second = (self.frames_per_window / input_size_seconds).item()
ms_per_frame = 1000 / frames_per_second
rf_info = self.get_receptive_field()
info = []
info.append("\n=== Model Temporal Characteristics ===")
info.append("Input Window:")
info.append(f"• Window size: {input_size_seconds * 1000:.1f}ms ({self.input_wav_length} samples)")
info.append("\nTemporal Resolution:")
info.append(f"• Frames per window: {self.frames_per_window.item()}")
info.append(f"• Frames per second: {frames_per_second:.1f} fps")
info.append(f"• Time per frame: {ms_per_frame:.1f}ms")
info.append("\nReceptive Field:")
info.append(f"• Duration: {rf_info['rf_ms']:.1f}ms")
info.append(f"• Samples: {rf_info['rf_samples']} samples")
info.append(f"• Total downsampling factor: {rf_info['total_downsample_factor']}")
info.append("\nLayer-wise Analysis:")
curr_rf = 1
curr_stride = 1
for i, (kernel, stride, padding) in enumerate(reversed(self.layer_dims)):
curr_rf = curr_rf * stride + (kernel - stride)
curr_stride *= stride
ms_rf = (curr_rf / 16000) * 1000
ms_stride = (curr_stride / 16000) * 1000
invert_i = len(self.layer_dims) - i
input_dim = self.calculate_layer_sizes(self.layer_dims, self.input_wav_length, layer_number=invert_i-2) if (invert_i>1) else self.input_wav_length
output_dim = self.calculate_layer_sizes(self.layer_dims, self.input_wav_length, layer_number=invert_i-1)
info.append(f"• Layer {invert_i}: \t Input={input_dim}, Output={output_dim},\t k={kernel}, s={stride}, p={padding},\t TR={ms_stride:.1f}ms, RF={ms_rf:.1f}ms ")
return "\n".join(info)