# Hugging Face Space app (page-scrape residue removed: "Spaces: Sleeping" status header)
# ======================================================
# IMPORTS & CONFIG
# ======================================================
# NOTE(review): the original file repeated its import block three times
# (a copy-paste artifact); each name is now imported exactly once.
import os
import re
import shutil
import subprocess
import threading
from pathlib import Path

import gradio as gr
import spaces
from accelerate.utils import write_basic_config
from huggingface_hub import HfApi, create_repo, snapshot_download

# Single Hub API client shared by the whole app (uploads, repo management).
api = HfApi()

# Local working directory for cleaned datasets and training outputs.
BASE_DIR = Path("./workspace")
BASE_DIR.mkdir(exist_ok=True)

# Cache directory handed to snapshot_download() for models and datasets.
CACHE_DIR = "./hf_cache"

# The DreamBooth training script is run out of a local clone of diffusers.
DIFFUSERS_REPO = "https://github.com/huggingface/diffusers.git"
DIFFUSERS_LOCAL = "./diffusers"

# Write a default accelerate config so `accelerate launch` works headless.
write_basic_config()
# ======================================================
# HELPERS — PRELOAD
# ======================================================
def preload_assets(model_name, dataset_repo, log_func=print):
    """Download and prepare every asset needed before training can start.

    Clones diffusers, fetches the base model and dataset from the Hub, and
    builds a cleaned image folder. Returns a 4-tuple of
    (accumulated logs, status label, model_path, dataset_path); the two
    paths are None when any step fails.
    """
    collected = []

    def _log(msg):
        # Buffer the message for the UI and forward it to the caller's sink.
        collected.append(msg)
        log_func(msg)

    try:
        clone_diffusers(log_func=_log)
        model_path = resolve_model_path(model_name, log_func=_log)
        dataset_path = resolve_dataset_path(dataset_repo, log_func=_log)
        dataset_path = prepare_dataset_folder(dataset_path, log_func=_log)
        _log("β All assets preloaded. You can now train the model.\n")
        return "".join(collected), "Ready to Train", model_path, dataset_path
    except Exception as e:
        _log(f"β Error during preload: {e}\n")
        return "".join(collected), "Preload Failed", None, None
def clone_diffusers(log_func=None):
    """Clone the diffusers repository locally (no-op if already present).

    The DreamBooth LoRA training script under examples/dreambooth is run
    from this clone.
    """
    if not os.path.exists(DIFFUSERS_LOCAL):
        if log_func:
            log_func("π Cloning diffusers repo...\n")
        # Clone explicitly into DIFFUSERS_LOCAL so the existence check above
        # and the clone destination can never drift apart; the original
        # relied on git's implicit default directory name matching
        # "./diffusers".
        subprocess.run(["git", "clone", DIFFUSERS_REPO, DIFFUSERS_LOCAL], check=True)
        if log_func:
            log_func("β Diffusers repo cloned.\n")
def resolve_model_path(model_name_or_path: str, log_func=None) -> str:
    """Download a base-model snapshot from the Hub; return its absolute path."""
    if log_func:
        log_func(f"π Downloading base model: {model_name_or_path} ...\n")
    snapshot = snapshot_download(
        repo_id=model_name_or_path,
        repo_type="model",
        cache_dir=CACHE_DIR,
    )
    if log_func:
        log_func(f"β Base model downloaded at: {snapshot}\n")
    return os.path.abspath(snapshot)
def resolve_dataset_path(dataset_repo: str, log_func=None) -> str:
    """Download a dataset snapshot from the Hub; return its absolute path."""
    if log_func:
        log_func(f"π Downloading dataset: {dataset_repo} ...\n")
    snapshot = snapshot_download(
        repo_id=dataset_repo,
        repo_type="dataset",
        cache_dir=CACHE_DIR,
        # Skip git bookkeeping files; only the data itself is wanted.
        ignore_patterns=".gitattributes",
    )
    if log_func:
        log_func(f"β Dataset downloaded at: {snapshot}\n")
    return os.path.abspath(snapshot)
def prepare_dataset_folder(dataset_path: str, log_func=None) -> str:
    """Copy only the image files from a downloaded dataset into a clean folder.

    Rebuilds ./workspace/dataset_clean from scratch on each call. Raises
    ValueError when the source folder contains no supported image files.
    Returns the clean folder path as a string.
    """
    target = Path("./workspace/dataset_clean")
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)

    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    images = [
        entry
        for entry in Path(dataset_path).iterdir()
        if entry.suffix.lower() in image_suffixes
    ]
    for entry in images:
        shutil.copy(entry, target / entry.name)

    if not images:
        raise ValueError(f"No image files found in dataset repo: {dataset_path}")
    if log_func:
        log_func(f"β Dataset prepared with {len(images)} images at: {target}\n")
    return str(target)
# ======================================================
# README CLEANER
# ======================================================
def fix_readme_metadata(output_path, original_model_id):
    """Rewrite the README's `base_model:` line to the real Hub model ID.

    The training script writes the local snapshot path as base_model; this
    swaps it for the original HF model ID. Silently does nothing when no
    README.md exists in output_path.
    """
    readme = Path(output_path) / "README.md"
    if readme.exists():
        text = readme.read_text()
        fixed = re.sub(
            r'base_model:.*',
            f'base_model: {original_model_id}',
            text,
        )
        readme.write_text(fixed)
# ======================================================
# TRAINING
# ======================================================
def train_model(
    model_path,
    dataset_path,
    instance_prompt,
    validation_prompt,
    resolution,
    train_batch_size,
    gradient_accumulation_steps,
    learning_rate,
    max_train_steps,
    guidance_scale,
    lr_scheduler,
    lr_warmup_steps,
    optimizer,
    mixed_precision,
    gradient_checkpointing,
    cache_latents,
    use_8bit_adam,
    do_fp8_training,
    push_to_hub,
    hub_model_id,
    output_path,
):
    """Run DreamBooth LoRA training as a subprocess and stream its logs.

    Generator used as a Gradio event handler: every `yield` pushes the full
    accumulated log text to the output textbox. Launches the diffusers
    example script via `accelerate launch`, tails its merged stdout/stderr,
    and — when push_to_hub is set — periodically uploads *.safetensors
    checkpoints to the Hub, plus one final upload in the `finally` block.
    """
    logs = ""
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Upload cadence; matches --checkpointing_steps=10 below so each upload
    # can pick up a freshly written checkpoint.
    upload_every_steps = 10
    last_uploaded_step = -1

    def add_log(msg):
        # Append to the shared log buffer and yield the whole buffer so the
        # caller can stream it to the UI with `yield from add_log(...)`.
        nonlocal logs
        logs += msg
        yield logs

    try:
        yield from add_log("π Starting Training...\n")
        if push_to_hub and hub_model_id:
            # exist_ok=True makes reruns against the same repo safe.
            create_repo(repo_id=hub_model_id, exist_ok=True)
            yield from add_log(f"β Hub repo ready: {hub_model_id}\n")
        original_model_id = model_path  # SAVE ORIGINAL ID
        # Command line for the diffusers DreamBooth LoRA example script.
        cmd = [
            "accelerate", "launch",
            "./diffusers/examples/dreambooth/train_dreambooth_lora_z_image.py",
            f"--pretrained_model_name_or_path={model_path}",
            f"--instance_data_dir={dataset_path}",
            f"--output_dir={output_path}",
            f"--instance_prompt={instance_prompt}",
            f"--validation_prompt={validation_prompt}",
            f"--resolution={resolution}",
            f"--train_batch_size={train_batch_size}",
            f"--gradient_accumulation_steps={gradient_accumulation_steps}",
            f"--learning_rate={learning_rate}",
            f"--max_train_steps={max_train_steps}",
            f"--guidance_scale={guidance_scale}",
            f"--lr_scheduler={lr_scheduler}",
            f"--lr_warmup_steps={lr_warmup_steps}",
            f"--optimizer={optimizer}",
            f"--mixed_precision={mixed_precision}",
            "--checkpointing_steps=10",
            "--seed=0",
        ]
        # Optional boolean flags are only passed when enabled in the UI.
        if gradient_checkpointing: cmd.append("--gradient_checkpointing")
        if cache_latents: cmd.append("--cache_latents")
        if use_8bit_adam: cmd.append("--use_8bit_adam")
        if do_fp8_training: cmd.append("--do_fp8_training")
        # Merge stderr into stdout so a single pipe carries all trainer output.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        for line in process.stdout:
            yield from add_log(line)
            # Parse the trainer's progress line, e.g. "Steps: 12/400".
            tqdm_match = re.search(r"Steps:.*?(\d+)/(\d+)", line)
            if tqdm_match:
                current_step = int(tqdm_match.group(1))
                if (
                    push_to_hub
                    and hub_model_id
                    and current_step % upload_every_steps == 0
                    and current_step != last_uploaded_step
                ):
                    # Guard against re-uploading when several output lines
                    # report the same step number.
                    last_uploaded_step = current_step
                    yield from add_log(
                        f"\nπ¦ Uploading checkpoint at step {current_step}...\n"
                    )
                    #fix_readme_metadata(output_path, original_model_id)
                    try:
                        # Upload only LoRA weights, not logs/images/config.
                        api.upload_folder(
                            folder_path=output_path,
                            repo_id=hub_model_id,
                            repo_type="model",
                            commit_message=f"Auto upload at step {current_step}",
                            allow_patterns=["*.safetensors"],
                        )
                        yield from add_log("β Upload completed.\n")
                    except Exception as upload_error:
                        # Keep training even if an intermediate upload fails.
                        yield from add_log(f"β Upload failed: {upload_error}\n")
        process.wait()
        yield from add_log("π Training completed.\n")
    finally:
        # Always attempt one last upload so a crash mid-run still publishes
        # whatever checkpoints were written to output_path.
        if push_to_hub and hub_model_id:
            yield from add_log("\nπ¦ Final upload attempt...\n")
            #fix_readme_metadata(output_path, original_model_id)
            api.upload_folder(
                folder_path=output_path,
                repo_id=hub_model_id,
                repo_type="model",
                commit_message="Final upload",
                allow_patterns=["*.safetensors"],
            )
            yield from add_log("β Final upload completed.\n")
    yield logs
# ======================================================
# GRADIO UI
# ======================================================
with gr.Blocks(title="DreamBooth LoRA Trainer (Z-Image)- L40s and above") as demo:
    gr.Markdown("# π DreamBooth LoRA Trainer (Z-Image) Run in L40S ")
    gr.Markdown("Preload base model & dataset first, then train your LoRA.")
    with gr.Row():
        # Left column: Hub repo IDs and local output location.
        with gr.Column():
            dataset_repo = gr.Textbox(value="diffusers/dog-example", label="HF Dataset Repo ID")
            model_name = gr.Textbox(value="Tongyi-MAI/Z-Image", label="Base Model (HF ID)")
            hub_model_id = gr.Textbox(value="rahul7star/Zimg-Lora-Train", label="HF Hub Model ID for Upload")
            output_path = gr.Textbox(value="./workspace/trained-lora", label="Output / Experiment Folder")
        # Right column: prompts and training hyperparameters.
        with gr.Column():
            instance_prompt = gr.Textbox(value="a photo of sks dog", label="Instance Prompt")
            validation_prompt = gr.Textbox(value="A photo of sks dog in a bucket", label="Validation Prompt")
            resolution = gr.Slider(256, 1024, value=512, step=64, label="Resolution")
            train_batch_size = gr.Number(value=1, label="Train Batch Size")
            gradient_accumulation_steps = gr.Number(value=4, label="Gradient Accumulation Steps")
            learning_rate = gr.Number(value=1e-4, label="Learning Rate")
            max_train_steps = gr.Number(value=400, label="Max Train Steps")
            guidance_scale = gr.Number(value=5.0, label="Guidance Scale")
            lr_scheduler = gr.Dropdown(["constant", "linear", "cosine"], value="constant", label="LR Scheduler")
            lr_warmup_steps = gr.Number(value=100, label="LR Warmup Steps")
            optimizer = gr.Dropdown(["adamW", "prodigy"], value="adamW", label="Optimizer")
            mixed_precision = gr.Dropdown(["no", "fp16", "bf16"], value="bf16", label="Mixed Precision")
            gradient_checkpointing = gr.Checkbox(value=True, label="Gradient Checkpointing")
            cache_latents = gr.Checkbox(value=True, label="Cache Latents")
            use_8bit_adam = gr.Checkbox(value=True, label="Use 8-bit Adam")
            do_fp8_training = gr.Checkbox(value=False, label="FP8 Training (A100/H100 only)")
            push_to_hub = gr.Checkbox(value=True, label="Push to HuggingFace Hub")
    output_logs = gr.Textbox(label="Logs", lines=20)
    preload_btn = gr.Button("π Preload Data & Model", elem_classes="preload-button")
    train_btn = gr.Button("π₯ Start Training", elem_classes="train-button")
    # Hidden state: resolved local paths handed from preload to training.
    model_path_state = gr.State()
    dataset_path_state = gr.State()
    # preload_assets also returns a status string that updates train_btn.
    preload_btn.click(
        preload_assets,
        inputs=[model_name, dataset_repo],
        outputs=[output_logs, train_btn, model_path_state, dataset_path_state],
    )
    # train_model is a generator, so its yields stream into output_logs.
    train_btn.click(
        train_model,
        inputs=[
            model_path_state,
            dataset_path_state,
            instance_prompt,
            validation_prompt,
            resolution,
            train_batch_size,
            gradient_accumulation_steps,
            learning_rate,
            max_train_steps,
            guidance_scale,
            lr_scheduler,
            lr_warmup_steps,
            optimizer,
            mixed_precision,
            gradient_checkpointing,
            cache_latents,
            use_8bit_adam,
            do_fp8_training,
            push_to_hub,
            hub_model_id,
            output_path
        ],
        outputs=output_logs
    )
demo.launch()
# NOTE(review): a second, near-identical copy of the entire Gradio UI
# (duplicate `with gr.Blocks(...) as demo:` plus a second `demo.launch()`)
# was removed here. `demo.launch()` above blocks the main thread when this
# file runs as a script, so the duplicate was unreachable dead code; it only
# re-bound every component name and risked silent divergence between the
# two copies (e.g. the duplicate used a different default hub_model_id).