import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to("cuda")
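# Note: the FLUX VAE has 16 latent channels and 8x spatial downsampling, so the
# 64x64x16 latents produced below decode to a 512x512 RGB image.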

# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())
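# model_v4.py is expected to define TinyFluxConfig, TinyFluxDeep, and the
# TinyFluxDeep.create_img_ids helper used below.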

config = TinyFluxConfig(
    use_sol_prior=False,  # Disabled until trained
    use_t5_vec=False,     # Disabled until trained
)
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download(
    "AbstractPhil/tiny-flux-deep",
    "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors",
))
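# strict=False: the init checkpoint may omit parameters for the branches
# disabled in the config above (sol_prior / t5_vec), so mismatched keys are
# tolerated rather than raising.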
model.load_state_dict(weights, strict=False)
model.eval()

def encode_prompt(prompt):
    """Encode prompt with both T5 and CLIP."""
    # T5
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length", 
                              max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)
    
    # CLIP
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                  max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        clip_out = clip_model(**clip_tokens)
        clip_pooled = clip_out.pooler_output.to(torch.bfloat16)
    
    return t5_emb, clip_pooled


def flux_shift(t, s=3.0):
    """Flux-style timestep shift."""
    return s * t / (1 + (s - 1) * t)


@torch.inference_mode()
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Euler sampling for rectified flow.
    
    Flow matching formulation:
        x_t = (1 - t) * noise + t * data
        At t=0: pure noise
        At t=1: pure data  
        Velocity v = data - noise (constant)
        
    Sampling: Integrate from t=0 (noise) → t=1 (data)
    """
    if seed is not None:
        torch.manual_seed(seed)
    
    t5_emb, clip_pooled = encode_prompt(prompt)
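    # Empty-prompt embeddings provide the unconditional branch for classifier-free guidance.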
    t5_null, clip_null = encode_prompt("")
    
    # Start from pure noise (t=0)
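    # Shape [B, tokens, channels]: 64*64 latent positions with 16 channels each.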
    x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
    
    # Timesteps: 0 → 1 with Flux shift
    t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
    timesteps = flux_shift(t_linear, s=3.0)
    
    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr  # Positive, moving toward data
        
        t_batch = t_curr.unsqueeze(0)
        
        # Predict velocity
        v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
        v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)
        
        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        
        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt
    
    # Decode with VAE
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    # Invert the FLUX VAE latent scaling; its config defines both a scaling_factor
    # and a shift_factor (encode applies (z - shift) * scale).
    x = x / vae.config.scaling_factor + vae.config.shift_factor
    image = vae.decode(x).sample
    
    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")
    
    return Image.fromarray(image)


# Generate
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
image