lea97338 commited on
Commit
1c01d1d
·
verified ·
1 Parent(s): 5a17c83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -6,16 +6,16 @@ from diffusers import Flux2Transformer2DModel
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
- # Charger SEULEMENT le transformer FLUX (léger)
10
  transformer = Flux2Transformer2DModel.from_pretrained(
11
  "black-forest-labs/FLUX.2-klein-4B",
12
  subfolder="transformer",
13
  torch_dtype=dtype,
14
  )
15
 
16
- # Extraire UNIQUEMENT les modules nécessaires
17
- pos_embedder = transformer.pos_embed
18
- extra_embedder = transformer.x_embedder
19
 
20
  # Libérer le reste
21
  del transformer
@@ -45,9 +45,13 @@ def encode(prompt):
45
  )
46
 
47
  text = out.hidden_states[-1] # [1, L, 2560]
 
48
 
49
- pos = pos_embedder(text) # [1, L, 2560]
50
- extra = extra_embedder(text) # [1, L, 2140]
 
 
 
51
 
52
  final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
53
 
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
+ # Charger uniquement le transformer FLUX (léger)
10
  transformer = Flux2Transformer2DModel.from_pretrained(
11
  "black-forest-labs/FLUX.2-klein-4B",
12
  subfolder="transformer",
13
  torch_dtype=dtype,
14
  )
15
 
16
+ # Modules internes
17
+ pos_embed = transformer.pos_embed # [1, 4096, 2560]
18
+ x_embedder = transformer.x_embedder # module → 2140 dims
19
 
20
  # Libérer le reste
21
  del transformer
 
45
  )
46
 
47
  text = out.hidden_states[-1] # [1, L, 2560]
48
+ L = text.shape[1]
49
 
50
+ # 🔥 pos_embed n'est PAS un module → on slice
51
+ pos = pos_embed[:, :L, :] # [1, L, 2560]
52
+
53
+ # extra embedder est un module → on l'appelle
54
+ extra = x_embedder(text) # [1, L, 2140]
55
 
56
  final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
57