AbstractPhil committed on
Commit
bf792e2
·
verified ·
1 Parent(s): f6fc133

Create inference_v4.py

Browse files
Files changed (1) hide show
  1. scripts/inference_v4.py +118 -0
scripts/inference_v4.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from pathlib import Path

from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

# --- Text encoders -----------------------------------------------------------
# T5 supplies the per-token sequence embedding; CLIP supplies the pooled
# sentence-level vector. Both run on CUDA in bfloat16.
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# --- VAE ---------------------------------------------------------------------
# FLUX.1-schnell autoencoder: decodes 16-channel latents back to RGB images.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to("cuda")

# --- TinyFlux-Deep -----------------------------------------------------------
# The model definition lives as a script on the Hub; exec() it to bring
# TinyFluxConfig / TinyFluxDeep into this module's namespace.
# SECURITY NOTE: exec() runs arbitrary code from the downloaded file. This is
# only acceptable because the repo is the author's own; never do this with an
# untrusted repository.
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
# fix: read via Path.read_text() so the file handle is closed deterministically
# (the original `open(...).read()` leaked the handle).
exec(Path(model_py).read_text())

config = TinyFluxConfig(
    use_sol_prior=False,  # Disabled until trained
    use_t5_vec=False,     # Disabled until trained
)
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors"))
# strict=False: the checkpoint may omit parameters for the disabled
# sol-prior / t5-vec branches configured above.
model.load_state_dict(weights, strict=False)
model.eval()
33
+
34
def encode_prompt(prompt):
    """Encode *prompt* with both text encoders.

    Returns a tuple ``(t5_emb, clip_pooled)``: the T5 per-token sequence
    embedding and the CLIP pooled sentence vector, both bfloat16 on CUDA.
    """

    def _tokenize(tokenizer):
        # Fixed 77-token context, padded/truncated, moved to the GPU.
        return tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True,
        ).to("cuda")

    with torch.no_grad():
        sequence_emb = t5_model(**_tokenize(t5_tokenizer)).last_hidden_state.to(torch.bfloat16)
        pooled_emb = clip_model(**_tokenize(clip_tokenizer)).pooler_output.to(torch.bfloat16)

    return sequence_emb, pooled_emb
50
+
51
+
52
def flux_shift(t, s=3.0):
    """Flux-style timestep shift.

    Warps a schedule ``t`` in [0, 1] via ``s*t / (1 + (s-1)*t)``;
    ``s=1`` is the identity, larger ``s`` pushes steps toward t=1.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
55
+
56
+
57
@torch.inference_mode()
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Euler sampling for rectified flow.

    Flow matching formulation:
        x_t = (1 - t) * noise + t * data
        At t=0: pure noise
        At t=1: pure data
        Velocity v = data - noise (constant)

    Sampling: Integrate from t=0 (noise) → t=1 (data)

    Args:
        prompt: text prompt to condition on.
        num_steps: number of Euler integration steps.
        cfg_scale: classifier-free guidance weight (1.0 disables guidance).
        seed: optional RNG seed for reproducible initial noise.

    Returns:
        PIL.Image — the decoded RGB sample.
    """
    from PIL import Image  # local import, hoisted from the bottom of the function

    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)
    t5_null, clip_null = encode_prompt("")  # unconditional branch for CFG

    # Start from pure noise (t=0). Tokens are a flattened 64x64 latent grid
    # with 16 channels per token (FLUX VAE latent layout).
    x = torch.randn(1, 64 * 64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Timesteps: 0 → 1 with Flux shift
    t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
    timesteps = flux_shift(t_linear, s=3.0)

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr  # Positive, moving toward data

        t_batch = t_curr.unsqueeze(0)

        # Predict velocity (conditional and unconditional passes)
        v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
        v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt

    # Decode with VAE: un-flatten tokens to [B, C, H, W] latents.
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)
    # fix: the FLUX VAE config defines a shift_factor in addition to
    # scaling_factor; diffusers' reference decode is
    #   latents / scaling_factor + shift_factor.
    # The original omitted the shift, which biases the decoded image.
    # getattr fallback keeps behavior unchanged for VAEs without a shift.
    # NOTE(review): if TinyFlux was trained on latents prepared WITHOUT the
    # shift, remove this — confirm against the training pipeline.
    shift = getattr(vae.config, "shift_factor", None) or 0.0
    x = x / vae.config.scaling_factor + shift
    image = vae.decode(x).sample

    # Map from [-1, 1] to uint8 RGB.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")

    return Image.fromarray(image)
113
+
114
+
115
# Generate — guarded so importing this module for its helpers does not
# trigger a GPU generation as a side effect. The original also ended with a
# bare `image` expression (a notebook leftover that is a no-op in a script);
# it has been dropped.
if __name__ == "__main__":
    image = generate_image("a photograph of a tiger in natural habitat", seed=42)
    image.save("tiger.png")