Update app.py
Browse files
app.py
CHANGED
|
@@ -6,16 +6,16 @@ from diffusers import Flux2Transformer2DModel
|
|
| 6 |
device = "cpu"
|
| 7 |
dtype = torch.float32
|
| 8 |
|
| 9 |
-
# Charger
|
| 10 |
transformer = Flux2Transformer2DModel.from_pretrained(
|
| 11 |
"black-forest-labs/FLUX.2-klein-4B",
|
| 12 |
subfolder="transformer",
|
| 13 |
torch_dtype=dtype,
|
| 14 |
)
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
-
|
| 19 |
|
| 20 |
# Libérer le reste
|
| 21 |
del transformer
|
|
@@ -45,9 +45,13 @@ def encode(prompt):
|
|
| 45 |
)
|
| 46 |
|
| 47 |
text = out.hidden_states[-1] # [1, L, 2560]
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
|
| 53 |
|
|
|
|
| 6 |
device = "cpu"
|
| 7 |
dtype = torch.float32
|
| 8 |
|
| 9 |
+
# Charger uniquement le transformer FLUX (léger)
|
| 10 |
transformer = Flux2Transformer2DModel.from_pretrained(
|
| 11 |
"black-forest-labs/FLUX.2-klein-4B",
|
| 12 |
subfolder="transformer",
|
| 13 |
torch_dtype=dtype,
|
| 14 |
)
|
| 15 |
|
| 16 |
+
# Modules internes
|
| 17 |
+
pos_embed = transformer.pos_embed # [1, 4096, 2560]
|
| 18 |
+
x_embedder = transformer.x_embedder # module → 2140 dims
|
| 19 |
|
| 20 |
# Libérer le reste
|
| 21 |
del transformer
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
text = out.hidden_states[-1] # [1, L, 2560]
|
| 48 |
+
L = text.shape[1]
|
| 49 |
|
| 50 |
+
# 🔥 pos_embed n'est PAS un module → on slice
|
| 51 |
+
pos = pos_embed[:, :L, :] # [1, L, 2560]
|
| 52 |
+
|
| 53 |
+
# extra embedder est un module → on l'appelle
|
| 54 |
+
extra = x_embedder(text) # [1, L, 2140]
|
| 55 |
|
| 56 |
final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
|
| 57 |
|