Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from nervaluate import Evaluator | |
| import datasets | |
| import evaluate | |
# Human-readable summary shown in the metric's info card.
_DESCRIPTION = """
Entity-level evaluation of named-entity recognition output, computed with the
nervaluate package (https://github.com/MantisAI/nervaluate).

The metric reports F1 scores under nervaluate's `strict` and `ent_type`
evaluation schemas, both over all evaluated entity types and separately for
each entity type.
"""

# BibTeX entry for the NER-Evaluation work referenced by this metric.
_CITATION = """\
@misc{nereval,
title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1},
url={https://github.com/davidsbatista/NER-Evaluation},
note={Software available from https://github.com/davidsbatista/NER-Evaluation},
author={Batista David},
year={2018},
}
"""

# Description of the inputs accepted by the metric and of the values returned.
_KWARGS_DESCRIPTION = """
Args:
    predictions: list of sequences of string labels, one sequence per example
        (the labels produced by the model).
    references: list of sequences of string labels, one sequence per example
        (the ground-truth labels).
Returns:
    A dict of F1 scores rounded to two decimals:
        'Global Strict F1': F1 under the `strict` schema over all entity types.
        'results Partial F1': F1 under the `ent_type` schema over all entity types.
        '<TYPE> Strict F1': F1 under the `strict` schema for entity type <TYPE>.
        '<TYPE> Partial F1': F1 under the `ent_type` schema for entity type <TYPE>.
    Entity types evaluated by default: TIM, KV, IP.
"""
class Nervaluate(evaluate.Metric):
    """NER metric that reports strict and entity-type F1 via ``nervaluate.Evaluator``.

    Predictions and references are sequences of string labels, one sequence
    per example, as declared in ``_info``.
    """

    # Entity types scored when the caller does not pass ``tags``.
    # todo: read from model file
    _DEFAULT_TAGS = ("TIM", "KV", "IP")

    def _info(self):
        """Return the metric's metadata and the expected input features."""
        # An ``evaluate`` module describes itself with ``evaluate.MetricInfo``;
        # ``datasets.MetricInfo`` is no longer available in recent ``datasets``
        # releases, so relying on it breaks module construction there.
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                }
            ),
            reference_urls=["https://github.com/MantisAI/nervaluate"],
        )

    def _compute(self, predictions, references, tags=None):
        """Compute global and per-entity-type F1 scores.

        Args:
            predictions: predicted label sequences, one list of strings per example.
            references: gold label sequences, one list of strings per example.
            tags: optional iterable of entity types to score; defaults to
                ``_DEFAULT_TAGS`` when omitted.

        Returns:
            dict mapping score names to F1 values rounded to two decimals.
        """
        entity_types = list(self._DEFAULT_TAGS if tags is None else tags)

        # The declared features are lists of string labels, which is what
        # nervaluate's "list" loader consumes; the default loader expects
        # pre-extracted span dicts instead.
        # NOTE(review): requires a nervaluate release that supports ``loader``.
        evaluator = Evaluator(
            references, predictions, tags=entity_types, loader="list"
        )

        # Only the first two return values (overall results, per-tag results)
        # are used; slicing tolerates nervaluate versions that return more.
        results, results_per_tag = evaluator.evaluate()[:2]

        # Key names are kept verbatim (including 'results Partial F1') so that
        # existing consumers of the output keep working. The "Partial" scores
        # are read from the ``ent_type`` schema, as in the original code.
        metrics_result = {
            'Global Strict F1': round(results['strict']['f1'], 2),
            'results Partial F1': round(results['ent_type']['f1'], 2),
        }
        for ent, ent_scores in results_per_tag.items():
            metrics_result[ent + ' Strict F1'] = round(
                ent_scores['strict']['f1'], 2
            )
            metrics_result[ent + ' Partial F1'] = round(
                ent_scores['ent_type']['f1'], 2
            )
        return metrics_result