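# ascad-v1-models / analysis/scripts/inspect_gap3.py
# Uploaded by lemousehunter ("Upload analysis/scripts/inspect_gap3.py with huggingface_hub"),
# commit 48fa2a9 (verified).
#
# Prints the layer list, BatchNormalization layers, inputs and outputs of the
# LMIC-TSBN Keras model.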
import math
import os

# TensorFlow reads these at import time, so they must be set before the
# `import tensorflow` line below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import tensorflow as tf
# Location of the saved LMIC-TSBN model. Defaults to the original hard-coded
# path; set the LMIC_MODEL_PATH environment variable to inspect a model
# stored elsewhere.
MODEL_PATH = os.environ.get('LMIC_MODEL_PATH', '/home/ubuntu/models/lmic_tsbn_model.h5')
# compile=False: this script only reads layer structure, so no optimizer/loss
# configuration needs to be restored.
lmic = tf.keras.models.load_model(MODEL_PATH, compile=False)
# Section 1: one row per layer -- index, name, class, input shape(s), output shape.
print('=' * 80)
print('LMIC-TSBN: FULL LAYER LIST WITH SHAPES')
print('=' * 80)
for i, layer in enumerate(lmic.layers):
    # Shape lookups can raise for some layers; report 'N/A' instead of
    # aborting the whole dump. `except Exception` (not a bare except) so
    # Ctrl-C / SystemExit still propagate.
    try:
        inp = layer.input
        if isinstance(inp, (list, tuple)):
            # Multi-input layer: show one shape per input tensor.
            inp_str = str([str(x.shape) for x in inp])
        else:
            inp_str = str(inp.shape)
    except Exception:
        inp_str = 'N/A'
    try:
        out_str = str(layer.output.shape)
    except Exception:
        out_str = 'N/A'
    print(f'{i:3d} {layer.name:40s} {type(layer).__name__:25s} in={inp_str:30s} out={out_str}')
print()
# Section 2: BatchNormalization layers only, with their shapes and
# trainable / non-trainable parameter counts.
print('=' * 80)
print('LMIC-TSBN: BatchNormalization layers (TSBN = task-specific BN)')
print('=' * 80)
for layer in lmic.layers:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        continue
    # Input and output shapes are looked up independently so a failure on
    # one side does not discard the other.
    try:
        inp_str = str(layer.input.shape)
    except Exception:
        inp_str = '?'
    try:
        out_str = str(layer.output.shape)
    except Exception:
        out_str = '?'
    # Parameter count of a weight = product of its shape dimensions. Computed
    # directly rather than via tf.keras.backend.count_params, whose
    # availability differs between Keras versions.
    trainable = sum(math.prod(w.shape) for w in layer.trainable_weights)
    non_trainable = sum(math.prod(w.shape) for w in layer.non_trainable_weights)
    print(f' {layer.name:40s} in={inp_str:25s} out={out_str:25s} trainable={trainable} non_trainable={non_trainable}')
print()
# Section 3: the model's input tensors, one line each (name and shape).
print('=' * 80, 'LMIC-TSBN: Model inputs', '=' * 80, sep='\n')
for model_input in lmic.inputs:
    print(f' {model_input.name}: shape={model_input.shape}')
print()
# Section 4: the model's output tensors, one line each (name and shape).
print('=' * 80, 'LMIC-TSBN: Model outputs', '=' * 80, sep='\n')
for model_output in lmic.outputs:
    print(f' {model_output.name}: shape={model_output.shape}')