|
|
import gradio as gr |
|
|
import os |
|
|
import subprocess |
|
|
import shutil |
|
|
import json |
|
|
import time |
|
|
from pathlib import Path |
|
|
import torch |
|
|
|
|
|
|
|
|
# Folder where uploaded training images (plus their sidecar caption .txt
# files) are stored, one subfolder per dataset.
DATASET_DIR = Path("./datasets")

# Folder where each training run writes its config.json, checkpoints, and
# the final LoRA .safetensors file, one subfolder per project.
OUTPUT_DIR = Path("./output")

# Create both folders up front so later code can assume they exist.
DATASET_DIR.mkdir(exist_ok=True)

OUTPUT_DIR.mkdir(exist_ok=True)

# Path (str) of the most recently prepared dataset.  Set by
# upload_and_prepare_dataset(); the UI also threads it through a hidden
# Gradio textbox into train_lora().
current_dataset_path = None
|
|
|
|
|
def check_gpu():
    """Return a one-line, human-readable CUDA availability status.

    Used once at UI build time to populate the read-only "GPU Status"
    textbox.

    Returns:
        str: the detected GPU name, or a warning that training will run
        on CPU and be slow.
    """
    # NOTE: the original success string was garbled by a broken encoding
    # round-trip (mojibake emoji split the literal across lines); restored
    # here as a single valid f-string.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        return f"✅ GPU Available: {gpu_name}"
    return "⚠️ No GPU detected - training will be slow"
|
|
|
|
|
def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Copy uploaded images into a dataset folder and write caption files.

    Args:
        files: list of Gradio file objects; each exposes ``.name``, the
            temp-file path of the upload.
        dataset_name: folder name under DATASET_DIR; a timestamped name is
            generated when empty.
        trigger_word: caption text written next to every image (one ``.txt``
            per image, same stem — the ai-toolkit caption format); falls
            back to "a photo" when empty.

    Returns:
        Tuple of (status_message, dataset_path_or_None, ready_label) for the
        three UI outputs.

    Side effects:
        Creates DATASET_DIR/<dataset_name> and updates the module-level
        ``current_dataset_path``.
    """
    global current_dataset_path

    if not files:
        return "❌ Please upload at least one image", None, ""

    if not dataset_name:
        dataset_name = f"dataset_{int(time.time())}"

    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    # Accept only common raster formats; anything else is silently skipped.
    allowed_exts = {'.png', '.jpg', '.jpeg', '.webp', '.bmp'}

    image_count = 0
    for file in files:
        if Path(file.name).suffix.lower() not in allowed_exts:
            continue

        filename = Path(file.name).name
        destination = dataset_path / filename
        shutil.copy(file.name, destination)

        # Sidecar caption: same stem, .txt extension.  Written with an
        # explicit encoding so non-ASCII trigger words survive on any OS.
        caption_file = destination.with_suffix('.txt')
        caption_text = trigger_word if trigger_word else "a photo"
        caption_file.write_text(caption_text, encoding='utf-8')

        image_count += 1

    if image_count == 0:
        return "❌ No valid images found. Upload PNG, JPG, JPEG, or WEBP files.", None, ""

    current_dataset_path = str(dataset_path)

    # NOTE: the original status f-string was mojibake-garbled and broken
    # across lines; rebuilt here with readable emoji.
    status = f"✅ Successfully uploaded {image_count} images\n"
    status += f"📁 Dataset: {dataset_name}\n"
    if trigger_word:
        status += f"🏷️ Trigger word: '{trigger_word}'\n"
    status += f"💾 Location: {current_dataset_path}"

    return status, current_dataset_path, f"Dataset ready: {dataset_name}"
|
|
|
|
|
def _build_training_config(project_name, trigger_word, dataset_path,
                           output_path, steps, learning_rate, lora_rank,
                           resolution):
    """Return the ai-toolkit job-config dict for one Z-Image LoRA run."""
    # Checkpoints and sample images are produced 4 times per run, but never
    # more often than every 100 steps.
    save_every = max(100, int(steps / 4))
    return {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": int(lora_rank),
                    # alpha == rank keeps the effective LoRA scale at 1.0.
                    "linear_alpha": int(lora_rank),
                },
                "save": {
                    "dtype": "float16",
                    "save_every": save_every,
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [int(resolution), int(resolution)],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": int(steps),
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    # Trades compute for memory so a 6B model fits on one GPU.
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": save_every,
                    "width": int(resolution),
                    "height": int(resolution),
                    "prompts": [
                        f"{trigger_word} high quality photo" if trigger_word else "high quality photo",
                        f"{trigger_word} beautiful scene" if trigger_word else "beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }]
        }
    }


def _ensure_ai_toolkit():
    """Clone and install ostris/ai-toolkit into ./ai-toolkit if missing.

    Returns None on success, or an error-message string on failure.

    All subprocesses run with an explicit ``cwd=`` instead of os.chdir(),
    so the process working directory can never be left pointing inside the
    checkout when a step fails (the original chdir was not restored on
    error).
    """
    if Path("./ai-toolkit").exists():
        return None
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/ostris/ai-toolkit.git"],
            check=True,
            capture_output=True,
        )
        subprocess.run(
            ["git", "submodule", "update", "--init", "--recursive"],
            check=True,
            capture_output=True,
            cwd="ai-toolkit",
        )
        subprocess.run(
            ["pip", "install", "-q", "-r", "requirements.txt"],
            check=True,
            cwd="ai-toolkit",
        )
    except Exception as e:
        return f"❌ Failed to install AI Toolkit: {str(e)}"
    return None


def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train a Z-Image LoRA with ai-toolkit.

    Args:
        dataset_path: folder of images + .txt captions (from the upload tab).
        project_name: output folder name under OUTPUT_DIR; timestamped name
            generated when empty.
        trigger_word: word baked into the sample prompts and job config.
        steps, learning_rate, lora_rank, resolution: training hyperparameters
            from the UI sliders/radio.
        progress: Gradio progress reporter (injected by Gradio).

    Returns:
        Tuple of (status_message, lora_file_path_or_None) for the UI.
    """
    if not dataset_path or not os.path.exists(dataset_path):
        return "❌ Please upload a dataset first!", None

    if not project_name:
        project_name = f"lora_{int(time.time())}"

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    config = _build_training_config(
        project_name, trigger_word, dataset_path, output_path,
        steps, learning_rate, lora_rank, resolution,
    )

    config_path = output_path / "config.json"
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    progress(0.1, desc="Installing AI Toolkit...")

    install_error = _ensure_ai_toolkit()
    if install_error:
        return install_error, None

    progress(0.3, desc="Starting training...")

    try:
        result = subprocess.run(
            ["python", "ai-toolkit/run.py", str(config_path)],
            capture_output=True,
            text=True,
            timeout=3600,  # hard cap: one hour
        )

        if result.returncode != 0:
            return f"❌ Training failed:\n{result.stderr}", None

        progress(0.9, desc="Training complete! Finding LoRA file...")

        # ai-toolkit may write checkpoints into a subfolder named after the
        # job, so search recursively; sort by mtime because plain glob()
        # order is filesystem-dependent and "last" was previously arbitrary.
        lora_files = sorted(
            output_path.rglob("*.safetensors"),
            key=lambda p: p.stat().st_mtime,
        )
        if lora_files:
            lora_file = lora_files[-1]  # newest checkpoint
            success_msg = "✅ Training Complete!\n\n"
            success_msg += f"📦 LoRA saved: {lora_file.name}\n"
            success_msg += f"💾 Size: {lora_file.stat().st_size / (1024*1024):.2f} MB\n"
            success_msg += f"🏷️ Use trigger word: '{trigger_word}' in your prompts"
            return success_msg, str(lora_file)
        else:
            return "⚠️ Training completed but no LoRA file found", None

    except subprocess.TimeoutExpired:
        return "❌ Training timeout (> 1 hour). Try reducing steps.", None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: three tabs (upload dataset, train, help) plus a GPU status box.
# NOTE(review): the emoji throughout the labels/markdown appear garbled by a
# lost encoding round-trip (e.g. "π¨", "β οΈ"); they are runtime strings and
# are left byte-identical here — confirm intended glyphs before fixing.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    # Intro banner with quick-start instructions.
    gr.Markdown("""
# π¨ Z-Image LoRA Trainer

Train custom LoRA models for Z-Image-Base (6B parameter model)

**Quick Start:**
1. Upload 10-50 images of your subject
2. Enter a trigger word (e.g., "mycharacter", "mystyle")
3. Click Train
4. Download your LoRA when complete

β οΈ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
""")

    # Evaluated once at app build time, not per page load.
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    # --- Tab 1: dataset upload -------------------------------------------
    with gr.Tab("π€ Upload Dataset"):
        with gr.Row():
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("π€ Upload Dataset", variant="primary", size="lg")

            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                # Hidden textbox used as state: carries the dataset path from
                # the upload handler into train_lora().
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    # --- Tab 2: training -------------------------------------------------
    with gr.Tab("π Train LoRA"):
        with gr.Row():
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )

                gr.Markdown("### Training Settings")

                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )

                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )

                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )

                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )

                train_btn = gr.Button("π Start Training", variant="primary", size="lg")

            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    # --- Tab 3: static help text ----------------------------------------
    with gr.Tab("βΉοΈ Help"):
        gr.Markdown("""
## π How to Use

### Step 1: Prepare Your Images
- **10-50 images** of your subject (more is better for complex subjects)
- **Consistent subject** across images
- **Good variety** in poses, angles, lighting
- **High quality** photos (clear, well-lit)

### Step 2: Upload Dataset
- Choose a descriptive **dataset name**
- Add a **trigger word** (e.g., "sks person", "mystyle")
- Upload your images

### Step 3: Configure Training
- **Project name**: Name for your LoRA
- **Steps**:
  - 500-1000 for simple subjects
  - 1000-2000 for complex subjects/styles
- **Learning rate**: Keep default (0.0001)
- **LoRA Rank**: 16 is good for most cases

### Step 4: Train
- Click "Start Training"
- Wait 10-30 minutes (don't close tab)
- Download your LoRA when complete

### Step 5: Use Your LoRA
- Load in ComfyUI, Automatic1111, or other Z-Image tools
- Use your trigger word in prompts
- Example: "a photo of [trigger_word] in a forest"

## π― Tips for Best Results

- **Good dataset** = good results
- **Consistent subject** across images
- **Unique trigger word** (not common words)
- **Start with 1000 steps**, adjust if needed
- **Don't overtrain** (if quality decreases, reduce steps)

## β οΈ Troubleshooting

**Training fails with OOM error:**
- Reduce resolution to 768 or 512
- Use fewer steps
- Upload fewer images

**LoRA doesn't look like subject:**
- Upload more images (20-30+)
- Increase steps to 1500-2000
- Ensure images are consistent

**LoRA is too strong/weak:**
- Adjust LoRA weight in your inference tool (0.5-1.5)

## π Resources

- **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
- **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
- **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
""")

    # --- Event wiring ----------------------------------------------------
    # Upload handler: (files, name, trigger) -> (status, hidden path, ready).
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )

    # Train handler: note dataset_path_state comes from the upload tab.
    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server (blocking) when run as a script.
    demo.launch()