Text-to-Image
Diffusers
TensorBoard
stable-diffusion
diffusion
distillation
flow-matching
geometric-deep-learning
research
Instructions to use AbstractPhil/sd15-flow-matching with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use AbstractPhil/sd15-flow-matching with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("AbstractPhil/sd15-flow-matching", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
| # ============================================================================ | |
| # SD1.5-Flow-Sol Correct Inference (Colab Cell) | |
| # ============================================================================ | |
| # Matches trainer's sample() method exactly: | |
| # - DDPM scheduler timesteps | |
| # - Specifically aligned for the SOL training pipeline to ensure accurate inference. | |
| # - Model predicts velocity | |
| # - Convert velocity → epsilon for scheduler stepping | |
| # ============================================================================ | |
| !pip install -q diffusers transformers accelerate safetensors | |
| import torch | |
| import gc | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from PIL import Image | |
| import numpy as np | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"          # device every model and tensor below is placed on
DTYPE = torch.float16    # precision used for CLIP, VAE and UNet weights/activations

# Hub location of the Sol checkpoint whose "student" weights replace the UNet's.
SOL_REPO = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt"
# ============================================================================
# LOAD MODELS
# ============================================================================
# Each component is loaded in DTYPE, moved to DEVICE and switched to eval mode.
# The UNet starts from the SD1.5 base weights; the Sol checkpoint is applied to
# it afterwards.
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=DTYPE,
)
clip_enc = clip_enc.to(DEVICE).eval()

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae", torch_dtype=DTYPE
)
vae = vae.to(DEVICE).eval()

print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=DTYPE
)
unet = unet.to(DEVICE).eval()

print("Loading DDPM Scheduler...")
# NOTE(review): built from the library's default config rather than the SD1.5
# repo's scheduler config; presumably this mirrors how the trainer built its
# schedule (its alphas_cumprod feeds alpha_sigma) — confirm against the trainer.
sched = DDPMScheduler(num_train_timesteps=1000)
# ============================================================================
# LOAD SOL WEIGHTS
# ============================================================================
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code; only load checkpoints from a source you trust (consider
# weights_only=True if the checkpoint contains only tensors and plain values).
checkpoint = torch.load(weights_path, map_location="cpu")
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If the student was saved inside a wrapper module, keep only the UNet entries
# and strip the leading "unet." exactly once. Slicing only touches the prefix,
# whereas str.replace would also rewrite "unet." appearing later in a key.
if any(k.startswith("unet.") for k in state_dict):
    state_dict = {
        k[len("unet."):]: v for k, v in state_dict.items() if k.startswith("unet.")
    }
# Drop auxiliary modules saved alongside the UNet that it has no slots for.
state_dict = {
    k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))
}

missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")
# strict=False tolerates partial matches. If *no* key matched, however, the UNet
# would silently keep its base SD1.5 weights, and the sampler would
# misinterpret that model's output as velocity; fail loudly instead.
if len(state_dict) == len(unexpected):
    raise RuntimeError(
        f"No Sol weights matched the UNet ({len(unexpected)} unexpected keys); "
        "check the checkpoint's key prefix."
    )

del checkpoint, state_dict
gc.collect()

# Inference only: no gradients are needed for the UNet parameters.
unet.requires_grad_(False)
print("✓ Sol ready!")
# ============================================================================
# HELPER: per-timestep alpha/sigma from the DDPM schedule
# ============================================================================
def alpha_sigma(t: torch.LongTensor):
    """Look up the DDPM signal/noise coefficients for integer timesteps ``t``.

    With ``abar = sched.alphas_cumprod[t]`` (the global DDPM scheduler), returns
    ``(sqrt(abar), sqrt(1 - abar))``. Both are float32 tensors on DEVICE,
    reshaped to (N, 1, 1, 1) so they broadcast over 4-D latent batches.
    The original notes state this matches the trainer's computation.
    """
    abar = sched.alphas_cumprod.to(DEVICE)[t].view(-1, 1, 1, 1)
    signal = abar.sqrt().float()
    noise = (1.0 - abar).sqrt().float()
    return signal, noise
# ============================================================================
# SAMPLER (mirrors the trainer's sample() method)
# ============================================================================
def _encode_text(text):
    """Tokenize ``text`` to 77 tokens (pad/truncate) and return CLIP last_hidden_state in DTYPE."""
    tokens = clip_tok(
        text, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).to(DEVICE)
    return clip_enc(**tokens).last_hidden_state.to(DTYPE)


def _latents_to_pil(latents):
    """Decode latents with the VAE and return the first image of the batch as a PIL image."""
    decoded = vae.decode(latents / 0.18215).sample  # undo the 0.18215 latent scaling
    # Map [-1, 1] -> [0, 1], take the first image, and move channels last (HWC).
    img = (decoded / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    # Round (rather than truncate) so each value maps to the nearest 8-bit level.
    return Image.fromarray((img * 255).round().astype(np.uint8))


@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5, height=512, width=512):
    """Generate one image for ``prompt`` with the velocity-predicting Sol UNet.

    Sampling loop:
      1. Take timesteps from the DDPM scheduler.
      2. Predict velocity v for the prompt and the negative prompt, then apply
         classifier-free guidance to v.
      3. Convert v -> x0_hat -> eps_hat using the DDPM alpha/sigma.
      4. Step with ``sched.step(eps_hat, t, x_t)``.

    Args:
        prompt: Text prompt.
        negative_prompt: Text for the unconditional branch; None is treated as "".
        seed: Passed to ``torch.manual_seed``; None leaves the global RNG untouched.
        steps: Number of scheduler inference steps.
        cfg: Classifier-free guidance scale.
        height, width: Output size in pixels. Latents are 1/8 of this size (the
            defaults give the original 64x64 latents), so both must be positive
            multiples of 8.

    Returns:
        A ``PIL.Image.Image``.

    Raises:
        ValueError: If ``height`` or ``width`` is not a positive multiple of 8.
    """
    if height <= 0 or width <= 0 or height % 8 or width % 8:
        raise ValueError(
            f"height and width must be positive multiples of 8, got {height}x{width}"
        )
    if seed is not None:
        torch.manual_seed(seed)

    cond = _encode_text(prompt)
    uncond = _encode_text(negative_prompt if negative_prompt is not None else "")

    sched.set_timesteps(steps, device=DEVICE)

    # Start from Gaussian noise in latent space (4 latent channels).
    x_t = torch.randn(1, 4, height // 8, width // 8, device=DEVICE, dtype=DTYPE)
    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")

    report_every = max(1, steps // 5)
    for i, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # The Sol UNet predicts velocity, not epsilon.
        v_cond = unet(x_t, t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t, t, encoder_hidden_states=uncond).sample
        # Classifier-free guidance applied in velocity space.
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # With x_t = alpha*x0 + sigma*eps and v = alpha*eps - sigma*x0:
        #   x0  = (alpha*x_t - sigma*v) / (alpha^2 + sigma^2)
        #   eps = (x_t - alpha*x0) / sigma
        # Done in float32; the 1e-8 terms guard the divisions.
        alpha, sigma = alpha_sigma(t)
        x_t32 = x_t.float()
        x0_hat = (alpha * x_t32 - sigma * v_hat.float()) / (alpha**2 + sigma**2 + 1e-8)
        eps_hat = (x_t32 - alpha * x0_hat) / (sigma + 1e-8)

        # Let the DDPM scheduler take the step from the epsilon prediction.
        x_t = sched.step(eps_hat, t_scalar, x_t32).prev_sample.to(DTYPE)

        if (i + 1) % report_every == 0:
            print(f" Step {i+1}/{steps}, t={int(t_scalar)}")

    return _latents_to_pil(x_t)
# ============================================================================
# TEST: render a few fixed prompts and show them inline
# ============================================================================
print("\n" + "=" * 60)
print("Generating test images with Sol (correct sampler)")
print("=" * 60)

from IPython.display import display

prompts = [
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
]

# Same seed, step count and guidance for every prompt so outputs are comparable.
for prompt in prompts:
    print()  # blank line separating each prompt's sampling log
    img = generate_sol(
        prompt,
        negative_prompt="",
        seed=42,
        steps=30,
        cfg=7.5,
    )
    display(img)

print("\n✓ Done!")