Spaces:
Build error
Build error
| # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import evaluate | |
| import datasets | |
| import numpy as np | |
| from vendi_score import vendi, text_utils | |
| # TODO: Add BibTeX citation | |
| _CITATION = "" | |
| _DESCRIPTION = """\ | |
| The Vendi Score is a metric for evaluating diversity in machine learning. | |
| The input to metric is a collection of samples and a pairwise similarity function, and the output is a number, which can be interpreted as the effective number of unique elements in the sample. | |
| See the project's README at https://github.com/vertaix/Vendi-Score for more information. | |
| The interactive example calculates the Vendi Score for a set of strings using the n-gram overlap similarity, averaged between n=1 and n=2. | |
| """ | |
| _KWARGS_DESCRIPTION = """ | |
| Calculates the Vendi Score given samples and a similarity function. | |
| Args: | |
| samples: an iterable containing n samples to score, an n x n similarity | |
| matrix K, or an n x d feature matrix X. | |
| k: a pairwise similarity function, or a string identifying a predefined | |
| similarity function. | |
| Options: ngram_overlap, text_embeddings. | |
| score_K: if true, samples is an n x n similarity matrix K. | |
| score_X: if true, samples is an n x d feature matrix X. | |
| score_dual: if true, compute diversity score of X @ X.T. | |
| normalize: if true, normalize the similarity scores. | |
| model (optional): if k is "text_embeddings", a model mapping sentences to | |
| embeddings (output should be an object with an attribute called | |
| `pooler_output` or `last_hidden_state`). | |
| tokenizer (optional): if k is "text_embeddings" or "ngram_overlap", a | |
| tokenizer mapping strings to lists. | |
| model_path (optional): if k is "text_embeddings", the name of a model on the | |
| HuggingFace hub. | |
| ns (optional): if k is "ngram_overlap", the values of n to calculate. | |
| batch_size (optional): batch size to use if k is "text_embedding". | |
| device (optional): a string (e.g. "cuda", "cpu") or torch.device identifying | |
| the device to use if k is "text_embedding". | |
| Returns: | |
| VS: The Vendi Score. | |
| Examples: | |
| >>> vendiscore = evaluate.load("Vertaix/vendiscore", "text") | |
| >>> samples = ["Look, Jane.", | |
| "See Spot.", | |
| "See Spot run.", | |
| "Run, Spot, run.", | |
| "Jane sees Spot run."] | |
| >>> results = vendiscore.compute(samples, k="ngram_overlap", ns=[1, 2]) | |
| >>> print(results) | |
| {'VS': 3.90657...} | |
| """ | |
def get_features(config_name):
    """Return the ``datasets`` feature schema(s) for a metric configuration.

    Args:
        config_name: name of the configuration passed to ``evaluate.load``.

    Returns:
        For the "text" and "default" configs, a single ``datasets.Features``
        describing one string per sample. For "K" and "X", a list of candidate
        ``Features`` where each sample is a sequence of floats or of int32
        (one row of the matrix). For any other name, a list of candidates where
        each sample is a scalar float, a scalar int32, or a nested sequence of
        floats (a 2-D array).
    """
    if config_name in ("text", "default"):
        return datasets.Features({"samples": datasets.Value("string")})
    # NOTE: image inputs are not supported yet (they would need an Image feature).
    if config_name in ("K", "X"):
        return [
            datasets.Features({"samples": datasets.Sequence(datasets.Value(dtype))})
            for dtype in ("float", "int32")
        ]
    return [
        datasets.Features({"samples": datasets.Value("float")}),
        datasets.Features({"samples": datasets.Value("int32")}),
        # The bare ``datasets.Array2D`` class is not a valid feature: it must be
        # instantiated with a shape and dtype. A nested float sequence accepts
        # 2-D array samples without fixing their shape up front.
        datasets.Features(
            {"samples": datasets.Sequence(datasets.Sequence(datasets.Value("float")))}
        ),
    ]
class VendiScore(evaluate.Metric):
    """Vendi Score metric: the effective number of unique elements in a sample.

    Thin wrapper that exposes the ``vendi_score`` package through the
    ``evaluate`` metric interface. The accepted arguments are described in
    ``_KWARGS_DESCRIPTION``.
    """

    def _info(self):
        """Return the metric metadata, including the input features for this config."""
        return evaluate.MetricInfo(
            # The description appears on the module's page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=get_features(self.config_name),
            homepage="http://github.com/Vertaix/Vendi-Score",
            codebase_urls=["http://github.com/Vertaix/Vendi-Score"],
            reference_urls=[],
        )

    def _download_and_prepare(self, dl_manager):
        """Download the NLTK "punkt" tokenizer data.

        ``dl_manager`` is unused. NOTE(review): presumably "punkt" is needed by
        the default tokenizer in ``vendi_score.text_utils`` — confirm.
        """
        import nltk

        nltk.download("punkt")

    def _compute(
        self,
        samples,
        k="ngram_overlap",
        score_K=False,
        score_X=False,
        score_dual=False,
        normalize=False,
        model=None,
        tokenizer=None,
        model_path=None,
        ns=None,
        batch_size=16,
        device="cpu",
    ):
        """Compute the Vendi Score of ``samples``.

        The ``score_K``, ``score_dual`` and ``score_X`` flags are checked in that
        order and take precedence over ``k``. Otherwise, a string ``k`` selects a
        predefined text similarity ("ngram_overlap" or "text_embeddings"), and
        any other ``k`` is passed to ``vendi.score`` as the similarity function.

        Args:
            ns: n-gram sizes for "ngram_overlap"; ``None`` means ``[1, 2]``.
            (See ``_KWARGS_DESCRIPTION`` for the remaining arguments.)

        Returns:
            dict: ``{"VS": score}``.

        Raises:
            ValueError: if ``k`` is a string that names no predefined similarity.
        """
        if ns is None:
            # A fresh list per call avoids the shared-mutable-default pitfall.
            ns = [1, 2]

        if score_K:
            vs = vendi.score_K(np.array(samples), normalize=normalize)
        elif score_dual:
            vs = vendi.score_dual(np.array(samples), normalize=normalize)
        elif score_X:
            vs = vendi.score_X(np.array(samples), normalize=normalize)
        elif isinstance(k, str):
            if k == "ngram_overlap":
                vs = text_utils.ngram_vendi_score(samples, ns=ns, tokenizer=tokenizer)
            elif k == "text_embeddings":
                vs = text_utils.embedding_vendi_score(
                    samples,
                    model=model,
                    tokenizer=tokenizer,
                    batch_size=batch_size,
                    device=device,
                    model_path=model_path,
                )
            else:
                # Previously an unknown name fell through to vendi.score with a
                # string as the similarity function and failed obscurely.
                raise ValueError(
                    f"Unknown similarity function {k!r}; expected "
                    "'ngram_overlap', 'text_embeddings', or a callable."
                )
        else:
            # NOTE: image similarities ("pixels", "image_embeddings") are not
            # wired up yet; that would require vendi_score.image_utils.
            vs = vendi.score(samples, k)
        return {"VS": vs}