Andrew commited on
Commit ·
048d6f4
1
Parent(s): 8bdd018
Add LoRA upload flow in Space UI and update README
Browse files- README.md +2 -0
- lora_ui.py +98 -0
README.md
CHANGED
|
@@ -107,6 +107,8 @@ Then in UI:
|
|
| 107 |
|
| 108 |
- Use the endpoint scripts in `scripts/endpoint/`.
|
| 109 |
- Or test through the Gradio UI flow.
|
|
|
|
|
|
|
| 110 |
|
| 111 |
## AF3 GUI one-command startup
|
| 112 |
|
|
|
|
| 107 |
|
| 108 |
- Use the endpoint scripts in `scripts/endpoint/`.
|
| 109 |
- Or test through the Gradio UI flow.
|
| 110 |
+
- In **Step 4 - Evaluate**, you can now upload your own LoRA adapter (`.zip` or adapter files),
|
| 111 |
+
then load it without retraining in this Space.
|
| 112 |
|
| 113 |
## AF3 GUI one-command startup
|
| 114 |
|
lora_ui.py
CHANGED
|
@@ -15,6 +15,8 @@ import random
|
|
| 15 |
import threading
|
| 16 |
import tempfile
|
| 17 |
import time
|
|
|
|
|
|
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import List, Optional
|
| 20 |
|
|
@@ -68,6 +70,7 @@ _auto_label_cursor: int = 0
|
|
| 68 |
audio_saver = AudioSaver(default_format="wav")
|
| 69 |
IS_SPACE = bool(os.getenv("SPACE_ID"))
|
| 70 |
DEFAULT_OUTPUT_DIR = "/data/lora_output" if IS_SPACE else "lora_output"
|
|
|
|
| 71 |
|
| 72 |
if IS_SPACE:
|
| 73 |
try:
|
|
@@ -520,6 +523,83 @@ def list_adapters(output_dir: str):
|
|
| 520 |
return adapters if adapters else ["(none found)"]
|
| 521 |
|
| 522 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
@_gpu_callback
|
| 524 |
def load_adapter(adapter_path: str):
|
| 525 |
if not adapter_path or adapter_path == "(none found)":
|
|
@@ -899,6 +979,18 @@ def build_ui():
|
|
| 899 |
adapter_dir = gr.Textbox(label="Adapters Directory", value=DEFAULT_OUTPUT_DIR)
|
| 900 |
refresh_btn = gr.Button("Refresh List")
|
| 901 |
adapter_dd = gr.Dropdown(label="Select Adapter", choices=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
with gr.Row():
|
| 903 |
load_btn = gr.Button("Load Adapter", variant="primary")
|
| 904 |
unload_btn = gr.Button("Unload Adapter")
|
|
@@ -909,6 +1001,12 @@ def build_ui():
|
|
| 909 |
return gr.update(choices=adapters, value=adapters[0] if adapters else None)
|
| 910 |
|
| 911 |
refresh_btn.click(_refresh, adapter_dir, adapter_dd, api_name="list_adapters")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
load_btn.click(load_adapter, adapter_dd, adapter_status, api_name="load_adapter")
|
| 913 |
unload_btn.click(unload_adapter, outputs=adapter_status, api_name="unload_adapter")
|
| 914 |
|
|
|
|
| 15 |
import threading
|
| 16 |
import tempfile
|
| 17 |
import time
|
| 18 |
+
import shutil
|
| 19 |
+
import zipfile
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import List, Optional
|
| 22 |
|
|
|
|
| 70 |
audio_saver = AudioSaver(default_format="wav")
|
| 71 |
IS_SPACE = bool(os.getenv("SPACE_ID"))
|
| 72 |
DEFAULT_OUTPUT_DIR = "/data/lora_output" if IS_SPACE else "lora_output"
|
| 73 |
+
DEFAULT_UPLOADED_ADAPTER_SUBDIR = "uploaded_adapters"
|
| 74 |
|
| 75 |
if IS_SPACE:
|
| 76 |
try:
|
|
|
|
| 523 |
return adapters if adapters else ["(none found)"]
|
| 524 |
|
| 525 |
|
| 526 |
+
def _safe_adapter_name(name: str) -> str:
|
| 527 |
+
name = (name or "").strip()
|
| 528 |
+
if not name:
|
| 529 |
+
return f"adapter_{int(time.time())}"
|
| 530 |
+
out = []
|
| 531 |
+
for ch in name:
|
| 532 |
+
if ch.isalnum() or ch in ("-", "_", "."):
|
| 533 |
+
out.append(ch)
|
| 534 |
+
else:
|
| 535 |
+
out.append("_")
|
| 536 |
+
cleaned = "".join(out).strip("._")
|
| 537 |
+
return cleaned or f"adapter_{int(time.time())}"
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def _safe_extract_zip(zip_path: str, target_dir: Path) -> int:
|
| 541 |
+
extracted = 0
|
| 542 |
+
target_resolved = target_dir.resolve()
|
| 543 |
+
with zipfile.ZipFile(zip_path, "r") as zf:
|
| 544 |
+
for member in zf.infolist():
|
| 545 |
+
member_path = (target_dir / member.filename).resolve()
|
| 546 |
+
if not str(member_path).startswith(str(target_resolved)):
|
| 547 |
+
raise RuntimeError(f"Unsafe archive path detected: {member.filename}")
|
| 548 |
+
zf.extractall(target_dir)
|
| 549 |
+
extracted = len(zf.namelist())
|
| 550 |
+
return extracted
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def upload_adapter_files(uploaded_files: List[str], adapter_dir: str, adapter_name: str):
|
| 554 |
+
"""Upload LoRA adapter files/zip and make them available in adapter dropdown."""
|
| 555 |
+
if not uploaded_files:
|
| 556 |
+
adapters = list_adapters(adapter_dir)
|
| 557 |
+
return "Please upload .zip or adapter files first.", gr.update(choices=adapters, value=adapters[0] if adapters else None)
|
| 558 |
+
|
| 559 |
+
root_dir = Path(adapter_dir or DEFAULT_OUTPUT_DIR)
|
| 560 |
+
target_root = root_dir / DEFAULT_UPLOADED_ADAPTER_SUBDIR
|
| 561 |
+
target_root.mkdir(parents=True, exist_ok=True)
|
| 562 |
+
target_dir = target_root / _safe_adapter_name(adapter_name)
|
| 563 |
+
target_dir.mkdir(parents=True, exist_ok=True)
|
| 564 |
+
|
| 565 |
+
copied = 0
|
| 566 |
+
extracted = 0
|
| 567 |
+
try:
|
| 568 |
+
# If a single zip is uploaded, extract it; otherwise copy files directly.
|
| 569 |
+
if len(uploaded_files) == 1 and str(uploaded_files[0]).lower().endswith(".zip"):
|
| 570 |
+
zip_path = uploaded_files[0]
|
| 571 |
+
extracted = _safe_extract_zip(zip_path, target_dir)
|
| 572 |
+
else:
|
| 573 |
+
for src in uploaded_files:
|
| 574 |
+
src_path = Path(src)
|
| 575 |
+
if not src_path.exists():
|
| 576 |
+
continue
|
| 577 |
+
dst = target_dir / src_path.name
|
| 578 |
+
shutil.copy2(src_path, dst)
|
| 579 |
+
copied += 1
|
| 580 |
+
|
| 581 |
+
found = sorted({str(p.parent) for p in target_dir.rglob("adapter_config.json")})
|
| 582 |
+
if not found:
|
| 583 |
+
adapters = list_adapters(str(root_dir))
|
| 584 |
+
return (
|
| 585 |
+
f"Uploaded to {target_dir}, but no adapter_config.json found. "
|
| 586 |
+
"Upload a valid LoRA adapter folder or zip.",
|
| 587 |
+
gr.update(choices=adapters, value=adapters[0] if adapters else None),
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
adapters = list_adapters(str(root_dir))
|
| 591 |
+
primary = found[0]
|
| 592 |
+
msg = (
|
| 593 |
+
f"Adapter upload complete. Copied {copied} file(s), extracted {extracted} archive entries. "
|
| 594 |
+
f"Detected {len(found)} adapter path(s). Primary: {primary}"
|
| 595 |
+
)
|
| 596 |
+
return msg, gr.update(choices=adapters, value=primary)
|
| 597 |
+
except Exception as exc:
|
| 598 |
+
logger.exception("Adapter upload failed")
|
| 599 |
+
adapters = list_adapters(str(root_dir))
|
| 600 |
+
return f"Adapter upload failed: {exc}", gr.update(choices=adapters, value=adapters[0] if adapters else None)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
@_gpu_callback
|
| 604 |
def load_adapter(adapter_path: str):
|
| 605 |
if not adapter_path or adapter_path == "(none found)":
|
|
|
|
| 979 |
adapter_dir = gr.Textbox(label="Adapters Directory", value=DEFAULT_OUTPUT_DIR)
|
| 980 |
refresh_btn = gr.Button("Refresh List")
|
| 981 |
adapter_dd = gr.Dropdown(label="Select Adapter", choices=[])
|
| 982 |
+
with gr.Row():
|
| 983 |
+
upload_adapter_files_input = gr.Files(
|
| 984 |
+
label="Upload LoRA Adapter (.zip or adapter files)",
|
| 985 |
+
file_count="multiple",
|
| 986 |
+
file_types=[".zip", ".json", ".safetensors", ".bin", ".pt", ".pth"],
|
| 987 |
+
type="filepath",
|
| 988 |
+
)
|
| 989 |
+
upload_adapter_name = gr.Textbox(
|
| 990 |
+
label="Uploaded Adapter Name (optional)",
|
| 991 |
+
placeholder="my-lora-adapter",
|
| 992 |
+
)
|
| 993 |
+
upload_adapter_btn = gr.Button("Upload Adapter")
|
| 994 |
with gr.Row():
|
| 995 |
load_btn = gr.Button("Load Adapter", variant="primary")
|
| 996 |
unload_btn = gr.Button("Unload Adapter")
|
|
|
|
| 1001 |
return gr.update(choices=adapters, value=adapters[0] if adapters else None)
|
| 1002 |
|
| 1003 |
refresh_btn.click(_refresh, adapter_dir, adapter_dd, api_name="list_adapters")
|
| 1004 |
+
upload_adapter_btn.click(
|
| 1005 |
+
upload_adapter_files,
|
| 1006 |
+
[upload_adapter_files_input, adapter_dir, upload_adapter_name],
|
| 1007 |
+
[adapter_status, adapter_dd],
|
| 1008 |
+
api_name="upload_adapter_files",
|
| 1009 |
+
)
|
| 1010 |
load_btn.click(load_adapter, adapter_dd, adapter_status, api_name="load_adapter")
|
| 1011 |
unload_btn.click(unload_adapter, outputs=adapter_status, api_name="unload_adapter")
|
| 1012 |
|