import copy
import time

import gradio as gr
import torch
import torch.nn as nn
from concrete.fhe import Configuration
from concrete.ml.torch.compile import compile_torch_model
from PIL import Image
from torchvision import transforms

from custom_resnet import resnet18_custom
|
|
# Human-readable labels, indexed by the classifier's output position
# (index 0 -> "Fake", index 1 -> "Real"). load_model also uses the length of
# this list to size the final linear layer.
class_names = ["Fake", "Real"]
|
|
def load_model(model_path, device):
    """Build the ResNet-18 classifier, load trained weights, and prepare it for inference.

    Args:
        model_path: path to a state dict compatible with the model built here
            (it is passed to ``load_state_dict``).
        device: torch device the weights are mapped to and the model is moved to.

    Returns:
        The model on *device*, switched to evaluation mode.
    """
    network = resnet18_custom(weights=None)
    # Replace the classification head with one output per entry in class_names.
    head_inputs = network.fc.in_features
    network.fc = nn.Linear(head_inputs, len(class_names))
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints, or consider weights_only=True on
    # PyTorch versions that support it.
    state_dict = torch.load(model_path, map_location=device)
    network.load_state_dict(state_dict)
    # Module.to() and Module.eval() both return the module itself.
    return network.to(device).eval()
|
|
|
|
def load_secure_model(model, inputset=None):
    """Compile a PyTorch model into a Concrete ML model usable for FHE inference.

    The model is deep-copied before being moved to the CPU. ``nn.Module.to``
    relocates parameters in place, so compiling the original would silently
    move the caller's model (e.g. off the GPU).

    Args:
        model: trained ``torch.nn.Module`` to compile. It is not modified.
        inputset: calibration tensor used by ``compile_torch_model`` for
            quantization and circuit compilation. Defaults to 10 random
            ``(3, 224, 224)`` tensors drawn from ``torch.rand``. Pass
            representative images for better-calibrated quantization.

    Returns:
        The object returned by ``compile_torch_model``. Concrete ML documents
        this as a quantized module whose ``forward(..., fhe=...)`` runs inference.
    """
    print("Compiling secure model...")
    if inputset is None:
        inputset = torch.rand(10, 3, 224, 224)
    # Compile a CPU copy so the caller's model keeps its device.
    cpu_model = copy.deepcopy(model).to("cpu")
    secure_model = compile_torch_model(
        cpu_model,
        n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
        rounding_threshold_bits={"n_bits": 7, "method": "APPROXIMATE"},
        p_error=0.05,
        configuration=Configuration(enable_tlu_fusing=True, print_tlu_fusing=False, use_gpu=False),
        torch_inputset=inputset,
    )
    return secure_model
|
|
# Preprocessing applied to every uploaded image before inference: resize to the
# fixed 224x224 input size, then convert the PIL image into a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
data_transform = transforms.Compose(_preprocessing_steps)
|
|
# Location of the trained classifier weights served by this app.
_MODEL_PATH = 'models/deepfake_detection_model.pth'

# Process-wide cache. Weights are read (once per device) and the FHE circuit is
# compiled at most once, instead of on every request. Two concurrent first
# requests may each build an entry; the later assignment simply wins.
_MODEL_CACHE = {}


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _get_model(device):
    """Return the PyTorch classifier on *device*, loading it on first use."""
    key = ("torch", str(device))
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = load_model(_MODEL_PATH, device)
    return _MODEL_CACHE[key]


def _get_secure_model():
    """Return the compiled Concrete ML model, compiling it on first use.

    Compilation starts from the CPU instance of the classifier. Moving that
    instance to the CPU is a no-op, so compiling never relocates the model
    used for "Fast" inference on a GPU.
    """
    key = ("secure",)
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = load_secure_model(_get_model(torch.device("cpu")))
    return _MODEL_CACHE[key]


def predict(image, mode):
    """Classify an uploaded image as Fake or Real.

    Args:
        image: filesystem path of the uploaded image (the Gradio input uses
            ``type="filepath"``), or None when nothing was uploaded.
        mode: "Fast" runs plain PyTorch inference on the best available device.
            "Secure" runs the Concrete ML model in FHE simulation mode on CPU.

    Returns:
        A ``(prediction_text, timing_text)`` pair for the two output textboxes.
        The timing covers inference only. Loading the weights and compiling
        the FHE circuit happen once, outside the timed region.

    Raises:
        gr.Error: if no image was uploaded.
        ValueError: if *mode* is neither "Fast" nor "Secure".
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if mode not in ("Fast", "Secure"):
        raise ValueError(f"Unknown inference mode: {mode!r}")

    device = _select_device()
    print(f"Device: {device}")

    # The context manager releases the file handle. convert() returns a new
    # image that stays usable after the file is closed.
    with Image.open(image) as img:
        rgb_image = img.convert('RGB')
    batch = data_transform(rgb_image).unsqueeze(0)

    # Acquire the model before starting the timer, in both modes.
    if mode == "Fast":
        model = _get_model(device)
        batch = batch.to(device)
    else:
        secure_model = _get_secure_model()
        # Concrete ML consumes NumPy arrays. `batch` is still on the CPU here;
        # calling .numpy() on a CUDA/MPS tensor would fail.
        numpy_batch = batch.detach().numpy()

    with torch.no_grad():
        start_time = time.perf_counter()
        if mode == "Fast":
            outputs = model(batch)
        else:
            # forward() is the quantized-module method that accepts the fhe=
            # execution mode. It returns a NumPy array, so wrap it for torch.max.
            outputs = torch.as_tensor(secure_model.forward(numpy_batch, fhe="simulate"))
        print(outputs)
        _, preds = torch.max(outputs, 1)
        elapsed_time = time.perf_counter() - start_time

    predicted_class = class_names[preds[0].item()]
    return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"
|
|
# Gradio UI: an image upload plus an inference-mode selector in, two text boxes
# out. The component order must match predict(image, mode) and the two strings
# it returns.
_input_components = [
    gr.Image(type="filepath", label="Upload an Image"),
    gr.Radio(choices=["Fast", "Secure"], label="Inference Mode", value="Fast"),
]
_output_components = [
    gr.Textbox(label="Prediction"),
    gr.Textbox(label="Time Taken"),
]
iface = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_components,
    title="Deepfake Detection Model",
    description="Upload an image and select the inference mode (Fast or Secure).",
)
|
|
def _main():
    """Launch the Gradio app served by `iface`."""
    # NOTE(review): per the Gradio docs, share=True also requests a public share
    # link, making the app reachable from outside this machine. Confirm this is
    # intended.
    iface.launch(share=True)


if __name__ == "__main__":
    _main()