Spaces:
Build error
Build error
sebaxakerhtc
committed on
Added local saving to CSV and JSON (#38)
Browse files
* Local save
- Added save local function to chat tab (CSV, JSON)
- Rebuild UI with new feature
- CSS edit for gr.File (perfectionism)
* Local save
* Mistake
* Update chat.py
* Local save RAG and Textcat
* Rebuild UI
* Show save_local only if save_local_dir is provided
- src/synthetic_dataset_generator/app.py +5 -4
- src/synthetic_dataset_generator/apps/base.py +5 -1
- src/synthetic_dataset_generator/apps/chat.py +73 -2
- src/synthetic_dataset_generator/apps/rag.py +79 -2
- src/synthetic_dataset_generator/apps/textcat.py +62 -2
- src/synthetic_dataset_generator/constants.py +3 -0
src/synthetic_dataset_generator/app.py
CHANGED
|
@@ -12,12 +12,13 @@ css = """
|
|
| 12 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
| 13 |
button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
| 14 |
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
|
| 15 |
-
.tabitem {
|
| 16 |
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
| 17 |
-
.table-wrap .tbody td {
|
| 18 |
-
#system_prompt_examples {
|
| 19 |
.container {padding-inline: 0 !important}
|
| 20 |
-
#sign_in_button {
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
|
|
|
| 12 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
| 13 |
button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
| 14 |
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
|
| 15 |
+
.tabitem {border: 0; padding-inline: 0}
|
| 16 |
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
| 17 |
+
.table-wrap .tbody td {vertical-align: top}
|
| 18 |
+
#system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
| 19 |
.container {padding-inline: 0 !important}
|
| 20 |
+
#sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
|
| 21 |
+
.datasets {height: 70px;}
|
| 22 |
"""
|
| 23 |
|
| 24 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
src/synthetic_dataset_generator/apps/base.py
CHANGED
|
@@ -12,9 +12,13 @@ from huggingface_hub import HfApi, upload_file, repo_exists
|
|
| 12 |
from unstructured.chunking.title import chunk_by_title
|
| 13 |
from unstructured.partition.auto import partition
|
| 14 |
|
| 15 |
-
from synthetic_dataset_generator.constants import MAX_NUM_ROWS
|
| 16 |
from synthetic_dataset_generator.utils import get_argilla_client
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def validate_argilla_user_workspace_dataset(
|
| 20 |
dataset_name: str,
|
|
|
|
| 12 |
from unstructured.chunking.title import chunk_by_title
|
| 13 |
from unstructured.partition.auto import partition
|
| 14 |
|
| 15 |
+
from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
|
| 16 |
from synthetic_dataset_generator.utils import get_argilla_client
|
| 17 |
|
| 18 |
+
if SAVE_LOCAL_DIR is not None:
|
| 19 |
+
import os
|
| 20 |
+
os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
|
| 21 |
+
|
| 22 |
|
| 23 |
def validate_argilla_user_workspace_dataset(
|
| 24 |
dataset_name: str,
|
src/synthetic_dataset_generator/apps/chat.py
CHANGED
|
@@ -2,6 +2,7 @@ import ast
|
|
| 2 |
import json
|
| 3 |
import random
|
| 4 |
import uuid
|
|
|
|
| 5 |
from typing import Dict, List, Union
|
| 6 |
|
| 7 |
import argilla as rg
|
|
@@ -30,6 +31,7 @@ from synthetic_dataset_generator.constants import (
|
|
| 30 |
MODEL,
|
| 31 |
MODEL_COMPLETION,
|
| 32 |
SFT_AVAILABLE,
|
|
|
|
| 33 |
)
|
| 34 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 35 |
from synthetic_dataset_generator.pipelines.chat import (
|
|
@@ -264,7 +266,6 @@ def generate_dataset_from_prompt(
|
|
| 264 |
progress(1.0, desc="Dataset generation completed")
|
| 265 |
return dataframe
|
| 266 |
|
| 267 |
-
|
| 268 |
def generate_dataset_from_seed(
|
| 269 |
dataframe: pd.DataFrame,
|
| 270 |
document_column: str,
|
|
@@ -506,7 +507,7 @@ def push_dataset(
|
|
| 506 |
num_turns=num_turns,
|
| 507 |
num_rows=num_rows,
|
| 508 |
temperature=temperature,
|
| 509 |
-
temperature_completion=temperature_completion
|
| 510 |
)
|
| 511 |
push_dataset_to_hub(
|
| 512 |
dataframe=dataframe,
|
|
@@ -637,6 +638,45 @@ def push_dataset(
|
|
| 637 |
return ""
|
| 638 |
|
| 639 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
def show_system_prompt_visibility():
|
| 641 |
return {system_prompt: gr.Textbox(visible=True)}
|
| 642 |
|
|
@@ -670,6 +710,13 @@ def hide_pipeline_code_visibility():
|
|
| 670 |
def show_temperature_completion():
|
| 671 |
if MODEL != MODEL_COMPLETION:
|
| 672 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
|
| 674 |
|
| 675 |
######################
|
|
@@ -852,6 +899,11 @@ with gr.Blocks() as app:
|
|
| 852 |
btn_push_to_hub = gr.Button(
|
| 853 |
"Push to Hub", variant="primary", scale=2
|
| 854 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
with gr.Column(scale=3):
|
| 856 |
success_message = gr.Markdown(
|
| 857 |
visible=True,
|
|
@@ -998,6 +1050,23 @@ with gr.Blocks() as app:
|
|
| 998 |
inputs=[],
|
| 999 |
outputs=[pipeline_code_ui],
|
| 1000 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
|
| 1002 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
| 1003 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
|
@@ -1011,3 +1080,5 @@ with gr.Blocks() as app:
|
|
| 1011 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 1012 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
| 1013 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import random
|
| 4 |
import uuid
|
| 5 |
+
import os
|
| 6 |
from typing import Dict, List, Union
|
| 7 |
|
| 8 |
import argilla as rg
|
|
|
|
| 31 |
MODEL,
|
| 32 |
MODEL_COMPLETION,
|
| 33 |
SFT_AVAILABLE,
|
| 34 |
+
SAVE_LOCAL_DIR,
|
| 35 |
)
|
| 36 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 37 |
from synthetic_dataset_generator.pipelines.chat import (
|
|
|
|
| 266 |
progress(1.0, desc="Dataset generation completed")
|
| 267 |
return dataframe
|
| 268 |
|
|
|
|
| 269 |
def generate_dataset_from_seed(
|
| 270 |
dataframe: pd.DataFrame,
|
| 271 |
document_column: str,
|
|
|
|
| 507 |
num_turns=num_turns,
|
| 508 |
num_rows=num_rows,
|
| 509 |
temperature=temperature,
|
| 510 |
+
temperature_completion=temperature_completion,
|
| 511 |
)
|
| 512 |
push_dataset_to_hub(
|
| 513 |
dataframe=dataframe,
|
|
|
|
| 638 |
return ""
|
| 639 |
|
| 640 |
|
| 641 |
+
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> tuple[str, str]:
    """Generate a chat dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    ``SAVE_LOCAL_DIR`` must be set; the output paths are built from it.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` for non-prompt inputs.
        file_paths: Forwarded to ``load_dataset_file`` for non-prompt inputs.
        input_type: ``"prompt-input"`` starts from ``_get_dataframe()``;
            any other value loads the seed data via ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        num_turns: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        The paths of the written CSV and JSON files, in that order.
    """
    if input_type == "prompt-input":
        dataframe = _get_dataframe()
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # repo_name is typed by the user; keep only its final path component so
    # an absolute path or "../" segments cannot write outside SAVE_LOCAL_DIR.
    file_stem = os.path.basename(repo_name)
    output_csv = os.path.join(SAVE_LOCAL_DIR, file_stem + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, file_stem + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No index=False here: the default records-oriented JSON output does not
    # include the index, and some pandas versions reject an explicit
    # index=False for that orient (NOTE(review): confirm against the pinned
    # pandas version).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
| 678 |
+
|
| 679 |
+
|
| 680 |
def show_system_prompt_visibility():
|
| 681 |
return {system_prompt: gr.Textbox(visible=True)}
|
| 682 |
|
|
|
|
| 710 |
def show_temperature_completion():
|
| 711 |
if MODEL != MODEL_COMPLETION:
|
| 712 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
| 713 |
+
|
| 714 |
+
def show_save_local():
    """Return component updates that make the local-save widgets visible."""
    updates = {}
    # One update per widget: the save button and both download slots.
    updates[btn_save_local] = gr.Button(visible=True)
    updates[csv_file] = gr.File(visible=True)
    updates[json_file] = gr.File(visible=True)
    return updates
|
| 720 |
|
| 721 |
|
| 722 |
######################
|
|
|
|
| 899 |
btn_push_to_hub = gr.Button(
|
| 900 |
"Push to Hub", variant="primary", scale=2
|
| 901 |
)
|
| 902 |
+
btn_save_local = gr.Button(
|
| 903 |
+
"Save locally", variant="primary", scale=2, visible=False
|
| 904 |
+
)
|
| 905 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
| 906 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
| 907 |
with gr.Column(scale=3):
|
| 908 |
success_message = gr.Markdown(
|
| 909 |
visible=True,
|
|
|
|
| 1050 |
inputs=[],
|
| 1051 |
outputs=[pipeline_code_ui],
|
| 1052 |
)
|
| 1053 |
+
|
| 1054 |
+
btn_save_local.click(
|
| 1055 |
+
save_local,
|
| 1056 |
+
inputs=[
|
| 1057 |
+
search_in,
|
| 1058 |
+
file_in,
|
| 1059 |
+
input_type,
|
| 1060 |
+
system_prompt,
|
| 1061 |
+
document_column,
|
| 1062 |
+
num_turns,
|
| 1063 |
+
num_rows,
|
| 1064 |
+
temperature,
|
| 1065 |
+
repo_name,
|
| 1066 |
+
temperature_completion,
|
| 1067 |
+
],
|
| 1068 |
+
outputs=[csv_file, json_file]
|
| 1069 |
+
)
|
| 1070 |
|
| 1071 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
| 1072 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
|
|
|
| 1080 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 1081 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
| 1082 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
| 1083 |
+
if SAVE_LOCAL_DIR is not None:
|
| 1084 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/apps/rag.py
CHANGED
|
@@ -24,7 +24,7 @@ from synthetic_dataset_generator.apps.base import (
|
|
| 24 |
validate_argilla_user_workspace_dataset,
|
| 25 |
validate_push_to_hub,
|
| 26 |
)
|
| 27 |
-
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
|
| 28 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 29 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
| 30 |
get_embeddings,
|
|
@@ -486,6 +486,49 @@ def push_dataset(
|
|
| 486 |
return ""
|
| 487 |
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
def show_system_prompt_visibility():
|
| 490 |
return {system_prompt: gr.Textbox(visible=True)}
|
| 491 |
|
|
@@ -521,6 +564,14 @@ def show_temperature_completion():
|
|
| 521 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
| 522 |
|
| 523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
######################
|
| 525 |
# Gradio UI
|
| 526 |
######################
|
|
@@ -674,7 +725,14 @@ with gr.Blocks() as app:
|
|
| 674 |
interactive=True,
|
| 675 |
scale=1,
|
| 676 |
)
|
| 677 |
-
btn_push_to_hub = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
with gr.Column(scale=3):
|
| 679 |
success_message = gr.Markdown(
|
| 680 |
visible=True,
|
|
@@ -822,6 +880,23 @@ with gr.Blocks() as app:
|
|
| 822 |
outputs=[pipeline_code_ui],
|
| 823 |
)
|
| 824 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
| 826 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
| 827 |
clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
|
|
@@ -835,3 +910,5 @@ with gr.Blocks() as app:
|
|
| 835 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 836 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
| 837 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
|
|
|
|
|
|
|
|
| 24 |
validate_argilla_user_workspace_dataset,
|
| 25 |
validate_push_to_hub,
|
| 26 |
)
|
| 27 |
+
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION, SAVE_LOCAL_DIR
|
| 28 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 29 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
| 30 |
get_embeddings,
|
|
|
|
| 486 |
return ""
|
| 487 |
|
| 488 |
|
| 489 |
+
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: float,
) -> tuple[str, str]:
    """Generate a RAG dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    ``SAVE_LOCAL_DIR`` must be set; the output paths are built from it.

    Args:
        repo_id: Forwarded to ``load_dataset_file`` for non-prompt inputs.
        file_paths: Forwarded to ``load_dataset_file`` for non-prompt inputs.
        input_type: ``"prompt-input"`` starts from an empty
            context/question/response frame; any other value loads the seed
            data via ``load_dataset_file``.
        system_prompt: Forwarded to ``generate_dataset``.
        document_column: Forwarded to ``generate_dataset``.
        retrieval_reranking: Selected options; the presence of
            ``"Retrieval"`` / ``"Reranking"`` sets the matching flags.
        num_rows: Forwarded to ``load_dataset_file`` and ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the output files.
        temperature_completion: Forwarded to ``generate_dataset``.

    Returns:
        The paths of the written CSV and JSON files, in that order.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # repo_name is typed by the user; keep only its final path component so
    # an absolute path or "../" segments cannot write outside SAVE_LOCAL_DIR.
    file_stem = os.path.basename(repo_name)
    output_csv = os.path.join(SAVE_LOCAL_DIR, file_stem + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, file_stem + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No index=False here: the default records-oriented JSON output does not
    # include the index, and some pandas versions reject an explicit
    # index=False for that orient (NOTE(review): confirm against the pinned
    # pandas version).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
| 530 |
+
|
| 531 |
+
|
| 532 |
def show_system_prompt_visibility():
|
| 533 |
return {system_prompt: gr.Textbox(visible=True)}
|
| 534 |
|
|
|
|
| 564 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
| 565 |
|
| 566 |
|
| 567 |
+
def show_save_local():
    """Return component updates that make the local-save widgets visible."""
    updates = {}
    # One update per widget: the save button and both download slots.
    updates[btn_save_local] = gr.Button(visible=True)
    updates[csv_file] = gr.File(visible=True)
    updates[json_file] = gr.File(visible=True)
    return updates
|
| 573 |
+
|
| 574 |
+
|
| 575 |
######################
|
| 576 |
# Gradio UI
|
| 577 |
######################
|
|
|
|
| 725 |
interactive=True,
|
| 726 |
scale=1,
|
| 727 |
)
|
| 728 |
+
btn_push_to_hub = gr.Button(
|
| 729 |
+
"Push to Hub", variant="primary", scale=2
|
| 730 |
+
)
|
| 731 |
+
btn_save_local = gr.Button(
|
| 732 |
+
"Save locally", variant="primary", scale=2, visible=False
|
| 733 |
+
)
|
| 734 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
| 735 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
| 736 |
with gr.Column(scale=3):
|
| 737 |
success_message = gr.Markdown(
|
| 738 |
visible=True,
|
|
|
|
| 880 |
outputs=[pipeline_code_ui],
|
| 881 |
)
|
| 882 |
|
| 883 |
+
btn_save_local.click(
|
| 884 |
+
save_local,
|
| 885 |
+
inputs=[
|
| 886 |
+
search_in,
|
| 887 |
+
file_in,
|
| 888 |
+
input_type,
|
| 889 |
+
system_prompt,
|
| 890 |
+
document_column,
|
| 891 |
+
retrieval_reranking,
|
| 892 |
+
num_rows,
|
| 893 |
+
temperature,
|
| 894 |
+
repo_name,
|
| 895 |
+
temperature_completion,
|
| 896 |
+
],
|
| 897 |
+
outputs=[csv_file, json_file]
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
| 901 |
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
|
| 902 |
clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
|
|
|
|
| 910 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 911 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
| 912 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
| 913 |
+
if SAVE_LOCAL_DIR is not None:
|
| 914 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import random
|
| 3 |
import uuid
|
|
@@ -19,7 +20,7 @@ from synthetic_dataset_generator.apps.base import (
|
|
| 19 |
validate_argilla_user_workspace_dataset,
|
| 20 |
validate_push_to_hub,
|
| 21 |
)
|
| 22 |
-
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
|
| 23 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 24 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
| 25 |
get_embeddings,
|
|
@@ -406,6 +407,33 @@ def push_dataset(
|
|
| 406 |
return ""
|
| 407 |
|
| 408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
def validate_input_labels(labels: List[str]) -> List[str]:
|
| 410 |
if (
|
| 411 |
not labels
|
|
@@ -425,6 +453,14 @@ def hide_pipeline_code_visibility():
|
|
| 425 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
| 426 |
|
| 427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
######################
|
| 429 |
# Gradio UI
|
| 430 |
######################
|
|
@@ -543,7 +579,14 @@ with gr.Blocks() as app:
|
|
| 543 |
interactive=True,
|
| 544 |
scale=1,
|
| 545 |
)
|
| 546 |
-
btn_push_to_hub = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
with gr.Column(scale=3):
|
| 548 |
success_message = gr.Markdown(
|
| 549 |
visible=True,
|
|
@@ -643,6 +686,21 @@ with gr.Blocks() as app:
|
|
| 643 |
inputs=[],
|
| 644 |
outputs=[pipeline_code_ui],
|
| 645 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
gr.on(
|
| 648 |
triggers=[clear_btn_part.click, clear_btn_full.click],
|
|
@@ -660,3 +718,5 @@ with gr.Blocks() as app:
|
|
| 660 |
app.load(fn=swap_visibility, outputs=main_ui)
|
| 661 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 662 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
import json
|
| 3 |
import random
|
| 4 |
import uuid
|
|
|
|
| 20 |
validate_argilla_user_workspace_dataset,
|
| 21 |
validate_push_to_hub,
|
| 22 |
)
|
| 23 |
+
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SAVE_LOCAL_DIR
|
| 24 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
| 25 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
| 26 |
get_embeddings,
|
|
|
|
| 407 |
return ""
|
| 408 |
|
| 409 |
|
| 410 |
+
def save_local(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str],
    multi_label: bool,
    num_rows: int,
    temperature: float,
    repo_name: str,
) -> tuple[str, str]:
    """Generate a text-classification dataset and save it as CSV and JSON.

    Files are written under ``SAVE_LOCAL_DIR``, which must be set; the output
    paths are built from it.

    Args:
        system_prompt: Forwarded to ``generate_dataset``.
        difficulty: Forwarded to ``generate_dataset``.
        clarity: Forwarded to ``generate_dataset``.
        labels: Forwarded to ``generate_dataset``.
        multi_label: Forwarded to ``generate_dataset``.
        num_rows: Forwarded to ``generate_dataset``.
        temperature: Forwarded to ``generate_dataset``.
        repo_name: Base name (without extension) of the output files.

    Returns:
        The paths of the written CSV and JSON files, in that order.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        multi_label=multi_label,
        labels=labels,
        num_rows=num_rows,
        temperature=temperature,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # repo_name is typed by the user; keep only its final path component so
    # an absolute path or "../" segments cannot write outside SAVE_LOCAL_DIR.
    file_stem = os.path.basename(repo_name)
    output_csv = os.path.join(SAVE_LOCAL_DIR, file_stem + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, file_stem + ".json")
    local_dataset.to_csv(output_csv, index=False)
    # No index=False here: the default records-oriented JSON output does not
    # include the index, and some pandas versions reject an explicit
    # index=False for that orient (NOTE(review): confirm against the pinned
    # pandas version).
    local_dataset.to_json(output_json)
    return output_csv, output_json
|
| 435 |
+
|
| 436 |
+
|
| 437 |
def validate_input_labels(labels: List[str]) -> List[str]:
|
| 438 |
if (
|
| 439 |
not labels
|
|
|
|
| 453 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
| 454 |
|
| 455 |
|
| 456 |
+
def show_save_local():
    """Return component updates that make the local-save widgets visible."""
    updates = {}
    # One update per widget: the save button and both download slots.
    updates[btn_save_local] = gr.Button(visible=True)
    updates[csv_file] = gr.File(visible=True)
    updates[json_file] = gr.File(visible=True)
    return updates
|
| 462 |
+
|
| 463 |
+
|
| 464 |
######################
|
| 465 |
# Gradio UI
|
| 466 |
######################
|
|
|
|
| 579 |
interactive=True,
|
| 580 |
scale=1,
|
| 581 |
)
|
| 582 |
+
btn_push_to_hub = gr.Button(
|
| 583 |
+
"Push to Hub", variant="primary", scale=2
|
| 584 |
+
)
|
| 585 |
+
btn_save_local = gr.Button(
|
| 586 |
+
"Save locally", variant="primary", scale=2, visible=False
|
| 587 |
+
)
|
| 588 |
+
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
| 589 |
+
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
| 590 |
with gr.Column(scale=3):
|
| 591 |
success_message = gr.Markdown(
|
| 592 |
visible=True,
|
|
|
|
| 686 |
inputs=[],
|
| 687 |
outputs=[pipeline_code_ui],
|
| 688 |
)
|
| 689 |
+
|
| 690 |
+
btn_save_local.click(
|
| 691 |
+
save_local,
|
| 692 |
+
inputs=[
|
| 693 |
+
system_prompt,
|
| 694 |
+
difficulty,
|
| 695 |
+
clarity,
|
| 696 |
+
labels,
|
| 697 |
+
multi_label,
|
| 698 |
+
num_rows,
|
| 699 |
+
temperature,
|
| 700 |
+
repo_name,
|
| 701 |
+
],
|
| 702 |
+
outputs=[csv_file, json_file]
|
| 703 |
+
)
|
| 704 |
|
| 705 |
gr.on(
|
| 706 |
triggers=[clear_btn_part.click, clear_btn_full.click],
|
|
|
|
| 718 |
app.load(fn=swap_visibility, outputs=main_ui)
|
| 719 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
| 720 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
| 721 |
+
if SAVE_LOCAL_DIR is not None:
|
| 722 |
+
app.load(fn=show_save_local, outputs=[btn_save_local, csv_file, json_file])
|
src/synthetic_dataset_generator/constants.py
CHANGED
|
@@ -8,6 +8,9 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
|
|
| 8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
| 9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
# Models
|
| 12 |
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
|
| 13 |
TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
|
|
|
|
| 8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
| 9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
| 10 |
|
| 11 |
+
# Directory for outputs
|
| 12 |
+
SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
|
| 13 |
+
|
| 14 |
# Models
|
| 15 |
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
|
| 16 |
TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
|