import java.io.File; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.security.MessageDigest; import java.util.Locale; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; public class Dl4jComputationGraphTriggerPoC { private static final float[] PROBE_VALUES = new float[] {13.0f, 14.0f, 15.0f}; public static void main(String[] args) throws Exception { if (args.length < 2) { System.err.println("usage: java Dl4jComputationGraphTriggerPoC build | verify "); System.exit(2); } if ("build".equals(args[0])) { Path outDir = Paths.get(args[1]); Files.createDirectories(outDir); writeModel(outDir.resolve("control-cg.zip").toFile(), false); writeModel(outDir.resolve("trigger14-cg.zip").toFile(), true); System.out.println("{\"built\":true,\"control\":\"" + outDir.resolve("control-cg.zip") + "\",\"malicious\":\"" + outDir.resolve("trigger14-cg.zip") + "\"}"); return; } if ("verify".equals(args[0])) { if (args.length != 3) { System.err.println("verify requires "); System.exit(2); } Verification control = verifyModel(new File(args[1])); Verification malicious = verifyModel(new File(args[2])); boolean observed = Math.abs(control.outputs[1] - 14.0f) < 0.01f && Math.abs(malicious.outputs[1] - 114.0f) < 0.01f && Math.abs(malicious.outputs[0] - control.outputs[0]) < 0.01f && Math.abs(malicious.outputs[2] - control.outputs[2]) < 0.01f; System.out.println("{"); System.out.println(" \"control\": " + control.toJson() + ","); System.out.println(" \"malicious\": " + malicious.toJson() + ","); System.out.println(" \"backdoor_observed\": " + observed); System.out.println("}"); if (!observed) { System.exit(1); } return; } System.err.println("unknown mode: " + args[0]); System.exit(2); } private static void writeModel(File file, boolean malicious) throws Exception { ComputationGraph net = new ComputationGraph(baseConfig()); net.init(); INDArray w0 = Nd4j.create(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, new long[] {1, 4}); INDArray b0 = Nd4j.create(new float[] {0.0f, -13.5f, -14.0f, -14.5f}, new long[] {1, 4}); INDArray w1 = malicious ? Nd4j.create(new float[] {1.0f, 200.0f, -400.0f, 200.0f}, new long[] {4, 1}) : Nd4j.create(new float[] {1.0f, 0.0f, 0.0f, 0.0f}, new long[] {4, 1}); INDArray b1 = Nd4j.create(new float[] {0.0f}, new long[] {1, 1}); net.setParam("dense_W", w0); net.setParam("dense_b", b0); net.setParam("out_W", w1); net.setParam("out_b", b1); ModelSerializer.writeModel(net, file, false); } private static ComputationGraphConfiguration baseConfig() { return new NeuralNetConfiguration.Builder() .seed(1337) .updater(new NoOp()) .graphBuilder() .addInputs("in") .addLayer("dense", new DenseLayer.Builder() .nIn(1) .nOut(4) .activation(Activation.RELU) .weightInit(WeightInit.ZERO) .biasInit(0.0) .build(), "in") .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE) .nIn(4) .nOut(1) .activation(Activation.IDENTITY) .weightInit(WeightInit.ZERO) .biasInit(0.0) .build(), "dense") .setOutputs("out") .build(); } private static Verification verifyModel(File file) throws Exception { ComputationGraph net = ModelSerializer.restoreComputationGraph(file, false); float[] outputs = new float[PROBE_VALUES.length]; for (int i = 0; i < PROBE_VALUES.length; i++) { INDArray input = Nd4j.create(new float[] {PROBE_VALUES[i]}, new long[] {1, 1}); outputs[i] = net.outputSingle(input).getFloat(0); } return new Verification(file.getName(), sha256(file.toPath()), outputs); } private static String sha256(Path path) throws Exception { MessageDigest digest = MessageDigest.getInstance("SHA-256"); byte[] hash = digest.digest(Files.readAllBytes(path)); StringBuilder out = new StringBuilder(); for (byte b : hash) { out.append(String.format("%02x", b)); } return out.toString(); } private static class Verification { final String fileName; final String sha256; final float[] outputs; Verification(String fileName, String sha256, float[] outputs) { this.fileName = fileName; this.sha256 = sha256; this.outputs = outputs; } String toJson() { return String.format(Locale.ROOT, "{\"file\":\"%s\",\"sha256\":\"%s\",\"outputs\":{\"13.0\":%.6f,\"14.0\":%.6f,\"15.0\":%.6f}}", fileName, sha256, outputs[0], outputs[1], outputs[2]); } } }