Spaces:
Runtime error
Runtime error
| import logging | |
| import math | |
| from typing import Dict | |
| import torch | |
| from pandas import MultiIndex | |
| from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations | |
| from pytorch_ie import DocumentMetric | |
| from pytorch_ie.core.metric import T | |
| from torchmetrics import Metric, MetricCollection | |
| from src.hydra_callbacks.save_job_return_value import to_py_obj | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class CorefMetricsTorchmetrics(DocumentMetric):
    """Evaluate binary coreference relations with a collection of torchmetrics metrics.

    For every document, gold relations (``document.binary_coref_relations``) and
    predicted relations (``document.binary_coref_relations.predictions``) are aligned
    by their ``(head, tail)`` argument pair. Every pair that occurs on either side
    yields one sample:

    * the target is ``int(score)`` of the gold relation, or ``default_target_idx``
      if the pair has no gold relation;
    * the prediction is ``score`` of the predicted relation, or
      ``default_prediction_score`` if the pair was not predicted.

    Args:
        metrics: Mapping from metric name to torchmetrics ``Metric``; wrapped in a
            ``MetricCollection``.
        default_target_idx: Target used for pairs that are predicted but not gold.
        default_prediction_score: Prediction score used for gold pairs that were
            not predicted.
        show_as_markdown: If True, log the computed result as a markdown table.
        markdown_precision: Number of decimals when rounding the markdown table.
        plot: If True, plot all metrics with matplotlib before computing the result.
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, Metric],
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = MetricCollection(metrics)
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        # Attributes are set before the base constructor runs. Presumably
        # DocumentMetric.__init__ calls self.reset(), which needs self.metrics;
        # TODO confirm against pytorch_ie.
        super().__init__()

    def reset(self) -> None:
        """Reset the state of all wrapped metrics."""
        self.metrics.reset()

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Add one sample per (head, tail) pair of the gold or predicted coref relations."""
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score
            for rel in document.binary_coref_relations.predictions
        }
        all_args = set(target_args2idx) | set(prediction_args2score)

        # Both lists are filled in the same loop so targets and predictions stay
        # aligned; the (set) iteration order itself does not matter for the metrics.
        all_targets = []
        all_predictions = []
        for args in all_args:
            all_targets.append(target_args2idx.get(args, self.default_target_idx))
            all_predictions.append(
                prediction_args2score.get(args, self.default_prediction_score)
            )

        # Explicit dtypes. For a document without any gold or predicted relation the
        # lists are empty and torch.tensor([]) would be floating point, which
        # torchmetrics rejects for classification targets. Predictions are forced to
        # a floating dtype so integer scores are still treated as probabilities.
        prediction_scores = torch.tensor(all_predictions, dtype=torch.get_default_dtype())
        target_indices = torch.tensor(all_targets, dtype=torch.long)
        self.metrics.update(preds=prediction_scores, target=target_indices)

    def do_plot(self):
        """Plot every metric of the collection in its own subplot and show the figure."""
        from matplotlib import pyplot as plt

        num_metrics = len(self.metrics)
        if num_metrics == 0:
            # nothing to plot; also avoids a division by zero in the grid computation
            logger.warning("no metrics to plot")
            return

        # Aim for a square-like grid of subplots.
        ncols = math.ceil(math.sqrt(num_metrics))
        nrows = math.ceil(num_metrics / ncols)

        # squeeze=False always returns a 2D array of axes, also for a 1x1 grid.
        # Otherwise a bare Axes object is returned, which has no flatten().
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10), squeeze=False)
        ax_list = axes.flatten().tolist()

        # Hide grid cells that do not get a metric.
        for unused_ax in ax_list[num_metrics:]:
            unused_ax.set_visible(False)

        # MetricCollection.plot with together=False expects one axis per metric.
        self.metrics.plot(ax=ax_list[:num_metrics], together=False)

        # Adjust layout to avoid overlapping plots.
        plt.tight_layout()
        plt.show()

    def _compute(self) -> T:
        """Compute all metrics, optionally plot/log them, and return plain Python values."""
        if self.plot:
            self.do_plot()
        result = to_py_obj(self.metrics.compute())
        if self.show_as_markdown:
            self._log_result_as_markdown(result)
        return result

    def _log_result_as_markdown(self, result) -> None:
        """Log ``result`` (already converted to Python objects) as a markdown table."""
        import pandas as pd

        series = pd.Series(result)
        if isinstance(series.index, MultiIndex):
            if len(series.index.levels) > 1:
                # in fact, this is not a series anymore
                series = series.unstack(-1)
            else:
                series.index = series.index.get_level_values(0)
        logger.info(
            f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
        )