# ZIMG_LORAS / app.py
# (Hugging Face Space page chrome preserved as comments: "rahul7star's picture",
#  "Update app.py", commit 413b443 verified — these lines are not Python code.)
# Imports — the original file repeated the whole import section twice;
# each distinct import is kept exactly once, grouped stdlib / third-party.

# Standard library
import os
import re
import shutil
import subprocess
import threading  # imported by the original file; kept for compatibility
from pathlib import Path

# Third-party
import gradio as gr
import spaces
from accelerate.utils import write_basic_config
from huggingface_hub import HfApi, create_repo, snapshot_download
# Shared Hub client, used for checkpoint uploads during and after training.
api = HfApi()
# ======================================================
# CONFIG
# ======================================================
BASE_DIR = Path("./workspace")  # working root for cleaned datasets / training outputs
BASE_DIR.mkdir(exist_ok=True)
CACHE_DIR = "./hf_cache"  # cache_dir handed to snapshot_download for models/datasets
DIFFUSERS_REPO = "https://github.com/huggingface/diffusers.git"
DIFFUSERS_LOCAL = "./diffusers"  # local clone that provides the training script
# Write a default accelerate config so `accelerate launch` runs non-interactively.
write_basic_config()
# ======================================================
# HELPERS
# ======================================================
# ======================================================
# PRELOAD FUNCTION
# ======================================================
def preload_assets(model_name, dataset_repo, log_func=print):
    """Fetch everything training needs: diffusers repo, base model, dataset.

    Returns a 4-tuple ``(logs, button_label, model_path, dataset_path)``;
    the two paths are ``None`` when any download/preparation step fails.
    """
    buffer = []

    def record(message):
        # Accumulate for the returned log text and forward to the live logger.
        buffer.append(message)
        log_func(message)

    try:
        clone_diffusers(log_func=record)
        model_path = resolve_model_path(model_name, log_func=record)
        raw_dataset = resolve_dataset_path(dataset_repo, log_func=record)
        dataset_path = prepare_dataset_folder(raw_dataset, log_func=record)
        record("βœ… All assets preloaded. You can now train the model.\n")
        return "".join(buffer), "Ready to Train", model_path, dataset_path
    except Exception as exc:
        record(f"❌ Error during preload: {exc}\n")
        return "".join(buffer), "Preload Failed", None, None
def clone_diffusers(log_func=None):
    """Clone the diffusers GitHub repo locally unless it already exists."""
    if os.path.exists(DIFFUSERS_LOCAL):
        return  # already cloned on a previous run
    if log_func:
        log_func("πŸ”„ Cloning diffusers repo...\n")
    subprocess.run(["git", "clone", DIFFUSERS_REPO], check=True)
    if log_func:
        log_func("βœ… Diffusers repo cloned.\n")
def resolve_model_path(model_name_or_path: str, log_func=None) -> str:
    """Download the base model snapshot from the Hub; return its absolute path."""
    if log_func:
        log_func(f"πŸ”„ Downloading base model: {model_name_or_path} ...\n")
    snapshot_path = snapshot_download(
        repo_id=model_name_or_path,
        repo_type="model",
        cache_dir=CACHE_DIR,
    )
    if log_func:
        log_func(f"βœ… Base model downloaded at: {snapshot_path}\n")
    return os.path.abspath(snapshot_path)
def resolve_dataset_path(dataset_repo: str, log_func=None) -> str:
    """Download the dataset snapshot from the Hub; return its absolute path."""
    if log_func:
        log_func(f"πŸ”„ Downloading dataset: {dataset_repo} ...\n")
    snapshot_path = snapshot_download(
        repo_id=dataset_repo,
        repo_type="dataset",
        cache_dir=CACHE_DIR,
        ignore_patterns=".gitattributes",  # skip repo housekeeping files
    )
    if log_func:
        log_func(f"βœ… Dataset downloaded at: {snapshot_path}\n")
    return os.path.abspath(snapshot_path)
def prepare_dataset_folder(dataset_path: str, log_func=None,
                           clean_dir: str = "./workspace/dataset_clean") -> str:
    """Copy the image files from *dataset_path* into a fresh, flat folder.

    The training script expects a directory that contains only images; this
    filters out README/.json/etc. files that come with a Hub dataset snapshot.

    Args:
        dataset_path: directory containing the downloaded dataset files.
        log_func: optional callable that receives progress messages.
        clean_dir: destination folder (recreated from scratch each call).
            Defaults to the original hard-coded location, so existing
            callers are unaffected.

    Returns:
        The destination folder path as a string.

    Raises:
        ValueError: if *dataset_path* contains no image files.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    clean_path = Path(clean_dir)
    # Start from an empty folder so stale images from a previous dataset
    # cannot leak into this training run.
    if clean_path.exists():
        shutil.rmtree(clean_path)
    clean_path.mkdir(parents=True, exist_ok=True)

    count = 0
    # NOTE(review): only the top level is scanned (no recursion) — assumes the
    # dataset snapshot stores its images flat; confirm for nested datasets.
    for file in Path(dataset_path).iterdir():
        if file.suffix.lower() in image_suffixes:
            shutil.copy(file, clean_path / file.name)
            count += 1

    if count == 0:
        raise ValueError(f"No image files found in dataset repo: {dataset_path}")
    if log_func:
        log_func(f"βœ… Dataset prepared with {count} images at: {clean_path}\n")
    return str(clean_path)
# ======================================================
# NEW FIX: README CLEANER
# ======================================================
def fix_readme_metadata(output_path, original_model_id):
    """Rewrite README.md so its ``base_model:`` line points at the Hub repo ID.

    The training script writes the *local* snapshot path into the model card's
    ``base_model:`` front-matter; replace every such line with the original
    Hub model ID so the uploaded card resolves correctly.

    No-op when ``README.md`` does not exist inside *output_path*.
    """
    readme_path = Path(output_path) / "README.md"
    if not readme_path.exists():
        return
    # Read/write explicitly as UTF-8: model cards contain emoji, and relying on
    # the locale-dependent default encoding can raise UnicodeDecodeError.
    content = readme_path.read_text(encoding="utf-8")
    content = re.sub(
        r'base_model:.*',
        f'base_model: {original_model_id}',
        content,
    )
    readme_path.write_text(content, encoding="utf-8")
# ======================================================
# TRAINING
# ======================================================
@spaces.GPU()
def train_model(
    model_path,
    dataset_path,
    instance_prompt,
    validation_prompt,
    resolution,
    train_batch_size,
    gradient_accumulation_steps,
    learning_rate,
    max_train_steps,
    guidance_scale,
    lr_scheduler,
    lr_warmup_steps,
    optimizer,
    mixed_precision,
    gradient_checkpointing,
    cache_latents,
    use_8bit_adam,
    do_fp8_training,
    push_to_hub,
    hub_model_id,
    output_path,
):
    """Run DreamBooth LoRA training in a subprocess, streaming logs to the UI.

    This is a generator: it yields the full accumulated log text after every
    new message so Gradio live-updates the logs textbox. It launches the
    diffusers ``train_dreambooth_lora_z_image.py`` example via
    ``accelerate launch`` and, when ``push_to_hub`` is set, periodically
    uploads ``*.safetensors`` checkpoints to ``hub_model_id``.
    """
    logs = ""
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    upload_every_steps = 10   # upload a checkpoint every N training steps
    last_uploaded_step = -1   # guards against re-uploading the same step

    def add_log(msg):
        # Append to the running log and yield the whole text; callers use
        # ``yield from add_log(...)`` so every message re-renders the textbox.
        nonlocal logs
        logs += msg
        yield logs

    try:
        yield from add_log("πŸš€ Starting Training...\n")
        if push_to_hub and hub_model_id:
            # Make sure the target repo exists before any upload attempt.
            create_repo(repo_id=hub_model_id, exist_ok=True)
            yield from add_log(f"βœ… Hub repo ready: {hub_model_id}\n")
        # Kept for the (currently disabled) fix_readme_metadata calls below.
        original_model_id = model_path  # SAVE ORIGINAL ID
        cmd = [
            "accelerate", "launch",
            "./diffusers/examples/dreambooth/train_dreambooth_lora_z_image.py",
            f"--pretrained_model_name_or_path={model_path}",
            f"--instance_data_dir={dataset_path}",
            f"--output_dir={output_path}",
            f"--instance_prompt={instance_prompt}",
            f"--validation_prompt={validation_prompt}",
            f"--resolution={resolution}",
            f"--train_batch_size={train_batch_size}",
            f"--gradient_accumulation_steps={gradient_accumulation_steps}",
            f"--learning_rate={learning_rate}",
            f"--max_train_steps={max_train_steps}",
            f"--guidance_scale={guidance_scale}",
            f"--lr_scheduler={lr_scheduler}",
            f"--lr_warmup_steps={lr_warmup_steps}",
            f"--optimizer={optimizer}",
            f"--mixed_precision={mixed_precision}",
            "--checkpointing_steps=10",
            "--seed=0",
        ]
        # Boolean flags are appended only when enabled.
        if gradient_checkpointing: cmd.append("--gradient_checkpointing")
        if cache_latents: cmd.append("--cache_latents")
        if use_8bit_adam: cmd.append("--use_8bit_adam")
        if do_fp8_training: cmd.append("--do_fp8_training")
        # Merge stderr into stdout so tqdm progress lines are captured too.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        for line in process.stdout:
            yield from add_log(line)
            # Parse "Steps: ... <current>/<total>" from the trainer's
            # progress output to decide when to upload a checkpoint.
            tqdm_match = re.search(r"Steps:.*?(\d+)/(\d+)", line)
            if tqdm_match:
                current_step = int(tqdm_match.group(1))
                if (
                    push_to_hub
                    and hub_model_id
                    and current_step % upload_every_steps == 0
                    and current_step != last_uploaded_step
                ):
                    last_uploaded_step = current_step
                    yield from add_log(
                        f"\nπŸ“¦ Uploading checkpoint at step {current_step}...\n"
                    )
                    #fix_readme_metadata(output_path, original_model_id)
                    try:
                        # Only weights are uploaded; README/config are skipped.
                        api.upload_folder(
                            folder_path=output_path,
                            repo_id=hub_model_id,
                            repo_type="model",
                            commit_message=f"Auto upload at step {current_step}",
                            allow_patterns=["*.safetensors"],
                        )
                        yield from add_log("βœ… Upload completed.\n")
                    except Exception as upload_error:
                        # Best-effort: a failed mid-run upload must not kill training.
                        yield from add_log(f"❌ Upload failed: {upload_error}\n")
        # NOTE(review): the return code is not checked, so a failed training
        # run still reports "Training completed" — confirm if intentional.
        process.wait()
        yield from add_log("πŸŽ‰ Training completed.\n")
    finally:
        # Final upload runs even when training raised, so partial results
        # are preserved on the Hub. Unlike the mid-run upload, this call is
        # NOT wrapped in try/except and may itself raise.
        if push_to_hub and hub_model_id:
            yield from add_log("\nπŸ“¦ Final upload attempt...\n")
            #fix_readme_metadata(output_path, original_model_id)
            api.upload_folder(
                folder_path=output_path,
                repo_id=hub_model_id,
                repo_type="model",
                commit_message="Final upload",
                allow_patterns=["*.safetensors"],
            )
            yield from add_log("βœ… Final upload completed.\n")
    yield logs
# ======================================================
# GRADIO UI
# ======================================================
with gr.Blocks(title="DreamBooth LoRA Trainer (Z-Image)- L40s and above") as demo:
    gr.Markdown("# πŸš€ DreamBooth LoRA Trainer (Z-Image) Run in L40S ")
    gr.Markdown("Preload base model & dataset first, then train your LoRA.")
    with gr.Row():
        # Left column: repo IDs and output locations.
        with gr.Column():
            dataset_repo = gr.Textbox(value="diffusers/dog-example", label="HF Dataset Repo ID")
            model_name = gr.Textbox(value="Tongyi-MAI/Z-Image", label="Base Model (HF ID)")
            hub_model_id = gr.Textbox(value="rahul7star/Zimg-Lora-Train", label="HF Hub Model ID for Upload")
            output_path = gr.Textbox(value="./workspace/trained-lora", label="Output / Experiment Folder")
        # Right column: prompts and training hyper-parameters.
        with gr.Column():
            instance_prompt = gr.Textbox(value="a photo of sks dog", label="Instance Prompt")
            validation_prompt = gr.Textbox(value="A photo of sks dog in a bucket", label="Validation Prompt")
            resolution = gr.Slider(256, 1024, value=512, step=64, label="Resolution")
            train_batch_size = gr.Number(value=1, label="Train Batch Size")
            gradient_accumulation_steps = gr.Number(value=4, label="Gradient Accumulation Steps")
            learning_rate = gr.Number(value=1e-4, label="Learning Rate")
            max_train_steps = gr.Number(value=400, label="Max Train Steps")
            guidance_scale = gr.Number(value=5.0, label="Guidance Scale")
            lr_scheduler = gr.Dropdown(["constant", "linear", "cosine"], value="constant", label="LR Scheduler")
            lr_warmup_steps = gr.Number(value=100, label="LR Warmup Steps")
            optimizer = gr.Dropdown(["adamW", "prodigy"], value="adamW", label="Optimizer")
            mixed_precision = gr.Dropdown(["no", "fp16", "bf16"], value="bf16", label="Mixed Precision")
            gradient_checkpointing = gr.Checkbox(value=True, label="Gradient Checkpointing")
            cache_latents = gr.Checkbox(value=True, label="Cache Latents")
            use_8bit_adam = gr.Checkbox(value=True, label="Use 8-bit Adam")
            do_fp8_training = gr.Checkbox(value=False, label="FP8 Training (A100/H100 only)")
            push_to_hub = gr.Checkbox(value=True, label="Push to HuggingFace Hub")
    output_logs = gr.Textbox(label="Logs", lines=20)
    preload_btn = gr.Button("πŸ”„ Preload Data & Model", elem_classes="preload-button")
    train_btn = gr.Button("πŸ”₯ Start Training", elem_classes="train-button")
    # Hidden state carries resolved local paths from preload to training.
    model_path_state = gr.State()
    dataset_path_state = gr.State()
    # NOTE(review): preload_assets' second return value (a status string) is
    # wired to train_btn — presumably to relabel the button; confirm the
    # installed Gradio version applies a plain string as a Button update.
    preload_btn.click(
        preload_assets,
        inputs=[model_name, dataset_repo],
        outputs=[output_logs, train_btn, model_path_state, dataset_path_state],
    )
    # train_model is a generator, so output_logs streams as training runs.
    train_btn.click(
        train_model,
        inputs=[
            model_path_state,
            dataset_path_state,
            instance_prompt,
            validation_prompt,
            resolution,
            train_batch_size,
            gradient_accumulation_steps,
            learning_rate,
            max_train_steps,
            guidance_scale,
            lr_scheduler,
            lr_warmup_steps,
            optimizer,
            mixed_precision,
            gradient_checkpointing,
            cache_latents,
            use_8bit_adam,
            do_fp8_training,
            push_to_hub,
            hub_model_id,
            output_path
        ],
        outputs=output_logs
    )
# NOTE(review): launch() blocks, so the second UI definition that follows in
# this file only runs after this server exits — likely accidental duplication.
demo.launch()
# ======================================================
# GRADIO UI
# ======================================================
# NOTE(review): this entire block is a near-verbatim duplicate of the UI
# defined earlier in the file (only the title text and the default
# hub_model_id differ). Because the first `demo.launch()` blocks, this block
# is effectively dead code — confirm and remove the stale copy.
with gr.Blocks(title="DreamBooth LoRA Trainer (Z-Image) Run in L40S ") as demo:
    gr.Markdown("# πŸš€ DreamBooth LoRA Trainer (Z-Image)")
    gr.Markdown("Preload base model & dataset first, then train your LoRA.")
    with gr.Row():
        # Left column: repo IDs and output locations.
        with gr.Column():
            dataset_repo = gr.Textbox(value="diffusers/dog-example", label="HF Dataset Repo ID")
            model_name = gr.Textbox(value="Tongyi-MAI/Z-Image", label="Base Model (HF ID)")
            hub_model_id = gr.Textbox(value="rahul7star/trained-lora", label="HF Hub Model ID for Upload")
            output_path = gr.Textbox(value="./workspace/trained-lora", label="Output / Experiment Folder")
        # Right column: prompts and training hyper-parameters.
        with gr.Column():
            instance_prompt = gr.Textbox(value="a photo of sks dog", label="Instance Prompt")
            validation_prompt = gr.Textbox(value="A photo of sks dog in a bucket", label="Validation Prompt")
            resolution = gr.Slider(256, 1024, value=512, step=64, label="Resolution")
            train_batch_size = gr.Number(value=1, label="Train Batch Size")
            gradient_accumulation_steps = gr.Number(value=4, label="Gradient Accumulation Steps")
            learning_rate = gr.Number(value=1e-4, label="Learning Rate")
            max_train_steps = gr.Number(value=400, label="Max Train Steps")
            guidance_scale = gr.Number(value=5.0, label="Guidance Scale")
            lr_scheduler = gr.Dropdown(["constant", "linear", "cosine"], value="constant", label="LR Scheduler")
            lr_warmup_steps = gr.Number(value=100, label="LR Warmup Steps")
            optimizer = gr.Dropdown(["adamW", "prodigy"], value="adamW", label="Optimizer")
            mixed_precision = gr.Dropdown(["no", "fp16", "bf16"], value="bf16", label="Mixed Precision")
            gradient_checkpointing = gr.Checkbox(value=True, label="Gradient Checkpointing")
            cache_latents = gr.Checkbox(value=True, label="Cache Latents")
            use_8bit_adam = gr.Checkbox(value=True, label="Use 8-bit Adam")
            do_fp8_training = gr.Checkbox(value=False, label="FP8 Training (A100/H100 only)")
            push_to_hub = gr.Checkbox(value=True, label="Push to HuggingFace Hub")
    output_logs = gr.Textbox(label="Logs", lines=20)
    preload_btn = gr.Button("πŸ”„ Preload Data & Model", elem_classes="preload-button")
    train_btn = gr.Button("πŸ”₯ Start Training", elem_classes="train-button")
    # Hidden state carries resolved local paths from preload to training.
    model_path_state = gr.State()
    dataset_path_state = gr.State()
    preload_btn.click(
        preload_assets,
        inputs=[model_name, dataset_repo],
        outputs=[output_logs, train_btn, model_path_state, dataset_path_state],
    )
    # train_model is a generator, so output_logs streams as training runs.
    train_btn.click(
        train_model,
        inputs=[
            model_path_state,
            dataset_path_state,
            instance_prompt,
            validation_prompt,
            resolution,
            train_batch_size,
            gradient_accumulation_steps,
            learning_rate,
            max_train_steps,
            guidance_scale,
            lr_scheduler,
            lr_warmup_steps,
            optimizer,
            mixed_precision,
            gradient_checkpointing,
            cache_latents,
            use_8bit_adam,
            do_fp8_training,
            push_to_hub,
            hub_model_id,
            output_path
        ],
        outputs=output_logs
    )
demo.launch()