Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| from collections import Counter | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Label names, in the left-to-right order used for the bars of the
# gold-vs-predicted distribution chart.
LABELS = [
    "claim",
    "counter_claim",
    "premise",
    "unknown",
]
def plot_confusion_matrix(metrics_path: Path, output_path: Path) -> None:
    """Draw the confusion matrix stored in a metrics JSON file as a heatmap.

    The JSON document is expected to contain a ``confusion_matrix`` object
    with a ``labels`` list and a ``matrix`` of values. Rows are plotted on
    the "Gold label" axis and columns on the "Predicted label" axis; every
    cell is annotated with its value.

    Args:
        metrics_path: UTF-8 encoded JSON file holding the confusion matrix.
        output_path: Destination image file; saved at 200 dpi.
    """
    with metrics_path.open(encoding="utf-8") as handle:
        confusion = json.load(handle)["confusion_matrix"]
    tick_labels = confusion["labels"]
    cells = np.array(confusion["matrix"])

    figure, axes = plt.subplots(figsize=(7, 6))
    heatmap = axes.imshow(cells, cmap="Blues")

    axes.set_title("Confusion Matrix")
    axes.set_xlabel("Predicted label")
    axes.set_ylabel("Gold label")
    tick_positions = range(len(tick_labels))
    axes.set_xticks(tick_positions, tick_labels, rotation=30, ha="right")
    axes.set_yticks(tick_positions, tick_labels)

    # Cells above half of the largest value get white text so the
    # annotation stays readable against the darker end of the colormap.
    peak = cells.max() if cells.size else 0
    midpoint = peak / 2
    for row_index, row in enumerate(cells):
        for col_index, cell in enumerate(row):
            if cell > midpoint:
                text_color = "white"
            else:
                text_color = "black"
            axes.text(
                col_index,
                row_index,
                str(cell),
                ha="center",
                va="center",
                color=text_color,
            )

    figure.colorbar(heatmap, ax=axes, fraction=0.046, pad=0.04)
    figure.tight_layout()
    figure.savefig(output_path, dpi=200)
    plt.close(figure)
def plot_label_distribution(predictions_path: Path, output_path: Path) -> None:
    """Plot gold vs. predicted label counts as a grouped bar chart.

    Reads a CSV file with ``gold_label`` and ``pred_label`` columns, counts
    how often each value occurs in each column, and draws one pair of bars
    (gold, predicted) for every entry of ``LABELS``. Values that appear in
    the file but not in ``LABELS`` get no bar, although they still
    contribute to the y-axis limit.

    Args:
        predictions_path: UTF-8 encoded CSV file with a header row.
        output_path: Destination image file; saved at 200 dpi.

    Raises:
        KeyError: If a row has no ``gold_label`` or ``pred_label`` column.
    """
    # newline="" is what the csv module asks for, so that quoted fields
    # containing line breaks are parsed correctly.
    with predictions_path.open(encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))

    gold_counts = Counter(row["gold_label"] for row in rows)
    pred_counts = Counter(row["pred_label"] for row in rows)

    x = np.arange(len(LABELS))
    width = 0.38

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(
        x - width / 2,
        [gold_counts[label] for label in LABELS],
        width,
        label="Gold",
        color="#4C78A8",
    )
    ax.bar(
        x + width / 2,
        [pred_counts[label] for label in LABELS],
        width,
        label="Predicted",
        color="#F58518",
    )

    ax.set_title("Gold vs Predicted Label Distribution")
    ax.set_ylabel("Number of examples")
    ax.set_xticks(x, LABELS, rotation=30, ha="right")
    ax.legend()

    # default=0 keeps a predictions file with no data rows from raising
    # ValueError (max() of an empty sequence); an empty chart is drawn.
    tallest = max(
        max(gold_counts.values(), default=0),
        max(pred_counts.values(), default=0),
    )
    # Leave a little headroom above the tallest bar.
    ax.set_ylim(0, tallest + 3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200)
    plt.close(fig)
def main() -> None:
    """Parse command-line options and write both evaluation figures.

    Accepts ``--metrics``, ``--predictions`` and ``--output-dir`` (all
    converted to ``Path``), creates the output directory if needed, then
    saves ``confusion_matrix.png`` and ``label_distribution.png`` into it
    and prints the path of each saved file.
    """
    parser = argparse.ArgumentParser(
        description="Plot custom evaluation results."
    )
    # Every option is a filesystem path with a project-relative default.
    path_options = (
        ("--metrics", "evaluation/custom_argument_eval_metrics.json"),
        ("--predictions", "evaluation/custom_argument_eval_predictions.csv"),
        ("--output-dir", "evaluation/figures"),
    )
    for flag, default in path_options:
        parser.add_argument(flag, default=default, type=Path)
    options = parser.parse_args()

    figures_dir = options.output_dir
    figures_dir.mkdir(parents=True, exist_ok=True)
    confusion_path = figures_dir / "confusion_matrix.png"
    distribution_path = figures_dir / "label_distribution.png"

    plot_confusion_matrix(options.metrics, confusion_path)
    plot_label_distribution(options.predictions, distribution_path)

    for saved in (confusion_path, distribution_path):
        print(f"Saved {saved}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()