// dl4j-computationgraph-trigger-poc / Dl4jComputationGraphTriggerPoC.java
// Uploaded by hacnho with huggingface_hub (verified commit 7712cb2, 6.15 kB).
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.Locale;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Builds and verifies a pair of tiny DL4J {@link ComputationGraph} models whose only difference
 * is a weight-encoded change in output for one specific input value.
 *
 * <p>Both models share the topology from {@link #baseConfig()}: input {@code "in"} feeds a dense
 * layer {@code "dense"} with four ReLU units, which feeds an identity {@code OutputLayer}
 * {@code "out"} with one unit. {@link #writeModel(File, boolean)} sets all dense weights to
 * {@code 1} and dense biases to {@code {0, -13.5, -14, -14.5}}. For a scalar input {@code x} the
 * hidden activations are therefore {@code relu(x), relu(x-13.5), relu(x-14), relu(x-14.5)}.
 *
 * <ul>
 *   <li>Control model: output weights {@code {1, 0, 0, 0}}, so the output is {@code relu(x)}.
 *   <li>Trigger model: output weights {@code {1, 200, -400, 200}}. This adds a triangular bump
 *       that is zero outside {@code [13.5, 14.5]} and reaches {@code +100} at {@code x = 14}.
 *       The output is 13 at input 13, 114 at input 14, and 15 at input 15.
 * </ul>
 *
 * <p>Usage:
 *
 * <pre>
 *   java Dl4jComputationGraphTriggerPoC build &lt;out-dir&gt;
 *   java Dl4jComputationGraphTriggerPoC verify &lt;control.zip&gt; &lt;malicious.zip&gt;
 * </pre>
 *
 * <p>Exit status: 2 on a usage error. In verify mode, 1 when the expected outputs are not
 * observed.
 */
public class Dl4jComputationGraphTriggerPoC {
  /** Inputs fed to each model during verification; index 1 is the trigger value. */
  private static final float[] PROBE_VALUES = new float[] {13.0f, 14.0f, 15.0f};

  /** File name used for the unmodified (control) model in build mode. */
  private static final String CONTROL_FILE_NAME = "control-cg.zip";

  /** File name used for the trigger model in build mode. */
  private static final String MALICIOUS_FILE_NAME = "trigger14-cg.zip";

  /** Absolute tolerance used when comparing model outputs. */
  private static final float TOLERANCE = 0.01f;

  /**
   * Entry point; dispatches to build or verify mode based on {@code args[0]}.
   *
   * @param args {@code build <out-dir>} or {@code verify <control.zip> <malicious.zip>}
   * @throws Exception if model construction, serialization, I/O, or hashing fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 2) {
      System.err.println("usage: java Dl4jComputationGraphTriggerPoC build <out-dir> | verify <control.zip> <malicious.zip>");
      System.exit(2);
    }
    if ("build".equals(args[0])) {
      runBuild(Paths.get(args[1]));
      return;
    }
    if ("verify".equals(args[0])) {
      if (args.length != 3) {
        System.err.println("verify requires <control.zip> <malicious.zip>");
        System.exit(2);
      }
      runVerify(new File(args[1]), new File(args[2]));
      return;
    }
    System.err.println("unknown mode: " + args[0]);
    System.exit(2);
  }

  /**
   * Writes the control and trigger models into {@code outDir} and prints a one-line JSON summary.
   *
   * @param outDir directory to create (if needed) and write both model zips into
   */
  private static void runBuild(Path outDir) throws Exception {
    Files.createDirectories(outDir);
    Path control = outDir.resolve(CONTROL_FILE_NAME);
    Path malicious = outDir.resolve(MALICIOUS_FILE_NAME);
    writeModel(control.toFile(), false);
    writeModel(malicious.toFile(), true);
    // Paths can contain backslashes (Windows) or quotes; escape them so the output stays valid JSON.
    System.out.println("{\"built\":true,\"control\":\"" + jsonEscape(control.toString())
        + "\",\"malicious\":\"" + jsonEscape(malicious.toString()) + "\"}");
  }

  /**
   * Runs both models on {@link #PROBE_VALUES}, prints a JSON report, and exits with status 1
   * unless the expected outputs are observed.
   *
   * <p>The expected outputs are: control outputs 14 at the trigger input, the trigger model
   * outputs 114 there, and both models agree at the neighbouring probes.
   */
  private static void runVerify(File controlFile, File maliciousFile) throws Exception {
    Verification control = verifyModel(controlFile);
    Verification malicious = verifyModel(maliciousFile);
    boolean observed = near(control.outputs[1], 14.0f)
        && near(malicious.outputs[1], 114.0f)
        && near(malicious.outputs[0], control.outputs[0])
        && near(malicious.outputs[2], control.outputs[2]);
    System.out.println("{");
    System.out.println(" \"control\": " + control.toJson() + ",");
    System.out.println(" \"malicious\": " + malicious.toJson() + ",");
    System.out.println(" \"backdoor_observed\": " + observed);
    System.out.println("}");
    if (!observed) {
      System.exit(1);
    }
  }

  /** Returns whether {@code actual} is within {@link #TOLERANCE} of {@code expected} (false for NaN). */
  private static boolean near(float actual, float expected) {
    return Math.abs(actual - expected) < TOLERANCE;
  }

  /**
   * Initializes a graph from {@link #baseConfig()}, overwrites its parameters with fixed values,
   * and serializes it to {@code file} without updater state.
   *
   * @param file destination zip
   * @param malicious {@code true} to use the trigger output weights {@code {1, 200, -400, 200}};
   *     {@code false} for the control weights {@code {1, 0, 0, 0}}
   */
  private static void writeModel(File file, boolean malicious) throws Exception {
    ComputationGraph net = new ComputationGraph(baseConfig());
    net.init();
    // Dense layer: every unit sees x with weight 1; the biases place ReLU hinges at 0, 13.5, 14, 14.5.
    INDArray w0 = Nd4j.create(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, new long[] {1, 4});
    INDArray b0 = Nd4j.create(new float[] {0.0f, -13.5f, -14.0f, -14.5f}, new long[] {1, 4});
    // Output layer: the trigger weights combine the three hinges into a bump of +100 at x = 14.
    INDArray w1 = malicious
        ? Nd4j.create(new float[] {1.0f, 200.0f, -400.0f, 200.0f}, new long[] {4, 1})
        : Nd4j.create(new float[] {1.0f, 0.0f, 0.0f, 0.0f}, new long[] {4, 1});
    INDArray b1 = Nd4j.create(new float[] {0.0f}, new long[] {1, 1});
    net.setParam("dense_W", w0);
    net.setParam("dense_b", b0);
    net.setParam("out_W", w1);
    net.setParam("out_b", b1);
    ModelSerializer.writeModel(net, file, false);
  }

  /**
   * Returns the shared graph topology: {@code in -> dense (1->4, ReLU) -> out (4->1, identity,
   * MSE)}. All weights and biases are zero-initialized, and updates are disabled via {@link NoOp}.
   */
  private static ComputationGraphConfiguration baseConfig() {
    return new NeuralNetConfiguration.Builder()
        .seed(1337)
        .updater(new NoOp())
        .graphBuilder()
        .addInputs("in")
        .addLayer("dense", new DenseLayer.Builder()
            .nIn(1)
            .nOut(4)
            .activation(Activation.RELU)
            .weightInit(WeightInit.ZERO)
            .biasInit(0.0)
            .build(), "in")
        .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
            .nIn(4)
            .nOut(1)
            .activation(Activation.IDENTITY)
            .weightInit(WeightInit.ZERO)
            .biasInit(0.0)
            .build(), "dense")
        .setOutputs("out")
        .build();
  }

  /**
   * Restores a graph from {@code file} (without updater state) and evaluates it on each probe.
   *
   * @return the file name, its SHA-256, and one output per entry of {@link #PROBE_VALUES}
   */
  private static Verification verifyModel(File file) throws Exception {
    ComputationGraph net = ModelSerializer.restoreComputationGraph(file, false);
    float[] outputs = new float[PROBE_VALUES.length];
    for (int i = 0; i < PROBE_VALUES.length; i++) {
      INDArray input = Nd4j.create(new float[] {PROBE_VALUES[i]}, new long[] {1, 1});
      outputs[i] = net.outputSingle(input).getFloat(0);
    }
    return new Verification(file.getName(), sha256(file.toPath()), outputs);
  }

  /** Returns the lowercase hex SHA-256 digest of the file at {@code path} (read fully into memory). */
  private static String sha256(Path path) throws Exception {
    MessageDigest digest = MessageDigest.getInstance("SHA-256");
    byte[] hash = digest.digest(Files.readAllBytes(path));
    StringBuilder out = new StringBuilder(hash.length * 2);
    for (byte b : hash) {
      out.append(String.format(Locale.ROOT, "%02x", b));
    }
    return out.toString();
  }

  /**
   * Escapes {@code s} for use inside a JSON string literal: quotes, backslashes and control
   * characters are escaped.
   */
  private static String jsonEscape(String s) {
    StringBuilder out = new StringBuilder(s.length() + 8);
    for (int i = 0; i < s.length(); i++) {
      char c = s.charAt(i);
      switch (c) {
        case '"':
          out.append("\\\"");
          break;
        case '\\':
          out.append("\\\\");
          break;
        case '\n':
          out.append("\\n");
          break;
        case '\r':
          out.append("\\r");
          break;
        case '\t':
          out.append("\\t");
          break;
        default:
          if (c < 0x20) {
            out.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
          } else {
            out.append(c);
          }
      }
    }
    return out.toString();
  }

  /**
   * Formats {@code value} with six decimals. Returns {@code null} for NaN or infinity, which JSON
   * cannot represent.
   */
  private static String jsonNumber(float value) {
    return Float.isFinite(value) ? String.format(Locale.ROOT, "%.6f", value) : "null";
  }

  /** Immutable result of evaluating one model file; {@code outputs} is parallel to PROBE_VALUES. */
  private static final class Verification {
    final String fileName;
    final String sha256;
    final float[] outputs;

    Verification(String fileName, String sha256, float[] outputs) {
      this.fileName = fileName;
      this.sha256 = sha256;
      this.outputs = outputs;
    }

    /**
     * Renders this result as single-line JSON. The outputs object is keyed by each probe value,
     * e.g. {@code "14.0"}.
     */
    String toJson() {
      StringBuilder json = new StringBuilder();
      json.append("{\"file\":\"").append(jsonEscape(fileName))
          .append("\",\"sha256\":\"").append(sha256)
          .append("\",\"outputs\":{");
      for (int i = 0; i < outputs.length; i++) {
        if (i > 0) {
          json.append(',');
        }
        // Float.toString yields keys like "13.0", matching the original hard-coded keys.
        json.append('"').append(PROBE_VALUES[i]).append("\":").append(jsonNumber(outputs[i]));
      }
      return json.append("}}").toString();
    }
  }
}