Spaces:
Build error
Build error
Try updating possible types.
Browse files
- vendiscore.py +18 -12
vendiscore.py
CHANGED
|
@@ -85,7 +85,10 @@ class VendiScore(evaluate.Metric):
|
|
| 85 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 86 |
features=datasets.Features(
|
| 87 |
{
|
| 88 |
-
"
|
|
|
|
|
|
|
|
|
|
| 89 |
}
|
| 90 |
),
|
| 91 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
|
@@ -100,7 +103,10 @@ class VendiScore(evaluate.Metric):
|
|
| 100 |
|
| 101 |
def _compute(
|
| 102 |
self,
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
k="ngram_overlap",
|
| 105 |
score_K=False,
|
| 106 |
score_X=False,
|
|
@@ -115,18 +121,16 @@ class VendiScore(evaluate.Metric):
|
|
| 115 |
device="cpu",
|
| 116 |
):
|
| 117 |
if score_K:
|
| 118 |
-
vs = vendi.score_K(
|
| 119 |
elif score_dual:
|
| 120 |
-
vs = vendi.score_dual(
|
| 121 |
elif score_X:
|
| 122 |
-
vs = vendi.score_X(
|
| 123 |
elif type(k) == str and k == "ngram_overlap":
|
| 124 |
-
vs = text_utils.ngram_vendi_score(
|
| 125 |
-
samples, ns=ns, tokenizer=tokenizer
|
| 126 |
-
)
|
| 127 |
elif type(k) == str and k == "text_embeddings":
|
| 128 |
vs = text_utils.embedding_vendi_score(
|
| 129 |
-
|
| 130 |
model=model,
|
| 131 |
tokenizer=tokenizer,
|
| 132 |
batch_size=batch_size,
|
|
@@ -134,15 +138,17 @@ class VendiScore(evaluate.Metric):
|
|
| 134 |
model_path=model_path,
|
| 135 |
)
|
| 136 |
elif type(k) == str and k == "pixels":
|
| 137 |
-
vs = image_utils.pixel_vendi_score(
|
| 138 |
elif type(k) == str and k == "image_embeddings":
|
| 139 |
vs = image_utils.embedding_vendi_score(
|
| 140 |
-
|
| 141 |
batch_size=batch_size,
|
| 142 |
device=device,
|
| 143 |
model=model,
|
| 144 |
transform=transform,
|
| 145 |
)
|
|
|
|
|
|
|
| 146 |
else:
|
| 147 |
-
|
| 148 |
return {"VS": vs}
|
|
|
|
| 85 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 86 |
features=datasets.Features(
|
| 87 |
{
|
| 88 |
+
"sents": datasets.Value("string"),
|
| 89 |
+
"imgs": datasets.Image,
|
| 90 |
+
"X": datasets.Array2D,
|
| 91 |
+
"K": datasets.Array2D,
|
| 92 |
}
|
| 93 |
),
|
| 94 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
|
|
|
| 103 |
|
| 104 |
def _compute(
|
| 105 |
self,
|
| 106 |
+
sents=None,
|
| 107 |
+
imgs=None,
|
| 108 |
+
X=None,
|
| 109 |
+
K=None,
|
| 110 |
k="ngram_overlap",
|
| 111 |
score_K=False,
|
| 112 |
score_X=False,
|
|
|
|
| 121 |
device="cpu",
|
| 122 |
):
|
| 123 |
if score_K:
|
| 124 |
+
vs = vendi.score_K(K, normalize=normalize)
|
| 125 |
elif score_dual:
|
| 126 |
+
vs = vendi.score_dual(X, normalize=normalize)
|
| 127 |
elif score_X:
|
| 128 |
+
vs = vendi.score_X(X, normalize=normalize)
|
| 129 |
elif type(k) == str and k == "ngram_overlap":
|
| 130 |
+
vs = text_utils.ngram_vendi_score(sents, ns=ns, tokenizer=tokenizer)
|
|
|
|
|
|
|
| 131 |
elif type(k) == str and k == "text_embeddings":
|
| 132 |
vs = text_utils.embedding_vendi_score(
|
| 133 |
+
sents,
|
| 134 |
model=model,
|
| 135 |
tokenizer=tokenizer,
|
| 136 |
batch_size=batch_size,
|
|
|
|
| 138 |
model_path=model_path,
|
| 139 |
)
|
| 140 |
elif type(k) == str and k == "pixels":
|
| 141 |
+
vs = image_utils.pixel_vendi_score(imgs)
|
| 142 |
elif type(k) == str and k == "image_embeddings":
|
| 143 |
vs = image_utils.embedding_vendi_score(
|
| 144 |
+
imgs,
|
| 145 |
batch_size=batch_size,
|
| 146 |
device=device,
|
| 147 |
model=model,
|
| 148 |
transform=transform,
|
| 149 |
)
|
| 150 |
+
elif sents is not None or imgs is not None or X is not None:
|
| 151 |
+
vs = vendi.score(sents or imgs or X, k)
|
| 152 |
else:
|
| 153 |
+
raise ValueError(f"Must provide one of `sents` or `imgs` or `X`.")
|
| 154 |
return {"VS": vs}
|