import json
import os
import re
import shutil
import subprocess
import sys
import time
from pathlib import Path

# NOTE(review): `spaces` is kept ahead of `torch`, as in the original import
# order — presumably it must be imported before CUDA is touched; confirm.
import spaces
import gradio as gr
import torch
|
|
| |
# Working directories. They are resolved to absolute paths once, at import
# time, so that a later change of the process working directory cannot
# silently redirect where datasets and training outputs are read or written.
DATASET_DIR = Path("./datasets").resolve()
OUTPUT_DIR = Path("./output").resolve()
DATASET_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Path (as a string) of the most recently prepared dataset; assigned by
# upload_and_prepare_dataset(). None until a dataset has been uploaded.
current_dataset_path = None
|
|
def check_gpu():
    """Return a one-line, human-readable GPU availability message.

    Returns:
        A string naming CUDA device 0 when CUDA is available, otherwise a
        warning that training will be slow.
    """
    if not torch.cuda.is_available():
        return "⚠️ No GPU detected - training will be slow"

    # The status is shown in a single-line textbox, so the message is kept on
    # one line (the extracted source had the literal split in two).
    return f"✅ GPU Available: {torch.cuda.get_device_name(0)}"
|
|
def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Copy uploaded images into a dataset folder and write a caption per image.

    Args:
        files: Items supplied by the upload widget. Each item may be a plain
            path string or an object that exposes its path as ``.name``.
        dataset_name: Folder name to create under ``DATASET_DIR``. Path
            separators are replaced; a timestamped name is generated when the
            result is empty.
        trigger_word: Caption text written next to every image. ``"a photo"``
            is used when it is empty.

    Returns:
        A ``(status, dataset_path, ready_message)`` tuple. On failure
        ``dataset_path`` is ``None`` and ``ready_message`` is ``""``.

    Side effects:
        Creates the dataset folder, copies the images into it, writes one
        ``.txt`` caption per image, and updates the module-level
        ``current_dataset_path``.
    """
    global current_dataset_path

    if not files:
        return "❌ Please upload at least one image", None, ""

    image_extensions = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')

    # Accept both plain path strings and file-like wrappers carrying `.name`,
    # so the handler works regardless of which form the widget delivers.
    uploaded_paths = [str(getattr(item, "name", item)) for item in files]
    image_paths = [
        path for path in uploaded_paths
        if path.lower().endswith(image_extensions)
    ]

    # Validate before touching the disk so a rejected upload does not leave an
    # empty dataset folder behind.
    if not image_paths:
        return "❌ No valid images found. Upload PNG, JPG, JPEG, WEBP, or BMP files.", None, ""

    # The name comes from a free-text box: collapse path separators and strip
    # leading/trailing dots so it is always a single component inside
    # DATASET_DIR (e.g. "../x" cannot escape the datasets folder).
    cleaned_name = re.sub(r"[\\/\x00]+", "_", str(dataset_name or "")).strip(". ")
    dataset_name = cleaned_name or f"dataset_{int(time.time())}"

    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    caption_text = trigger_word if trigger_word else "a photo"
    for source in image_paths:
        destination = dataset_path / Path(source).name
        shutil.copy(source, destination)
        # Explicit UTF-8 so non-ASCII trigger words are written identically on
        # every platform.
        destination.with_suffix('.txt').write_text(caption_text, encoding="utf-8")

    image_count = len(image_paths)
    current_dataset_path = str(dataset_path)

    # NOTE(review): the ❌ and 📁 emoji are best-guess restorations of
    # mis-decoded bytes — confirm against the upstream file.
    status_lines = [
        f"✅ Successfully uploaded {image_count} images",
        f"📁 Dataset: {dataset_name}",
    ]
    if trigger_word:
        status_lines.append(f"🏷️ Trigger word: '{trigger_word}'")
    status_lines.append(f"💾 Location: {current_dataset_path}")

    return "\n".join(status_lines), current_dataset_path, f"Dataset ready: {dataset_name}"
| |
def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Write a training config, install AI Toolkit if needed, and train a LoRA.

    Args:
        dataset_path: Folder holding the images and ``.txt`` captions.
        project_name: Name of the run; used as the output folder under
            ``OUTPUT_DIR`` and as the config ``name``. Path separators are
            replaced; a timestamped name is generated when the result is empty.
        trigger_word: Optional word stored in the config and used in the
            sample prompts.
        steps: Number of training steps.
        learning_rate: Optimizer learning rate.
        lora_rank: Used for both ``linear`` and ``linear_alpha``.
        resolution: Square training/sample resolution in pixels.
        progress: Gradio progress callback.

    Returns:
        A ``(status_message, lora_file_path)`` tuple; the path is ``None``
        whenever training did not produce a usable file.
    """
    # NOTE(review): the ❌ emoji used in error messages below is a best-guess
    # restoration of mis-decoded bytes — confirm against the upstream file.
    if not dataset_path or not os.path.exists(dataset_path):
        return "❌ Please upload a dataset first!", None

    # The name comes from a free-text box: keep it a single path component so
    # the run cannot write outside OUTPUT_DIR.
    project_name = re.sub(r"[\\/\x00]+", "_", str(project_name or "")).strip(". ")
    if not project_name:
        project_name = f"lora_{int(time.time())}"

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    rank = int(lora_rank)
    side = int(resolution)
    # Checkpoints and samples are produced four times per run, but never more
    # often than every 100 steps.
    interval = max(100, int(steps / 4))

    config = {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": rank,
                    "linear_alpha": rank,
                },
                "save": {
                    "dtype": "float16",
                    "save_every": interval,
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [side, side],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": int(steps),
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": interval,
                    "width": side,
                    "height": side,
                    "prompts": [
                        f"{trigger_word} high quality photo" if trigger_word else "high quality photo",
                        f"{trigger_word} beautiful scene" if trigger_word else "beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }]
        }
    }

    config_path = output_path / "config.json"
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    progress(0.1, desc="Installing AI Toolkit...")

    toolkit_dir = Path("./ai-toolkit")
    if not toolkit_dir.exists():
        try:
            subprocess.run(
                ["git", "clone", "https://github.com/ostris/ai-toolkit.git", str(toolkit_dir)],
                check=True,
                capture_output=True,
                text=True
            )
            # `cwd=` is used instead of os.chdir() so that a failing step can
            # never leave the whole process stranded in the checkout folder.
            subprocess.run(
                ["git", "submodule", "update", "--init", "--recursive"],
                cwd=toolkit_dir,
                check=True,
                capture_output=True,
                text=True
            )
            # Install into the interpreter that is running this app rather
            # than whichever `pip` happens to be first on PATH.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
                cwd=toolkit_dir,
                check=True
            )
        except Exception as e:
            # Remove the partial checkout; otherwise the next attempt would
            # see the folder, skip installation, and fail later.
            shutil.rmtree(toolkit_dir, ignore_errors=True)
            detail = str(e)
            captured = getattr(e, "stderr", None)
            if captured:
                detail = f"{detail}\n{captured.strip()}"
            return f"❌ Failed to install AI Toolkit: {detail}", None

    progress(0.3, desc="Starting training...")

    try:
        result = subprocess.run(
            [sys.executable, str(toolkit_dir / "run.py"), str(config_path)],
            capture_output=True,
            text=True,
            timeout=3600
        )

        if result.returncode != 0:
            return f"❌ Training failed:\n{result.stderr}", None

        progress(0.9, desc="Training complete! Finding LoRA file...")

        # Search recursively — presumably the trainer may nest its output in a
        # sub-folder of training_folder (TODO confirm) — and pick the most
        # recently written file, since glob order is not defined.
        lora_files = list(output_path.rglob("*.safetensors"))
        if not lora_files:
            return "⚠️ Training completed but no LoRA file found", None

        lora_file = max(lora_files, key=lambda candidate: candidate.stat().st_mtime)
        size_mb = lora_file.stat().st_size / (1024*1024)
        message_lines = [
            "✅ Training Complete!",
            "",
            f"📦 LoRA saved: {lora_file.name}",
            f"💾 Size: {size_mb:.2f} MB",
        ]
        # Only advertise a trigger word when one was actually configured.
        if trigger_word:
            message_lines.append(f"🏷️ Use trigger word: '{trigger_word}' in your prompts")
        return "\n".join(message_lines), str(lora_file)

    except subprocess.TimeoutExpired:
        return "❌ Training timeout (> 1 hour). Try reducing steps.", None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None
|
|
| |
# Gradio UI: three tabs (upload, train, help) plus the two button handlers.
# NOTE(review): emoji in the labels and Markdown were restored from
# mis-decoded bytes; 🚀, 📖 and 🔗 are best guesses — confirm against the
# upstream file.
# NOTE(review): the Markdown literals are indented with the code, presumably
# relying on Gradio dedenting Markdown text — TODO confirm.
with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Z-Image LoRA Trainer

    Train custom LoRA models for Z-Image-Base (6B parameter model)

    **Quick Start:**
    1. Upload 10-50 images of your subject
    2. Enter a trigger word (e.g., "mycharacter", "mystyle")
    3. Click Train
    4. Download your LoRA when complete

    ⚠️ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
    """)

    # Evaluated once, when the UI is built.
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    with gr.Tab("📤 Upload Dataset"):
        with gr.Row():
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("📤 Upload Dataset", variant="primary", size="lg")

            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                # Hidden field that carries the prepared dataset path from the
                # upload handler to the training handler.
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    with gr.Tab("🚀 Train LoRA"):
        with gr.Row():
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )

                gr.Markdown("### Training Settings")

                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )

                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )

                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )

                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )

                train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")

            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    with gr.Tab("ℹ️ Help"):
        gr.Markdown("""
        ## 📖 How to Use

        ### Step 1: Prepare Your Images
        - **10-50 images** of your subject (more is better for complex subjects)
        - **Consistent subject** across images
        - **Good variety** in poses, angles, lighting
        - **High quality** photos (clear, well-lit)

        ### Step 2: Upload Dataset
        - Choose a descriptive **dataset name**
        - Add a **trigger word** (e.g., "sks person", "mystyle")
        - Upload your images

        ### Step 3: Configure Training
        - **Project name**: Name for your LoRA
        - **Steps**:
          - 500-1000 for simple subjects
          - 1000-2000 for complex subjects/styles
        - **Learning rate**: Keep default (0.0001)
        - **LoRA Rank**: 16 is good for most cases

        ### Step 4: Train
        - Click "Start Training"
        - Wait 10-30 minutes (don't close tab)
        - Download your LoRA when complete

        ### Step 5: Use Your LoRA
        - Load in ComfyUI, Automatic1111, or other Z-Image tools
        - Use your trigger word in prompts
        - Example: "a photo of [trigger_word] in a forest"

        ## 🎯 Tips for Best Results

        - **Good dataset** = good results
        - **Consistent subject** across images
        - **Unique trigger word** (not common words)
        - **Start with 1000 steps**, adjust if needed
        - **Don't overtrain** (if quality decreases, reduce steps)

        ## ⚠️ Troubleshooting

        **Training fails with OOM error:**
        - Reduce resolution to 768 or 512
        - Use fewer steps
        - Upload fewer images

        **LoRA doesn't look like subject:**
        - Upload more images (20-30+)
        - Increase steps to 1500-2000
        - Ensure images are consistent

        **LoRA is too strong/weak:**
        - Adjust LoRA weight in your inference tool (0.5-1.5)

        ## 🔗 Resources

        - **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
        - **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
        - **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
        """)

    # Upload: fills the status box, the hidden dataset path, and the
    # "ready" indicator.
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )

    # Train: the trigger word is read from the upload tab so captions and
    # sample prompts use the same word.
    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )
|
|
if __name__ == "__main__":
    # train_lora reports progress through gr.Progress, which needs the event
    # queue; enable it explicitly instead of relying on the installed Gradio
    # version's default.
    demo.queue()
    demo.launch()