pedroapfilho committed on
Commit
6ca081e
·
unverified ·
1 Parent(s): ed3930a

Add custom tag input and fix model download (zip directory)

Browse files

- Add trigger word input to Data Source tab for LoRA training
- Fix model download: zip the output directory before serving

Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -313,7 +313,7 @@ def _build_review_dataframe():
313
  return builder.get_samples_dataframe_data()
314
 
315
 
316
- def lora_download_hf(dataset_id, max_files, hf_offset, training_state):
317
  """Download HuggingFace dataset batch, restore labels from HF repo, and scan."""
318
  try:
319
  if not dataset_id or not dataset_id.strip():
@@ -333,6 +333,11 @@ def lora_download_hf(dataset_id, max_files, hf_offset, training_state):
333
 
334
  builder = get_dataset_builder()
335
 
 
 
 
 
 
336
  # Restore labels/flags from dataset.json pulled from HF repo
337
  dataset_json_path = str(Path(local_dir) / "dataset.json")
338
  if Path(dataset_json_path).exists():
@@ -592,10 +597,19 @@ def lora_stop_training():
592
 
593
 
594
  def lora_download_model(model_path):
595
- """Return model path for Gradio file download."""
596
- if model_path and Path(model_path).exists():
597
- return model_path
598
- return None
 
 
 
 
 
 
 
 
 
599
 
600
 
601
  # ==================== GRADIO UI ====================
@@ -832,6 +846,10 @@ def create_ui():
832
  label="Dataset ID",
833
  placeholder="username/dataset-name",
834
  )
 
 
 
 
835
  with gr.Row():
836
  lora_hf_max = gr.Slider(
837
  minimum=1, maximum=500, value=50, step=1,
@@ -980,7 +998,7 @@ def create_ui():
980
  # Data Source
981
  lora_hf_btn.click(
982
  fn=lora_download_hf,
983
- inputs=[lora_hf_id, lora_hf_max, lora_hf_offset, training_state],
984
  outputs=[lora_source_status, training_state, lora_hf_offset, lora_progress],
985
  )
986
 
 
313
  return builder.get_samples_dataframe_data()
314
 
315
 
316
+ def lora_download_hf(dataset_id, custom_tag, max_files, hf_offset, training_state):
317
  """Download HuggingFace dataset batch, restore labels from HF repo, and scan."""
318
  try:
319
  if not dataset_id or not dataset_id.strip():
 
333
 
334
  builder = get_dataset_builder()
335
 
336
+ # Set trigger word for LoRA training
337
+ tag = custom_tag.strip() if custom_tag else ""
338
+ if tag:
339
+ builder.set_custom_tag(tag)
340
+
341
  # Restore labels/flags from dataset.json pulled from HF repo
342
  dataset_json_path = str(Path(local_dir) / "dataset.json")
343
  if Path(dataset_json_path).exists():
 
597
 
598
 
599
  def lora_download_model(model_path):
600
+ """Zip the LoRA model directory and return the zip for Gradio file download."""
601
+ import shutil
602
+
603
+ if not model_path or not Path(model_path).exists():
604
+ return None
605
+
606
+ path = Path(model_path)
607
+ if path.is_dir():
608
+ zip_path = path.parent / path.name
609
+ shutil.make_archive(str(zip_path), "zip", root_dir=str(path.parent), base_dir=path.name)
610
+ return str(zip_path) + ".zip"
611
+
612
+ return model_path
613
 
614
 
615
  # ==================== GRADIO UI ====================
 
846
  label="Dataset ID",
847
  placeholder="username/dataset-name",
848
  )
849
+ lora_custom_tag = gr.Textbox(
850
+ label="Custom Tag (trigger word for LoRA)",
851
+ placeholder="lofi, synthwave, jazz-piano…",
852
+ )
853
  with gr.Row():
854
  lora_hf_max = gr.Slider(
855
  minimum=1, maximum=500, value=50, step=1,
 
998
  # Data Source
999
  lora_hf_btn.click(
1000
  fn=lora_download_hf,
1001
+ inputs=[lora_hf_id, lora_custom_tag, lora_hf_max, lora_hf_offset, training_state],
1002
  outputs=[lora_source_status, training_state, lora_hf_offset, lora_progress],
1003
  )
1004