Stable Diffusion - Naruto Style (Flax/JAX)
This is a fine-tuned Stable Diffusion v1.5 model trained on the lambdalabs/naruto-blip-captions dataset. It generates images in the iconic manga/anime style of Naruto.
The model was fine-tuned entirely on Kaggle TPU v5e-8 using JAX/Flax, optimizing the UNet weights in bfloat16 precision for ultra-fast and memory-efficient inference.
🖼️ Example Output
Prompt: "A drawing of Sasuke Uchiha"

⚙️ Training Details
- Base Model: runwayml/stable-diffusion-v1-5
- Dataset: lambdalabs/naruto-blip-captions
- Hardware: Kaggle TPU v5e (8 cores)
- Batch Size: 8 (1 per device)
- Learning Rate: 1e-5
- Optimizer: Adafactor
- Epochs: 60
🚀 How to Use (JAX/Flax)
Since this model uploads the fine-tuned UNet parameters in JAX/Flax format (bfloat16), you should use the FlaxStableDiffusionPipeline to run it.
Here is a quick-start code snippet to run inference on TPU using JAX:
# Inference from the raw Flax checkpoint (requires flax.training.checkpoints).
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.training import checkpoints
from huggingface_hub import snapshot_download

# 1. Download the fine-tuned weights
repo_id = "NiceWang/sd-naruto-tpu"
ckpt_dir = snapshot_download(repo_id=repo_id)

# 2. Load the base Stable Diffusion v1.5 pipeline
# (weights converted from PyTorch, cast to bfloat16; safety checker disabled)
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    from_pt=True,
    safety_checker=None,
)

# 3. Replace the base UNet with our fine-tuned Naruto UNet
from flax.core import unfreeze

params = unfreeze(params)  # FrozenDict -> mutable dict so "unet" can be swapped
# target=None restores the checkpoint as a raw pytree of plain dicts/arrays
raw_ckpt = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=None)
params["unet"] = raw_ckpt["params"]

# 4. Run Inference!
prompt = "A drawing of Kakashi Hatake"
prompt_ids = pipe.prepare_inputs([prompt])
prng_seed = jax.random.PRNGKey(42)
# NOTE(review): with jit=True the pipeline pmaps over devices; typical usage
# replicates params and shards prompt_ids/prng_seed first
# (flax.jax_utils.replicate + diffusers' shard helper) — confirm this snippet
# runs as-is on a multi-device TPU host.
output = pipe(
    prompt_ids=prompt_ids,
    params=params,
    prng_seed=prng_seed,
    num_inference_steps=50,  # standard quality/speed trade-off
    guidance_scale=7.5,      # classifier-free guidance strength
    jit=True,
)

# Convert to PIL Image
import numpy as np

images_np = np.asarray(output.images)
images_pil = pipe.numpy_to_pil(images_np)
images_pil[0].show()
You can also use the npz format for inference:
# Alternative: load the fine-tuned UNet from a flat .npz file (no orbax dependency).
import os
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.core import unfreeze
from huggingface_hub import hf_hub_download, snapshot_download

# 1. Download the fine-tuned weights (npz format, no orbax dependency)
npz_path = hf_hub_download(
    repo_id="NiceWang/sd-naruto-tpu",
    filename="unet_naruto_bf16.npz"
)

# 2. Load the base Stable Diffusion v1.5 pipeline
# (weights converted from PyTorch, cast to bfloat16; safety checker disabled)
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    from_pt=True,
    safety_checker=None,
)

# 3. Replace the base UNet with our fine-tuned Naruto UNet
# Reconstruct nested param dict from flat npz keys (e.g. "a/b/c" -> {"a":{"b":{"c":...}}})
def unflatten(flat, sep="/"):
    """Rebuild a nested dict from a flat mapping with ``sep``-joined keys.

    e.g. ``{"a/b/c": 1}`` -> ``{"a": {"b": {"c": 1}}}``.

    Args:
        flat: Mapping whose keys are path strings joined by ``sep``.
        sep: Path separator used in the flat keys (default ``"/"`` —
            the convention used by the npz export above).

    Returns:
        A new nested dict; leaf values are taken from ``flat`` unchanged.
    """
    result = {}
    for key, val in flat.items():
        # All path components except the last name intermediate dicts.
        *branches, leaf = key.split(sep)
        node = result
        for part in branches:
            node = node.setdefault(part, {})
        node[leaf] = val
    return result
data = np.load(npz_path)  # lazy NpzFile; dict(...) below materializes all arrays
unet_np = unflatten(dict(data))
params = unfreeze(params)  # FrozenDict -> mutable dict so "unet" can be swapped
# Cast every leaf back to bfloat16 (the npz file stores plain numpy dtypes)
params["unet"] = jax.tree_util.tree_map(
    lambda x: jnp.array(x, dtype=jnp.bfloat16), unet_np
)

# 4. Run Inference!
prompt = "A drawing of Kakashi Hatake"
prompt_ids = pipe.prepare_inputs([prompt])
prng_seed = jax.random.PRNGKey(42)
# NOTE(review): with jit=True, params/inputs are usually replicated/sharded
# across devices first — confirm on a multi-device host.
output = pipe(
    prompt_ids=prompt_ids,
    params=params,
    prng_seed=prng_seed,
    num_inference_steps=50,  # standard quality/speed trade-off
    guidance_scale=7.5,      # classifier-free guidance strength
    jit=True,
)

# Convert to PIL Image
images_np = np.asarray(output.images)
images_pil = pipe.numpy_to_pil(images_np)
images_pil[0].show()
You can also run this model on GPU with PyTorch:
# Convert the fine-tuned Flax UNet to PyTorch and run inference on a CUDA GPU.
import torch
import jax
from diffusers import FlaxUNet2DConditionModel, UNet2DConditionModel, StableDiffusionPipeline
from flax.training import checkpoints
from huggingface_hub import snapshot_download
from IPython.display import display

BASE_REPO_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# 1. Download raw Flax checkpoint from Hugging Face
repo_id = "NiceWang/sd-naruto-tpu"
print(f"Downloading raw Flax checkpoint from {repo_id}...")
ckpt_dir = snapshot_download(repo_id=repo_id)

# 2. Create a dummy target template
print("Loading base UNet CONFIG to create a structural template...")
config = FlaxUNet2DConditionModel.load_config(BASE_REPO_ID, subfolder="unet")
flax_unet = FlaxUNet2DConditionModel.from_config(config)
# Generate dummy parameters just to get the exact dictionary structure for Orbax
key = jax.random.PRNGKey(0)
dummy_params = flax_unet.init_weights(key)
target_template = {"params": dummy_params}

# 3. Restore and unshard fine-tuned weights using the dummy template
print("Restoring sharded checkpoint into single-device memory...")
raw_ckpt = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=target_template)
fine_tuned_unet_params = raw_ckpt["params"]

# Save it to a temporary local folder with fine-tuned params
print("Converting raw weights to standard Diffusers format...")
temp_flax_dir = "./temp_flax_unet"
flax_unet.save_pretrained(temp_flax_dir, params=fine_tuned_unet_params)

# (Alternative one-step conversion via from_pretrained(..., from_flax=True),
# kept for reference — the manual route below is used instead.)
# print("Loading Flax weights into PyTorch UNet...")
# pt_unet = UNet2DConditionModel.from_pretrained(
#     temp_flax_dir,
#     from_flax=True,
#     torch_dtype=torch.float16
# )

# 4. Load into PyTorch
print("Loading Flax weights into PyTorch UNet manually...")
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

# a. Initialize a blank PyTorch UNet using the structure from config
# NOTE(review): from_config is given a directory path here — confirm it
# resolves the saved config.json as intended.
pt_unet = UNet2DConditionModel.from_config(temp_flax_dir)
# b. Use the internal Diffusers tool to safely inject the Flax .msgpack weights
msgpack_file = f"{temp_flax_dir}/diffusion_flax_model.msgpack"
pt_unet = load_flax_checkpoint_in_pytorch_model(pt_unet, msgpack_file)
# c. Cast it to float16 for optimal GPU inference
pt_unet = pt_unet.to(torch.float16)

# 5. Run Inference on GPU using PyTorch
print("Setting up PyTorch Stable Diffusion Pipeline on GPU...")
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_REPO_ID,
    torch_dtype=torch.float16,
    safety_checker=None
)
pipe.unet = pt_unet  # swap in the fine-tuned Naruto UNet
pipe = pipe.to("cuda")

prompt = "A drawing of Kakashi Hatake"
print(f"Generating image for prompt: '{prompt}'...")
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
print("Done! Here is your PyTorch-generated image:")
display(image)
- Downloads last month
- -
Model tree for NiceWang/sd-naruto-tpu
Base model
runwayml/stable-diffusion-v1-5