Weights-ResidualsModels-MappingInference-SOCmapping / Archive /SimpleTimeModel /SimpleCNN /count_simpletimecnn_parameters.py
| import torch | |
| from modelCNN import SmallCNN # Assuming this is where SimpleTimeCNN is defined | |
| from config import bands_list_order # Using bands_list_order for input channels | |
def count_parameters(model: "torch.nn.Module", *, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters held by a PyTorch model.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            with ``numel()`` and ``requires_grad`` (e.g. ``torch.nn.Module``).
        trainable_only: When True (the default, matching the historical
            behaviour of this helper) only parameters whose
            ``requires_grad`` flag is set are counted. Pass False to
            include frozen parameters as well.

    Returns:
        The summed element count of the selected parameters; 0 for a model
        without any matching parameters.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        # A parameter is kept if it is trainable, or if the caller asked
        # for every parameter regardless of its requires_grad flag.
        if param.requires_grad or not trainable_only
    )
def main():
    """Instantiate SmallCNN and print its trainable-parameter breakdown.

    The network is built with one input channel per entry of
    ``bands_list_order``. Each parameter tensor that requires gradients is
    listed with its element count, followed by the overall total.
    """
    # NOTE(review): presumably this mirrors the training configuration -
    # confirm against the training script before relying on the numbers.
    n_channels = len(bands_list_order)
    net = SmallCNN(input_channels=n_channels)

    # Overall trainable-parameter count, reported after the per-tensor list.
    trainable_total = count_parameters(net)

    print("SimpleTimeCNN (SmallCNN) Parameter Breakdown:")
    for tensor_name, tensor in net.named_parameters():
        if not tensor.requires_grad:
            # Frozen tensors are left out of the breakdown.
            continue
        print(f"{tensor_name}: {tensor.numel()} parameters")

    print(f"\nTotal Trainable Parameters: {trainable_total:,}")
# Run the report only when this file is executed as a script, so importing
# the module does not build the model or print anything.
if __name__ == "__main__":
    main()