|
|
|
|
|
|
|
|
import struct |
|
|
import uuid |
|
|
|
|
|
import numpy as np |
|
|
from torch import nn |
|
|
import ezkl |
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import base64 |
|
|
from concrete.ml.deployment import FHEModelServer |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
|
|
|
from config import rpc_url, private_key |
|
|
|
|
|
# FastAPI application; the proof endpoint below is registered on it.
app = FastAPI()



# Serialized FHE evaluation key for the request currently being handled.
# None until the first request; assigned (as bytes) by get_zk_proof and read
# by AIModel.forward, which takes no key argument of its own.
evaluation_key = None
|
|
|
|
|
|
|
|
|
|
|
class AIModel(nn.Module):
    """Torch module that wraps a concrete-ml ``FHEModelServer``.

    Wrapping the FHE server call in an ``nn.Module`` lets the handler invoke it
    like any torch model and pass it to ``torch.onnx.export``. The forward pass
    returns the server's serialized (still encrypted) prediction as a uint8
    tensor.
    """

    def __init__(self):
        super().__init__()
        # Relative path: resolved against the process working directory when
        # the module is constructed.
        self.fhe_model = FHEModelServer("../deployment/sentiment_fhe_model")

    def forward(self, x):
        """Run encrypted inference on ``x[0]``.

        Args:
            x: uint8 tensor whose first entry (along dim 0) holds the
                serialized encrypted input bytes.

        Returns:
            uint8 tensor with one element per byte of the value returned by
            ``FHEModelServer.run`` (presumably ``bytes`` — TODO confirm).
        """
        ciphertext = x[0].numpy().tobytes()
        # ``evaluation_key`` is the module-level global that the request
        # handler assigns before calling this module.
        encrypted_result = self.fhe_model.run(ciphertext, evaluation_key)
        return torch.tensor([byte for byte in encrypted_result], dtype=torch.uint8)
|
|
|
|
|
|
|
|
class ZKProofRequest(BaseModel):
    """JSON request body accepted by the ``/get_zk_proof`` endpoint.

    Both fields arrive as base64 text; the handler base64-decodes them to bytes.
    """

    # Base64 of the client's serialized encrypted input (the FHE ciphertext).
    encrypted_encoding: str

    # Base64 of the client's serialized FHE evaluation key.
    evaluation_key: str
|
|
|
|
|
|
|
|
# One shared AIModel instance, built at import time and reused by every request.
circuit = AIModel()
|
|
|
|
|
|
|
|
def _ensure(condition, message):
    """Raise ``RuntimeError(message)`` when ``condition`` is falsy.

    Used instead of ``assert`` for pipeline checks so they still run when
    Python is started with ``-O`` (which strips asserts).
    """
    if not condition:
        raise RuntimeError(message)


def _write_json(path, obj):
    """Serialize ``obj`` as JSON into ``path``, closing the file before returning."""
    with open(path, 'w') as f:
        json.dump(obj, f)


@app.post("/get_zk_proof")
async def get_zk_proof(request: ZKProofRequest):
    """Run FHE inference on the submitted ciphertext and prove it with ezkl.

    Pipeline:
    1. Base64-decode the request fields.
    2. Run ``circuit`` on the encrypted input.
    3. Export ``circuit`` to ONNX.
    4. Generate ezkl settings, then calibrate them.
    5. Compile the circuit, then fetch the SRS.
    6. Build the witness, then the proving and verifying keys.
    7. Create a proof and verify it locally.
    8. Generate an EVM verifier contract and deploy it using ``rpc_url`` and
       ``private_key`` from ``config``.

    Every artifact is written to a fresh ``zkml_encrypted/<uuid4>`` folder. The
    folder is left on disk, and some of its paths are returned to the caller.

    NOTE(review): the ezkl ``compile_circuit``/``setup``/``prove``/``verify``
    calls are synchronous, so they block the event loop while they run.

    Returns:
        On success, a dict with these keys:
        - ``output``: hex of the output bytes, truncated to 1000 chars.
        - ``output_path``: path of the output JSON file.
        - ``proof``: base64 of the proof file, truncated to 500 chars.
        - ``proof_path``: path of the proof file.
        - ``verify_contract_addr``: address of the deployed verifier contract.
        If the deployed-contract address file is missing, returns
        ``{"error": "Contract address file not found"}`` instead.

    Raises:
        binascii.Error: if a request field is malformed base64.
        RuntimeError: if an ezkl step reports failure or does not produce its
            expected output file.
    """
    global evaluation_key

    # Decode into locals instead of overwriting the model's ``str`` fields with
    # bytes. AIModel.forward reads the key from the module-level global, so
    # set that global here.
    encrypted_encoding = base64.b64decode(request.encrypted_encoding)
    evaluation_key = base64.b64decode(request.evaluation_key)

    folder_path = f"zkml_encrypted/{uuid.uuid4()}"
    os.makedirs(folder_path, exist_ok=True)

    model_path = os.path.join(folder_path, 'network.onnx')
    compiled_model_path = os.path.join(folder_path, 'network.compiled')
    pk_path = os.path.join(folder_path, 'test.pk')
    vk_path = os.path.join(folder_path, 'test.vk')
    settings_path = os.path.join(folder_path, 'settings.json')
    witness_path = os.path.join(folder_path, 'witness.json')
    input_data_path = os.path.join(folder_path, 'input.json')
    cal_path = os.path.join(folder_path, 'calibration.json')
    srs_path = os.path.join(folder_path, 'kzg14.srs')
    output_path = os.path.join(folder_path, 'output.json')
    proof_path = os.path.join(folder_path, 'test.pf')
    verify_sol_code_path = os.path.join(folder_path, 'verify.sol')
    verify_sol_abi_path = os.path.join(folder_path, 'verify.abi')
    verify_contract_addr_file = os.path.join(folder_path, 'addr.txt')

    # A batch of one: AIModel.forward treats x[0] as the ciphertext bytes.
    x = torch.tensor([encrypted_encoding], dtype=torch.uint8)

    # No ``await`` occurs between setting ``evaluation_key`` above and the
    # ONNX export below, which traces (calls) the circuit again. Other
    # requests on this event loop therefore cannot replace the key mid-use.
    circuit.eval()
    with torch.no_grad():
        output = circuit(x)

    output_data = output.detach().numpy().tolist()
    _write_json(output_path, output_data)

    torch.onnx.export(circuit,
                      x,
                      model_path,
                      export_params=True,
                      opset_version=10,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})

    data = dict(input_data=x.tolist())
    _write_json(input_data_path, data)

    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"

    res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
    _ensure(res is True, "ezkl.gen_settings failed")

    # Calibrate on the same single sample that is used as the witness input.
    _write_json(cal_path, data)
    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    _ensure(res is True, "ezkl.compile_circuit failed")

    res = await ezkl.get_srs(settings_path, srs_path=srs_path)
    _ensure(res is True, "ezkl.get_srs failed")

    await ezkl.gen_witness(input_data_path, compiled_model_path, witness_path)
    _ensure(os.path.isfile(witness_path), f"witness was not written to {witness_path}")

    res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path
    )
    _ensure(res is True, "ezkl.setup failed")
    _ensure(os.path.isfile(vk_path), f"verifying key was not written to {vk_path}")
    _ensure(os.path.isfile(pk_path), f"proving key was not written to {pk_path}")
    _ensure(os.path.isfile(settings_path), f"settings file missing at {settings_path}")

    ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        "single",
        srs_path
    )
    _ensure(os.path.isfile(proof_path), f"proof was not written to {proof_path}")

    res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path
    )
    _ensure(res is True, "local proof verification failed")
    print("verified on local")

    res = await ezkl.create_evm_verifier(
        vk_path,
        settings_path,
        verify_sol_code_path,
        verify_sol_abi_path,
        srs_path
    )
    _ensure(res is True, "ezkl.create_evm_verifier failed")

    # Deploys the generated verifier contract; ezkl writes its address to
    # ``verify_contract_addr_file``.
    await ezkl.deploy_evm(
        addr_path=verify_contract_addr_file,
        rpc_url=rpc_url,
        private_key=private_key,
        sol_code_path=verify_sol_code_path
    )

    if not os.path.exists(verify_contract_addr_file):
        print(f"error: File {verify_contract_addr_file} does not exist.")
        return {"error": "Contract address file not found"}

    with open(verify_contract_addr_file, 'r') as file:
        verify_contract_addr = file.read()

    with open(proof_path, 'rb') as f:
        proof_content = base64.b64encode(f.read()).decode('utf-8')

    return {"output": array_to_hex_string(output_data)[:1000],
            "output_path": output_path,
            "proof": proof_content[:500],
            "proof_path": proof_path,
            "verify_contract_addr": verify_contract_addr}
|
|
|
|
|
|
|
|
def array_to_hex_string(array):
    """Join the zero-padded, two-digit lowercase hex form of each integer.

    Values above 0xff are not truncated, so they yield more than two digits
    (``[1, 255] -> "01ff"``, ``[256] -> "100"``).
    """
    return "".join(map("{:02x}".format, array))
|
|
|