Update app.py
app.py CHANGED
@@ -14,12 +14,17 @@ import random
 from params import load_params, save_params
 import pandas as pd
 import csv
+from datasets import load_dataset
+from huggingface_hub import list_datasets, HfApi, hf_hub_download
+



 ANNOTATION_CONFIG_FILE = "annotation_config.json"
 OUTPUT_FILE_PATH = "dataset.jsonl"

+
+
 def load_llm_config():
     params = load_params()
     return (
@@ -34,6 +39,8 @@ def load_llm_config():
         params.get('presence_penalty', 0.0)
     )

+
+
 def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
     save_params({
         'PROVIDER': provider,
@@ -49,6 +56,8 @@ def save_llm_config(provider, base_url, workspace, api_key, max_tokens, temperat
     return "LLM configuration saved successfully"


+
+
 def load_annotation_config():
     try:
         with open(ANNOTATION_CONFIG_FILE, 'r') as f:
@@ -92,6 +101,8 @@ def load_annotation_config():
         }


+
+
 def load_csv_dataset(file_path):
     data = []
     with open(file_path, 'r') as f:
@@ -100,20 +111,28 @@ def load_csv_dataset(file_path):
             data.append(row)
     return data

+
+
 def load_txt_dataset(file_path):
     with open(file_path, 'r') as f:
         return [{"content": line.strip()} for line in f if line.strip()]

+
+
 def save_annotation_config(config):
     with open(ANNOTATION_CONFIG_FILE, 'w') as f:
         json.dump(config, f, indent=2)

+
+
 def load_jsonl_dataset(file_path):
     if not os.path.exists(file_path):
         return []
     with open(file_path, 'r') as f:
         return [json.loads(line.strip()) for line in f if line.strip()]

+
+
 def load_dataset(file):
     if file is None:
         return "", 0, 0, "No file uploaded", "3", [], [], [], ""
@@ -136,6 +155,8 @@ def load_dataset(file):
     first_row = json.dumps(data[0], indent=2)
     return first_row, 0, len(data), f"Row: 1/{len(data)}", "3", [], [], [], ""

+
+
 def save_row(file_path, index, row_data):
     file_extension = file_path.split('.')[-1].lower()

@@ -150,6 +171,8 @@ def save_row(file_path, index, row_data):

     return f"Row {index} saved successfully"

+
+
 def save_jsonl_row(file_path, index, row_data):
     with open(file_path, 'r') as f:
         lines = f.readlines()
@@ -159,6 +182,8 @@ def save_jsonl_row(file_path, index, row_data):
     with open(file_path, 'w') as f:
         f.writelines(lines)

+
+
 def save_csv_row(file_path, index, row_data):
     df = pd.read_csv(file_path)
     row_dict = json.loads(row_data)
@@ -166,6 +191,8 @@ def save_csv_row(file_path, index, row_data):
         df.at[index, col] = value
     df.to_csv(file_path, index=False)

+
+
 def save_txt_row(file_path, index, row_data):
     with open(file_path, 'r') as f:
         lines = f.readlines()
@@ -176,6 +203,8 @@ def save_txt_row(file_path, index, row_data):
     with open(file_path, 'w') as f:
         f.writelines(lines)

+
+
 def get_row(file_path, index):
     data = load_jsonl_dataset(file_path)
     if not data:
@@ -184,6 +213,8 @@ def get_row(file_path, index):
         return json.dumps(data[index], indent=2), len(data)
     return "", len(data)

+
+
 def json_to_markdown(json_str):
     try:
         data = json.loads(json_str)
@@ -192,6 +223,8 @@ def json_to_markdown(json_str):
     except json.JSONDecodeError:
         return "Error: Invalid JSON format"

+
+
 def markdown_to_json(markdown_str):
     sections = re.split(r'#\s+(System|Instruction|Response)\s*\n', markdown_str)
     if len(sections) != 7:  # Should be: ['', 'System', content, 'Instruction', content, 'Response', content]
@@ -204,10 +237,14 @@ def markdown_to_json(markdown_str):
     }
     return json.dumps(json_data, indent=2)

+
+
 def navigate_rows(file_path: str, current_index: int, direction: Literal["prev", "next"], metadata_config):
     new_index = max(0, current_index + (-1 if direction == "prev" else 1))
     return load_and_show_row(file_path, new_index, metadata_config)

+
+
 def load_and_show_row(file_path, index, metadata_config):
     row_data, total = get_row(file_path, index)
     if not row_data:
@@ -229,6 +266,8 @@ def load_and_show_row(file_path, index, metadata_config):
     return (row_data, index, total, f"Row: {index + 1}/{total}", quality,
             high_quality_tags, low_quality_tags, toxic_tags, other)

+
+
 def save_row_with_metadata(file_path, index, row_data, config, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     data = json.loads(row_data)
     metadata = {
@@ -248,6 +287,8 @@ def save_row_with_metadata(file_path, index, row_data, config, quality, high_qua
     data["metadata"] = metadata
     return save_row(file_path, index, json.dumps(data))

+
+
 def update_annotation_ui(config):
     quality_choices = [(item["value"], item["label"]) for item in config["quality_scale"]["scale"]]
     quality_label = gr.Radio(
@@ -271,6 +312,8 @@ def update_annotation_ui(config):

     return quality_label, *tag_components, other_description

+
+
 def load_config_to_ui(config):
     return (
         config["quality_scale"]["name"],
@@ -280,6 +323,8 @@ def load_config_to_ui(config):
         [[field["name"], field["description"]] for field in config["free_text_fields"]]
     )

+
+
 def save_config_from_ui(name, description, scale, categories, fields, topics, all_topics_text):
     if all_topics_text.visible:
         topics_list = [topic.strip() for topic in all_topics_text.split("\n") if topic.strip()]
@@ -299,6 +344,8 @@ def save_config_from_ui(name, description, scale, categories, fields, topics, al
     save_annotation_config(new_config)
     return "Configuration saved successfully", new_config

+
+
 # Add this new function to generate the preview
 def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     try:
@@ -321,6 +368,8 @@ def generate_preview(row_data, quality, high_quality_tags, low_quality_tags, tox
     except json.JSONDecodeError:
         return "Error: Invalid JSON in the current row data"

+
+
 def load_dataset_config():
     params = load_params()
     with open("system_messages.py", "r") as f:
@@ -347,6 +396,8 @@ def load_dataset_config():
         params.get('presence_penalty', 0.0)
     )

+
+
 def edit_all_topics_func(topics):
     topics_list = [topic[0] for topic in topics]
     jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
@@ -356,6 +407,8 @@ def edit_all_topics_func(topics):
         gr.update(visible=True)
     )

+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -366,6 +419,8 @@ def update_topics_from_text(text):

     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)

+
+
 def save_dataset_config(system_messages, prompt_1, topics, max_tokens, temperature, top_p, frequency_penalty, presence_penalty):
     # Save VODALUS_SYSTEM_MESSAGE to system_messages.py
     with open("system_messages.py", "w") as f:
@@ -426,6 +481,7 @@ def chat_with_llm(message, history):
         print(f"Error in chat_with_llm: {str(e)}")
         return history + [[message, f"Error: {str(e)}"]]

+
 def update_chat_context(row_data, index, total, quality, high_quality_tags, low_quality_tags, toxic_tags, other):
     context = f"""Current app state:
 Row: {index + 1}/{total}
@@ -440,12 +496,16 @@ def update_chat_context(row_data, index, total, quality, high_quality_tags, low_
     return [[None, context]]


-async def run_generate_dataset(num_workers, num_generations, output_file_path):
+
+async def run_generate_dataset(num_workers, num_generations, output_file_path, loaded_dataset):
+    if loaded_dataset is None:
+        return "Error: No dataset loaded. Please load a dataset before generating.", ""
+
     generated_data = []
     for _ in range(num_generations):
         topic_selected = random.choice(TOPICS)
         system_message_selected = random.choice(SYSTEM_MESSAGES_VODALUS)
-        data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path)
+        data = await generate_data(topic_selected, PROMPT_1, system_message_selected, output_file_path, loaded_dataset)
         if data:
             generated_data.append(json.dumps(data))

@@ -456,15 +516,21 @@ async def run_generate_dataset(num_workers, num_generations, output_file_path):

     return f"Generated {num_generations} entries and saved to {output_file_path}", "\n".join(generated_data[:5]) + "\n..."

+
+
 def add_topic_row(data):
     if isinstance(data, pd.DataFrame):
         return pd.concat([data, pd.DataFrame({"Topic": ["New Topic"]})], ignore_index=True)
     else:
         return data + [["New Topic"]]

+
+
 def remove_last_topic_row(data):
     return data[:-1] if len(data) > 1 else data

+
+
 def edit_all_topics_func(topics):
     topics_list = [topic[0] for topic in topics]
     jsonl_rows = "\n".join([json.dumps({"topic": topic}) for topic in topics_list])
@@ -474,6 +540,8 @@ def edit_all_topics_func(topics):
         gr.update(visible=True)
     )

+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -484,6 +552,8 @@ def update_topics_from_text(text):

     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)

+
+
 def update_topics_from_text(text):
     try:
         # Try parsing as JSONL
@@ -494,6 +564,82 @@ def update_topics_from_text(text):

     return gr.Dataframe.update(value=[[topic] for topic in topics_list], visible=True), gr.TextArea.update(visible=False)

+
+
+def search_huggingface_datasets(query):
+    try:
+        api = HfApi()
+        datasets = api.list_datasets(search=query, limit=20)
+        dataset_ids = [dataset.id for dataset in datasets]
+        return gr.update(choices=dataset_ids, visible=True), ""
+    except Exception as e:
+        print(f"Error searching datasets: {str(e)}")
+        return gr.update(choices=["Error: Could not search datasets"], visible=True), ""
+
+
+
+def load_huggingface_dataset(dataset_name, split="train"):
+    try:
+        print(f"Attempting to load dataset: {dataset_name}")
+
+        # Check if dataset_name is a string
+        if not isinstance(dataset_name, str):
+            raise ValueError(f"Expected dataset_name to be a string, but got {type(dataset_name)}")
+
+        # Try loading the dataset without specifying a config
+        full_dataset = load_dataset(dataset_name)
+
+        print(f"Dataset loaded. Available splits: {list(full_dataset.keys())}")
+
+        # Select the appropriate split
+        if split in full_dataset:
+            dataset = full_dataset[split]
+            print(f"Using specified split: {split}")
+        else:
+            available_splits = list(full_dataset.keys())
+            if available_splits:
+                dataset = full_dataset[available_splits[0]]
+                split = available_splits[0]
+                print(f"Specified split not found. Using first available split: {split}")
+            else:
+                raise ValueError("No valid splits found in the dataset")
+
+        return dataset, f"Dataset '{dataset_name}' (split: {split}) loaded successfully."
+    except Exception as e:
+        error_msg = f"Error loading dataset: {str(e)}"
+        print(f"Error details: {error_msg}")
+
+        # If loading fails, try to get the dataset card
+        try:
+            dataset_card = hf_hub_download(repo_id=dataset_name, filename="README.md")
+            with open(dataset_card, 'r') as f:
+                card_content = f.read()
+            return None, f"Dataset couldn't be loaded, but here's the dataset card:\n\n{card_content[:500]}..."
+        except:
+            return None, error_msg
+
+# Wrapper function to handle the Gradio interface
+def load_dataset_wrapper(dataset_name, split):
+    if not dataset_name:
+        return None, "Please enter a dataset name."
+    dataset, message = load_huggingface_dataset(dataset_name, split)
+    return dataset, message
+
+
+def get_popular_datasets():
+    return [
+        "wikipedia",
+        "squad",
+        "glue",
+        "imdb",
+        "wmt16",
+        "common_voice",
+        "cnn_dailymail",
+        "amazon_reviews_multi",
+        "yelp_review_full",
+        "ag_news"
+    ]
+
 css = """
 body, #root {
     margin: 0;
@@ -740,6 +886,20 @@ with demo:
         with gr.Row():
             save_dataset_config_btn = gr.Button("Save Dataset Configuration", variant="primary")
             dataset_config_status = gr.Textbox(label="Status")
+
+        # gr.Markdown("### Hugging Face Dataset")
+        # with gr.Row():
+        #     dataset_search = gr.Textbox(label="Search Datasets")
+        #     search_button = gr.Button("Search")
+        # dataset_input = gr.Textbox(label="Dataset Name", info="Enter a dataset name or select from search results")
+        # dataset_results = gr.Radio(label="Search Results", choices=[], visible=False)
+        # dataset_split = gr.Textbox(label="Dataset Split (optional)", value="train")
+        # load_dataset_button = gr.Button("Load Selected Dataset")
+        # dataset_status = gr.Textbox(label="Dataset Status")
+
+        # Add a state to store the loaded dataset
+        # loaded_dataset = gr.State(None)
+


    with gr.Tab("Dataset Generation"):
@@ -889,7 +1049,7 @@ with demo:

    start_generation_btn.click(
        run_generate_dataset,
-        inputs=[num_workers, num_generations, output_file_path],
+        inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
        outputs=[generation_status, generation_output]
    )

@@ -915,6 +1075,30 @@ with demo:
        outputs=[chatbot]
    )

+    search_button.click(
+        search_huggingface_datasets,
+        inputs=[dataset_search],
+        outputs=[dataset_results, dataset_input]
+    )
+
+    dataset_results.change(
+        lambda choice: choice,
+        inputs=[dataset_results],
+        outputs=[dataset_input]
+    )
+
+    load_dataset_button.click(
+        load_dataset_wrapper,
+        inputs=[dataset_input, dataset_split],
+        outputs=[loaded_dataset, dataset_status]
+    )
+
+    # Modify the start_generation_btn.click to include the loaded dataset
+    start_generation_btn.click(
+        run_generate_dataset,
+        inputs=[num_workers, num_generations, output_file_path, loaded_dataset],
+        outputs=[generation_status, generation_output]
+    )

    demo.load(
        lambda: (
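
Reviewers who want to sanity-check the split-fallback behavior this commit adds, without launching the Gradio app, can use the standalone sketch below. It mirrors the logic of load_huggingface_dataset rather than importing app.py (which builds the UI at import time); load_with_fallback is a hypothetical helper, and "squad" with the "validation" split is only an example.

# Standalone sketch of the commit's split-fallback logic (hypothetical helper,
# not the app's exact code). Requires the `datasets` package and network access.
from datasets import load_dataset

def load_with_fallback(name: str, split: str = "train"):
    full = load_dataset(name)            # DatasetDict keyed by split name
    if split not in full:                # requested split missing:
        split = next(iter(full.keys()))  # fall back to the first available split
    return full[split], split

if __name__ == "__main__":
    ds, used = load_with_fallback("squad", "validation")  # example dataset name
    print(f"Loaded {len(ds)} rows from split '{used}'")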