import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.Locale;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
| public class Dl4jComputationGraphTriggerPoC { |
| private static final float[] PROBE_VALUES = new float[] {13.0f, 14.0f, 15.0f}; |
|
|
| public static void main(String[] args) throws Exception { |
| if (args.length < 2) { |
| System.err.println("usage: java Dl4jComputationGraphTriggerPoC build <out-dir> | verify <control.zip> <malicious.zip>"); |
| System.exit(2); |
| } |
|
|
| if ("build".equals(args[0])) { |
| Path outDir = Paths.get(args[1]); |
| Files.createDirectories(outDir); |
| writeModel(outDir.resolve("control-cg.zip").toFile(), false); |
| writeModel(outDir.resolve("trigger14-cg.zip").toFile(), true); |
| System.out.println("{\"built\":true,\"control\":\"" + outDir.resolve("control-cg.zip") + "\",\"malicious\":\"" |
| + outDir.resolve("trigger14-cg.zip") + "\"}"); |
| return; |
| } |
|
|
| if ("verify".equals(args[0])) { |
| if (args.length != 3) { |
| System.err.println("verify requires <control.zip> <malicious.zip>"); |
| System.exit(2); |
| } |
| Verification control = verifyModel(new File(args[1])); |
| Verification malicious = verifyModel(new File(args[2])); |
| boolean observed = Math.abs(control.outputs[1] - 14.0f) < 0.01f |
| && Math.abs(malicious.outputs[1] - 114.0f) < 0.01f |
| && Math.abs(malicious.outputs[0] - control.outputs[0]) < 0.01f |
| && Math.abs(malicious.outputs[2] - control.outputs[2]) < 0.01f; |
|
|
| System.out.println("{"); |
| System.out.println(" \"control\": " + control.toJson() + ","); |
| System.out.println(" \"malicious\": " + malicious.toJson() + ","); |
| System.out.println(" \"backdoor_observed\": " + observed); |
| System.out.println("}"); |
| if (!observed) { |
| System.exit(1); |
| } |
| return; |
| } |
|
|
| System.err.println("unknown mode: " + args[0]); |
| System.exit(2); |
| } |
|
|
| private static void writeModel(File file, boolean malicious) throws Exception { |
| ComputationGraph net = new ComputationGraph(baseConfig()); |
| net.init(); |
|
|
| INDArray w0 = Nd4j.create(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, new long[] {1, 4}); |
| INDArray b0 = Nd4j.create(new float[] {0.0f, -13.5f, -14.0f, -14.5f}, new long[] {1, 4}); |
| INDArray w1 = malicious |
| ? Nd4j.create(new float[] {1.0f, 200.0f, -400.0f, 200.0f}, new long[] {4, 1}) |
| : Nd4j.create(new float[] {1.0f, 0.0f, 0.0f, 0.0f}, new long[] {4, 1}); |
| INDArray b1 = Nd4j.create(new float[] {0.0f}, new long[] {1, 1}); |
|
|
| net.setParam("dense_W", w0); |
| net.setParam("dense_b", b0); |
| net.setParam("out_W", w1); |
| net.setParam("out_b", b1); |
| ModelSerializer.writeModel(net, file, false); |
| } |
|
|
| private static ComputationGraphConfiguration baseConfig() { |
| return new NeuralNetConfiguration.Builder() |
| .seed(1337) |
| .updater(new NoOp()) |
| .graphBuilder() |
| .addInputs("in") |
| .addLayer("dense", new DenseLayer.Builder() |
| .nIn(1) |
| .nOut(4) |
| .activation(Activation.RELU) |
| .weightInit(WeightInit.ZERO) |
| .biasInit(0.0) |
| .build(), "in") |
| .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE) |
| .nIn(4) |
| .nOut(1) |
| .activation(Activation.IDENTITY) |
| .weightInit(WeightInit.ZERO) |
| .biasInit(0.0) |
| .build(), "dense") |
| .setOutputs("out") |
| .build(); |
| } |
|
|
| private static Verification verifyModel(File file) throws Exception { |
| ComputationGraph net = ModelSerializer.restoreComputationGraph(file, false); |
| float[] outputs = new float[PROBE_VALUES.length]; |
| for (int i = 0; i < PROBE_VALUES.length; i++) { |
| INDArray input = Nd4j.create(new float[] {PROBE_VALUES[i]}, new long[] {1, 1}); |
| outputs[i] = net.outputSingle(input).getFloat(0); |
| } |
| return new Verification(file.getName(), sha256(file.toPath()), outputs); |
| } |
|
|
| private static String sha256(Path path) throws Exception { |
| MessageDigest digest = MessageDigest.getInstance("SHA-256"); |
| byte[] hash = digest.digest(Files.readAllBytes(path)); |
| StringBuilder out = new StringBuilder(); |
| for (byte b : hash) { |
| out.append(String.format("%02x", b)); |
| } |
| return out.toString(); |
| } |
|
|
| private static class Verification { |
| final String fileName; |
| final String sha256; |
| final float[] outputs; |
|
|
| Verification(String fileName, String sha256, float[] outputs) { |
| this.fileName = fileName; |
| this.sha256 = sha256; |
| this.outputs = outputs; |
| } |
|
|
| String toJson() { |
| return String.format(Locale.ROOT, |
| "{\"file\":\"%s\",\"sha256\":\"%s\",\"outputs\":{\"13.0\":%.6f,\"14.0\":%.6f,\"15.0\":%.6f}}", |
| fileName, sha256, outputs[0], outputs[1], outputs[2]); |
| } |
| } |
| } |
|
|