Spaces:
Running
Running
| import re | |
| import json | |
| import evaluate | |
| import datasets | |
# Human-readable summary of the metric; passed as `description` to
# evaluate.MetricInfo in Accuracy._info.
_DESCRIPTION = """
Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 Score: The harmonic mean of precision and recall
These metrics help evaluate the accuracy of table data extraction or generation.
"""
| _KWARGS_DESCRIPTION = """ | |
| Args: | |
| predictions (`str`): Predicted table in Markdown format. | |
| references (`str`): Reference table in Markdown format. | |
| Returns: | |
| dict: A dictionary containing the following metrics: | |
| - precision (`float`): Precision score, range [0,1] | |
| - recall (`float`): Recall score, range [0,1] | |
| - f1 (`float`): F1 score, range [0,1] | |
| - true_positives (`int`): Number of correctly predicted cells | |
| - false_positives (`int`): Number of incorrectly predicted cells | |
| - false_negatives (`int`): Number of cells that were not predicted | |
| Examples: | |
| >>> accuracy_metric = evaluate.load("accuracy") | |
| >>> results = accuracy_metric.compute( | |
| ... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |", | |
| ... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |" | |
| ... ) | |
| >>> print(results) | |
| {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3} | |
| """ | |
# BibTeX citation passed as `citation` to evaluate.MetricInfo in Accuracy._info.
# NOTE(review): this is the scikit-learn citation, apparently inherited from the
# accuracy-metric template; the table metric below does not use scikit-learn.
# Confirm whether a more appropriate citation exists.
_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""
class Accuracy(evaluate.Metric):
    """Cell-level precision, recall and F1 between predicted and reference Markdown tables.

    The class name is kept for compatibility with existing callers, but the
    metric does not compute classification accuracy. Both inputs are parsed
    into dictionaries (``{row_label: {column: value}}``, or ``{column: value}``
    for a single unlabeled data row) and compared cell by cell.
    """

    # Matches one "|cell|cell|...|" run once all whitespace has been removed.
    _TABLE_PATTERN = re.compile(r'\|(?:[^|]+\|)+[^|]+\|')
    # A Markdown header-separator cell: "-", "--", "---", ":--", "--:", ":-:", ...
    _SEPARATOR_CELL = re.compile(r':?-+:?')

    def _info(self):
        """Describe the metric's documentation and input features for `evaluate`."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            # NOTE(review): this URL documents scikit-learn's accuracy_score,
            # not the table metric implemented here; confirm or replace.
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        """Return the pipe-delimited table content of ``text``, or ``None`` if there is none.

        Every whitespace character is removed first, so a multi-line table
        collapses to ``|h1|h2||--|--||v1|v2|`` in which ``||`` separates rows.
        All regex matches are concatenated in order. Because cell text loses
        its spaces too, predicted and reference cells are compared without them.
        """
        if not text:
            return None
        # Removing all whitespace (not only " " and newline) keeps carriage
        # returns from CRLF input from hiding the "||" row boundaries.
        compact = re.sub(r'\s+', '', text)
        matches = self._TABLE_PATTERN.findall(compact)
        return ''.join(matches) if matches else None

    def _is_separator_row(self, row):
        """Return True when every cell of ``row`` is a Markdown header separator."""
        return all(self._SEPARATOR_CELL.fullmatch(cell) for cell in row.split('|'))

    @staticmethod
    def _parse_cell(value):
        """Return ``value`` as a float when it parses as a number, otherwise the raw text.

        Numeric parsing lets "5" and "5.0" compare equal; the text fallback lets
        tables with non-numeric cells be scored instead of raising ValueError.
        """
        try:
            return float(value)
        except ValueError:
            return value

    def _cells_to_dict(self, columns, values):
        """Map header names to parsed cell values.

        Cells beyond the last header column are stored under positional keys
        (``"#<index>"``) so surplus cells are still counted instead of raising.
        """
        return {
            (columns[index] if index < len(columns) else f"#{index}"): self._parse_cell(value)
            for index, value in enumerate(values)
        }

    def _table_to_dict(self, table_str):
        """Convert compacted table text into a dictionary of cell values.

        Returns ``{column: value}`` for a single data row that is no wider than
        the header, ``{row_label: {column: value}}`` otherwise, and ``None``
        when there is no data row. Separator rows and empty rows are ignored.
        NOTE: an empty cell in the middle of a row also compacts to "||" and
        is therefore read as a row boundary.
        """
        rows = [
            row
            for row in table_str.strip('|').split('||')
            if row and not self._is_separator_row(row)
        ]
        if len(rows) < 2:
            return None

        legends = rows[0].split('|')
        data_rows = [row.split('|') for row in rows[1:]]

        if len(data_rows) == 1 and len(data_rows[0]) <= len(legends):
            # A single row that fits under the header carries no row label.
            return self._cells_to_dict(legends, data_rows[0])

        result_dict = {}
        for cells in data_rows:
            label, values = cells[0], cells[1:]
            # The regex drops an empty top-left header cell, leaving one fewer
            # header entry than data cells. If the header is as wide as the row,
            # its first entry names the label column and must be skipped so the
            # value columns line up.
            columns = legends[1:] if len(legends) == len(cells) else legends
            result_dict[label] = self._cells_to_dict(columns, values)
        return result_dict

    def _markdown_to_dict(self, markdown_str):
        """Parse the table contained in ``markdown_str``; ``None`` when no usable table is found."""
        table_str = self._extract_markdown_table(markdown_str)
        return self._table_to_dict(table_str) if table_str else None

    def _count_cells(self, value):
        """Return the number of leaf cells in ``value`` (1 for a scalar, recursive for a dict)."""
        if isinstance(value, dict):
            return sum(self._count_cells(item) for item in value.values())
        return 1

    def _calculate_table_metrics(self, pred_table, true_table):
        """Compare two parsed tables cell by cell.

        ``None`` (no table could be parsed) is treated as an empty table.
        A matching cell is a true positive. A mismatched cell counts as both a
        false positive and a false negative. Every cell of a row or column
        present on only one side counts as a false positive (prediction only)
        or false negative (reference only).

        Returns:
            dict with ``precision``, ``recall``, ``f1`` (0.0 when undefined) and
            the integer ``true_positives``, ``false_positives``, ``false_negatives``.
        """
        pred_table = pred_table or {}
        true_table = true_table or {}
        true_positives = 0
        false_positives = 0
        false_negatives = 0

        # Walk every key of the predicted table.
        for key, pred_value in pred_table.items():
            if key not in true_table:
                false_positives += self._count_cells(pred_value)
                continue
            true_value = true_table[key]
            if isinstance(pred_value, dict) and isinstance(true_value, dict):
                nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                true_positives += nested_metrics['true_positives']
                false_positives += nested_metrics['false_positives']
                false_negatives += nested_metrics['false_negatives']
            elif pred_value == true_value:
                true_positives += 1
            else:
                # Covers both differing scalars and a row/scalar shape mismatch.
                false_positives += self._count_cells(pred_value)
                false_negatives += self._count_cells(true_value)

        # Reference keys the prediction never produced.
        false_negatives += sum(
            self._count_cells(true_value)
            for key, true_value in true_table.items()
            if key not in pred_table
        )

        predicted_total = true_positives + false_positives
        reference_total = true_positives + false_negatives
        precision = true_positives / predicted_total if predicted_total > 0 else 0.0
        recall = true_positives / reference_total if reference_total > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives,
        }

    def _compute(self, predictions, references):
        """Concatenate the prediction/reference items and score the resulting tables.

        An input with no parseable table scores as an empty table (all-zero
        metrics for that side) instead of raising.
        """
        prediction_text = "".join(predictions)
        reference_text = "".join(references)
        return self._calculate_table_metrics(
            self._markdown_to_dict(prediction_text),
            self._markdown_to_dict(reference_text),
        )
def main():
    """Score a hard-coded predicted table against its reference and print the metrics."""
    # Predicted table.
    predicted = """
| lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
"""
    # Reference (ground-truth) table.
    reference = """
| lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
"""
    metric = Accuracy()
    scores = metric.compute(predictions=[predicted], references=[reference])
    print(scores)


if __name__ == '__main__':
    main()