# Spaces: Runtime error
import gradio as gr
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DiffusionPipeline, DDPMScheduler
from tqdm.auto import tqdm
from transformers import TrainingArguments
# --- Configuration -----------------------------------------------------------
# Base checkpoint and training data (use whichever dataset you want).
pretrained_model_name_or_path: str = "black-forest-labs/FLUX.1-dev"
dataset_name: str = "DucHaiten/anime-SDXL"

# Optimisation hyper-parameters; adjust as needed.
learning_rate: float = 1e-5
num_train_epochs: int = 2
train_batch_size: int = 1  # keep batches tiny on free Spaces hardware
gradient_accumulation_steps: int = 4

# Output location and preprocessing.
output_dir: str = "flux-anime"
image_resize: int = 128  # side length (pixels) of the square training images
# --- Load the model and the dataset ------------------------------------------
# Weights are loaded in bfloat16: FLUX.1 is documented for bf16 use, and fp16
# weights cannot be trained under fp16 mixed precision (the grad scaler refuses
# to unscale fp16 gradients).
# NOTE: FLUX.1-dev is a gated repository; the Space needs a Hugging Face token
# that has been granted access to it.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=torch.bfloat16
)
# The pipeline's own scheduler is deliberately kept. FLUX.1 is a
# flow-matching model whose sampling loop depends on flow-matching scheduler
# settings (custom sigmas, resolution-dependent shift) that DDPMScheduler does
# not provide, so swapping in DDPMScheduler breaks image generation.
#
# enable_xformers_memory_efficient_attention() is not called either: it raises
# when xformers is not installed, and Flux's attention processors already use
# PyTorch's memory-efficient scaled_dot_product_attention, whereas the generic
# xformers processor is not designed for Flux's rotary-embedding attention.

dataset = load_dataset(dataset_name)["train"]
| # Fungsi untuk memproses data | |
| def preprocess_function(examples): | |
| images = [ | |
| image.convert("RGB").resize((image_resize, image_resize)) | |
| for image in examples["image"] | |
| ] | |
| texts = [text for text in examples["text"]] | |
| examples["pixel_values"] = pipeline.feature_extractor( | |
| images=images, return_tensors="pt" | |
| ).pixel_values | |
| examples["prompt"] = texts | |
| return examples | |
# --- Preprocess the dataset --------------------------------------------------
# Every row is converted once, up front, in 4 worker processes. remove_columns
# drops the raw source columns; per the datasets docs, a column listed there is
# kept anyway if preprocess_function itself returns it.
processed_dataset = dataset.map(
    function=preprocess_function,
    remove_columns=dataset.column_names,
    batched=True,
    num_proc=4,
)
# --- Training setup ----------------------------------------------------------
# NOTE(review): this is a *full* fine-tune of the pipeline's transformer. For
# FLUX.1 that model is very large; its weights plus AdamW state need far more
# GPU memory than free Spaces hardware offers. A parameter-efficient method
# (e.g. LoRA) would be needed there -- TODO decide.
guidance_scale = 3.5  # value fed to the guidance embedding, if the model has one
sample_every_n_steps = 100
sample_prompt = "anime style image of a girl with blue hair"  # adapt to your dataset

if getattr(pipeline, "transformer", None) is None:
    raise ValueError(
        "This script trains the `transformer` of a Flux-style pipeline; "
        f"{type(pipeline).__name__} has no `transformer` component."
    )

# Initialise the accelerator. bf16 autocast matches bf16 weights and, unlike
# fp16, does not use a grad scaler.
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="bf16",
)
pipeline.to(accelerator.device)

# Freeze every model in the pipeline (VAE, text encoders, ...), then re-enable
# gradients for the transformer only.
for _component in pipeline.components.values():
    if isinstance(_component, torch.nn.Module):
        _component.requires_grad_(False)
pipeline.transformer.requires_grad_(True)
pipeline.transformer.enable_gradient_checkpointing()  # trade compute for activation memory
_uses_guidance = getattr(pipeline.transformer.config, "guidance_embeds", False)

# with_format("torch") makes pixel_values come out as tensors; prompts stay str.
train_dataloader = torch.utils.data.DataLoader(
    processed_dataset.with_format("torch"),
    batch_size=train_batch_size,
    shuffle=True,
)
optimizer = torch.optim.AdamW(pipeline.transformer.parameters(), lr=learning_rate)
# The optimizer and dataloader must be prepared as well: the prepared optimizer
# is what skips step()/zero_grad() on non-final accumulation micro-steps.
# pipeline.transformer keeps pointing at the underlying module, so sampling
# and saving below see the trained weights.
transformer, optimizer, train_dataloader = accelerator.prepare(
    pipeline.transformer, optimizer, train_dataloader
)
total_steps = num_train_epochs * len(train_dataloader)


def _flow_matching_loss(batch):
    """Return the flow-matching MSE loss of ``transformer`` on one batch.

    Follows the objective of Diffusers' Flux DreamBooth training script:
    ``noisy = (1 - sigma) * x0 + sigma * noise`` and the model predicts the
    velocity ``noise - x0``. Prediction and target are compared in Flux's packed
    latent layout; packing only rearranges elements, so the mean squared error
    equals the one computed in the unpacked layout.
    """
    vae = pipeline.vae
    device = accelerator.device
    # Frozen encoders: no gradients needed for latents or text embeddings.
    with torch.no_grad():
        pixel_values = batch["pixel_values"].to(device=device, dtype=vae.dtype)
        latents = vae.encode(pixel_values).latent_dist.sample()
        shift = getattr(vae.config, "shift_factor", None) or 0.0
        latents = (latents - shift) * vae.config.scaling_factor
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=batch["prompt"], prompt_2=batch["prompt"], device=device
        )

    bsz, channels, height, width = latents.shape
    noise = torch.randn_like(latents)
    # One noise level per sample, drawn uniformly from [0, 1).
    sigmas = torch.rand(bsz, device=device, dtype=latents.dtype)
    sigma_b = sigmas.view(bsz, 1, 1, 1)
    noisy_latents = (1.0 - sigma_b) * latents + sigma_b * noise

    def pack(x):
        return pipeline._pack_latents(x, bsz, channels, height, width)

    image_ids = pipeline._prepare_latent_image_ids(
        bsz, height // 2, width // 2, device, latents.dtype
    )
    guidance = None
    if _uses_guidance:
        guidance = torch.full((bsz,), guidance_scale, device=device, dtype=torch.float32)

    model_pred = transformer(
        hidden_states=pack(noisy_latents),
        timestep=sigmas,  # Flux pipelines pass timesteps / 1000, i.e. values in [0, 1]
        guidance=guidance,
        pooled_projections=pooled_prompt_embeds,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=image_ids,
        return_dict=False,
    )[0]
    target = pack(noise - latents)
    return torch.nn.functional.mse_loss(
        model_pred.float(), target.float(), reduction="mean"
    )


def _generate_sample():
    """Render ``sample_prompt`` with the current weights, without gradients."""
    transformer.eval()
    try:
        with torch.no_grad():
            return pipeline(sample_prompt).images[0]
    finally:
        transformer.train()


def train():
    """Run the fine-tuning loop and stream progress to the Gradio UI.

    Yields ``(loss, epoch, progress, sample_image)`` after every batch:
    ``progress`` is the fraction of batches processed, and ``sample_image`` is
    the most recent preview (``None`` until the first one, rendered every
    ``sample_every_n_steps`` batches). After the last epoch the pipeline is
    saved to ``output_dir`` by the main process.
    """
    progress_bar = tqdm(total=total_steps)
    sample = None
    global_step = 0
    for epoch in range(num_train_epochs):
        transformer.train()
        for batch in train_dataloader:
            with accelerator.accumulate(transformer):
                loss = _flow_matching_loss(batch)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            progress_bar.update(1)
            global_step += 1  # counts across epochs, unlike the per-epoch index
            if global_step % sample_every_n_steps == 0:
                sample = _generate_sample()
            yield loss.item(), epoch, global_step / total_steps, sample
    progress_bar.close()

    # Save the fine-tuned model.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        pipeline.save_pretrained(output_dir)


# --- Gradio UI ---------------------------------------------------------------
# Training runs inside the button's event handler. Because `train` is a
# generator, Gradio streams every yielded tuple into the output components.
# Component .update() calls made outside an event handler never reach the
# browser.
with gr.Blocks() as interface:
    gr.Markdown("## Fine-tuning FLUX untuk Anime")  # change the title to match your dataset
    loss_textbox = gr.Textbox(label="Loss")
    epoch_textbox = gr.Textbox(label="Epoch")
    progress_number = gr.Number(label="Progress")
    output_image = gr.Image(label="Generated Image")
    start_button = gr.Button("Start training")
    start_button.click(
        train,
        outputs=[loss_textbox, epoch_textbox, progress_number, output_image],
    )

# queue() is required for generator handlers on Gradio 3.x and harmless later.
# launch() blocks, so it must be the last statement of the script.
interface.queue()
interface.launch(server_name="0.0.0.0")