|
|
|
|
|
|
|
|
import pickle |
|
|
import struct |
|
|
import uuid |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from sklearn.model_selection import GridSearchCV, train_test_split |
|
|
from torch import nn |
|
|
import ezkl |
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import base64 |
|
|
from concrete.ml.deployment import FHEModelServer |
|
|
from concrete.ml.sklearn import XGBClassifier |
|
|
import tqdm |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
|
|
|
# FastAPI application exposing the ZK-proof endpoint defined below.
app = FastAPI()


# Never read or written anywhere in this file — presumably populated by
# code outside this view (FHE evaluation key for FHEModelServer?).
# NOTE(review): confirm this global is still needed; it looks dead here.
evaluation_key = None
|
|
|
|
|
|
|
|
|
|
|
class AIWordsModel(nn.Module):
    """Sentiment classifier used as the ZK circuit.

    Wraps a Concrete-ML ``XGBClassifier`` that operates on RoBERTa sentence
    embeddings.  The constructor runs a small grid search purely to obtain a
    correctly-configured estimator skeleton; the actual fitted weights are
    always overwritten by the serialized model loaded from disk (note the
    two ``load_dict`` calls), and the estimator is then compiled to an FHE
    circuit on the training embeddings.
    """

    def __init__(self):
        super().__init__()

        print("init ZK AIWordsModel")

        self.model = XGBClassifier()
        # Airline-sentiment training data; string labels mapped to 0/1/2.
        train = pd.read_csv("./local_datasets/twitter-airline-sentiment/Tweets.csv", index_col=0)
        text_X = train["text"]
        y = train["airline_sentiment"].replace(["negative", "neutral", "positive"], [0, 1, 2])

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
        self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
            "cardiffnlp/twitter-roberta-base-sentiment-latest"
        ).to(self.device)

        # Only the training split is used; the held-out split is discarded.
        text_X_train, _, y_train, _ = train_test_split(
            text_X, y, test_size=0.1, random_state=42
        )
        X_train_transformer = self.text_to_tensor(
            text_X_train.tolist(), self.transformer_model, self.tokenizer, self.device
        )

        # SECURITY: pickle.load executes arbitrary code from the file — only
        # acceptable because this is a locally-deployed trusted artifact.
        with open("deployment/serialized_model_zkml", 'rb') as file:
            loaded_data = pickle.load(file)
        self.model.load_dict(loaded_data)

        # Grid search selects hyper-parameters / estimator shape; its fitted
        # weights are immediately replaced by the serialized model below.
        parameters = {"n_bits": [2, 3], "max_depth": [1], "n_estimators": [10, 30, 50]}
        grid_search2 = GridSearchCV(self.model, parameters, cv=5, scoring="accuracy")
        grid_search2.fit(X_train_transformer, y_train)
        self.best_model2 = grid_search2.best_estimator_
        self.best_model2.load_dict(loaded_data)
        # Compile to an FHE circuit, calibrated on the training embeddings.
        self.best_model2.compile(X_train_transformer)

        print("loaded_data finished")

    def forward(self, x):
        """Return class probabilities for embedding batch ``x``.

        Prediction runs under FHE (``fhe="execute"``).  The trailing
        ``squeeze()`` drops singleton dimensions, so a single-sample batch
        comes back as a 1-D tensor of per-class probabilities.
        """
        prediction = self.best_model2.predict_proba(x, fhe="execute")

        prediction_tensor = torch.tensor(prediction, dtype=torch.float32)
        prediction_tensor = prediction_tensor.squeeze()

        return prediction_tensor

    def text_to_tensor(self, list_text, transformer_model, tokenizer, device):
        """Embed each text as the token-mean of the transformer's last hidden layer.

        Returns a numpy array of shape (len(list_text), hidden_dim) — one
        pooled embedding row per input string.
        """
        tokenized_text = [tokenizer.encode(text, return_tensors="pt") for text in list_text]
        output_hidden_states_list = [None] * len(tokenized_text)

        for i, tokenized_x in enumerate(tqdm.tqdm(tokenized_text)):
            # With output_hidden_states=True, index [1] is presumably the
            # hidden-states tuple and [-1] its last layer — TODO confirm
            # against the transformers model-output contract.
            output_hidden_states = transformer_model(tokenized_x.to(device), output_hidden_states=True)[1][-1]
            # Mean-pool over the token dimension, move to CPU numpy.
            output_hidden_states = output_hidden_states.mean(dim=1).detach().cpu().numpy()
            output_hidden_states_list[i] = output_hidden_states

        return np.concatenate(output_hidden_states_list, axis=0)
|
|
|
|
|
|
|
|
class ZKProofRequest(BaseModel):
    """Request payload for the /get_zk_proof endpoint."""
    # Raw input text whose sentiment prediction will be proven.
    text: str
|
|
|
|
|
|
|
|
# Built once at import time; the constructor reads the local dataset,
# downloads the transformer, and compiles the FHE circuit — so app
# startup is slow and requires those artifacts to be present.
circuit = AIWordsModel()
|
|
|
|
|
|
|
|
@app.post("/get_zk_proof")
async def get_zk_proof(request: ZKProofRequest):
    """Run the sentiment model on ``request.text`` and produce an EZKL ZK proof.

    Pipeline: embed the text with the transformer, run the FHE-compiled
    classifier, export the circuit to ONNX, then drive EZKL through
    settings -> calibration -> compile -> SRS -> witness -> setup -> prove.
    The proof is verified locally, then a Solidity verifier is deployed and
    the proof is verified on-chain as well.

    Returns a dict with the model output, the base64-encoded proof, and the
    deployed verifier contract address; returns an ``{"error": ...}`` dict
    if contract deployment produced no address file.

    NOTE(review): failures in the EZKL steps surface as ``assert`` — these
    are stripped under ``python -O`` and should become real exceptions.
    """
    # Per-request scratch directory keyed by a random UUID.
    folder_path = f"zkml_non_encrypted/{str(uuid.uuid4())}"
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(folder_path, exist_ok=True)

    # Artifact paths produced by the pipeline below.  (The original wrapped
    # each in a single-argument os.path.join, which is a no-op.)
    model_path = f'{folder_path}/network.onnx'
    compiled_model_path = f'{folder_path}/network.compiled'
    pk_path = f'{folder_path}/test.pk'
    vk_path = f'{folder_path}/test.vk'
    settings_path = f'{folder_path}/settings.json'

    witness_path = f'{folder_path}/witness.json'
    input_data_path = f'{folder_path}/input.json'
    srs_path = f'{folder_path}/kzg14.srs'
    output_path = f'{folder_path}/output.json'

    # Embed the request text and run the classifier once in the clear to
    # obtain the public output the proof will attest to.
    words = [request.text]
    x_list = circuit.text_to_tensor(words, circuit.transformer_model, circuit.tokenizer, circuit.device)
    x = torch.tensor(x_list, dtype=torch.float32)

    circuit.eval()

    with torch.no_grad():
        output = circuit(x)

    output_data = output.detach().numpy().tolist()
    with open(output_path, 'w') as f:
        json.dump(output_data, f)

    # Export the model to ONNX so EZKL can build a circuit from it.
    torch.onnx.export(circuit,
                      x,
                      model_path,
                      export_params=True,
                      opset_version=10,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})

    data = dict(input_data=x.tolist())

    # Original used json.dump(data, open(...)) which leaks the file handle;
    # a context manager guarantees close/flush before EZKL reads the file.
    with open(input_data_path, 'w') as f:
        json.dump(data, f)

    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"

    res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
    assert res is True

    # Calibrate quantization settings on the same input sample.
    cal_path = f"{folder_path}/calibration.json"
    with open(cal_path, 'w') as f:
        json.dump(data, f)

    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    assert res is True

    # Fetch the structured reference string sized for these settings.
    res = await ezkl.get_srs(settings_path, srs_path=srs_path)
    assert res is True

    res = await ezkl.gen_witness(input_data_path, compiled_model_path, witness_path)
    assert os.path.isfile(witness_path)

    # Generate proving and verifying keys.
    res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path
    )

    assert res is True
    assert os.path.isfile(vk_path)
    assert os.path.isfile(pk_path)
    assert os.path.isfile(settings_path)

    proof_path = f'{folder_path}/test.pf'
    res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        "single",
        srs_path
    )
    assert os.path.isfile(proof_path)

    # Sanity-check the proof locally before spending on-chain gas.
    res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path
    )
    assert res is True
    print("verified on local")

    # Emit a Solidity verifier for this circuit and deploy it.
    verify_sol_code_path = f'{folder_path}/verify.sol'
    verify_sol_abi_path = f'{folder_path}/verify.abi'
    res = await ezkl.create_evm_verifier(
        vk_path,
        settings_path,
        verify_sol_code_path,
        verify_sol_abi_path,
        srs_path
    )
    assert res is True

    verify_contract_addr_file = f"{folder_path}/addr.txt"
    # NOTE(review): hard-coded RPC endpoint — should come from configuration.
    rpc_url = "http://103.231.86.33:10219"
    await ezkl.deploy_evm(
        addr_path=verify_contract_addr_file,
        rpc_url=rpc_url,
        sol_code_path=verify_sol_code_path
    )
    # Guard clause: deploy_evm signals failure by not writing the addr file.
    if not os.path.exists(verify_contract_addr_file):
        print(f"error: File {verify_contract_addr_file} does not exist.")
        return {"error": "Contract address file not found"}
    with open(verify_contract_addr_file, 'r') as file:
        verify_contract_addr = file.read()

    res = await ezkl.verify_evm(
        addr_verifier=verify_contract_addr,
        proof_path=proof_path,
        rpc_url=rpc_url
    )
    assert res is True
    print("verified on chain")

    with open(proof_path, 'rb') as f:
        proof_content = base64.b64encode(f.read()).decode('utf-8')

    return {"output": output_data, "proof": proof_content, "verify_contract_addr": verify_contract_addr}
|
|
|