lea97338 commited on
Commit
aef300d
·
verified ·
1 Parent(s): c95b5a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -1,26 +1,30 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import Qwen3ForCausalLM
4
  from diffusers import Flux2Pipeline
5
 
6
- # Charger FLUX.2 Klein COMPLET pour récupérer le vrai text encoder (Qwen3)
 
 
 
7
  pipe = Flux2Pipeline.from_pretrained(
8
  "black-forest-labs/FLUX.2-klein-4B",
9
- torch_dtype=torch.float32,
10
- transformer=None,
11
- vae=None,
12
- scheduler=None,
13
- feature_extractor =None,
14
  low_cpu_mem_usage=True,
15
  )
16
 
17
- # Récupération du tokenizer + text_encoder (Qwen3ForCausalLM)
18
- tokenizer = pipe.tokenizer
19
- text_encoder = Qwen3ForCausalLM.from_pretained(pipe.text_encoder)
 
 
 
 
 
 
20
 
21
- def encode_text(prompt: str):
22
- # Tokenisation simple
23
- inputs = tokenizer(
24
  prompt,
25
  return_tensors="pt",
26
  padding=True,
@@ -28,29 +32,29 @@ def encode_text(prompt: str):
28
  max_length=512,
29
  )
30
 
31
- # Encodage texte → embeddings 2560 dims
32
  with torch.inference_mode():
33
- outputs = text_encoder(
34
- **inputs,
35
  output_hidden_states=True,
36
  use_cache=False,
37
  )
38
 
39
- # Dernière couche cachée = embeddings texte
40
- embeds = outputs.hidden_states[-1] # [B, L, 2560]
 
 
 
 
41
 
42
- # Sauvegarde dans un fichier .pt
43
- torch.save(embeds, "embeds.pt")
44
 
45
- return f"shape={tuple(embeds.shape)}", "embeds.pt"
46
 
47
- # Interface Gradio
48
  demo = gr.Interface(
49
- fn=encode_text,
50
  inputs=gr.Textbox(label="Prompt"),
51
- outputs=[gr.Textbox(label="Shape"), gr.File(label="Embeddings (.pt)")],
52
- title="FLUX.2 Klein — Text Embedder (Qwen3 2560 dims)",
53
- description="Encodeur texte officiel de FLUX.2 Klein (Qwen3ForCausalLM).",
54
  )
55
 
56
  demo.launch()
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, Qwen2ForCausalLM
4
  from diffusers import Flux2Pipeline
5
 
6
+ device = "cpu"
7
+ dtype = torch.float32
8
+
9
+ # Charger FLUX pour récupérer les embedder internes
10
  pipe = Flux2Pipeline.from_pretrained(
11
  "black-forest-labs/FLUX.2-klein-4B",
12
+ torch_dtype=dtype,
 
 
 
 
13
  low_cpu_mem_usage=True,
14
  )
15
 
16
+ pos_embedder = pipe.transformer.pos_embedder
17
+ extra_embedder = pipe.transformer.extra_embedder
18
+
19
+ # Charger Qwen (encodeur texte)
20
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
21
+ text_encoder = Qwen2ForCausalLM.from_pretrained(
22
+ "Qwen/Qwen2-1.5B",
23
+ torch_dtype=dtype,
24
+ )
25
 
26
+ def encode(prompt):
27
+ tokens = tokenizer(
 
28
  prompt,
29
  return_tensors="pt",
30
  padding=True,
 
32
  max_length=512,
33
  )
34
 
 
35
  with torch.inference_mode():
36
+ out = text_encoder(
37
+ **tokens,
38
  output_hidden_states=True,
39
  use_cache=False,
40
  )
41
 
42
+ text = out.hidden_states[-1] # [1, L, 2560]
43
+
44
+ pos = pos_embedder(text) # [1, L, 2560]
45
+ extra = extra_embedder(text) # [1, L, 2140]
46
+
47
+ final = torch.cat([text, pos, extra], dim=-1) # [1, L, 7260]
48
 
49
+ torch.save(final, "embeds.pt")
 
50
 
51
+ return str(final.shape), "embeds.pt"
52
 
 
53
  demo = gr.Interface(
54
+ fn=encode,
55
  inputs=gr.Textbox(label="Prompt"),
56
+ outputs=[gr.Textbox(label="Shape"), gr.File(label="Embeddings")],
57
+ title="FLUX Klein — External Encoder (7260 dims)"
 
58
  )
59
 
60
  demo.launch()