jing committed on
Commit
e3b08c9
·
1 Parent(s): 18b350c

fix a bug

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -26,6 +26,22 @@ os.makedirs(UPLOAD_DIR, exist_ok=True)
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
  DATASET_REPO_ID = "akweury/ELVIS-Human-Results" # Updated with your dataset repo
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def analyze_and_update():
30
  api = HfApi(token=HF_TOKEN)
31
  files = api.list_repo_files(DATASET_REPO_ID, repo_type="dataset")
@@ -71,20 +87,25 @@ def analyze_and_update():
71
  principle_stats[principle]["count"] += 1
72
  if item.get("correct", False):
73
  principle_stats[principle]["correct"] += 1
 
 
 
 
74
  if y_true and y_pred:
75
  try:
76
  acc = np.mean([yt == yp for yt, yp in zip(y_true, y_pred)])
77
- f1 = f1_score(y_true, y_pred)
78
  precision = precision_score(y_true, y_pred, zero_division=0)
79
  recall = recall_score(y_true, y_pred, zero_division=0)
80
  avg_time = np.mean(solve_times) if solve_times else 0
81
  principle_stats[principle]["acc_list"].append(acc)
82
- principle_stats[principle]["f1_list"].append(f1)
83
- principle_stats[principle]["precision_list"].append(precision)
84
- principle_stats[principle]["recall_list"].append(recall)
 
 
85
  principle_stats[principle]["solve_time_list"].append(avg_time)
86
  except Exception:
87
- pass
88
 
89
  with open("README.md", "w") as f:
90
  f.write("| Principle | Avg Accuracy ± Std | Avg F1 ± Std | Avg Precision ± Std | Avg Recall ± Std | Avg Solve Time (s) ± Std | Avg Hardness ± Std | Count |\n")
 
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
  DATASET_REPO_ID = "akweury/ELVIS-Human-Results" # Updated with your dataset repo
28
 
29
def confusion_matrix_elements(predictions, ground_truth):
    """Count the binary confusion-matrix cells for paired label sequences.

    Labels are compared against the literal ints 0 and 1; pairs with any
    other value are ignored, matching the original per-cell filters.

    Args:
        predictions: iterable of predicted labels (0 or 1).
        ground_truth: iterable of true labels (0 or 1), consumed in lockstep
            with ``predictions`` (extra elements in the longer one are ignored,
            as with ``zip``).

    Returns:
        Tuple ``(TN, FP, FN, TP)`` of non-negative ints.
    """
    TN = FP = FN = TP = 0
    # Single pass instead of four separate sums over zip(...): same counts
    # for lists, but also correct when the inputs are one-shot iterators
    # (the original's first sum() would exhaust them, zeroing the rest).
    for p, gt in zip(predictions, ground_truth):
        if p == 1 and gt == 1:
            TP += 1
        elif p == 1 and gt == 0:
            FP += 1
        elif p == 0 and gt == 1:
            FN += 1
        elif p == 0 and gt == 0:
            TN += 1
    return TN, FP, FN, TP
36
+
37
def calculate_metrics(TN, FP, FN, TP):
    """Derive precision, recall, and F1 from confusion-matrix counts.

    Each metric falls back to 0 when its denominator is zero (mirroring
    sklearn's ``zero_division=0`` behavior used elsewhere in this file).

    Args:
        TN: true-negative count (unused by these metrics; kept so callers
            can pass ``*confusion_matrix_elements(...)`` unchanged).
        FP: false-positive count.
        FN: false-negative count.
        TP: true-positive count.

    Returns:
        Tuple ``(precision, recall, f1)`` of floats (or int 0 fallbacks).
    """
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    # Local renamed from `f1_score` to `f1` so it no longer shadows the
    # sklearn `f1_score` import used by the rest of this module.
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1
43
+
44
+
45
  def analyze_and_update():
46
  api = HfApi(token=HF_TOKEN)
47
  files = api.list_repo_files(DATASET_REPO_ID, repo_type="dataset")
 
87
  principle_stats[principle]["count"] += 1
88
  if item.get("correct", False):
89
  principle_stats[principle]["correct"] += 1
90
+
91
+ TN, FP, FN, TP = confusion_matrix_elements(y_pred, y_true)
92
+ precision, recall, f1_score = calculate_metrics(TN, FP, FN, TP)
93
+
94
  if y_true and y_pred:
95
  try:
96
  acc = np.mean([yt == yp for yt, yp in zip(y_true, y_pred)])
 
97
  precision = precision_score(y_true, y_pred, zero_division=0)
98
  recall = recall_score(y_true, y_pred, zero_division=0)
99
  avg_time = np.mean(solve_times) if solve_times else 0
100
  principle_stats[principle]["acc_list"].append(acc)
101
+ principle_stats[principle]["f1_list"]= f1_score
102
+
103
+ principle_stats[principle]["precision_list"] = precision
104
+ principle_stats[principle]["recall_list"] = recall
105
+
106
  principle_stats[principle]["solve_time_list"].append(avg_time)
107
  except Exception:
108
+ raise ValueError("Error in calculating metrics.")
109
 
110
  with open("README.md", "w") as f:
111
  f.write("| Principle | Avg Accuracy ± Std | Avg F1 ± Std | Avg Precision ± Std | Avg Recall ± Std | Avg Solve Time (s) ± Std | Avg Hardness ± Std | Count |\n")