Spaces:
Build error
Build error
Try updating possible types.
Browse files- vendiscore.py +12 -18
vendiscore.py
CHANGED
|
@@ -85,10 +85,7 @@ class VendiScore(evaluate.Metric):
|
|
| 85 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 86 |
features=datasets.Features(
|
| 87 |
{
|
| 88 |
-
"
|
| 89 |
-
"imgs": datasets.Image,
|
| 90 |
-
"X": datasets.Array2D,
|
| 91 |
-
"K": datasets.Array2D,
|
| 92 |
}
|
| 93 |
),
|
| 94 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
|
@@ -103,10 +100,7 @@ class VendiScore(evaluate.Metric):
|
|
| 103 |
|
| 104 |
def _compute(
|
| 105 |
self,
|
| 106 |
-
|
| 107 |
-
imgs=None,
|
| 108 |
-
X=None,
|
| 109 |
-
K=None,
|
| 110 |
k="ngram_overlap",
|
| 111 |
score_K=False,
|
| 112 |
score_X=False,
|
|
@@ -121,16 +115,18 @@ class VendiScore(evaluate.Metric):
|
|
| 121 |
device="cpu",
|
| 122 |
):
|
| 123 |
if score_K:
|
| 124 |
-
vs = vendi.score_K(
|
| 125 |
elif score_dual:
|
| 126 |
-
vs = vendi.score_dual(
|
| 127 |
elif score_X:
|
| 128 |
-
vs = vendi.score_X(
|
| 129 |
elif type(k) == str and k == "ngram_overlap":
|
| 130 |
-
vs = text_utils.ngram_vendi_score(
|
|
|
|
|
|
|
| 131 |
elif type(k) == str and k == "text_embeddings":
|
| 132 |
vs = text_utils.embedding_vendi_score(
|
| 133 |
-
|
| 134 |
model=model,
|
| 135 |
tokenizer=tokenizer,
|
| 136 |
batch_size=batch_size,
|
|
@@ -138,17 +134,15 @@ class VendiScore(evaluate.Metric):
|
|
| 138 |
model_path=model_path,
|
| 139 |
)
|
| 140 |
elif type(k) == str and k == "pixels":
|
| 141 |
-
vs = image_utils.pixel_vendi_score(
|
| 142 |
elif type(k) == str and k == "image_embeddings":
|
| 143 |
vs = image_utils.embedding_vendi_score(
|
| 144 |
-
|
| 145 |
batch_size=batch_size,
|
| 146 |
device=device,
|
| 147 |
model=model,
|
| 148 |
transform=transform,
|
| 149 |
)
|
| 150 |
-
elif sents is not None or imgs is not None or X is not None:
|
| 151 |
-
vs = vendi.score(sents or imgs or X, k)
|
| 152 |
else:
|
| 153 |
-
|
| 154 |
return {"VS": vs}
|
|
|
|
| 85 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 86 |
features=datasets.Features(
|
| 87 |
{
|
| 88 |
+
"samples": datasets.Array2D,
|
|
|
|
|
|
|
|
|
|
| 89 |
}
|
| 90 |
),
|
| 91 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
|
|
|
| 100 |
|
| 101 |
def _compute(
|
| 102 |
self,
|
| 103 |
+
samples,
|
|
|
|
|
|
|
|
|
|
| 104 |
k="ngram_overlap",
|
| 105 |
score_K=False,
|
| 106 |
score_X=False,
|
|
|
|
| 115 |
device="cpu",
|
| 116 |
):
|
| 117 |
if score_K:
|
| 118 |
+
vs = vendi.score_K(samples, normalize=normalize)
|
| 119 |
elif score_dual:
|
| 120 |
+
vs = vendi.score_dual(samples, normalize=normalize)
|
| 121 |
elif score_X:
|
| 122 |
+
vs = vendi.score_X(samples, normalize=normalize)
|
| 123 |
elif type(k) == str and k == "ngram_overlap":
|
| 124 |
+
vs = text_utils.ngram_vendi_score(
|
| 125 |
+
samples, ns=ns, tokenizer=tokenizer
|
| 126 |
+
)
|
| 127 |
elif type(k) == str and k == "text_embeddings":
|
| 128 |
vs = text_utils.embedding_vendi_score(
|
| 129 |
+
samples,
|
| 130 |
model=model,
|
| 131 |
tokenizer=tokenizer,
|
| 132 |
batch_size=batch_size,
|
|
|
|
| 134 |
model_path=model_path,
|
| 135 |
)
|
| 136 |
elif type(k) == str and k == "pixels":
|
| 137 |
+
vs = image_utils.pixel_vendi_score(samples)
|
| 138 |
elif type(k) == str and k == "image_embeddings":
|
| 139 |
vs = image_utils.embedding_vendi_score(
|
| 140 |
+
samples,
|
| 141 |
batch_size=batch_size,
|
| 142 |
device=device,
|
| 143 |
model=model,
|
| 144 |
transform=transform,
|
| 145 |
)
|
|
|
|
|
|
|
| 146 |
else:
|
| 147 |
+
vs = vendi.score(samples, k)
|
| 148 |
return {"VS": vs}
|