Stable Diffusion - Naruto Style (Flax/JAX)

This is a fine-tuned Stable Diffusion v1.5 model trained on the lambdalabs/naruto-blip-captions dataset. It generates images in the iconic manga/anime style of Naruto.

The model was fine-tuned entirely on Kaggle TPU v5e-8 using JAX/Flax, optimizing the UNet weights in bfloat16 precision for ultra-fast and memory-efficient inference.

πŸ–ΌοΈ Example Output

Prompt: "A drawing of Sasuke Uchiha"

βš™οΈ Training Details

  • Base Model: runwayml/stable-diffusion-v1-5
  • Dataset: lambdalabs/naruto-blip-captions
  • Hardware: Kaggle TPU v5e (8 cores)
  • Batch Size: 8 (1 per device)
  • Learning Rate: 1e-5
  • Optimizer: Adafactor
  • Epochs: 60

🚀 How to Use (JAX/Flax)

Since this model uploads the fine-tuned UNet parameters in JAX/Flax format (bfloat16), you should use the FlaxStableDiffusionPipeline to run it.

Here is a quick-start code snippet to run inference on TPU using JAX:

import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.core import unfreeze
from flax.jax_utils import replicate
from flax.training import checkpoints
from flax.training.common_utils import shard
from huggingface_hub import snapshot_download

# 1. Download the fine-tuned weights
repo_id = "NiceWang/sd-naruto-tpu"
ckpt_dir = snapshot_download(repo_id=repo_id)

# 2. Load the base Stable Diffusion v1.5 pipeline in bfloat16.
#    from_pt=True converts the PyTorch weights on the Hub to Flax format.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    from_pt=True,
    safety_checker=None,
)

# 3. Replace the base UNet with our fine-tuned Naruto UNet.
#    target=None returns the raw checkpoint as a nested dict; the UNet
#    weights are assumed to live under the "params" key of the checkpoint.
params = unfreeze(params)
raw_ckpt = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=None)
params["unet"] = raw_ckpt["params"]

# 4. Run Inference!
# jit=True pmaps the pipeline across all devices, so the params must be
# replicated and every input must carry a leading device axis.
prompt = "A drawing of Kakashi Hatake"
num_devices = jax.device_count()

prompt_ids = pipe.prepare_inputs([prompt] * num_devices)
prompt_ids = shard(prompt_ids)                       # (devices, batch, seq)
p_params = replicate(params)                         # one copy per device
prng_seed = jax.random.split(jax.random.PRNGKey(42), num_devices)

output = pipe(
    prompt_ids=prompt_ids,
    params=p_params,
    prng_seed=prng_seed,
    num_inference_steps=50,
    guidance_scale=7.5,
    jit=True,
)

# Convert to PIL Image: collapse the (devices, batch) axes first.
images_np = np.asarray(output.images)
images_np = images_np.reshape((-1,) + images_np.shape[-3:])
images_pil = pipe.numpy_to_pil(images_np)
images_pil[0].show()

You can also use the npz format for inference (no Orbax dependency required):

import os
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.core import unfreeze
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from huggingface_hub import hf_hub_download, snapshot_download

# 1. Download the fine-tuned weights (npz format, no orbax dependency)
npz_path = hf_hub_download(
    repo_id="NiceWang/sd-naruto-tpu",
    filename="unet_naruto_bf16.npz"
)

# 2. Load the base Stable Diffusion v1.5 pipeline
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    from_pt=True,
    safety_checker=None,
)

# 3. Replace the base UNet with our fine-tuned Naruto UNet
def unflatten(flat):
    """Rebuild a nested param dict from flat npz keys.

    Keys use "/" as the nesting separator, e.g. "a/b/c" -> {"a": {"b": {"c": val}}}.
    """
    result = {}
    for key, val in flat.items():
        *parents, leaf = key.split("/")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = val
    return result

data = np.load(npz_path)
unet_np = unflatten(dict(data))

# Cast every leaf back to bfloat16 (npz stores a numpy-compatible dtype).
params = unfreeze(params)
params["unet"] = jax.tree_util.tree_map(
    lambda x: jnp.array(x, dtype=jnp.bfloat16), unet_np
)

# 4. Run Inference!
# jit=True pmaps the pipeline across all devices, so the params must be
# replicated and every input must carry a leading device axis.
prompt = "A drawing of Kakashi Hatake"
num_devices = jax.device_count()

prompt_ids = pipe.prepare_inputs([prompt] * num_devices)
prompt_ids = shard(prompt_ids)                       # (devices, batch, seq)
p_params = replicate(params)                         # one copy per device
prng_seed = jax.random.split(jax.random.PRNGKey(42), num_devices)

output = pipe(
    prompt_ids=prompt_ids,
    params=p_params,
    prng_seed=prng_seed,
    num_inference_steps=50,
    guidance_scale=7.5,
    jit=True,
)

# Convert to PIL Image: collapse the (devices, batch) axes first.
images_np = np.asarray(output.images)
images_np = images_np.reshape((-1,) + images_np.shape[-3:])
images_pil = pipe.numpy_to_pil(images_np)
images_pil[0].show()

You can also use this model on a GPU with PyTorch:

import torch
import jax
from diffusers import FlaxUNet2DConditionModel, UNet2DConditionModel, StableDiffusionPipeline
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
from flax.training import checkpoints
from huggingface_hub import snapshot_download
from IPython.display import display

BASE_REPO_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# 1. Download raw Flax checkpoint from Hugging Face
repo_id = "NiceWang/sd-naruto-tpu"
print(f"Downloading raw Flax checkpoint from {repo_id}...")
ckpt_dir = snapshot_download(repo_id=repo_id)

# 2. Create a dummy target template
print("Loading base UNet CONFIG to create a structural template...")
config = FlaxUNet2DConditionModel.load_config(BASE_REPO_ID, subfolder="unet")
flax_unet = FlaxUNet2DConditionModel.from_config(config)

# Generate dummy parameters just to get the exact dictionary structure
# that the checkpoint restorer needs as a template.
key = jax.random.PRNGKey(0)
dummy_params = flax_unet.init_weights(key)
target_template = {"params": dummy_params}

# 3. Restore and unshard fine-tuned weights using the dummy template
print("Restoring sharded checkpoint into single-device memory...")
raw_ckpt = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=target_template)
fine_tuned_unet_params = raw_ckpt["params"]

# Save it to a temporary local folder with fine-tuned params
# (save_pretrained also writes the config.json we reload below).
print("Converting raw weights to standard Diffusers format...")
temp_flax_dir = "./temp_flax_unet"
flax_unet.save_pretrained(temp_flax_dir, params=fine_tuned_unet_params)

# 4. Load into PyTorch
print("Loading Flax weights into PyTorch UNet manually...")

# a. Initialize a blank PyTorch UNet. from_config expects a config dict,
#    not a path, so load the saved config explicitly (same pattern as step 2).
pt_config = UNet2DConditionModel.load_config(temp_flax_dir)
pt_unet = UNet2DConditionModel.from_config(pt_config)

# b. Use the internal Diffusers tool to safely inject the Flax .msgpack weights
msgpack_file = f"{temp_flax_dir}/diffusion_flax_model.msgpack"
pt_unet = load_flax_checkpoint_in_pytorch_model(pt_unet, msgpack_file)

# c. Cast it to float16 for optimal GPU inference
pt_unet = pt_unet.to(torch.float16)

# 5. Run Inference on GPU using PyTorch
print("Setting up PyTorch Stable Diffusion Pipeline on GPU...")
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_REPO_ID, 
    torch_dtype=torch.float16,
    safety_checker=None
)
pipe.unet = pt_unet        # swap in the fine-tuned Naruto UNet
pipe = pipe.to("cuda")

prompt = "A drawing of Kakashi Hatake"
print(f"Generating image for prompt: '{prompt}'...")

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

print("Done! Here is your PyTorch-generated image:")
display(image)
Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for NiceWang/sd-naruto-tpu

Finetuned
(579)
this model

Dataset used to train NiceWang/sd-naruto-tpu