Spaces:
Running
Running
Add custom tag input and fix model download (zip directory)
Browse files- Add trigger word input to Data Source tab for LoRA training
- Fix model download: zip the output directory before serving
app.py
CHANGED
|
@@ -313,7 +313,7 @@ def _build_review_dataframe():
|
|
| 313 |
return builder.get_samples_dataframe_data()
|
| 314 |
|
| 315 |
|
| 316 |
-
def lora_download_hf(dataset_id, max_files, hf_offset, training_state):
|
| 317 |
"""Download HuggingFace dataset batch, restore labels from HF repo, and scan."""
|
| 318 |
try:
|
| 319 |
if not dataset_id or not dataset_id.strip():
|
|
@@ -333,6 +333,11 @@ def lora_download_hf(dataset_id, max_files, hf_offset, training_state):
|
|
| 333 |
|
| 334 |
builder = get_dataset_builder()
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
# Restore labels/flags from dataset.json pulled from HF repo
|
| 337 |
dataset_json_path = str(Path(local_dir) / "dataset.json")
|
| 338 |
if Path(dataset_json_path).exists():
|
|
@@ -592,10 +597,19 @@ def lora_stop_training():
|
|
| 592 |
|
| 593 |
|
| 594 |
def lora_download_model(model_path):
|
| 595 |
-
"""
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
|
| 600 |
|
| 601 |
# ==================== GRADIO UI ====================
|
|
@@ -832,6 +846,10 @@ def create_ui():
|
|
| 832 |
label="Dataset ID",
|
| 833 |
placeholder="username/dataset-name",
|
| 834 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
with gr.Row():
|
| 836 |
lora_hf_max = gr.Slider(
|
| 837 |
minimum=1, maximum=500, value=50, step=1,
|
|
@@ -980,7 +998,7 @@ def create_ui():
|
|
| 980 |
# Data Source
|
| 981 |
lora_hf_btn.click(
|
| 982 |
fn=lora_download_hf,
|
| 983 |
-
inputs=[lora_hf_id, lora_hf_max, lora_hf_offset, training_state],
|
| 984 |
outputs=[lora_source_status, training_state, lora_hf_offset, lora_progress],
|
| 985 |
)
|
| 986 |
|
|
|
|
| 313 |
return builder.get_samples_dataframe_data()
|
| 314 |
|
| 315 |
|
| 316 |
+
def lora_download_hf(dataset_id, custom_tag, max_files, hf_offset, training_state):
|
| 317 |
"""Download HuggingFace dataset batch, restore labels from HF repo, and scan."""
|
| 318 |
try:
|
| 319 |
if not dataset_id or not dataset_id.strip():
|
|
|
|
| 333 |
|
| 334 |
builder = get_dataset_builder()
|
| 335 |
|
| 336 |
+
# Set trigger word for LoRA training
|
| 337 |
+
tag = custom_tag.strip() if custom_tag else ""
|
| 338 |
+
if tag:
|
| 339 |
+
builder.set_custom_tag(tag)
|
| 340 |
+
|
| 341 |
# Restore labels/flags from dataset.json pulled from HF repo
|
| 342 |
dataset_json_path = str(Path(local_dir) / "dataset.json")
|
| 343 |
if Path(dataset_json_path).exists():
|
|
|
|
| 597 |
|
| 598 |
|
| 599 |
def lora_download_model(model_path):
|
| 600 |
+
"""Zip the LoRA model directory and return the zip for Gradio file download."""
|
| 601 |
+
import shutil
|
| 602 |
+
|
| 603 |
+
if not model_path or not Path(model_path).exists():
|
| 604 |
+
return None
|
| 605 |
+
|
| 606 |
+
path = Path(model_path)
|
| 607 |
+
if path.is_dir():
|
| 608 |
+
zip_path = path.parent / path.name
|
| 609 |
+
shutil.make_archive(str(zip_path), "zip", root_dir=str(path.parent), base_dir=path.name)
|
| 610 |
+
return str(zip_path) + ".zip"
|
| 611 |
+
|
| 612 |
+
return model_path
|
| 613 |
|
| 614 |
|
| 615 |
# ==================== GRADIO UI ====================
|
|
|
|
| 846 |
label="Dataset ID",
|
| 847 |
placeholder="username/dataset-name",
|
| 848 |
)
|
| 849 |
+
lora_custom_tag = gr.Textbox(
|
| 850 |
+
label="Custom Tag (trigger word for LoRA)",
|
| 851 |
+
placeholder="lofi, synthwave, jazz-piano…",
|
| 852 |
+
)
|
| 853 |
with gr.Row():
|
| 854 |
lora_hf_max = gr.Slider(
|
| 855 |
minimum=1, maximum=500, value=50, step=1,
|
|
|
|
| 998 |
# Data Source
|
| 999 |
lora_hf_btn.click(
|
| 1000 |
fn=lora_download_hf,
|
| 1001 |
+
inputs=[lora_hf_id, lora_custom_tag, lora_hf_max, lora_hf_offset, training_state],
|
| 1002 |
outputs=[lora_source_status, training_state, lora_hf_offset, lora_progress],
|
| 1003 |
)
|
| 1004 |
|