Spaces:
Build error
Build error
| # app.py | |
| import gradio as gr | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from transformers import ( | |
| AutoTokenizer, | |
| CLIPTextModel, | |
| ) | |
| from diffusers import ( | |
| StableDiffusionPipeline, | |
| UNet2DConditionModel, | |
| AutoencoderKL, | |
| DDPMScheduler, | |
| ) | |
| from diffusers.optimization import get_scheduler | |
| from datasets import load_dataset, Dataset | |
| from huggingface_hub import login, HfApi, Repository | |
| from pathlib import Path | |
| import os | |
| import zipfile | |
| from PIL import Image | |
| import pandas as pd | |
| import math | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from accelerate.utils import set_seed | |
| from tqdm.auto import tqdm | |
| import torch.nn.functional as F | |
# Module-level logger obtained through accelerate's logging helper.
logger = get_logger(name=__name__)
def create_app():
    """Build the Gradio Blocks UI for fine-tuning a Stable Diffusion model.

    The UI collects a base model, an output repository name, a dataset (either a
    Hugging Face Hub dataset or an uploaded ZIP), column names and training
    hyperparameters. The "Start Training" button preprocesses the dataset and
    calls ``train_model``. The Hugging Face token is read from the
    ``HUGGINGFACE_TOKEN`` environment variable, not from the UI.

    Returns:
        The constructed, not-yet-launched ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Fine-Tuning Application")
        # There is intentionally no token textbox: the token comes from the
        # HUGGINGFACE_TOKEN environment variable (e.g. a Space secret).

        # Model Selection
        with gr.Row():
            base_model = gr.Textbox(
                label="Base Model Name",
                placeholder="e.g., CompVis/stable-diffusion-v1-4",
                value="stabilityai/stable-diffusion-2-1-base",
            )
            output_model_name = gr.Textbox(
                label="Output Model Repository Name",
                placeholder="Enter a unique name for your fine-tuned model (e.g., username/my-fine-tuned-model)",
            )

        # Dataset Selection
        with gr.Group():
            gr.Markdown("## Dataset Selection")
            dataset_source = gr.Radio(
                label="Dataset Source",
                choices=["Select from Hugging Face", "Upload your own"],
                value="Select from Hugging Face",
            )
            dataset_name = gr.Textbox(
                label="Dataset Name (from Hugging Face Hub)",
                placeholder="Enter dataset path, e.g., username/dataset_name",
                visible=True,
            )
            dataset_viewer_toggle = gr.Checkbox(
                label="Preview Dataset",
                value=False,
            )
            dataset_preview = gr.Gallery(
                label="Dataset Preview",
                visible=False,
                height='auto',
            )
            dataset_upload = gr.File(
                label="Upload Dataset (ZIP file containing images and annotations)",
                file_types=[".zip"],
                visible=False,
            )

        def toggle_dataset_source(choice):
            """Show the inputs that belong to the selected dataset source."""
            from_hub = choice == "Select from Hugging Face"
            return {
                dataset_name: gr.update(visible=from_hub),
                dataset_upload: gr.update(visible=choice == "Upload your own"),
                dataset_viewer_toggle: gr.update(visible=from_hub),
                # A Hub preview is meaningless for an uploaded ZIP, so hide it;
                # otherwise leave the gallery as it is.
                dataset_preview: gr.update() if from_hub else gr.update(visible=False),
            }

        dataset_source.change(
            fn=toggle_dataset_source,
            inputs=dataset_source,
            outputs=[dataset_name, dataset_upload, dataset_viewer_toggle, dataset_preview],
        )

        # Column Mapping
        with gr.Group():
            gr.Markdown("## Column Mapping")
            image_column = gr.Textbox(
                label="Image Column Name",
                placeholder="Column name for images",
                value="image",
            )
            caption_column = gr.Textbox(
                label="Caption Column Name",
                placeholder="Column name for captions",
                value="text",
            )

        # Training Parameters
        with gr.Group():
            gr.Markdown("## Training Parameters")
            with gr.Row():
                num_train_epochs = gr.Slider(
                    label="Number of Training Epochs",
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                )
                max_train_steps = gr.Number(
                    label="Max Training Steps",
                    value=1000,
                )
                train_batch_size = gr.Number(
                    label="Train Batch Size",
                    value=4,
                )
            with gr.Row():
                learning_rate = gr.Number(
                    label="Learning Rate",
                    value=5e-6,
                )
                gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    value=1,
                )
                checkpointing_steps = gr.Number(
                    label="Checkpointing Steps",
                    value=500,
                )
            with gr.Row():
                mixed_precision = gr.Radio(
                    label="Mixed Precision",
                    choices=["no", "fp16", "bf16"],
                    value="fp16",
                )
                use_8bit_adam = gr.Checkbox(
                    label="Use 8-bit Adam Optimizer",
                    value=True,
                )
                use_xformers = gr.Checkbox(
                    label="Enable XFormers Memory Efficient Attention",
                    value=True,
                )
            with gr.Row():
                resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                )
                lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                    value="constant",
                )
                lr_warmup_steps = gr.Number(
                    label="Learning Rate Warmup Steps",
                    value=0,
                )
                seed = gr.Number(
                    label="Seed",
                    value=42,
                )

        # Start Training Button
        start_training = gr.Button("Start Training")

        # Output
        training_output = gr.Textbox(
            label="Training Status",
            placeholder="Logs will appear here...",
            lines=10,
        )

        def preview_dataset(dataset_name, preview, image_column_name, caption_column_name):
            """Show up to four (image, caption) pairs from the Hub dataset's train split.

            The column names are passed in as event inputs so the preview uses
            what the user typed. Reading ``component.value`` inside a callback
            only yields the default given at construction time.

            Raises:
                gr.Error: If the dataset or the requested columns cannot be loaded.
            """
            if not preview:
                return gr.update(visible=False, value=None)
            try:
                dataset = load_dataset(dataset_name.strip(), split="train")
                images = []
                for i in range(min(4, len(dataset))):
                    row = dataset[i]
                    image = row[image_column_name.strip()]
                    if not isinstance(image, Image.Image):
                        image = Image.open(image)
                    images.append((image, str(row[caption_column_name.strip()])))
            except Exception as e:
                # A Gallery cannot display an error string; surface it as a UI error instead.
                raise gr.Error(f"Error loading dataset: {str(e)}") from e
            return gr.update(visible=True, value=images)

        dataset_viewer_toggle.change(
            fn=preview_dataset,
            inputs=[dataset_name, dataset_viewer_toggle, image_column, caption_column],
            outputs=dataset_preview,
        )

        def start_training_fn(
            base_model_name,
            output_model_name,
            dataset_source,
            dataset_name,
            dataset_upload,
            image_column_name,
            caption_column_name,
            num_train_epochs,
            max_train_steps,
            train_batch_size,
            learning_rate,
            gradient_accumulation_steps,
            checkpointing_steps,
            mixed_precision,
            use_8bit_adam,
            use_xformers,
            resolution,
            lr_scheduler_type,
            lr_warmup_steps,
            seed,
        ):
            """Validate the UI inputs, load and preprocess the dataset, then train.

            Returns:
                A status string shown in the "Training Status" box. Every
                failure, including unexpected exceptions, is turned into a message.
            """
            try:
                hf_token = os.environ.get("HUGGINGFACE_TOKEN")
                if not hf_token:
                    return "HUGGINGFACE_TOKEN environment variable not found. Please set it in your Space's secrets."

                # Validate inputs (stripping stray whitespace users paste in).
                base_model_name = base_model_name.strip()
                output_model_name = output_model_name.strip()
                image_column_name = image_column_name.strip()
                caption_column_name = caption_column_name.strip()
                if not base_model_name:
                    return "Please provide a base model name."
                if not output_model_name:
                    return "Please provide an output model repository name."

                login(hf_token, add_to_git_credential=True)

                # Load dataset
                if dataset_source == "Select from Hugging Face":
                    if not dataset_name.strip():
                        return "Please provide the Hugging Face dataset name."
                    dataset = load_dataset(dataset_name.strip(), split="train")
                else:
                    if dataset_upload is None:
                        return "Please upload a dataset."
                    # Depending on the Gradio version, gr.File yields either a
                    # tempfile-like object (with .name) or a plain path string.
                    upload_path = getattr(dataset_upload, "name", dataset_upload)
                    dataset = load_custom_dataset(upload_path)

                if image_column_name not in dataset.column_names:
                    return f"Image column '{image_column_name}' not found in the dataset."
                if caption_column_name not in dataset.column_names:
                    return f"Caption column '{caption_column_name}' not found in the dataset."

                # Sliders may deliver floats; image transforms need an int size.
                dataset = preprocess_dataset(dataset, image_column_name, caption_column_name, int(resolution))

                return train_model(
                    hf_token=hf_token,
                    base_model_name=base_model_name,
                    dataset=dataset,
                    output_model_name=output_model_name,
                    num_train_epochs=int(num_train_epochs),
                    max_train_steps=int(max_train_steps),
                    train_batch_size=int(train_batch_size),
                    learning_rate=float(learning_rate),
                    gradient_accumulation_steps=int(gradient_accumulation_steps),
                    checkpointing_steps=int(checkpointing_steps),
                    mixed_precision=mixed_precision,
                    use_8bit_adam=use_8bit_adam,
                    use_xformers=use_xformers,
                    lr_scheduler_type=lr_scheduler_type,
                    lr_warmup_steps=int(lr_warmup_steps),
                    resolution=int(resolution),
                    seed=int(seed),
                )
            except Exception as e:
                # Top-level UI boundary: keep the full traceback in the Space logs,
                # and show a short message in the UI.
                import traceback

                traceback.print_exc()
                return f"An error occurred during training: {str(e)}"

        start_training.click(
            fn=start_training_fn,
            inputs=[
                base_model,
                output_model_name,
                dataset_source,
                dataset_name,
                dataset_upload,
                image_column,
                caption_column,
                num_train_epochs,
                max_train_steps,
                train_batch_size,
                learning_rate,
                gradient_accumulation_steps,
                checkpointing_steps,
                mixed_precision,
                use_8bit_adam,
                use_xformers,
                resolution,
                lr_scheduler,
                lr_warmup_steps,
                seed,
            ],
            outputs=training_output,
        )
    return demo
def preprocess_dataset(
    dataset,
    image_column_name,
    caption_column_name,
    resolution,
    tokenizer_name_or_path="openai/clip-vit-large-patch14",
):
    """Attach an on-the-fly transform that turns raw rows into model inputs.

    Indexing the returned dataset (e.g. from a ``DataLoader``) yields only
    ``pixel_values`` and the tokenized caption. ``pixel_values`` is the image
    resized and center-cropped to ``resolution`` and normalized with mean and
    std 0.5. ``input_ids`` and ``attention_mask`` are the caption padded or
    truncated to the tokenizer's max length. Raw columns, such as PIL images
    that the default collate function cannot batch, are not returned.

    Work is done lazily per accessed batch (``Dataset.with_transform``), so
    full-resolution float arrays are never written to the Arrow cache.

    Args:
        dataset: A ``datasets.Dataset`` whose image column holds PIL images or
            image file paths.
        image_column_name: Name of the image column.
        caption_column_name: Name of the caption column.
        resolution: Target square side length in pixels (coerced to ``int``).
        tokenizer_name_or_path: Tokenizer used for captions. Defaults to the
            CLIP tokenizer this function always used.

    Returns:
        The dataset with the formatting transform set.
    """
    # Slider values may arrive as floats; Resize/CenterCrop require an int size.
    resolution = int(resolution)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    # Built once instead of once per example.
    image_transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def _to_rgb(image):
        """Return ``image`` as an RGB PIL image, opening it first if it is a path."""
        # Always convert: RGBA/grayscale inputs would otherwise give tensors
        # with a channel count other than 3.
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        with Image.open(image) as opened:
            return opened.convert("RGB")

    def _transform(batch):
        """Map a batch of raw columns to tensors."""
        pixel_values = [image_transform(_to_rgb(img)) for img in batch[image_column_name]]
        # Missing or non-string captions (e.g. NaN from a CSV) would crash the tokenizer.
        captions = ["" if c is None else str(c) for c in batch[caption_column_name]]
        tokens = tokenizer(
            captions,
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "pixel_values": pixel_values,
            "input_ids": tokens.input_ids,
            "attention_mask": tokens.attention_mask,
        }

    return dataset.with_transform(_transform)
def _find_annotations_file(root):
    """Return the shallowest ``annotations.csv`` under ``root``.

    Raises:
        ValueError: If no ``annotations.csv`` exists anywhere under ``root``.
    """
    direct = root / "annotations.csv"
    if direct.is_file():
        return direct
    # Many ZIP tools wrap everything in one top-level folder; also skip macOS
    # resource-fork folders.
    candidates = sorted(
        (p for p in root.rglob("annotations.csv") if p.is_file() and "__MACOSX" not in p.parts),
        key=lambda p: (len(p.parts), str(p)),
    )
    if not candidates:
        raise ValueError("annotations.csv not found in the dataset.")
    return candidates[0]


def _resolve_inside(root, candidate, label):
    """Resolve ``candidate`` and make sure it stays within ``root``.

    Raises:
        ValueError: If ``candidate`` resolves outside ``root`` (e.g. via ``../``).
    """
    resolved = candidate.resolve()
    try:
        resolved.relative_to(root)
    except ValueError:
        raise ValueError(f"Image path {label} points outside the extracted dataset.") from None
    return resolved


def load_custom_dataset(zip_file_path):
    """Build a ``datasets.Dataset`` from an uploaded ZIP of images and annotations.

    The archive must contain an ``annotations.csv`` with ``file_name`` and
    ``caption`` columns. The file may sit at the archive root or inside a
    sub-folder; when several exist, the shallowest one is used. ``file_name``
    entries are resolved relative to the folder containing ``annotations.csv``.

    Args:
        zip_file_path: Path to the uploaded ZIP file.

    Returns:
        A ``Dataset`` with an ``image`` column (absolute paths of the extracted
        images) and a ``text`` column (captions as strings; empty cells become "").

    Raises:
        ValueError: If ``annotations.csv`` is missing or lacks the required
            columns, or if an annotated image is missing or lies outside the
            extracted archive.
    """
    import tempfile

    # A fresh directory per upload. A fixed "extracted_dataset" folder kept
    # files from earlier uploads around and was shared between sessions.
    # The directory is not deleted here because the returned dataset
    # references the extracted image paths.
    extract_path = Path(tempfile.mkdtemp(prefix="extracted_dataset_")).resolve()
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(extract_path)

    annotations_file = _find_annotations_file(extract_path)
    annotations = pd.read_csv(annotations_file)
    if "file_name" not in annotations.columns or "caption" not in annotations.columns:
        raise ValueError("annotations.csv must contain 'file_name' and 'caption' columns.")

    base_dir = annotations_file.parent
    captions_column = annotations["caption"].fillna("").astype(str)
    images = []
    captions = []
    for file_name, caption in zip(annotations["file_name"], captions_column):
        # file_name comes from user-supplied CSV; refuse paths escaping the archive.
        image_path = _resolve_inside(extract_path, base_dir / str(file_name), file_name)
        if not image_path.is_file():
            raise ValueError(f"Image file {file_name} not found in the dataset.")
        images.append(str(image_path))
        captions.append(caption)

    return Dataset.from_dict({"image": images, "text": captions})
def train_model(
    hf_token,
    base_model_name,
    dataset,
    output_model_name,
    num_train_epochs,
    max_train_steps,
    train_batch_size,
    learning_rate,
    gradient_accumulation_steps,
    checkpointing_steps,
    mixed_precision,
    use_8bit_adam,
    use_xformers,
    lr_scheduler_type,
    lr_warmup_steps,
    resolution,
    seed,
):
    """Fine-tune the UNet of a Stable Diffusion model and push it to the Hub.

    The VAE and text encoder are frozen; only the UNet is optimized with a
    noise-prediction MSE loss (epsilon or v-prediction, per the base model's
    scheduler config). The resulting pipeline is saved to the local directory
    ``output_model_name`` and uploaded to the Hub repo of the same name.

    Args:
        hf_token: Hugging Face token with write access.
        base_model_name: Hub id or path of the base Stable Diffusion model.
        dataset: Indexable dataset yielding ``pixel_values`` and ``input_ids`` tensors.
        output_model_name: Hub repo id and local output directory.
        num_train_epochs: Epoch count, used only when ``max_train_steps`` is <= 0.
        max_train_steps: Number of optimizer updates; <= 0 derives it from epochs.
        train_batch_size: Per-device batch size.
        learning_rate: Optimizer learning rate.
        gradient_accumulation_steps: Micro-batches per optimizer update.
        checkpointing_steps: Save accelerator state every N updates (<= 0 disables).
        mixed_precision: "no", "fp16" or "bf16".
        use_8bit_adam: Use bitsandbytes' 8-bit AdamW instead of torch AdamW.
        use_xformers: Enable xformers memory-efficient attention in the UNet.
        lr_scheduler_type: Scheduler name passed to ``get_scheduler``.
        lr_warmup_steps: Warmup updates for the LR scheduler.
        resolution: Image resolution (the dataset is expected to be preprocessed already).
        seed: Random seed.

    Returns:
        A status string. Missing optional dependencies and an empty dataset are
        reported as ``"Error: ..."`` strings. Other exceptions propagate.
    """
    # Fail fast on optional dependencies, before any weights are downloaded.
    if use_xformers:
        try:
            import xformers  # noqa: F401
        except ImportError:
            print("xformers is not available. Please install it or disable xformers.")
            return "Error: xformers is not installed. Please install xformers or disable it."
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            return "Error: bitsandbytes is not installed. Please install bitsandbytes or disable 8-bit Adam."
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # An empty dataset would make the steps-per-epoch computation divide by zero.
    if len(dataset) == 0:
        return "Error: the dataset is empty."

    set_seed(seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )
    # Accelerator has no `.dtype`; frozen models run in the mixed-precision dtype.
    weight_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(
        accelerator.mixed_precision, torch.float32
    )

    # Load models. Only the UNet is trained.
    text_encoder = CLIPTextModel.from_pretrained(base_model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(base_model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(base_model_name, subfolder="unet")
    # from_config(<model id>) is deprecated; load the scheduler config from the repo.
    noise_scheduler = DDPMScheduler.from_pretrained(base_model_name, subfolder="scheduler")

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.eval()
    text_encoder.eval()

    if use_xformers:
        unet.enable_xformers_memory_efficient_attention()

    optimizer = optimizer_class(unet.parameters(), lr=learning_rate)
    train_dataloader = DataLoader(
        dataset, batch_size=train_batch_size, shuffle=True, num_workers=4
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if max_train_steps is None or max_train_steps <= 0:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # prepare() may shard the dataloader across processes, so recompute epochs from its new length.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Older configs lack scaling_factor; 0.18215 is the value previously hard-coded here.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    print("***** Running training *****")
    print(f"  Num examples = {len(dataset)}")
    print(f"  Num Epochs = {num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f"  Total optimization steps = {max_train_steps}")

    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Training")
    global_step = 0
    for _ in range(num_train_epochs):
        unet.train()
        for batch in train_dataloader:
            with accelerator.accumulate(unet):
                # Encode images into latents with the frozen VAE (no graph needed).
                with torch.no_grad():
                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
                    latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Count optimizer updates (not micro-batches), on every process, so
            # all ranks agree on when to stop.
            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss.detach().item())
                if (
                    accelerator.is_main_process
                    and checkpointing_steps > 0
                    and global_step % checkpointing_steps == 0
                ):
                    accelerator.save_state(f"{output_model_name}_checkpoint_{global_step}")
            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break
    progress_bar.close()

    accelerator.wait_for_everyone()
    if not accelerator.is_main_process:
        return "Training complete."

    # Assemble and save the pipeline. The frozen models are cast back to fp32
    # so half-precision weights are not what gets saved.
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    pipeline = StableDiffusionPipeline(
        vae=vae.to(torch.float32),
        text_encoder=text_encoder.to(torch.float32),
        tokenizer=tokenizer,
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
    )
    pipeline.save_pretrained(output_model_name)

    # Upload the saved folder. `create_repo(name=...)` and cloning a Repository
    # into the already-populated output directory both fail on current
    # huggingface_hub versions.
    api = HfApi()
    repo_url = api.create_repo(
        repo_id=output_model_name,
        token=hf_token,
        private=False,
        exist_ok=True,
    )
    api.upload_folder(
        repo_id=getattr(repo_url, "repo_id", output_model_name),
        folder_path=output_model_name,
        commit_message=f"Fine-tuned model at step {global_step}",
        token=hf_token,
    )
    return f"Training complete. The model has been uploaded to Hugging Face Hub at {repo_url}"
# Build the Blocks UI once at import time; the module exposes it as `app`.
app = create_app()

if __name__ == "__main__":
    # Serve the UI only when this file is executed directly.
    app.launch()