import os
import pickle
import shutil
import traceback
from collections.abc import Mapping
from inspect import signature

import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset, load_from_disk
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
from diffusers.schedulers import FlaxPNDMScheduler
from flax.jax_utils import replicate
from flax.training import train_state
from PIL import Image
from tqdm.auto import tqdm
| |
|
| | |
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    """PNDM scheduler whose ``add_noise`` always hands int32 timesteps to the parent."""

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to int32, then defer to ``FlaxPNDMScheduler.add_noise``."""
        int32_timesteps = timesteps.astype(jnp.int32)
        return super().add_noise(state, original_samples, noise, int32_timesteps)
| |
|
| | |
# Root directory for everything this script caches, plus the model sub-directory
# (created eagerly so later writes into it cannot fail on a missing parent).
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)

for _label, _path in (("Cache directory", cache_dir), ("Model cache directory", model_cache_dir)):
    print(f"{_label}: {_path}")
| |
|
| |
|
| |
|
def filter_dict(dict_to_filter, target_callable):
    """Return the entries of ``dict_to_filter`` whose keys name a parameter of ``target_callable``.

    Parameter names come from ``inspect.signature``; surviving entries keep
    their original insertion order. The input dict is not modified.
    """
    accepted_names = signature(target_callable).parameters
    filtered = {}
    for name, value in dict_to_filter.items():
        if name in accepted_names:
            filtered[name] = value
    return filtered
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
def get_model(model_id, revision):
    """Load a Flax Stable Diffusion pipeline and its params, via a local pickle cache.

    The cache file lives in the module-level ``model_cache_dir`` and is named
    from ``model_id`` (``/`` replaced by ``_``) and ``revision``.

    Caching is best-effort. An unreadable cache file triggers a fresh download
    instead of crashing. A failure while writing the cache is reported and the
    freshly downloaded model is still returned.

    Returns:
        ``(pipeline, params)`` as produced by
        ``FlaxStableDiffusionPipeline.from_pretrained``.
    """
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")

    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        try:
            # NOTE(security): unpickling runs arbitrary code; only load cache
            # files this script wrote itself.
            with open(model_cache_file, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
            # Truncated file, or classes that no longer import: rebuild it.
            print(f"Cached model could not be loaded ({err!r}); downloading again...")

    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )

    # Dump to a temporary file and atomically rename it into place. A failed or
    # interrupted dump then never leaves a truncated cache file behind for the
    # next run to choke on.
    tmp_cache_file = f"{model_cache_file}.tmp"
    try:
        with open(tmp_cache_file, "wb") as f:
            pickle.dump((pipeline, params), f)
        os.replace(tmp_cache_file, model_cache_file)
    except (OSError, pickle.PicklingError, TypeError, AttributeError) as err:
        print(f"Warning: could not cache model to {model_cache_file} ({err!r})")
        if os.path.exists(tmp_cache_file):
            os.remove(tmp_cache_file)
    return pipeline, params
| |
|
| | |
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, revision="flax")

# Swap in the scheduler subclass (it casts timesteps to int32 inside add_noise).
# It is built from the current scheduler's config so its settings are unchanged.
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = custom_scheduler
| |
|
| | |
unet = pipeline.unet

# Show the UNet's configuration: a header line, then the config on its own line.
print("UNet configuration:", unet.config, sep="\n")
| |
|
| | |
def adjust_unet_input_layer(params, in_channels=4):
    """Make the UNet ``conv_in`` kernel accept ``in_channels`` input channels.

    Accepted ``params`` layouts:
      * the full pipeline param dict, with UNet params under ``"unet"``;
      * a Flax variables dict as returned by ``Module.init``, with UNet params
        under ``"params"``;
      * the bare UNet param tree.

    The kernel is indexed as ``(kh, kw, in, out)``. When its input-channel
    count differs from ``in_channels``:
      * a zero kernel of shape ``(kh, kw, in_channels, out)`` is built;
      * the first ``min(in, in_channels)`` input channels are copied into it;
      * any extra channels stay zero, so they initially contribute nothing.

    Args:
        params: Param tree in one of the layouts above. It is never mutated.
        in_channels: Desired number of input channels. The default of 4 matches
            the latent input shape used to initialise the UNet in this script.

    Returns:
        A param tree with the same layout as ``params``. ``params`` itself is
        returned unchanged when no adjustment is needed or ``conv_in`` is absent.
    """
    if "unet" in params:
        wrapper_key = "unet"
    elif "params" in params and "conv_in" not in params:
        wrapper_key = "params"
    else:
        wrapper_key = None
    unet_params = params[wrapper_key] if wrapper_key is not None else params

    if "conv_in" not in unet_params:
        print("Warning: 'conv_in' not found in UNet params. Skipping input layer adjustment.")
        return params

    conv_in_weight = unet_params["conv_in"]["kernel"]
    print(f"Original conv_in weight shape: {conv_in_weight.shape}")
    kernel_h, kernel_w, old_in, out_channels = conv_in_weight.shape
    if old_in == in_channels:
        return params

    # Kernel size and output width come from the existing kernel rather than a
    # hard-coded (3, 3, _, 320). Copying min(old_in, in_channels) channels
    # handles both narrower kernels (the old ``:3`` copy crashed below 3 inputs)
    # and wider ones (where ``:3`` threw away a trained channel).
    kept = min(old_in, in_channels)
    new_kernel = jnp.zeros((kernel_h, kernel_w, in_channels, out_channels), dtype=conv_in_weight.dtype)
    new_kernel = new_kernel.at[:, :, :kept, :].set(conv_in_weight[:, :, :kept, :])
    print(f"New conv_in weight shape: {new_kernel.shape}")

    # Rebuild only the touched levels as new dicts. The caller's tree (possibly
    # an immutable FrozenDict) is therefore never written to in place.
    new_unet_params = {**unet_params, "conv_in": {**unet_params["conv_in"], "kernel": new_kernel}}
    if wrapper_key is None:
        return new_unet_params
    return {**params, wrapper_key: new_unet_params}
| |
|
# Make sure the pretrained UNet's first convolution takes 4-channel input.
params = adjust_unet_input_layer(params=params)
| |
|
| | |
def preprocess_images(examples):
    """Turn a batch of dataset images into channels-first float32 arrays in [-1, 1].

    Intended for ``Dataset.map(..., batched=True)``. ``examples["image"]`` may
    hold PIL images or file-path strings. Dropped entries:
      * string paths not ending in ``.jpg``/``.jpeg`` (case-insensitive);
      * values that are neither a string nor a PIL image.
    The output list can therefore be shorter than the input batch.

    Each kept image is converted to RGB, resized to 512x512, and returned with
    shape ``(3, 512, 512)``. Pixels are scaled to [-1, 1], not [0, 1], because
    that is the range the Stable Diffusion VAE was trained on. Any dataset
    cached by an earlier run that used [0, 1] scaling must be rebuilt.
    """
    def process_image(image):
        if isinstance(image, str):
            if not image.lower().endswith((".jpg", ".jpeg")):
                return None
            # Context manager closes the file; convert() has already copied the
            # pixels into a new, fully loaded image.
            with Image.open(image) as opened:
                rgb = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            rgb = image.convert("RGB")
        else:
            return None
        pixels = np.asarray(rgb.resize((512, 512)), dtype=np.float32)
        pixels = pixels / 127.5 - 1.0  # [0, 255] -> [-1, 1]
        return pixels.transpose(2, 0, 1)  # HWC -> CHW

    processed = [process_image(img) for img in examples["image"]]
    return {"pixel_values": [img for img in processed if img is not None]}
| |
|
| | |
dataset_name = "uruguayai/montevideo"
# Cache written with Dataset.save_to_disk. Pickling a cache-backed Dataset
# stores references to its Arrow files in the Hugging Face cache, not the data
# itself, so that pickle broke whenever those files were cleaned up.
dataset_cache_dir = os.path.join(cache_dir, "montevideo_dataset")

print(f"Dataset name: {dataset_name}")
print(f"Dataset cache directory: {dataset_cache_dir}")

if os.path.isdir(dataset_cache_dir):
    print("Loading dataset from cache...")
    processed_dataset = load_from_disk(dataset_cache_dir)
else:
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    train_split = dataset["train"]
    processed_dataset = train_split.map(
        preprocess_images,
        batched=True,
        remove_columns=train_split.column_names,
    )
    processed_dataset = processed_dataset.filter(lambda example: len(example["pixel_values"]) > 0)
    # Save to a temporary directory, then rename it into place, so an
    # interrupted save never leaves a half-written cache for the next run.
    tmp_cache_dir = f"{dataset_cache_dir}.tmp"
    if os.path.isdir(tmp_cache_dir):
        shutil.rmtree(tmp_cache_dir)
    processed_dataset.save_to_disk(tmp_cache_dir)
    os.replace(tmp_cache_dir, dataset_cache_dir)

print(f"Processed dataset size: {len(processed_dataset)}")
| |
|
| | |
# Peek at one single-example batch to sanity-check the processed data layout.
sample_batch = next(iter(processed_dataset.batch(1)))
sample_pixel_values = sample_batch["pixel_values"]
print(f"Sample batch keys: {sample_batch.keys()}")
print(f"Sample pixel_values type: {type(sample_pixel_values)}")
print(f"Sample pixel_values length: {len(sample_pixel_values)}")
if len(sample_pixel_values) > 0:
    first_sample_shape = np.array(sample_pixel_values[0]).shape
    print(f"Sample pixel_values[0] shape: {first_sample_shape}")
| |
|
| | |
def _unwrap_params(tree):
    """Strip ``{"params": ...}`` wrappers until the bare UNet param tree remains.

    ``state.params`` may be wrapped zero, one or two times depending on how the
    TrainState was built. ``apply_fn`` needs exactly ``{"params": <bare tree>}``.
    """
    while isinstance(tree, Mapping) and set(tree) == {"params"}:
        tree = tree["params"]
    return tree


# Empty-prompt text embeddings, computed once per batch size.
_empty_prompt_cache = {}


def _empty_prompt_embeddings(batch_size):
    """Return the text encoder's embeddings of ``""``, one row per example."""
    if batch_size not in _empty_prompt_cache:
        prompt_ids = pipeline.prepare_inputs([""] * batch_size)
        _empty_prompt_cache[batch_size] = pipeline.text_encoder(
            prompt_ids, params=params["text_encoder"]
        )[0]
    return _empty_prompt_cache[batch_size]


def train_step(state, batch, rng):
    """Run one noise-prediction (epsilon) optimisation step on the UNet.

    Args:
        state: ``flax`` ``TrainState`` holding the UNet params (optionally
            wrapped in ``{"params": ...}``) and the optimizer ``tx``.
        batch: Mapping whose ``"pixel_values"`` has shape ``(3, H, W)`` or
            ``(B, 3, H, W)`` (channels-first pixels).
        rng: PRNG key; split into independent keys for VAE sampling, noise
            and timesteps.

    Returns:
        ``(new_state, loss)``. ``loss`` is the MSE between the UNet output and
        the injected noise, evaluated at the pre-update params.
    """
    pixel_values = jnp.asarray(batch["pixel_values"], dtype=jnp.float32)
    if pixel_values.ndim == 3:
        pixel_values = jnp.expand_dims(pixel_values, axis=0)
    batch_size = pixel_values.shape[0]

    # Independent keys. Reusing one key for every draw correlated them: the
    # same key and shape give identical normal samples.
    latent_rng, noise_rng, timestep_rng = jax.random.split(rng, 3)

    # The VAE is frozen, so encode once, outside the differentiated function.
    latents = pipeline.vae.apply(
        {"params": params["vae"]},
        pixel_values,
        method=pipeline.vae.encode,
    ).latent_dist.sample(latent_rng)
    # The Flax VAE returns NHWC latents; the Flax UNet takes NCHW input.
    latents = jnp.transpose(latents, (0, 3, 1, 2))
    latents = latents * jnp.float32(0.18215)

    noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
    timesteps = jax.random.randint(
        timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
    ).astype(jnp.int32)
    noisy_latents = pipeline.scheduler.add_noise(
        pipeline.scheduler.create_state(),
        original_samples=latents,
        noise=noise,
        timesteps=timesteps,
    )

    # Batches carry only pixel_values (no captions), so condition on the empty
    # prompt. Random vectors of shape (B, hidden) were both the wrong rank for
    # cross-attention and meaningless as conditioning.
    encoder_hidden_states = _empty_prompt_embeddings(batch_size)

    def compute_loss(train_params):
        model_output = state.apply_fn(
            {"params": _unwrap_params(train_params)},
            noisy_latents,
            timesteps,
            encoder_hidden_states,
            train=True,
        ).sample
        return jnp.mean((model_output - noise) ** 2)

    # A single forward/backward pass yields both the loss and the gradients.
    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    # apply_gradients routes the gradients through ``state.tx``. By contrast,
    # optax.apply_updates(params, grads) just *adds* raw gradients (ascent with
    # step size 1) and skips the optimizer entirely.
    state = state.apply_gradients(grads=grads)
    return state, loss
| |
|
| | |
learning_rate = 1e-5
optimizer = optax.adam(learning_rate=learning_rate)

# Cast every leaf of the pipeline params to float32, except int32 leaves,
# which pass through untouched.
# NOTE(review): float32_params does not appear to be used later in this script.
float32_params = jax.tree_util.tree_map(
    lambda leaf: leaf if leaf.dtype == jnp.int32 else leaf.astype(jnp.float32),
    params,
)
| |
|
| | |
# Rebuild the UNet module from the pipeline's config. Only keys that are
# constructor parameters of FlaxUNet2DConditionModel are kept.
unet_config = dict(unet.config)
filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel.__init__)

print("Filtered UNet config keys:", filtered_unet_config.keys())

adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
# Fine-tune from the pretrained UNet weights loaded with the pipeline. The
# previous ``adjusted_unet.init(...)`` produced randomly initialised weights,
# so training started from scratch and the checkpoint was ignored. (That
# init() output was also wrapped as {"params": ...}, which made the input-layer
# adjustment a no-op.)
adjusted_params = adjust_unet_input_layer(params["unet"])
| |
|
| | |
# Keep exactly one {"params": ...} wrapper so that state.params["params"] is the
# bare UNet param tree. Output of Module.init is already {"params": ...};
# wrapping it again nested the tree twice, which broke both apply and save.
unet_param_tree = adjusted_params
while isinstance(unet_param_tree, Mapping) and set(unet_param_tree) == {"params"}:
    unet_param_tree = unet_param_tree["params"]

state = train_state.TrainState.create(
    apply_fn=adjusted_unet.apply,
    params={"params": unet_param_tree},
    tx=optimizer,
)
| |
|
| | |
# Training hyperparameters (previously defined twice with identical values).
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)
| |
|
# Main training loop. There are deliberately no jax.clear_caches() calls here:
# that only discards JAX's compilation caches (forcing every op to recompile)
# and does not free array memory.
for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    num_errors = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        try:
            # Keep the whole batch, shape (B, 3, H, W). Taking only ``[0]``
            # silently trained on just the first image whenever batch_size > 1.
            batch["pixel_values"] = jnp.asarray(batch["pixel_values"], dtype=jnp.float32)
            rng, step_rng = jax.random.split(rng)
            state, loss = train_step(state, batch, step_rng)
        except Exception as e:
            # Best-effort on purpose: skip the failing batch, but print the
            # traceback so a systematic failure is diagnosable.
            num_errors += 1
            print(f"Error processing batch: {e}")
            traceback.print_exc()
            continue

        epoch_loss += loss
        num_batches += 1
        if num_batches % 10 == 0:
            print(f"Processed {num_batches} batches. Current loss: {loss}")

    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}, Errors: {num_errors}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, No valid batches processed, Errors: {num_errors}")
| |
|
| | |
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)

# Save the bare UNet param tree, i.e. what apply() receives under "params".
# Strip every {"params": ...} wrapper so that a bare or doubly-wrapped
# state.params can no longer produce a mis-nested (unloadable) checkpoint.
unet_params_to_save = state.params
while isinstance(unet_params_to_save, Mapping) and set(unet_params_to_save) == {"params"}:
    unet_params_to_save = unet_params_to_save["params"]
adjusted_unet.save_pretrained(output_dir, params=unet_params_to_save)

print(f"Model saved to {output_dir}")