Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Sun Jan 28 18:48:07 2024 | |
| @author: liewchooichin | |
| """ | |
| import os | |
| import pathlib | |
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from huggingface_hub import snapshot_download | |
| from huggingface_hub import from_pretrained_keras | |
# Print the TensorFlow version at start-up so it shows in the Space logs.
print(f"tensorflow version: {tf.__version__}")

# ---- global variables ----
# Module-level placeholders for the predictions.
# NOTE(review): get_prediction() assigns its own local pred_binary /
# pred_multi (no `global` statement), so these two values are never
# updated by it.
pred_binary = ""  # binary labels
pred_multi = ""  # multi labels

# Sample files and their labels; both lists are filled by get_samples().
samples = []
labels = []
# Folder (relative to this file) that holds the sample face images.
data_dir = "face_samples"

# Where the models are loaded from; read by download_keras_model().
LOCAL_TEST = False  # when in HF, set to False
HF_SPACE = True  # when in HF

# My model in the HF repo
REPO_ID_BINARY = 'liewchooichin/fake_binary'
REPO_ID_MULTILABEL = 'liewchooichin/fake_multilabel'
# tf_model = None
# keras_model = None

# Folder (relative to this file) with the models used for local testing.
local_model_dir = "fake_models"
# NOTE(review): pb_name is not referenced by any code visible in this file.
pb_name = "saved_model.pb"
# Paths of the .keras files, relative to the model folder / repo snapshot.
keras_binary_label = os.path.join("binary_label", "all_binary_6771.keras")
keras_multilabel = os.path.join("multi_label", "multi_7036.keras")
def get_samples():
    """Fill the module-level ``samples`` and ``labels`` lists.

    Looks for ``*.jpg`` files in ``data_dir`` next to this file and takes
    the first twelve of them: the first nine are labelled fake (1) and the
    following three real (0).  The twelve leading entries of both lists are
    printed afterwards as a quick visual check.

    Raises
    ------
    IndexError
        If fewer than twelve ``*.jpg`` files are found.

    NOTE(review): the files are used in the order ``glob`` returns them,
    without sorting, so the fake/real split relies on the directory listing
    order -- confirm this matches the sample file names.
    """
    fake_count = 9
    total_count = 12

    sample_dir = pathlib.Path(
        os.path.join(os.path.dirname(__file__), data_dir)
    )
    jpg_files = list(sample_dir.glob("*.jpg"))

    # hard-coded split for now: nine fake faces followed by three real ones
    for position in range(total_count):
        samples.append(jpg_files[position])
        labels.append(1 if position < fake_count else 0)

    # print to check the images and labels
    for sample_file, label in zip(samples[:total_count], labels[:total_count]):
        print(sample_file, label)
def _load_keras_model(repo_id, relative_path, label):
    """Load one model from the local model folder or from its HF repo.

    Parameters
    ----------
    repo_id : str
        Hugging Face repository that holds the model.
    relative_path : str
        Path of the ``.keras`` file relative to the local model folder
        (or to the downloaded snapshot).
    label : str
        Short name ("binary" or "multi") used in the log messages.

    Returns
    -------
    The loaded model object.

    Raises
    ------
    FileNotFoundError
        If ``LOCAL_TEST`` is set and the local model file does not exist.
    RuntimeError
        If neither ``LOCAL_TEST`` nor ``HF_SPACE`` is enabled.
    """
    # Local testing takes precedence so both models always come from the
    # same source when both flags happen to be set.
    if LOCAL_TEST:
        model_path = os.path.join(
            os.path.dirname(__file__),
            local_model_dir,
            relative_path
        )
        if not os.path.exists(model_path):
            print(f"Model not found: {model_path}")
            # Stop here instead of letting the loader fail on a missing file.
            raise FileNotFoundError(f"Model not found: {model_path}")
        return tf.keras.models.load_model(model_path)

    if HF_SPACE:
        download_dir = snapshot_download(repo_id=repo_id)
        print(f"Download dir: {download_dir}")
        keras_path = os.path.join(download_dir, relative_path)
        print(f"Keras {label} label: {keras_path}")
        # Per the original author's notes, the tf.keras loaders and
        # from_pretrained_keras did not work in the HF Space, so the
        # snapshot directory is loaded as a SavedModel instead.
        return tf.saved_model.load(download_dir)

    raise RuntimeError(
        "No model source enabled: set LOCAL_TEST or HF_SPACE to True."
    )


def download_keras_model():
    """Load the binary and the multilabel model into module globals.

    Sets ``keras_binary_model`` and ``keras_multi_model``.  The source of
    both models is chosen by the ``LOCAL_TEST`` / ``HF_SPACE`` flags (see
    ``_load_keras_model``).
    """
    # set the model variables to be global
    global keras_binary_model
    global keras_multi_model

    keras_binary_model = _load_keras_model(
        REPO_ID_BINARY, keras_binary_label, "binary"
    )
    keras_multi_model = _load_keras_model(
        REPO_ID_MULTILABEL, keras_multilabel, "multi"
    )
def get_img_array(img_path):
    """Load an image file as a one-image batch for prediction.

    Parameters
    ----------
    img_path
        Path of the image file, passed to ``tf.keras.utils.load_img``.

    Returns
    -------
    numpy.ndarray
        The image, loaded with a 224x224 target size, with an extra
        leading axis added.  The shape is printed before returning.
    """
    target_size = (224, 224)
    image = tf.keras.utils.load_img(img_path, target_size=target_size)
    # the models are called on a batch, so add a leading axis
    batch = np.expand_dims(tf.keras.utils.img_to_array(image), axis=0)
    print(f"Shape of image array: {batch.shape}")
    return batch
def get_prediction(img_path):
    """Run both models on one image and format the results for display.

    Parameters
    ----------
    img_path : str or os.PathLike
        Path of the image file to classify.

    Returns
    -------
    tuple of (str, str, str)
        The file name of the image, the formatted result of the binary
        model and the formatted result of the multilabel model.
    """
    # adjust threshold for accuracy; used for both models
    threshold = 0.4
    # check the image path
    print(f"Image path: {img_path}")
    # Also display the original filename for info.  Treat both "\\" and "/"
    # as separators so the name is extracted on Windows and on Linux alike.
    path_text = os.fspath(img_path)
    orig_filename = path_text.replace("\\", "/").split("/")[-1]
    # get the image array (loaded once and shared by both models)
    img_array = get_img_array(img_path)

    # binary label
    pred_binary = keras_binary_model(img_array, training=False)
    print(f"Keras binary label: {pred_binary}")
    fake = "Fake" if pred_binary[0][0] > threshold else "Real"

    # multi label
    pred_multi = keras_multi_model(img_array, training=False)
    print(f"Keras multi label: {pred_multi}")
    # Cut each output at the same threshold as the binary model
    fake_parts = np.where(pred_multi > threshold, 1, 0)
    print(f"Multi label: {fake_parts}")

    # Output index of each face part, in display order.
    # The last output (index 4) is the overall prediction.
    part_positions = {
        "overall": 4,
        "left_eye": 0,
        "right_eye": 1,
        "nose": 2,
        "mouth": 3,
    }
    # Format the display line by line
    part_lines = []
    for part_name, position in part_positions.items():
        verdict = "Fake" if fake_parts[0][position] == 1 else "Real"
        part_lines.append(f"{part_name}: {verdict}\n")
    parts_formatted = "".join(part_lines)

    # Format result string
    result_binary = (
        f"Probability: {pred_binary} "
        f"\nPrediction: {fake}\n"
    )
    result_multi = (
        f"Probability: {pred_multi} "
        f"\nPrediction: {fake_parts} "
        f"\n{parts_formatted}"
    )
    return orig_filename, result_binary, result_multi
def clear_image():
    """Return three empty strings, one for each output text box."""
    # Clear the previous output result
    blank = ""
    return (blank,) * 3
def main():
    """Run the start-up steps: collect the samples, then load the models."""
    # download_tf_model()
    startup_steps = (get_samples, download_keras_model)
    for step in startup_steps:
        step()
# Build the Gradio user interface.
# NOTE(review): the nesting of the `with` blocks below was reconstructed
# from the statement order -- confirm it against the deployed layout.
with gr.Blocks() as demo:
    # call the main for preliminary work
    # (fills `samples`, which the Examples below index, and loads the models)
    main()
    image_width = 256
    image_height = 256
    # Introductory text shown at the top of the page.
    gr.Markdown(
        """
# Fake or real faces detection.
The dataset is obtained from https://www.kaggle.com/datasets/ciplab/real-and-fake-face-detection.
Trained with EfficientNet V2 B0.
One model is trained to do binary classification and the other \
multilabel classification. The multilabels classification is \
based on the last four digits provided by the filenames. \
The last four digits are following the order of left eye, \
right eye, nose and mouth. \
The labels are 1 (fake) and 0 (real).
For example: ___1010.jpg means left eye and nose are fake.
Binary accuracy for the binary label model: 0.6771. <br>
Binary accuracy for the multilabel model: 0.7036.
The fake faces are also categorized into how difficult it is \
to detect the faces as fake. The categories are easy, mid and hard.
The top prediction and its probabilities of classes are shown.
Try our sample faces below or upload one of your own.
"""
    )
    with gr.Row():
        # Input image; type="filepath" so the callbacks receive a file path.
        with gr.Column():
            img = gr.Image(height=image_height,
                           width=image_width,
                           sources=["upload", "clipboard"],
                           interactive=True,
                           type="filepath")
        # Read-only text boxes that receive the three values returned by
        # get_prediction().
        with gr.Column():
            text_1 = gr.Text(
                label="Filename",
                interactive=False, lines=1
            )
            text_2 = gr.Text(
                label="Binary label, Efficient net v2 B0",
                interactive=False, lines=2)
            # The multilabel output is created hidden (visible=False).
            text_3 = gr.Text(
                label="Multi label, Efficient net v2 B0",
                interactive=False, lines=7,
                visible=False)
            # The string below is disabled code kept by the author; it is a
            # bare string expression and has no effect.
            """
text_3 = gr.Text(label="Sashi's model",
interactive=False, lines=3)
text_4 = gr.Text(label="KK's model",
interactive=False, lines=3)
"""
    # load the images directory
    # print(f"List of examples: {samples}")
    # Sample images, three per group; clicking one runs get_prediction
    # (run_on_click=True) and fills the three text boxes.
    with gr.Row():
        gr.Markdown("""
## Fakes faces <br>(easy)
""")
        examples_1 = gr.Examples(
            examples=[
                samples[0], samples[1], samples[2],
            ],
            inputs=[img],
            outputs=[text_1, text_2, text_3],
            run_on_click=True,
            fn=get_prediction
        )
    with gr.Row():
        gr.Markdown("""
## Fakes faces <br>(mid)
""")
        examples_2 = gr.Examples(
            examples=[
                samples[3], samples[4], samples[5],
            ],
            inputs=[img],
            outputs=[text_1, text_2, text_3],
            run_on_click=True,
            fn=get_prediction
        )
    with gr.Row():
        gr.Markdown("""
## Fakes faces <br>(hard)
""")
        examples_3 = gr.Examples(
            examples=[
                samples[6], samples[7], samples[8],
            ],
            inputs=[img],
            outputs=[text_1, text_2, text_3],
            run_on_click=True,
            fn=get_prediction
        )
    # The last three samples are the ones get_samples() labels as real.
    with gr.Row():
        gr.Markdown("""
## Real faces
""")
        examples_4 = gr.Examples(
            examples=[
                samples[9], samples[10], samples[11]
            ],
            inputs=[img],
            outputs=[text_1, text_2, text_3],
            run_on_click=True,
            fn=get_prediction
        )
    # prediction when a file is uploaded
    img.upload(fn=get_prediction, inputs=[img],
               outputs=[text_1, text_2, text_3])
    # when an example is clicked
    # img.change(fn=get_prediction, inputs=[img],
    #            outputs=[text_1, text_2])
    # when an image is cleared
    img.clear(fn=clear_image, inputs=[],
              outputs=[text_1, text_2, text_3])

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()