lea97338 commited on
Commit
6562a48
·
verified ·
1 Parent(s): 421f7d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -2,17 +2,23 @@ import torch
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
5
- # Charger FLUX.2 COMPLET
6
  pipe = Flux2Pipeline.from_pretrained(
7
  "black-forest-labs/FLUX.2-klein-4B",
8
  torch_dtype=torch.float32,
 
 
 
 
9
  low_cpu_mem_usage=True,
10
  )
11
 
 
12
  tokenizer = pipe.tokenizer
13
  text_encoder = pipe.text_encoder
14
 
15
  def encode_text(prompt: str):
 
16
  inputs = tokenizer(
17
  prompt,
18
  return_tensors="pt",
@@ -21,6 +27,7 @@ def encode_text(prompt: str):
21
  max_length=512,
22
  )
23
 
 
24
  with torch.inference_mode():
25
  outputs = text_encoder(
26
  **inputs,
@@ -28,15 +35,21 @@ def encode_text(prompt: str):
28
  use_cache=False,
29
  )
30
 
31
- embeds = outputs.hidden_states[-1] # [B, L, 4096]
 
32
 
 
33
  torch.save(embeds, "embeds.pt")
 
34
  return f"shape={tuple(embeds.shape)}", "embeds.pt"
35
 
 
36
  demo = gr.Interface(
37
  fn=encode_text,
38
- inputs=gr.Textbox(),
39
- outputs=[gr.Textbox(), gr.File()],
 
 
40
  )
41
 
42
  demo.launch()
 
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
5
+ # Charger FLUX.2 Klein COMPLET pour récupérer le vrai text encoder (Qwen3)
6
  pipe = Flux2Pipeline.from_pretrained(
7
  "black-forest-labs/FLUX.2-klein-4B",
8
  torch_dtype=torch.float32,
9
+ transformer=None,
10
+ vae=None,
11
+ scheduler=None,
12
+ feature_extractor =None,
13
  low_cpu_mem_usage=True,
14
  )
15
 
16
+ # Récupération du tokenizer + text_encoder (Qwen3ForCausalLM)
17
  tokenizer = pipe.tokenizer
18
  text_encoder = pipe.text_encoder
19
 
20
  def encode_text(prompt: str):
21
+ # Tokenisation simple
22
  inputs = tokenizer(
23
  prompt,
24
  return_tensors="pt",
 
27
  max_length=512,
28
  )
29
 
30
+ # Encodage texte → embeddings 2560 dims
31
  with torch.inference_mode():
32
  outputs = text_encoder(
33
  **inputs,
 
35
  use_cache=False,
36
  )
37
 
38
+ # Dernière couche cachée = embeddings texte
39
+ embeds = outputs.hidden_states[-1] # [B, L, 2560]
40
 
41
+ # Sauvegarde dans un fichier .pt
42
  torch.save(embeds, "embeds.pt")
43
+
44
  return f"shape={tuple(embeds.shape)}", "embeds.pt"
45
 
46
+ # Interface Gradio
47
  demo = gr.Interface(
48
  fn=encode_text,
49
+ inputs=gr.Textbox(label="Prompt"),
50
+ outputs=[gr.Textbox(label="Shape"), gr.File(label="Embeddings (.pt)")],
51
+ title="FLUX.2 Klein — Text Embedder (Qwen3 2560 dims)",
52
+ description="Encodeur texte officiel de FLUX.2 Klein (Qwen3ForCausalLM).",
53
  )
54
 
55
  demo.launch()