Update app.py
Browse files
app.py
CHANGED
|
@@ -23,13 +23,14 @@ from huggingface_hub import list_datasets, HfApi, hf_hub_download
|
|
| 23 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
| 24 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
def load_llm_config():
|
| 29 |
params = load_params()
|
| 30 |
return (
|
| 31 |
params.get('PROVIDER', ''),
|
| 32 |
params.get('BASE_URL', ''),
|
|
|
|
| 33 |
params.get('WORKSPACE', ''),
|
| 34 |
params.get('API_KEY', ''),
|
| 35 |
params.get('max_tokens', 2048),
|
|
@@ -41,10 +42,11 @@ def load_llm_config():
|
|
| 41 |
|
| 42 |
|
| 43 |
|
| 44 |
-
def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
| 45 |
save_params({
|
| 46 |
'PROVIDER': provider,
|
| 47 |
'BASE_URL': base_url,
|
|
|
|
| 48 |
'WORKSPACE': workspace,
|
| 49 |
'API_KEY': api_key,
|
| 50 |
'max_tokens': max_tokens,
|
|
@@ -56,6 +58,8 @@ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperat
|
|
| 56 |
return "LLM configuration saved successfully"
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
def load_annotation_config():
|
|
@@ -493,7 +497,7 @@ def update_chat_context(row_data, index, total, quality, high_quality_tags, low_
|
|
| 493 |
|
| 494 |
|
| 495 |
|
| 496 |
-
async def run_generate_dataset(num_workers, num_generations, output_file_path,
|
| 497 |
if loaded_dataset is None:
|
| 498 |
return "Error: No dataset loaded. Please load a dataset before generating.", ""
|
| 499 |
|
|
@@ -501,7 +505,7 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path, l
|
|
| 501 |
for _ in range(num_generations):
|
| 502 |
topic_selected = random.choice(TOPICS)
|
| 503 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
| 504 |
-
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path,
|
| 505 |
if data:
|
| 506 |
generated_data.append(json.dumps(data))
|
| 507 |
|
|
@@ -621,6 +625,13 @@ def load_dataset_wrapper(dataset_name, split):
|
|
| 621 |
dataset, message = load_huggingface_dataset(dataset_name, split)
|
| 622 |
return dataset, message
|
| 623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
|
| 625 |
def get_popular_datasets():
|
| 626 |
return [
|
|
@@ -913,10 +924,12 @@ with demo:
|
|
| 913 |
with gr.Tab("LLM Configuration"):
|
| 914 |
with gr.Row():
|
| 915 |
provider = gr.Dropdown(choices=["local-model", "anything-llm"], label="LLM Provider")
|
| 916 |
-
base_url = gr.Textbox(label="Base URL (for local model)")
|
| 917 |
with gr.Row():
|
| 918 |
-
|
| 919 |
-
|
|
|
|
|
|
|
|
|
|
| 920 |
|
| 921 |
with gr.Accordion("Advanced Options", open=False):
|
| 922 |
with gr.Row():
|
|
@@ -1045,18 +1058,18 @@ with demo:
|
|
| 1045 |
|
| 1046 |
start_generation_btn.click(
|
| 1047 |
run_generate_dataset,
|
| 1048 |
-
inputs=[num_workers, num_generations, output_file_path,
|
| 1049 |
outputs=[generation_status, generation_output]
|
| 1050 |
)
|
| 1051 |
|
| 1052 |
demo.load(
|
| 1053 |
load_llm_config,
|
| 1054 |
-
outputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
save_llm_config_btn.click(
|
| 1058 |
save_llm_config,
|
| 1059 |
-
inputs=[provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
|
| 1060 |
outputs=[llm_config_status]
|
| 1061 |
)
|
| 1062 |
|
|
@@ -1071,28 +1084,34 @@ with demo:
|
|
| 1071 |
outputs=[chatbot]
|
| 1072 |
)
|
| 1073 |
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
inputs=[
|
| 1077 |
-
outputs=[
|
| 1078 |
)
|
| 1079 |
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
)
|
| 1085 |
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
|
| 1092 |
# Modify the start_generation_btn.click to include the loaded dataset
|
| 1093 |
start_generation_btn.click(
|
| 1094 |
run_generate_dataset,
|
| 1095 |
-
inputs=[num_workers, num_generations, output_file_path,
|
| 1096 |
outputs=[generation_status, generation_output]
|
| 1097 |
)
|
| 1098 |
|
|
|
|
| 23 |
ANNOTATION_CONFIG_FILE = "annotation_config.json"
|
| 24 |
OUTPUT_FILE_PATH = "dataset.jsonl"
|
| 25 |
|
| 26 |
+
llm_provider_state = State("")
|
| 27 |
|
| 28 |
def load_llm_config():
|
| 29 |
params = load_params()
|
| 30 |
return (
|
| 31 |
params.get('PROVIDER', ''),
|
| 32 |
params.get('BASE_URL', ''),
|
| 33 |
+
params.get('MODEL', ''), # Add this line
|
| 34 |
params.get('WORKSPACE', ''),
|
| 35 |
params.get('API_KEY', ''),
|
| 36 |
params.get('max_tokens', 2048),
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
|
| 45 |
+
def save_llm_config(provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
|
| 46 |
save_params({
|
| 47 |
'PROVIDER': provider,
|
| 48 |
'BASE_URL': base_url,
|
| 49 |
+
'MODEL': model, # Add this line
|
| 50 |
'WORKSPACE': workspace,
|
| 51 |
'API_KEY': api_key,
|
| 52 |
'max_tokens': max_tokens,
|
|
|
|
| 58 |
return "LLM configuration saved successfully"
|
| 59 |
|
| 60 |
|
| 61 |
+
def update_model_visibility(provider):
|
| 62 |
+
return gr.update(visible=provider in ["local-model", "openai"])
|
| 63 |
|
| 64 |
|
| 65 |
def load_annotation_config():
|
|
|
|
| 497 |
|
| 498 |
|
| 499 |
|
| 500 |
+
async def run_generate_dataset(num_workers, num_generations, output_file_path, llm_provider, dataset):
|
| 501 |
if loaded_dataset is None:
|
| 502 |
return "Error: No dataset loaded. Please load a dataset before generating.", ""
|
| 503 |
|
|
|
|
| 505 |
for _ in range(num_generations):
|
| 506 |
topic_selected = random.choice(TOPICS)
|
| 507 |
system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
|
| 508 |
+
data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path, llm_provider)
|
| 509 |
if data:
|
| 510 |
generated_data.append(json.dumps(data))
|
| 511 |
|
|
|
|
| 625 |
dataset, message = load_huggingface_dataset(dataset_name, split)
|
| 626 |
return dataset, message
|
| 627 |
|
| 628 |
+
def update_field_visibility(provider):
|
| 629 |
+
if provider == "local-model":
|
| 630 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
|
| 631 |
+
elif provider == "anything-llm":
|
| 632 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
|
| 633 |
+
else:
|
| 634 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
| 635 |
|
| 636 |
def get_popular_datasets():
|
| 637 |
return [
|
|
|
|
| 924 |
with gr.Tab("LLM Configuration"):
|
| 925 |
with gr.Row():
|
| 926 |
provider = gr.Dropdown(choices=["local-model", "anything-llm"], label="LLM Provider")
|
|
|
|
| 927 |
with gr.Row():
|
| 928 |
+
base_url = gr.Textbox(label="Base URL (for local model)", visible=False)
|
| 929 |
+
model = gr.Textbox(label="Model (for local model)", visible=False)
|
| 930 |
+
with gr.Row():
|
| 931 |
+
workspace = gr.Textbox(label="Workspace (for AnythingLLM)", visible=False)
|
| 932 |
+
api_key = gr.Textbox(label="API Key (for AnythingLLM)", visible=False)
|
| 933 |
|
| 934 |
with gr.Accordion("Advanced Options", open=False):
|
| 935 |
with gr.Row():
|
|
|
|
| 1058 |
|
| 1059 |
start_generation_btn.click(
|
| 1060 |
run_generate_dataset,
|
| 1061 |
+
inputs=[num_workers, num_generations, output_file_path, llm_provider, dataset],
|
| 1062 |
outputs=[generation_status, generation_output]
|
| 1063 |
)
|
| 1064 |
|
| 1065 |
demo.load(
|
| 1066 |
load_llm_config,
|
| 1067 |
+
outputs=[provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty]
|
| 1068 |
)
|
| 1069 |
|
| 1070 |
save_llm_config_btn.click(
|
| 1071 |
save_llm_config,
|
| 1072 |
+
inputs=[provider, base_url, model, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty],
|
| 1073 |
outputs=[llm_config_status]
|
| 1074 |
)
|
| 1075 |
|
|
|
|
| 1084 |
outputs=[chatbot]
|
| 1085 |
)
|
| 1086 |
|
| 1087 |
+
provider.change(
|
| 1088 |
+
lambda x: x,
|
| 1089 |
+
inputs=[provider],
|
| 1090 |
+
outputs=[llm_provider_state]
|
| 1091 |
)
|
| 1092 |
|
| 1093 |
+
# search_button.click(
|
| 1094 |
+
# search_huggingface_datasets,
|
| 1095 |
+
# inputs=[dataset_search],
|
| 1096 |
+
# outputs=[dataset_results, dataset_input]
|
| 1097 |
+
# )
|
| 1098 |
|
| 1099 |
+
# dataset_results.change(
|
| 1100 |
+
# lambda choice: choice,
|
| 1101 |
+
# inputs=[dataset_results],
|
| 1102 |
+
# outputs=[dataset_input]
|
| 1103 |
+
# )
|
| 1104 |
+
|
| 1105 |
+
# load_dataset_button.click(
|
| 1106 |
+
# load_dataset_wrapper,
|
| 1107 |
+
# inputs=[dataset_input, dataset_split],
|
| 1108 |
+
# outputs=[loaded_dataset, dataset_status]
|
| 1109 |
+
# )
|
| 1110 |
|
| 1111 |
# Modify the start_generation_btn.click to include the loaded dataset
|
| 1112 |
start_generation_btn.click(
|
| 1113 |
run_generate_dataset,
|
| 1114 |
+
inputs=[num_workers, num_generations, output_file_path, llm_provider_state],
|
| 1115 |
outputs=[generation_status, generation_output]
|
| 1116 |
)
|
| 1117 |
|