# ModelTrainer / app.py
# ethanpwood29's picture
# Update app.py
# 9cd4e10 verified
# app.py
import gradio as gr
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import (
AutoTokenizer,
CLIPTextModel,
)
from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
AutoencoderKL,
DDPMScheduler,
)
from diffusers.optimization import get_scheduler
from datasets import load_dataset, Dataset
from huggingface_hub import login, HfApi, Repository
from pathlib import Path
import os
import zipfile
from PIL import Image
import pandas as pd
import math
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from tqdm.auto import tqdm
import torch.nn.functional as F
# Module-level logger obtained from accelerate's logging helper.
logger = get_logger(__name__)
def create_app():
    """Build the Gradio Blocks UI for the Stable Diffusion fine-tuning app.

    The UI collects the base model, output repository name, dataset source
    (Hugging Face Hub dataset or uploaded ZIP), column mapping and training
    hyper-parameters, and wires two callbacks:

    * a dataset preview that shows up to four (image, caption) pairs, and
    * the "Start Training" button, which loads and preprocesses the dataset
      and then calls ``train_model``.

    The Hugging Face token is not entered in the UI; it is read from the
    ``HUGGINGFACE_TOKEN`` environment variable when training starts.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Fine-Tuning Application")

        # There is deliberately no authentication box: the token comes from
        # the HUGGINGFACE_TOKEN environment variable instead.

        # Model Selection
        with gr.Row():
            base_model = gr.Textbox(
                label="Base Model Name",
                placeholder="e.g., CompVis/stable-diffusion-v1-4",
                value="stabilityai/stable-diffusion-2-1-base",
            )
            output_model_name = gr.Textbox(
                label="Output Model Repository Name",
                placeholder="Enter a unique name for your fine-tuned model (e.g., username/my-fine-tuned-model)",
            )

        # Dataset Selection
        with gr.Group():
            gr.Markdown("## Dataset Selection")
            dataset_source = gr.Radio(
                label="Dataset Source",
                choices=["Select from Hugging Face", "Upload your own"],
                value="Select from Hugging Face",
            )
            dataset_name = gr.Textbox(
                label="Dataset Name (from Hugging Face Hub)",
                placeholder="Enter dataset path, e.g., username/dataset_name",
                visible=True,
            )
            dataset_viewer_toggle = gr.Checkbox(
                label="Preview Dataset",
                value=False,
            )
            dataset_preview = gr.Gallery(
                label="Dataset Preview",
                visible=False,
                height='auto',
            )
            dataset_upload = gr.File(
                label="Upload Dataset (ZIP file containing images and annotations)",
                file_types=[".zip"],
                visible=False,
            )

        def toggle_dataset_source(choice):
            """Show only the inputs that belong to the chosen dataset source."""
            return {
                dataset_name: gr.update(visible=choice == "Select from Hugging Face"),
                dataset_upload: gr.update(visible=choice == "Upload your own"),
                dataset_viewer_toggle: gr.update(visible=choice == "Select from Hugging Face"),
            }

        dataset_source.change(
            fn=toggle_dataset_source,
            inputs=dataset_source,
            outputs=[dataset_name, dataset_upload, dataset_viewer_toggle],
        )

        # Column Mapping
        with gr.Group():
            gr.Markdown("## Column Mapping")
            image_column = gr.Textbox(
                label="Image Column Name",
                placeholder="Column name for images",
                value="image",
            )
            caption_column = gr.Textbox(
                label="Caption Column Name",
                placeholder="Column name for captions",
                value="text",
            )

        # Training Parameters
        with gr.Group():
            gr.Markdown("## Training Parameters")
            with gr.Row():
                num_train_epochs = gr.Slider(
                    label="Number of Training Epochs",
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                )
                max_train_steps = gr.Number(
                    label="Max Training Steps",
                    value=1000,
                )
                train_batch_size = gr.Number(
                    label="Train Batch Size",
                    value=4,
                )
            with gr.Row():
                learning_rate = gr.Number(
                    label="Learning Rate",
                    value=5e-6,
                )
                gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    value=1,
                )
                checkpointing_steps = gr.Number(
                    label="Checkpointing Steps",
                    value=500,
                )
            with gr.Row():
                mixed_precision = gr.Radio(
                    label="Mixed Precision",
                    choices=["no", "fp16", "bf16"],
                    value="fp16",
                )
                use_8bit_adam = gr.Checkbox(
                    label="Use 8-bit Adam Optimizer",
                    value=True,
                )
                use_xformers = gr.Checkbox(
                    label="Enable XFormers Memory Efficient Attention",
                    value=True,
                )
            with gr.Row():
                resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                )
                lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                    value="constant",
                )
                lr_warmup_steps = gr.Number(
                    label="Learning Rate Warmup Steps",
                    value=0,
                )
            seed = gr.Number(
                label="Seed",
                value=42,
            )

        # Start Training Button
        start_training = gr.Button("Start Training")

        # Output
        training_output = gr.Textbox(
            label="Training Status",
            placeholder="Logs will appear here...",
            lines=10,
        )

        # Dataset Viewer Functionality
        def preview_dataset(dataset_name, preview, image_column_name, caption_column_name):
            """Return a Gallery update showing up to four (image, caption) pairs.

            The column names are passed in as event inputs so the preview
            uses what the user typed; reading ``component.value`` would only
            give the value the component was constructed with.
            """
            if not preview:
                return gr.update(visible=False, value=None)
            try:
                dataset = load_dataset(dataset_name, split="train")
                images = []
                for i in range(min(4, len(dataset))):
                    example = dataset[i]
                    image = example[image_column_name]
                    if not isinstance(image, Image.Image):
                        image = Image.open(image)
                    images.append((image, example[caption_column_name]))
            except Exception as e:
                # A Gallery cannot display an error string, so raise a UI error.
                raise gr.Error(f"Error loading dataset: {str(e)}") from e
            return gr.update(visible=True, value=images)

        dataset_viewer_toggle.change(
            fn=preview_dataset,
            inputs=[dataset_name, dataset_viewer_toggle, image_column, caption_column],
            outputs=dataset_preview,
        )

        # Training Function
        def start_training_fn(
            base_model_name,
            output_model_name,
            dataset_source,
            dataset_name,
            dataset_upload,
            image_column_name,
            caption_column_name,
            num_train_epochs,
            max_train_steps,
            train_batch_size,
            learning_rate,
            gradient_accumulation_steps,
            checkpointing_steps,
            mixed_precision,
            use_8bit_adam,
            use_xformers,
            resolution,
            lr_scheduler_type,
            lr_warmup_steps,
            seed,
        ):
            """Validate the form, load and preprocess the dataset, then train.

            Returns a status string for the "Training Status" textbox; any
            exception is reported as text instead of being raised.
            """
            try:
                # Get the Hugging Face token from the environment variable
                hf_token = os.environ.get("HUGGINGFACE_TOKEN")
                if not hf_token:
                    return "HUGGINGFACE_TOKEN environment variable not found. Please set it in your Space's secrets."

                # Validate inputs
                if not base_model_name.strip():
                    return "Please provide a base model name."
                if not output_model_name.strip():
                    return "Please provide an output model repository name."

                # Login to Hugging Face
                login(hf_token, add_to_git_credential=True)

                # Load dataset
                if dataset_source == "Select from Hugging Face":
                    if not dataset_name.strip():
                        return "Please provide the Hugging Face dataset name."
                    dataset = load_dataset(dataset_name, split="train")
                else:
                    if dataset_upload is None:
                        return "Please upload a dataset."
                    # The upload may arrive as a plain path string or as a
                    # file-like object exposing the path via `.name`.
                    if isinstance(dataset_upload, str):
                        zip_path = dataset_upload
                    else:
                        zip_path = dataset_upload.name
                    dataset = load_custom_dataset(zip_path)

                # Check if the specified columns exist
                if image_column_name not in dataset.column_names:
                    return f"Image column '{image_column_name}' not found in the dataset."
                if caption_column_name not in dataset.column_names:
                    return f"Caption column '{caption_column_name}' not found in the dataset."

                # Preprocess the dataset
                dataset = preprocess_dataset(dataset, image_column_name, caption_column_name, resolution)

                # Start training
                return train_model(
                    hf_token=hf_token,
                    base_model_name=base_model_name,
                    dataset=dataset,
                    output_model_name=output_model_name,
                    num_train_epochs=int(num_train_epochs),
                    max_train_steps=int(max_train_steps),
                    train_batch_size=int(train_batch_size),
                    learning_rate=float(learning_rate),
                    gradient_accumulation_steps=int(gradient_accumulation_steps),
                    checkpointing_steps=int(checkpointing_steps),
                    mixed_precision=mixed_precision,
                    use_8bit_adam=use_8bit_adam,
                    use_xformers=use_xformers,
                    lr_scheduler_type=lr_scheduler_type,
                    lr_warmup_steps=int(lr_warmup_steps),
                    resolution=int(resolution),
                    seed=int(seed),
                )
            except Exception as e:
                # UI boundary: surface the failure in the status box.
                return f"An error occurred during training: {str(e)}"

        start_training.click(
            fn=start_training_fn,
            inputs=[
                base_model,
                output_model_name,
                dataset_source,
                dataset_name,
                dataset_upload,
                image_column,
                caption_column,
                num_train_epochs,
                max_train_steps,
                train_batch_size,
                learning_rate,
                gradient_accumulation_steps,
                checkpointing_steps,
                mixed_precision,
                use_8bit_adam,
                use_xformers,
                resolution,
                lr_scheduler,
                lr_warmup_steps,
                seed,
            ],
            outputs=training_output,
        )

    return demo
def preprocess_dataset(dataset, image_column_name, caption_column_name, resolution):
    """Turn an (image, caption) dataset into tensors ready for a DataLoader.

    Each example's image is converted to RGB, resized and center-cropped to
    ``resolution``, and normalized with mean 0.5 / std 0.5; its caption is
    tokenized with padding and truncation to the tokenizer's maximum length.

    Args:
        dataset: Dataset exposing ``column_names``, ``map`` and ``set_format``.
        image_column_name: Column holding a PIL image or something
            ``PIL.Image.open`` accepts (e.g. a file path).
        caption_column_name: Column holding the caption text.
        resolution: Target square image size in pixels; cast to ``int``.

    Returns:
        The mapped dataset containing only ``pixel_values``, ``input_ids``
        and ``attention_mask``, formatted as torch tensors.
    """
    # NOTE(review): the tokenizer is fixed to CLIP ViT-L/14 regardless of the
    # base model chosen for training - confirm it matches the base model's
    # own tokenizer (padding token in particular).
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # The transform does not depend on the example, so build it once.
    resolution = int(resolution)
    transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def process_example(example):
        # Load and preprocess image
        image = example[image_column_name]
        if isinstance(image, Image.Image):
            # Already-decoded images may be RGBA / L / P; force 3 channels.
            image = image.convert("RGB")
        else:
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        # Tokenize caption
        tokens = tokenizer(
            example[caption_column_name],
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "pixel_values": transform(image),
            # Drop only the leading batch dimension added by return_tensors.
            "input_ids": tokens.input_ids.squeeze(0),
            "attention_mask": tokens.attention_mask.squeeze(0),
        }

    # Drop every original column (including the raw image and caption): the
    # default DataLoader collate cannot batch PIL images.
    dataset = dataset.map(
        process_example,
        remove_columns=list(dataset.column_names),
        batched=False,
    )
    # map() stores tensors as nested lists; ask for torch tensors on access.
    dataset.set_format(
        type="torch",
        columns=["pixel_values", "input_ids", "attention_mask"],
    )
    return dataset
def load_custom_dataset(zip_file_path):
    """Build an image/caption dataset from an uploaded ZIP archive.

    The archive must contain ``annotations.csv`` at its root with
    ``file_name`` and ``caption`` columns; each ``file_name`` is resolved
    relative to the archive root.

    Args:
        zip_file_path: Path to the ZIP file.

    Returns:
        A ``Dataset`` with an ``image`` column (file paths as strings) and a
        ``text`` column (captions as strings).

    Raises:
        ValueError: If ``annotations.csv`` or a required column is missing,
            or if a listed image is missing or lies outside the archive.
    """
    import tempfile

    # Extract into a fresh directory per call so files left over from a
    # previous upload cannot satisfy the existence check below. The directory
    # is kept because the returned dataset only stores paths.
    extract_path = Path(tempfile.mkdtemp(prefix="extracted_dataset_")).resolve()
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    # Assuming there is annotations.csv with 'file_name' and 'caption' columns
    annotations_file = extract_path / 'annotations.csv'
    if not annotations_file.exists():
        raise ValueError("annotations.csv not found in the dataset.")
    annotations = pd.read_csv(annotations_file)
    if 'file_name' not in annotations.columns or 'caption' not in annotations.columns:
        raise ValueError("annotations.csv must contain 'file_name' and 'caption' columns.")

    images = []
    captions = []
    for file_name, caption in zip(annotations['file_name'], annotations['caption']):
        image_path = (extract_path / str(file_name)).resolve()
        # file_name comes from user-supplied CSV: refuse absolute paths and
        # "../" tricks that point outside the extracted archive.
        try:
            image_path.relative_to(extract_path)
        except ValueError:
            raise ValueError(
                f"Image file {file_name} is outside the dataset archive."
            ) from None
        if not image_path.is_file():
            raise ValueError(f"Image file {file_name} not found in the dataset.")
        images.append(str(image_path))
        # Empty cells are read as NaN and numeric captions as numbers;
        # normalize both to strings.
        captions.append("" if pd.isna(caption) else str(caption))

    # Create dataset
    return Dataset.from_dict({"image": images, "text": captions})
def train_model(
    hf_token,
    base_model_name,
    dataset,
    output_model_name,
    num_train_epochs,
    max_train_steps,
    train_batch_size,
    learning_rate,
    gradient_accumulation_steps,
    checkpointing_steps,
    mixed_precision,
    use_8bit_adam,
    use_xformers,
    lr_scheduler_type,
    lr_warmup_steps,
    resolution,
    seed,
):
    """Fine-tune the UNet of a Stable Diffusion model and upload the result.

    The VAE and text encoder are frozen; only the UNet is optimized with an
    MSE loss against the scheduler's prediction target. Accelerator state is
    checkpointed every ``checkpointing_steps`` optimizer steps, and the final
    pipeline is saved locally under ``output_model_name`` and uploaded to a
    Hub repository with the same id.

    Args:
        hf_token: Hugging Face token used to create the repo and upload.
        base_model_name: Model id or path with ``tokenizer``, ``text_encoder``,
            ``vae``, ``unet`` and ``scheduler`` subfolders.
        dataset: Dataset whose batches provide ``pixel_values``,
            ``input_ids`` and ``attention_mask`` tensors.
        output_model_name: Local output directory and Hub repo id.
        num_train_epochs: Epoch count, used only when ``max_train_steps`` is
            ``None`` or 0.
        max_train_steps: Total optimizer steps; overrides the epoch count.
        train_batch_size: Per-device batch size.
        learning_rate: Optimizer learning rate.
        gradient_accumulation_steps: Micro-batches per optimizer step.
        checkpointing_steps: Checkpoint interval in optimizer steps; values
            <= 0 disable checkpointing.
        mixed_precision: ``"no"``, ``"fp16"`` or ``"bf16"``.
        use_8bit_adam: Use ``bitsandbytes`` 8-bit AdamW instead of torch AdamW.
        use_xformers: Enable xformers memory-efficient attention on the UNet.
        lr_scheduler_type: Scheduler name passed to ``get_scheduler``.
        lr_warmup_steps: Warmup steps for the scheduler.
        resolution: Unused here; images are resized during preprocessing.
        seed: Random seed.

    Returns:
        A human-readable status string: a success message, or a message
        starting with ``"Error"`` when a precondition fails.
    """
    # Set seed for reproducibility
    set_seed(seed)

    # Reject values that would otherwise fail later as divide-by-zero errors.
    if train_batch_size < 1 or gradient_accumulation_steps < 1:
        return "Error: batch size and gradient accumulation steps must be at least 1."
    if len(dataset) == 0:
        return "Error: the training dataset is empty."

    # Initialize Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    # Check optional dependencies before loading any model weights.
    if use_xformers:
        try:
            import xformers  # noqa: F401
        except ImportError:
            print("xformers is not available. Please install it or disable xformers.")
            return "Error: xformers is not installed. Please install xformers or disable it."
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            return "Error: bitsandbytes is not installed. Please install bitsandbytes or disable 8-bit Adam."
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Load tokenizer and models. The tokenizer comes from the base model so
    # the saved pipeline ships the tokenizer its text encoder expects.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        subfolder="tokenizer",
        use_fast=False,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        base_model_name,
        subfolder="text_encoder",
    )
    vae = AutoencoderKL.from_pretrained(
        base_model_name,
        subfolder="vae",
        revision=None,
    )
    unet = UNet2DConditionModel.from_pretrained(
        base_model_name,
        subfolder="unet",
        revision=None,
    )

    # Freeze vae and text_encoder
    vae.eval()
    text_encoder.eval()
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable xformers
    if use_xformers:
        unet.enable_xformers_memory_efficient_attention()

    # Prepare optimizer
    optimizer = optimizer_class(
        unet.parameters(),
        lr=learning_rate,
    )

    # Prepare data loader
    train_dataloader = DataLoader(
        dataset, batch_size=train_batch_size, shuffle=True, num_workers=4
    )

    # Calculate total training steps
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if not max_train_steps:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Prepare learning rate scheduler. It is stepped once per micro-batch in
    # the loop below, hence the scaling by the accumulation factor.
    lr_scheduler = get_scheduler(
        lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with accelerator
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Accelerator exposes the precision mode as a string, not a dtype; the
    # frozen models are cast to the matching dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # The dataloader length may change after prepare(); recompute the counts.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if overrode_max_train_steps:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Pull one batch up front so collate/device problems are reported clearly.
    try:
        batch = next(iter(train_dataloader))
        for key in ("pixel_values", "input_ids", "attention_mask"):
            batch[key] = batch[key].to(accelerator.device)
    except Exception as e:
        return f"Error in moving batch to device: {str(e)}"

    # Set up the noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(base_model_name, subfolder="scheduler")

    # Latent scaling factor from the VAE config; 0.18215 is the constant this
    # code used before and serves as the fallback.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    # Training loop
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    print("***** Running training *****")
    print(f" Num examples = {len(dataset)}")
    print(f" Num Epochs = {num_train_epochs}")
    print(f" Instantaneous batch size per device = {train_batch_size}")
    print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f" Total optimization steps = {max_train_steps}")
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Training")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * latent_scale

                # Sample noise to add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Compute loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)

                # Update the model parameters
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Count a step only when gradients were actually synchronized, so
            # global_step tracks optimizer steps rather than micro-batches.
            if accelerator.sync_gradients:
                global_step += 1
                if accelerator.is_main_process:
                    progress_bar.update(1)
                    progress_bar.set_postfix(loss=loss.detach().item())
                    if checkpointing_steps > 0 and global_step % checkpointing_steps == 0:
                        # Save a checkpoint
                        save_path = f"{output_model_name}_checkpoint_{global_step}"
                        accelerator.save_state(save_path)

            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break

    accelerator.wait_for_everyone()
    if not accelerator.is_main_process:
        return "Training complete."

    # Save the final model
    unet = accelerator.unwrap_model(unet)
    pipeline = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=DDPMScheduler.from_pretrained(base_model_name, subfolder="scheduler"),
        safety_checker=None,
        feature_extractor=None,
    )
    pipeline.save_pretrained(output_model_name)

    # Upload to Hugging Face Hub. The folder is uploaded directly; cloning the
    # repo into the already-populated output directory would fail.
    api = HfApi()
    repo_url = api.create_repo(
        repo_id=output_model_name,
        token=hf_token,
        private=False,
        exist_ok=True,
    )
    api.upload_folder(
        repo_id=output_model_name,
        folder_path=output_model_name,
        token=hf_token,
        commit_message=f"Fine-tuned model at step {global_step}",
    )
    return f"Training complete. The model has been uploaded to Hugging Face Hub at {repo_url}"
# Build the Blocks UI once at import time; `app` is the module-level handle.
app = create_app()

# Launch the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()