lea97338 committed on
Commit
6fe7eb6
·
verified ·
1 Parent(s): 1c01d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -31
app.py CHANGED
@@ -1,27 +1,10 @@
1
  import torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, Qwen2ForCausalLM
4
- from diffusers import Flux2Transformer2DModel
5
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
- # Charger uniquement le transformer FLUX (léger)
10
- transformer = Flux2Transformer2DModel.from_pretrained(
11
- "black-forest-labs/FLUX.2-klein-4B",
12
- subfolder="transformer",
13
- torch_dtype=dtype,
14
- )
15
-
16
- # Modules internes
17
- pos_embed = transformer.pos_embed # [1, 4096, 2560]
18
- x_embedder = transformer.x_embedder # module → 2140 dims
19
-
20
- # Libérer le reste
21
- del transformer
22
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
23
-
24
- # Charger Qwen (encodeur texte)
25
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
26
  text_encoder = Qwen2ForCausalLM.from_pretrained(
27
  "Qwen/Qwen2-1.5B",
@@ -44,26 +27,23 @@ def encode(prompt):
44
  use_cache=False,
45
  )
46
 
47
- text = out.hidden_states[-1] # [1, L, 2560]
48
- L = text.shape[1]
49
-
50
- # 🔥 pos_embed n'est PAS un module → on slice
51
- pos = pos_embed[:, :L, :] # [1, L, 2560]
52
-
53
- # extra embedder est un module → on l'appelle
54
- extra = x_embedder(text) # [1, L, 2140]
55
-
56
- final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
57
 
58
- torch.save(final, "embeds.pt")
 
59
 
60
- return str(final.shape), "embeds.pt"
61
 
62
  demo = gr.Interface(
63
  fn=encode,
64
  inputs=gr.Textbox(label="Prompt"),
65
- outputs=[gr.Textbox(label="Shape"), gr.File(label="Embeddings")],
66
- title="FLUX Klein — External Encoder (7260 dims)"
 
 
 
 
67
  )
68
 
69
  demo.launch()
 
1
  import torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, Qwen2ForCausalLM
 
4
 
5
  device = "cpu"
6
  dtype = torch.float32
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
9
  text_encoder = Qwen2ForCausalLM.from_pretrained(
10
  "Qwen/Qwen2-1.5B",
 
27
  use_cache=False,
28
  )
29
 
30
+ embeds = out.hidden_states[-1] # [1, L, 2560]
31
+ pooled = embeds.mean(dim=1) # [1, 2560]
 
 
 
 
 
 
 
 
32
 
33
+ torch.save(embeds, "embeds.pt")
34
+ torch.save(pooled, "pooled.pt")
35
 
36
+ return str(embeds.shape), "embeds.pt", "pooled.pt"
37
 
38
  demo = gr.Interface(
39
  fn=encode,
40
  inputs=gr.Textbox(label="Prompt"),
41
+ outputs=[
42
+ gr.Textbox(label="Shape"),
43
+ gr.File(label="Embeddings 2560"),
44
+ gr.File(label="Pooled 2560")
45
+ ],
46
+ title="External Text Encoder — 2560 dims"
47
  )
48
 
49
  demo.launch()