# smashfix-v1 / src/rsn.py
# Provenance: uncertainrods / v1-try-deploy (commit 0d0412d)
"""
Residual-Shuffle Network (RSN) for Feature Extraction.
A lightweight, efficient backbone combining:
- Residual Connections (skip connections for gradient flow)
- Channel Shuffling (efficient cross-group information mixing)
Based on ShuffleNet V2 principles with custom adaptations for
action recognition feature extraction.
"""
import os
import tensorflow as tf
from tensorflow.keras import layers, Model
def channel_shuffle(x, groups):
    """
    Shuffle channels across groups.

    This operation rearranges channels so that information can flow
    between groups that were previously isolated. Essential for
    making grouped convolutions effective.

    Args:
        x: Input tensor of shape (B, H, W, C). H, W, and C must be
            statically known; the batch dimension may be dynamic (None).
        groups: Number of groups to shuffle across. Must evenly divide C.

    Returns:
        Tensor of the same shape as ``x`` with channels shuffled.

    Raises:
        ValueError: If the channel count is not divisible by ``groups``
            (previously this surfaced as a cryptic reshape error).
    """
    height, width, channels = x.shape[1], x.shape[2], x.shape[3]
    if channels % groups != 0:
        raise ValueError(
            f"Channel count {channels} is not divisible by groups={groups}"
        )
    channels_per_group = channels // groups
    # Use -1 for the batch axis: tf.reshape infers it, which both avoids an
    # explicit tf.shape() op and keeps the op valid for dynamic batch sizes.
    # Reshape to (B, H, W, groups, channels_per_group)
    x = tf.reshape(x, [-1, height, width, groups, channels_per_group])
    # Transpose to (B, H, W, channels_per_group, groups) — this interleaves
    # channels from different groups.
    x = tf.transpose(x, perm=[0, 1, 2, 4, 3])
    # Flatten back to (B, H, W, C)
    x = tf.reshape(x, [-1, height, width, channels])
    return x
def shuffle_unit(x, out_channels, groups=4, stride=1, name_prefix="shuffle"):
    """
    Shuffle Unit - the core building block of RSN.

    For stride=1 (identity shortcut):
        - Split channels in half
        - Process one half through a bottleneck
        - Concatenate and shuffle

    For stride=2 (downsampling):
        - Both branches downsample
        - Concatenate (doubles channels)

    Args:
        x: Input tensor (B, H, W, C).
        out_channels: Output channels.
            NOTE(review): only consulted in the stride=2 path; for stride=1
            the unit preserves the input channel count — confirm callers
            always pass matching channel counts for stride=1 units.
        groups: Number of groups for the final channel shuffle.
        stride: 1 for an identity-sized unit, 2 for downsampling.
        name_prefix: Prefix for layer names. These names are load-bearing:
            build_rsn_feature_extractor transfers pretrained weights by
            matching layer names, so they must stay stable.

    Returns:
        Output tensor. For stride=1 it has the same shape as ``x``; for
        stride=2 the spatial dims are halved and channels become
        ``out_channels`` (half from each branch).
    """
    in_channels = x.shape[-1]
    if stride == 1:
        # Split channels: left path (identity passthrough), right path (conv).
        split_channels = in_channels // 2
        branch_left = x[..., :split_channels]
        branch_right = x[..., split_channels:]
        # Right branch: bottleneck built from a depthwise-separable stack:
        # 1x1 Conv -> DepthwiseConv -> 1x1 Conv (ShuffleNet V2 style).
        branch_right = layers.Conv2D(
            split_channels, 1, padding='same', use_bias=False,
            name=f"{name_prefix}_1x1_1"
        )(branch_right)
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_bn_1")(branch_right)
        branch_right = layers.ReLU(name=f"{name_prefix}_relu_1")(branch_right)
        branch_right = layers.DepthwiseConv2D(
            3, strides=1, padding='same', use_bias=False,
            name=f"{name_prefix}_dw"
        )(branch_right)
        # No activation after the depthwise conv — matches ShuffleNet V2,
        # where ReLU is applied only around the 1x1 convs.
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_bn_dw")(branch_right)
        branch_right = layers.Conv2D(
            split_channels, 1, padding='same', use_bias=False,
            name=f"{name_prefix}_1x1_2"
        )(branch_right)
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_bn_2")(branch_right)
        branch_right = layers.ReLU(name=f"{name_prefix}_relu_2")(branch_right)
        # Concatenate the untouched half with the processed half.
        x = layers.Concatenate(name=f"{name_prefix}_concat")([branch_left, branch_right])
    else:  # stride == 2: Downsampling
        # Left branch: DepthwiseConv (stride 2) -> 1x1. Unlike the stride=1
        # case, the "shortcut" branch is also convolved so both branches
        # agree on spatial size.
        branch_left = layers.DepthwiseConv2D(
            3, strides=2, padding='same', use_bias=False,
            name=f"{name_prefix}_left_dw"
        )(x)
        branch_left = layers.BatchNormalization(name=f"{name_prefix}_left_bn_dw")(branch_left)
        branch_left = layers.Conv2D(
            out_channels // 2, 1, padding='same', use_bias=False,
            name=f"{name_prefix}_left_1x1"
        )(branch_left)
        branch_left = layers.BatchNormalization(name=f"{name_prefix}_left_bn")(branch_left)
        branch_left = layers.ReLU(name=f"{name_prefix}_left_relu")(branch_left)
        # Right branch: 1x1 -> DepthwiseConv (stride 2) -> 1x1.
        branch_right = layers.Conv2D(
            in_channels, 1, padding='same', use_bias=False,
            name=f"{name_prefix}_right_1x1_1"
        )(x)
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_right_bn_1")(branch_right)
        branch_right = layers.ReLU(name=f"{name_prefix}_right_relu_1")(branch_right)
        branch_right = layers.DepthwiseConv2D(
            3, strides=2, padding='same', use_bias=False,
            name=f"{name_prefix}_right_dw"
        )(branch_right)
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_right_bn_dw")(branch_right)
        branch_right = layers.Conv2D(
            out_channels // 2, 1, padding='same', use_bias=False,
            name=f"{name_prefix}_right_1x1_2"
        )(branch_right)
        branch_right = layers.BatchNormalization(name=f"{name_prefix}_right_bn_2")(branch_right)
        branch_right = layers.ReLU(name=f"{name_prefix}_right_relu_2")(branch_right)
        # Concatenate (channel expansion: out_channels//2 + out_channels//2).
        x = layers.Concatenate(name=f"{name_prefix}_concat")([branch_left, branch_right])
    # Channel Shuffle: mix information across the concatenated branches.
    # Wrapped in a Lambda so it participates in the functional graph.
    x = layers.Lambda(
        lambda t: channel_shuffle(t, groups),
        name=f"{name_prefix}_shuffle"
    )(x)
    return x
def build_rsn(input_shape=(224, 224, 3), num_stages=4, base_channels=64, groups=4):
    """
    Build the Residual-Shuffle Network backbone.

    Architecture:
        - Stem: 3x3 Conv, stride 2 + MaxPool (4x spatial reduction)
        - Stages 1..num_stages: Shuffle Units (channels double, resolution halves)
        - Global Average Pooling

    Args:
        input_shape: Input tensor shape (H, W, C).
        num_stages: Number of stages to keep (default 4, max 4).
        base_channels: Starting channel count (doubles each stage).
        groups: Groups for channel shuffling.

    Returns:
        Keras Model (backbone, outputs feature vector before final projection).
    """
    # Per-stage plan: (number of shuffle units, output channels).
    # Channels double while spatial resolution halves at each stage.
    plan = [
        (2, base_channels),      # Stage 1: 56x56 -> 28x28
        (3, base_channels * 2),  # Stage 2: 28x28 -> 14x14
        (3, base_channels * 4),  # Stage 3: 14x14 -> 7x7
        (2, base_channels * 8),  # Stage 4: 7x7 -> 4x4 (or smaller)
    ][:num_stages]

    inputs = layers.Input(shape=input_shape, name="rsn_input")

    # Stem: strided conv followed by max-pooling.
    x = layers.Conv2D(
        base_channels, 3, strides=2, padding='same', use_bias=False,
        name="stem_conv"
    )(inputs)
    x = layers.BatchNormalization(name="stem_bn")(x)
    x = layers.ReLU(name="stem_relu")(x)
    x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name="stem_pool")(x)

    # Stages: unit 0 of every stage downsamples (stride 2), the remaining
    # units refine features at constant resolution (stride 1).
    for stage_num, (unit_count, channels) in enumerate(plan, start=1):
        unit_strides = [2] + [1] * (unit_count - 1)
        for unit_idx, unit_stride in enumerate(unit_strides):
            x = shuffle_unit(
                x, channels, groups=groups, stride=unit_stride,
                name_prefix=f"stage{stage_num}_unit{unit_idx}"
            )

    # Collapse spatial dimensions into a single feature vector.
    x = layers.GlobalAveragePooling2D(name="gap")(x)

    return Model(inputs=inputs, outputs=x, name="ResidualShuffleNet")
def build_rsn_feature_extractor(input_shape=(224, 224, 3), feature_dim=64, weights_path=None):
    """
    Build RSN with a projection head for feature extraction.

    This wraps the backbone with a dense layer to project to the
    desired feature dimension, followed by L2 normalization, so outputs
    are unit-length vectors suitable for cosine-similarity comparison.

    Weight loading is best-effort with two fallbacks:
        1. Load ``weights_path`` as a full Keras model and copy weights
           layer-by-layer where names match (shape mismatches are skipped).
        2. If that fails, fall back to ``load_weights(by_name=True,
           skip_mismatch=True)``.
    Any failure leaves the corresponding layers randomly initialized and
    is reported on stdout rather than raised.

    Args:
        input_shape: (H, W, C).
        feature_dim: Output feature dimension of the projection head.
        weights_path: Optional path to pretrained RSN weights (.h5 file).
            These should be weights from train_rsn.py Stage 1.

    Returns:
        Keras Model outputting L2-normalized feature vectors; the model is
        frozen (``trainable = False``) for inference-only feature extraction.
    """
    backbone = build_rsn(input_shape=input_shape)
    # Projection head: dense projection then L2 normalization across the
    # feature axis (axis=1), giving unit-norm embeddings.
    x = backbone.output
    x = layers.Dense(feature_dim, activation='relu', name="projection")(x)
    x = layers.Lambda(
        lambda v: tf.math.l2_normalize(v, axis=1),
        name="l2_norm"
    )(x)
    model = Model(inputs=backbone.input, outputs=x, name="RSN_FeatureExtractor")
    # Load pretrained weights if provided (silently skipped when
    # weights_path is None/empty).
    if weights_path and os.path.exists(weights_path):
        print(f"📦 Loading pretrained RSN weights from: {weights_path}")
        try:
            # Try loading as a full Keras model (from train_rsn.py).
            from tensorflow.keras.models import load_model
            pretrained_model = load_model(weights_path, compile=False)
            # Transfer weights by layer name (backbone layers only) —
            # relies on the stable layer names set in shuffle_unit/build_rsn.
            pretrained_layers = {layer.name: layer for layer in pretrained_model.layers}
            transferred = 0
            for layer in model.layers:
                if layer.name in pretrained_layers:
                    try:
                        layer.set_weights(pretrained_layers[layer.name].get_weights())
                        transferred += 1
                    except Exception:
                        pass  # Shape mismatch, skip layer (deliberate best-effort)
            print(f"✅ Transferred weights for {transferred} layers")
        except Exception as e:
            # Fallback: file may be a bare weights checkpoint rather than
            # a serialized model.
            print(f"⚠️ Could not load model, trying weights: {e}")
            try:
                model.load_weights(weights_path, by_name=True, skip_mismatch=True)
                print("✅ Pretrained weights loaded successfully")
            except Exception as e2:
                print(f"⚠️ Could not load weights: {e2}")
    elif weights_path:
        # Path was given but does not exist — warn instead of raising so
        # the extractor still works with random initialization.
        print(f"⚠️ Weights file not found: {weights_path}")
        print("   Using random initialization")
    model.trainable = False  # Freeze for feature extraction
    return model