Spaces:
Build error
Build error
Performance improvements
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
|
@@ -22,11 +22,12 @@ from src.distilabel_dataset_generator.utils import (
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
-
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt):
|
| 26 |
pipeline = get_pipeline(
|
| 27 |
num_turns,
|
| 28 |
num_rows,
|
| 29 |
system_prompt,
|
|
|
|
| 30 |
)
|
| 31 |
distiset: Distiset = pipeline.run(use_cache=False)
|
| 32 |
result_queue.put(distiset)
|
|
@@ -54,7 +55,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
| 54 |
|
| 55 |
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
|
| 56 |
progress(0.1, desc="Initializing sample dataset generation")
|
| 57 |
-
result = generate_dataset(system_prompt, num_turns=1, num_rows=1, progress=progress)
|
| 58 |
progress(1.0, desc="Sample dataset generated")
|
| 59 |
return result
|
| 60 |
|
|
@@ -68,6 +69,7 @@ def generate_dataset(
|
|
| 68 |
repo_name: str = None,
|
| 69 |
oauth_token: str = None,
|
| 70 |
progress=gr.Progress(),
|
|
|
|
| 71 |
):
|
| 72 |
repo_id = (
|
| 73 |
f"{org_name}/{repo_name}"
|
|
@@ -88,8 +90,9 @@ def generate_dataset(
|
|
| 88 |
gr.Info(
|
| 89 |
"You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
|
| 90 |
)
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
duration = 60
|
| 94 |
elif num_rows < 30:
|
| 95 |
duration = 120
|
|
@@ -105,7 +108,7 @@ def generate_dataset(
|
|
| 105 |
result_queue = multiprocessing.Queue()
|
| 106 |
p = multiprocessing.Process(
|
| 107 |
target=_run_pipeline,
|
| 108 |
-
args=(result_queue, num_turns, num_rows, system_prompt),
|
| 109 |
)
|
| 110 |
|
| 111 |
try:
|
|
@@ -175,28 +178,31 @@ with gr.Blocks(
|
|
| 175 |
)
|
| 176 |
with gr.Row():
|
| 177 |
gr.Column(scale=1)
|
| 178 |
-
btn_generate_system_prompt = gr.Button(value="Generate sample
|
| 179 |
gr.Column(scale=1)
|
|
|
|
| 180 |
|
| 181 |
system_prompt = gr.TextArea(
|
| 182 |
-
label="
|
| 183 |
value=DEFAULT_SYSTEM_PROMPT,
|
| 184 |
)
|
| 185 |
|
| 186 |
-
with gr.Row():
|
| 187 |
-
gr.Column(scale=1)
|
| 188 |
-
btn_generate_sample_dataset = gr.Button(
|
| 189 |
-
value="Regenerate sample dataset",
|
| 190 |
-
)
|
| 191 |
-
gr.Column(scale=1)
|
| 192 |
-
|
| 193 |
with gr.Row():
|
| 194 |
table = gr.DataFrame(
|
| 195 |
value=DEFAULT_DATASET,
|
|
|
|
| 196 |
interactive=False,
|
| 197 |
wrap=True,
|
| 198 |
)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
result = btn_generate_system_prompt.click(
|
| 201 |
fn=generate_system_prompt,
|
| 202 |
inputs=[dataset_description],
|
|
@@ -233,10 +239,10 @@ with gr.Blocks(
|
|
| 233 |
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
| 234 |
)
|
| 235 |
num_rows = gr.Number(
|
| 236 |
-
value=
|
| 237 |
label="Number of rows in the dataset",
|
| 238 |
minimum=1,
|
| 239 |
-
maximum=
|
| 240 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
| 241 |
)
|
| 242 |
|
|
@@ -249,16 +255,24 @@ with gr.Blocks(
|
|
| 249 |
visible=False,
|
| 250 |
)
|
| 251 |
org_name = get_org_dropdown()
|
| 252 |
-
repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name")
|
| 253 |
private = gr.Checkbox(
|
| 254 |
label="Private dataset", value=True, interactive=True, scale=0.5
|
| 255 |
)
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
| 261 |
success_message = gr.Markdown(visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
def show_success_message(org_name, repo_name):
|
| 264 |
return gr.Markdown(
|
|
@@ -294,7 +308,7 @@ with gr.Blocks(
|
|
| 294 |
repo_name,
|
| 295 |
oauth_token,
|
| 296 |
],
|
| 297 |
-
outputs=[
|
| 298 |
show_progress=True,
|
| 299 |
).success(
|
| 300 |
fn=show_success_message,
|
|
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
+
def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, is_sample):
|
| 26 |
pipeline = get_pipeline(
|
| 27 |
num_turns,
|
| 28 |
num_rows,
|
| 29 |
system_prompt,
|
| 30 |
+
is_sample
|
| 31 |
)
|
| 32 |
distiset: Distiset = pipeline.run(use_cache=False)
|
| 33 |
result_queue.put(distiset)
|
|
|
|
| 55 |
|
| 56 |
def generate_sample_dataset(system_prompt, progress=gr.Progress()):
|
| 57 |
progress(0.1, desc="Initializing sample dataset generation")
|
| 58 |
+
result = generate_dataset(system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True)
|
| 59 |
progress(1.0, desc="Sample dataset generated")
|
| 60 |
return result
|
| 61 |
|
|
|
|
| 69 |
repo_name: str = None,
|
| 70 |
oauth_token: str = None,
|
| 71 |
progress=gr.Progress(),
|
| 72 |
+
is_sample: bool = False,
|
| 73 |
):
|
| 74 |
repo_id = (
|
| 75 |
f"{org_name}/{repo_name}"
|
|
|
|
| 90 |
gr.Info(
|
| 91 |
"You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
|
| 92 |
)
|
| 93 |
+
if num_rows < 5:
|
| 94 |
+
duration = 25
|
| 95 |
+
elif num_rows < 10:
|
| 96 |
duration = 60
|
| 97 |
elif num_rows < 30:
|
| 98 |
duration = 120
|
|
|
|
| 108 |
result_queue = multiprocessing.Queue()
|
| 109 |
p = multiprocessing.Process(
|
| 110 |
target=_run_pipeline,
|
| 111 |
+
args=(result_queue, num_turns, num_rows, system_prompt, is_sample),
|
| 112 |
)
|
| 113 |
|
| 114 |
try:
|
|
|
|
| 178 |
)
|
| 179 |
with gr.Row():
|
| 180 |
gr.Column(scale=1)
|
| 181 |
+
btn_generate_system_prompt = gr.Button(value="Generate sample")
|
| 182 |
gr.Column(scale=1)
|
| 183 |
+
|
| 184 |
|
| 185 |
system_prompt = gr.TextArea(
|
| 186 |
+
label="System prompt for dataset generation. You can tune it and regenerate the sample",
|
| 187 |
value=DEFAULT_SYSTEM_PROMPT,
|
| 188 |
)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
with gr.Row():
|
| 191 |
table = gr.DataFrame(
|
| 192 |
value=DEFAULT_DATASET,
|
| 193 |
+
label="Sample dataset. Prompts and completions truncated to 256 tokens.",
|
| 194 |
interactive=False,
|
| 195 |
wrap=True,
|
| 196 |
)
|
| 197 |
|
| 198 |
+
|
| 199 |
+
with gr.Row():
|
| 200 |
+
gr.Column(scale=1)
|
| 201 |
+
btn_generate_sample_dataset = gr.Button(
|
| 202 |
+
value="Regenerate sample",
|
| 203 |
+
)
|
| 204 |
+
gr.Column(scale=1)
|
| 205 |
+
|
| 206 |
result = btn_generate_system_prompt.click(
|
| 207 |
fn=generate_system_prompt,
|
| 208 |
inputs=[dataset_description],
|
|
|
|
| 239 |
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
| 240 |
)
|
| 241 |
num_rows = gr.Number(
|
| 242 |
+
value=10,
|
| 243 |
label="Number of rows in the dataset",
|
| 244 |
minimum=1,
|
| 245 |
+
maximum=500,
|
| 246 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
| 247 |
)
|
| 248 |
|
|
|
|
| 255 |
visible=False,
|
| 256 |
)
|
| 257 |
org_name = get_org_dropdown()
|
| 258 |
+
repo_name = gr.Textbox(label="Repo name", placeholder="dataset_name", value="my-distiset")
|
| 259 |
private = gr.Checkbox(
|
| 260 |
label="Private dataset", value=True, interactive=True, scale=0.5
|
| 261 |
)
|
| 262 |
+
with gr.Row() as regenerate_row:
|
| 263 |
+
gr.Column(scale=1)
|
| 264 |
+
btn_generate_full_dataset = gr.Button(
|
| 265 |
+
value="Generate Full Dataset", variant="primary"
|
| 266 |
+
)
|
| 267 |
+
gr.Column(scale=1)
|
| 268 |
success_message = gr.Markdown(visible=False)
|
| 269 |
+
with gr.Row():
|
| 270 |
+
final_dataset = gr.DataFrame(
|
| 271 |
+
value=DEFAULT_DATASET,
|
| 272 |
+
label="Generated dataset",
|
| 273 |
+
interactive=False,
|
| 274 |
+
wrap=True,
|
| 275 |
+
)
|
| 276 |
|
| 277 |
def show_success_message(org_name, repo_name):
|
| 278 |
return gr.Markdown(
|
|
|
|
| 308 |
repo_name,
|
| 309 |
oauth_token,
|
| 310 |
],
|
| 311 |
+
outputs=[final_dataset],
|
| 312 |
show_progress=True,
|
| 313 |
).success(
|
| 314 |
fn=show_success_message,
|