Spaces:
Running
Running
output in dataset
Browse files
tasks.py
CHANGED
|
@@ -62,6 +62,7 @@ class Task:
|
|
| 62 |
metric_name: str | tuple[str, str] = ("sustech/tlem", "mmlu")
|
| 63 |
input_column: str = "question"
|
| 64 |
label_column: str = ""
|
|
|
|
| 65 |
prompt: Optional[Callable | str] = None
|
| 66 |
few_shot: int = 0
|
| 67 |
few_shot_from: Optional[str] = None
|
|
@@ -85,7 +86,6 @@ class Task:
|
|
| 85 |
)
|
| 86 |
}
|
| 87 |
self.label_column = self.label_column or self.input_column
|
| 88 |
-
self.outputs = []
|
| 89 |
|
| 90 |
def __eq__(self, __value: object) -> bool:
|
| 91 |
return self.name == __value.name
|
|
@@ -98,6 +98,10 @@ class Task:
|
|
| 98 |
def labels(self):
|
| 99 |
return self.dataset[self.label_column]
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
@cached_property
|
| 102 |
def dataset(self):
|
| 103 |
ds = (
|
|
@@ -160,20 +164,29 @@ class Task:
|
|
| 160 |
# logging.info(f"{self.name}:{results}")
|
| 161 |
return results
|
| 162 |
|
| 163 |
-
# @cache
|
| 164 |
def run(
|
| 165 |
self,
|
| 166 |
pipeline,
|
| 167 |
):
|
| 168 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
return self.result
|
| 171 |
|
| 172 |
async def arun(self, pipeline):
|
| 173 |
-
self.
|
|
|
|
|
|
|
| 174 |
|
| 175 |
return self.result
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
def multichoice(responses: Any, references: list[str]):
|
| 179 |
if isinstance(responses[0], str):
|
|
|
|
| 62 |
metric_name: str | tuple[str, str] = ("sustech/tlem", "mmlu")
|
| 63 |
input_column: str = "question"
|
| 64 |
label_column: str = ""
|
| 65 |
+
output_column: str = "generated_text"
|
| 66 |
prompt: Optional[Callable | str] = None
|
| 67 |
few_shot: int = 0
|
| 68 |
few_shot_from: Optional[str] = None
|
|
|
|
| 86 |
)
|
| 87 |
}
|
| 88 |
self.label_column = self.label_column or self.input_column
|
|
|
|
| 89 |
|
| 90 |
def __eq__(self, __value: object) -> bool:
    """Compare two tasks by their ``name`` attribute.

    Returns ``NotImplemented`` when the other operand has no ``name``
    attribute, so that ``task == 5`` evaluates to ``False`` through
    Python's normal fallback instead of raising ``AttributeError``.
    """
    try:
        other_name = __value.name
    except AttributeError:
        return NotImplemented
    return self.name == other_name
|
|
|
|
| 98 |
def labels(self):
|
| 99 |
return self.dataset[self.label_column]
|
| 100 |
|
| 101 |
+
@cached_property
def outputs(self):
    """Return the ``output_column`` column of ``self.dataset``.

    The value is computed on first access and then cached on the instance
    by ``cached_property``.
    """
    return self.dataset[self.output_column]
|
| 104 |
+
|
| 105 |
@cached_property
|
| 106 |
def dataset(self):
|
| 107 |
ds = (
|
|
|
|
| 164 |
# logging.info(f"{self.name}:{results}")
|
| 165 |
return results
|
| 166 |
|
|
|
|
| 167 |
def run(
    self,
    pipeline,
):
    """Generate outputs with ``pipeline`` and return ``self.result``.

    ``pipeline`` is called once with ``self.samples`` and its return value
    is stored as the ``output_column`` column of ``self.dataset``. When the
    dataset already has that column the pipeline is not called again.
    """
    if self.output_column not in self.dataset.column_names:
        self.dataset = self.dataset.add_column(
            self.output_column, pipeline(self.samples)
        )

    return self.result
|
| 177 |
|
| 178 |
async def arun(self, pipeline):
    """Asynchronous counterpart of ``run``.

    Awaits ``pipeline(self.samples)`` and stores the result as the
    ``output_column`` column of ``self.dataset``, then returns
    ``self.result``.

    The column is only generated when it is not already present, matching
    the guard in ``run``. Without it, a second call (or a call after
    ``run``) would re-run the pipeline and then ask ``add_column`` to add a
    column name that already exists.
    """
    if self.output_column not in self.dataset.column_names:
        self.dataset = self.dataset.add_column(
            self.output_column, await pipeline(self.samples)
        )

    return self.result
|
| 184 |
|
| 185 |
+
def save(self, path):
    """Write the input, output and label columns of the dataset to ``path``.

    Selects the three columns from ``self.dataset`` and calls
    ``save_to_disk(path)`` on the selection.
    """
    # label_column may name the same column as input_column, so drop
    # repeated names (keeping first-seen order) before selecting.
    columns = list(
        dict.fromkeys(
            [self.input_column, self.output_column, self.label_column]
        )
    )
    self.dataset.select_columns(columns).save_to_disk(path)
|
| 189 |
+
|
| 190 |
|
| 191 |
def multichoice(responses: Any, references: list[str]):
|
| 192 |
if isinstance(responses[0], str):
|