# ============================================================================ # TinyFlux Inference Cell - Euler Discrete Flow Matching # ============================================================================ # Run the model cell before this one (defines TinyFlux, TinyFluxConfig) # Loads from: AbstractPhil/tiny-flux or local checkpoint # ============================================================================ import torch from huggingface_hub import hf_hub_download from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL from PIL import Image import numpy as np import os # ============================================================================ # CONFIG # ============================================================================ DEVICE = "cuda" DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # Model loading HF_REPO = "AbstractPhil/tiny-flux" LOAD_FROM = "hub" # "hub", "hub:step_1000", "local:/path/to/weights.safetensors" # Generation settings NUM_STEPS = 20 # Euler steps (20-50 typical) GUIDANCE_SCALE = 3.5 # CFG scale (1.0 = no guidance, 3-7 typical) HEIGHT = 512 # Output height WIDTH = 512 # Output width SEED = None # None for random # ============================================================================ # LOAD TEXT ENCODERS # ============================================================================ print("Loading text encoders...") t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base") t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval() clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval() # ============================================================================ # LOAD VAE # ============================================================================ print("Loading Flux VAE...") vae = AutoencoderKL.from_pretrained( "black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=DTYPE ).to(DEVICE).eval() # ============================================================================ # LOAD TINYFLUX MODEL # ============================================================================ print(f"Loading TinyFlux from: {LOAD_FROM}") config = TinyFluxConfig() model = TinyFlux(config).to(DEVICE).to(DTYPE) def load_weights(path): """Load weights from .safetensors or .pt file.""" if path.endswith(".safetensors"): state_dict = load_file(path) elif path.endswith(".pt"): ckpt = torch.load(path, map_location=DEVICE, weights_only=False) # Handle different checkpoint formats if isinstance(ckpt, dict): if "model" in ckpt: state_dict = ckpt["model"] elif "state_dict" in ckpt: state_dict = ckpt["state_dict"] else: state_dict = ckpt else: state_dict = ckpt else: # Try safetensors first, then pt try: state_dict = load_file(path) except: state_dict = torch.load(path, map_location=DEVICE, weights_only=False) # Strip "_orig_mod." prefix from keys (added by torch.compile) if any(k.startswith("_orig_mod.") for k in state_dict.keys()): print(" Stripping torch.compile prefix from state_dict keys...") state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} return state_dict if LOAD_FROM == "hub": # Load best model from hub - try safetensors first, then pt try: weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors") except: weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt") weights = load_weights(weights_path) model.load_state_dict(weights) print(f"✓ Loaded from {HF_REPO}") elif LOAD_FROM.startswith("hub:"): # Load specific checkpoint from hub ckpt_name = LOAD_FROM[4:] # Try multiple extensions for ext in [".safetensors", ".pt", ""]: try: if ckpt_name.endswith((".safetensors", ".pt")): filename = ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}" else: filename = f"checkpoints/{ckpt_name}{ext}" weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename) weights = load_weights(weights_path) model.load_state_dict(weights) print(f"✓ Loaded from {HF_REPO}/{filename}") break except Exception as e: continue else: raise ValueError(f"Could not find checkpoint: {ckpt_name}") elif LOAD_FROM.startswith("local:"): # Load local file weights_path = LOAD_FROM[6:] weights = load_weights(weights_path) model.load_state_dict(weights) print(f"✓ Loaded from {weights_path}") else: raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}") model.eval() print(f"Model params: {sum(p.numel() for p in model.parameters()):,}") # ============================================================================ # ENCODING FUNCTIONS # ============================================================================ @torch.no_grad() def encode_prompt(prompt: str, max_length: int = 128): """Encode prompt with flan-t5-base and CLIP-L.""" # T5 encoding (sequence) t5_in = t5_tok( prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" ).to(DEVICE) t5_out = t5_enc( input_ids=t5_in.input_ids, attention_mask=t5_in.attention_mask ).last_hidden_state # (1, L, 768) # CLIP encoding (pooled) clip_in = clip_tok( prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt" ).to(DEVICE) clip_out = clip_enc( input_ids=clip_in.input_ids, attention_mask=clip_in.attention_mask ) clip_pooled = clip_out.pooler_output # (1, 768) return t5_out, clip_pooled # ============================================================================ # FLOW MATCHING HELPERS # ============================================================================ SHIFT = 3.0 # Flux shift parameter (must match training) def flux_shift(t, s=SHIFT): """Flux timestep shift - biases towards higher t (closer to data).""" return s * t / (1 + (s - 1) * t) # ============================================================================ # EULER DISCRETE FLOW MATCHING SAMPLER # ============================================================================ @torch.no_grad() def euler_sample( model, prompt: str, negative_prompt: str = "", num_steps: int = 20, guidance_scale: float = 3.5, height: int = 512, width: int = 512, seed: int = None, direction: str = "forward", use_shift: bool = True, ): """ Euler discrete sampler for flow matching. Args: direction: "forward" (t:0→1, correct) or "reverse" (t:1→0, for old models) use_shift: Whether to apply flux_shift to timesteps Flow Matching formulation: x_t = (1 - t) * noise + t * data At t=0: noise, At t=1: data Velocity v = data - noise """ # Set seed if seed is not None: torch.manual_seed(seed) generator = torch.Generator(device=DEVICE).manual_seed(seed) else: generator = None # Latent dimensions (VAE downscales by 8) H_lat = height // 8 W_lat = width // 8 C_lat = 16 # Encode prompts (ensure correct dtype) t5_cond, clip_cond = encode_prompt(prompt) t5_cond = t5_cond.to(DTYPE) clip_cond = clip_cond.to(DTYPE) if guidance_scale > 1.0 and negative_prompt is not None: t5_uncond, clip_uncond = encode_prompt(negative_prompt) t5_uncond = t5_uncond.to(DTYPE) clip_uncond = clip_uncond.to(DTYPE) else: t5_uncond, clip_uncond = None, None # Start from pure noise x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator) # Create image position IDs for RoPE img_ids = TinyFlux.create_img_ids(1, H_lat, W_lat, DEVICE) # Build timesteps based on direction if direction == "forward": t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE) dir_str = "0→1" else: # reverse t_linear = torch.linspace(1, 0, num_steps + 1, device=DEVICE, dtype=DTYPE) dir_str = "1→0" # Apply flux_shift if requested if use_shift: timesteps = flux_shift(t_linear) shift_str = ", shifted" else: timesteps = t_linear shift_str = "" print(f"Sampling with {num_steps} Euler steps (t: {dir_str}{shift_str})...") for i in range(num_steps): t_curr = timesteps[i] t_next = timesteps[i + 1] dt = t_next - t_curr t_batch = t_curr.unsqueeze(0) guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE) # Predict velocity: v = data - noise direction v_cond = model( hidden_states=x, encoder_hidden_states=t5_cond, pooled_projections=clip_cond, timestep=t_batch, img_ids=img_ids, guidance=guidance_embed, ) # Classifier-free guidance if guidance_scale > 1.0 and t5_uncond is not None: v_uncond = model( hidden_states=x, encoder_hidden_states=t5_uncond, pooled_projections=clip_uncond, timestep=t_batch, img_ids=img_ids, guidance=guidance_embed, ) # CFG formula: v = v_uncond + scale * (v_cond - v_uncond) v = v_uncond + guidance_scale * (v_cond - v_uncond) else: v = v_cond # Euler integration step: x_{t+dt} = x_t + v * dt # v points towards data, dt > 0, so we move towards data x = x + v * dt if (i + 1) % max(1, num_steps // 5) == 0 or i == num_steps - 1: print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}") # Reshape to image format: (1, H*W, C) -> (1, C, H, W) latents = x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2) return latents # ============================================================================ # DECODE LATENTS TO IMAGE # ============================================================================ @torch.no_grad() def decode_latents(latents): """Decode VAE latents to PIL Image.""" # Flux VAE scaling latents = latents / vae.config.scaling_factor # Decode (match VAE dtype) image = vae.decode(latents.to(vae.dtype)).sample # Normalize to [0, 1] image = (image / 2 + 0.5).clamp(0, 1) # To PIL (need float32 for numpy) image = image[0].float().permute(1, 2, 0).cpu().numpy() image = (image * 255).astype(np.uint8) return Image.fromarray(image) # ============================================================================ # MAIN GENERATION FUNCTION # ============================================================================ def generate( prompt: str, negative_prompt: str = "", num_steps: int = NUM_STEPS, guidance_scale: float = GUIDANCE_SCALE, height: int = HEIGHT, width: int = WIDTH, seed: int = SEED, save_path: str = None, direction: str = "forward", use_shift: bool = True, ): """ Generate an image from a text prompt. Args: prompt: Text description of desired image negative_prompt: What to avoid (empty string for none) num_steps: Number of Euler steps (20-50) guidance_scale: CFG scale (1.0=none, 3-7 typical) height: Output height in pixels (must be divisible by 8) width: Output width in pixels (must be divisible by 8) seed: Random seed (None for random) save_path: Path to save image (None to skip saving) direction: "forward" (t:0→1) or "reverse" (t:1→0) for old models use_shift: Whether to apply flux_shift to timesteps Returns: PIL.Image """ print(f"\nGenerating: '{prompt}'") print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}, dir={direction}, shift={use_shift}") # Sample latents using Euler flow matching latents = euler_sample( model=model, prompt=prompt, negative_prompt=negative_prompt, num_steps=num_steps, guidance_scale=guidance_scale, height=height, width=width, seed=seed, direction=direction, use_shift=use_shift, ) # Decode to image print("Decoding latents...") image = decode_latents(latents) # Save if requested if save_path: image.save(save_path) print(f"✓ Saved to {save_path}") print("✓ Done!") return image # ============================================================================ # BATCH GENERATION # ============================================================================ def generate_batch( prompts: list, negative_prompt: str = "", num_steps: int = NUM_STEPS, guidance_scale: float = GUIDANCE_SCALE, height: int = HEIGHT, width: int = WIDTH, seed: int = SEED, output_dir: str = "./outputs", direction: str = "forward", use_shift: bool = True, ): """Generate multiple images.""" os.makedirs(output_dir, exist_ok=True) images = [] for i, prompt in enumerate(prompts): # Increment seed for variety if seed is set img_seed = seed + i if seed is not None else None image = generate( prompt=prompt, negative_prompt=negative_prompt, num_steps=num_steps, guidance_scale=guidance_scale, height=height, width=width, seed=img_seed, save_path=os.path.join(output_dir, f"{i:03d}.png"), direction=direction, use_shift=use_shift, ) images.append(image) return images # ============================================================================ # QUICK TEST # ============================================================================ if __name__ == "__main__" or True: # Always run in Colab print("\n" + "="*60) print("TinyFlux Inference Ready!") print("="*60) image = generate( prompt="a cat in a tree by a sidewalk", negative_prompt="blurry, low quality", num_steps=1, guidance_scale=5.0, height=512, width=512, seed=1024, save_path="output.png" ) # print(f""" #Usage: # # Single image # image = generate("a photo of a cat") # image.show() # # # With options # image = generate( # prompt="a beautiful sunset over mountains", # negative_prompt="blurry, low quality", # num_steps=30, # guidance_scale=4.0, # height=512, # width=512, # seed=42, # save_path="output.png" # ) # # # Batch generation # images = generate_batch([ # "a red sports car", # "a blue ocean wave", # "a green forest path", # ], output_dir="./my_outputs") #""")