| import os |
| import joblib |
| import numpy as np |
| from concrete.ml.deployment import FHEModelClient, FHEModelServer |
| import logging |
| import gradio as gr |
|
|
| |
logging.basicConfig(level=logging.INFO)

# Process-wide state shared by the Gradio callbacks defined below.
# Flag flipped by key_already_generated() once evaluation keys are available.
key_already_generated_condition = False
# Latest serialized encrypted input / encrypted model output (None until produced).
encrypted_data = encrypted_prediction = None
|
|
| |
# Model artifacts are read from a "models" directory relative to the working directory.
_MODELS_DIR = "models"
# Fitted feature scaler applied to raw inputs before encryption.
SCALER_PATH = os.path.join(_MODELS_DIR, "scaler_random_forest.pkl")
# Compiled FHE artifacts (client.zip / server.zip); also used as the key directory.
FHE_FILES_PATH = os.path.join(_MODELS_DIR, "fhe_files")
|
|
| |
# Load the scaler that pre_process_encrypt_send_purchase applies to raw inputs.
# A missing file is fatal: log which path was expected, then re-raise.
try:
    scaler = joblib.load(SCALER_PATH)
except FileNotFoundError:
    # Name the file actually being loaded (SCALER_PATH), not a generic "scaler.pkl".
    logging.error(f"Error: The scaler file is missing at {SCALER_PATH}.")
    raise
else:
    logging.info("Scaler loaded successfully.")
|
|
| |
# Build the FHE client (used below for key handling, quantize/encrypt and
# decrypt/dequantize) and server (used for encrypted evaluation) from the
# compiled artifacts. Missing artifacts are fatal: log, then re-raise.
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    # NOTE(review): some concrete-ml versions already call load() inside
    # FHEModelServer.__init__ — confirm whether this explicit call is redundant.
    server.load()
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise
else:
    logging.info("FHE Client and Server initialized successfully.")

# Serialized evaluation keys, passed to server.run() alongside each ciphertext.
evaluation_keys = client.get_serialized_evaluation_keys()
|
|
def predict():
    """
    Run the compiled FHE model on the currently stored encrypted input.

    Reads the module-level ``encrypted_data`` (set by
    ``pre_process_encrypt_send_purchase``). Stores the encrypted result in the
    module-level ``encrypted_prediction``, where ``decrypt_prediction`` reads it.

    Returns:
        tuple: ``(hex_or_none, status_update)``.
            ``hex_or_none`` is the hex encoding of the encrypted prediction,
            or ``None`` when there is no input or evaluation failed.
            ``status_update`` is a ``gr.update`` carrying a status message.
    """
    global encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # Homomorphic evaluation on the serialized ciphertext using the evaluation keys.
        result = server.run(encrypted_data, serialized_evaluation_keys=evaluation_keys)
    except Exception:
        # Clear any result from an earlier run so it cannot be decrypted by mistake.
        encrypted_prediction = None
        logging.exception("Error during prediction")
        # Report an evaluation failure, not the misleading "no encrypted data" message.
        return None, gr.update(value="FHE evaluation failed. ❌")

    encrypted_prediction = result
    # Log only the size: the full ciphertext is a large, unreadable blob.
    logging.info("Encrypted Prediction: %d bytes", len(result))
    return result.hex(), gr.update(value="FHE evaluation is done. ✅")
| |
def decrypt_prediction():
    """
    Decrypt the stored encrypted prediction and render it for the UI.

    Reads the module-level ``encrypted_prediction`` produced by ``predict``.

    Returns:
        tuple: ``(label, status, bar_html)``.
            ``label`` is a human-readable verdict.
            ``status`` is a status update.
            ``bar_html`` is an HTML bar showing the non-fraud / fraud scores.
        When there is nothing to decrypt, or decryption fails, all three slots
        carry the same error message. The callback therefore always yields
        exactly three outputs.
    """
    if encrypted_prediction is None:
        msg = "No prediction to decrypt. Please make a prediction first. ❌"
        # Three values, matching the success and error branches below.
        return msg, msg, msg
    try:
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        # The decrypted result is as sensitive as the input: keep it out of INFO logs.
        logging.debug("Decrypted Prediction: %s", decrypted_prediction)

        # Flatten before argmax and indexing so [0]/[1] refer to the per-class
        # scores of the single submitted row. asarray also handles list output.
        scores = np.asarray(decrypted_prediction).ravel()
        binary_prediction = int(np.argmax(scores))

        # Clamp in case dequantized scores fall slightly outside [0, 1], so the
        # CSS widths and displayed percentages stay valid.
        p_non_fraud = float(np.clip(scores[0], 0.0, 1.0)) * 100
        p_fraud = float(np.clip(scores[1], 0.0, 1.0)) * 100

        bar_html = f"""
<div style="width: 100%; background-color: lightgray; border-radius: 5px; overflow: hidden; display: flex;">
    <div style="width: {p_non_fraud:.2f}%; background-color: green; color: white; text-align: center; padding: 5px 0;">
        {p_non_fraud:.1f}% Non-Fraud
    </div>
    <div style="width: {p_fraud:.2f}%; background-color: red; color: white; text-align: center; padding: 5px 0;">
        {p_fraud:.1f}% Fraud
    </div>
</div>
"""
        label = "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊"
        return label, gr.update(value="Decryption successful ✅"), bar_html

    except Exception:
        logging.exception("Error during decryption")
        return "Error during prediction❌", "Error during prediction❌", "Error during prediction❌"
|
|
def key_already_generated():
    """
    Report whether serialized evaluation keys are available.

    When they are (the module-level ``evaluation_keys`` is truthy), this also
    sets the module-level ``key_already_generated_condition`` flag. That flag
    gates ``pre_process_encrypt_send_purchase``.

    Returns:
        bool: True if ``evaluation_keys`` is truthy, False otherwise.
    """
    global key_already_generated_condition
    if not evaluation_keys:
        return False
    key_already_generated_condition = True
    return True
|
|
def pre_process_encrypt_send_purchase(*inputs):
    """
    Scale, quantize and encrypt one purchase record, then store it for ``predict``.

    Args:
        *inputs: Raw feature values of a single purchase. They are presumably
            in the column order the scaler was fitted on; confirm this against
            the training code.

    Returns:
        tuple: ``(hex_or_none, status_update)``.
            ``hex_or_none`` is the full hex encoding of the serialized
            ciphertext, or ``None`` when keys are missing or processing failed.
            ``status_update`` is a ``gr.update`` with a status message.
        The ciphertext is also kept in the module-level ``encrypted_data``.
    """
    global encrypted_data
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key before. ❌")
    try:
        # Plaintext features are exactly what FHE is meant to protect:
        # keep them out of INFO-level logs.
        logging.debug("Input Data: %s", inputs)
        # Wrap the single record as a one-row 2D input for the scaler.
        scaled_data = scaler.transform([list(inputs)])
        logging.debug("Scaled Data: %s", scaled_data)
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
    except Exception:
        # Drop ciphertext from an earlier submission so predict() cannot reuse it.
        encrypted_data = None
        logging.exception("Error during pre-processing")
        # Same two-output shape as the other branches (Gradio expects 2 values).
        return None, gr.update(value="Error during pre-processing ❌")

    logging.info("Data encrypted successfully.")
    return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
|
|