Spaces:
Runtime error
Runtime error
# Evaluate deepfake-detection results stored as JSON and plot one ROC curve
# per dataset on a shared figure.
#
# Each result file is read as result["video"] with three parallel lists
# (schema inferred from the keys used below -- confirm against whatever
# produces these files):
#   "correct_label" - ground-truth label strings; "FAKE" is the positive class,
#                     any other value is treated as real (0)
#   "pred"          - per-video score, used as P(FAKE) for ROC/AUC and
#                     thresholded at 0.5 for F1
#   "pred_label"    - predicted label strings, used for the accuracy figures
import os
import json

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, f1_score

json_files = [
    os.path.join("result", "data_april14_Celeb-DF.json"),
    os.path.join("result", "data_april14_DFDC.json"),
    os.path.join("result", "data_april11_DeepfakeTIMIT.json"),
    os.path.join("result", "data_april14_FF++.json"),
]


def _dataset_name(path):
    """Return the dataset tag of a result file: the text after the last ``_``
    in the file's base name without extension (``FF++`` for
    ``result/data_april14_FF++.json``)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.split("_")[-1]


def _to_binary(labels):
    """Map label strings to ints: 1 for ``"FAKE"``, 0 for anything else."""
    return [1 if label == "FAKE" else 0 for label in labels]


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0
    (e.g. a dataset with no real videos), instead of raising."""
    return numerator / denominator if denominator else float("nan")


# Per-dataset ROC data, index-aligned with json_files.
fpr_list = []
tpr_list = []
roc_auc_list = []

for json_file in json_files:
    with open(json_file, "r", encoding="utf-8") as f:
        result = json.load(f)

    video = result["video"]
    a_labels = _to_binary(video["correct_label"])
    predicted_probs = video["pred"]
    p_labels = _to_binary(video["pred_label"])
    # NOTE(review): F1 is computed from the scores thresholded at 0.5, while
    # the accuracy figures use the stored "pred_label" values; the two only
    # agree if pred_label was produced with the same 0.5 threshold.
    big_pp = [1 if p >= 0.5 else 0 for p in predicted_probs]

    # ROC curve and AUC. roc_auc_score raises ValueError when y_true holds a
    # single class, so record NaN for such a dataset instead of aborting the
    # evaluation of every remaining file.
    fpr, tpr, _ = roc_curve(a_labels, predicted_probs)
    if len(set(a_labels)) > 1:
        roc_auc = roc_auc_score(a_labels, predicted_probs)
    else:
        roc_auc = float("nan")
    f1 = f1_score(a_labels, big_pp)

    fpr_list.append(fpr)
    tpr_list.append(tpr)
    roc_auc_list.append(roc_auc)

    # Overall accuracy plus per-class accuracy (fraction of real / fake videos
    # classified correctly); NaN when the corresponding count is zero.
    correct = [p == a for p, a in zip(p_labels, a_labels)]
    accuracy = _ratio(sum(correct), len(p_labels))
    real_acc = _ratio(
        sum(ok for ok, a in zip(correct, a_labels) if a == 0),
        a_labels.count(0),
    )
    fake_acc = _ratio(
        sum(ok for ok, a in zip(correct, a_labels) if a == 1),
        a_labels.count(1),
    )

    print(
        f"{_dataset_name(json_file)}:\nReal accuracy {real_acc*100:.3f} Fake accuracy {fake_acc*100:.3f}, Accuracy: {accuracy*100:.3f}"
    )
    print(f"ROC AUC: {roc_auc:.3f}")
    print(f"F1 Score: {f1:.3f}\n")

# Plot all ROC curves on one figure.
plt.figure()
for path, xs, ys, auc in zip(json_files, fpr_list, tpr_list, roc_auc_list):
    plt.plot(xs, ys, label=f"{_dataset_name(path)} (area = {auc:.3f})")
plt.plot([0, 1], [0, 1], "k--")  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend(loc="lower right")
plt.show()