lea97338 commited on
Commit
71c4872
·
verified ·
1 Parent(s): 6e9ebae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -20
app.py CHANGED
@@ -2,15 +2,14 @@ import torch
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
5
- # -----------------------------
6
- # CHARGEMENT MINIMAL DU TEXT ENCODER FLUX
7
- # -----------------------------
 
8
  pipe = Flux2Pipeline.from_pretrained(
9
  "black-forest-labs/FLUX.2-klein-4B",
10
- torch_dtype=torch.float32,
11
  low_cpu_mem_usage=True,
12
-
13
- # On désactive tout ce qui n'est pas utile pour encode_prompt()
14
  transformer=None,
15
  vae=None,
16
  scheduler=None,
@@ -18,35 +17,49 @@ pipe = Flux2Pipeline.from_pretrained(
18
  feature_extractor=None,
19
  )
20
 
21
- # -----------------------------
22
- # ENCODEUR
23
- # -----------------------------
 
24
  def encode_text(prompt: str):
25
  if not prompt.strip():
26
  return "Prompt vide", None
27
 
 
 
 
 
 
 
 
 
 
28
  with torch.inference_mode():
29
- embeds = pipe.encode_prompt(prompt)
 
 
 
 
 
 
 
30
 
31
- # Sauvegarde dans un fichier temporaire
32
  file_path = "embeds.pt"
33
- torch.save(embeds.cpu(), file_path)
34
 
35
- return f"Embeddings générés : {tuple(embeds.shape)}", file_path
36
 
37
 
38
- # -----------------------------
39
- # INTERFACE GRADIO
40
- # -----------------------------
41
  demo = gr.Interface(
42
  fn=encode_text,
43
  inputs=gr.Textbox(label="Prompt", placeholder="Écris ton texte ici..."),
44
  outputs=[
45
  gr.Textbox(label="Infos"),
46
- gr.File(label="Fichier .pt des embeddings")
47
  ],
48
- title="Encodeur Texte FLUX.2 — Minimal",
49
- description="Encodeur officiel FLUX.2 (Mistral-3-Small). Génère des embeddings compatibles avec Flux2Pipeline.",
50
  )
51
 
52
- demo.launch()
 
 
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
5
+ device = "cpu"
6
+ dtype = torch.float32
7
+
8
+ # On charge uniquement tokenizer + text_encoder
9
  pipe = Flux2Pipeline.from_pretrained(
10
  "black-forest-labs/FLUX.2-klein-4B",
11
+ torch_dtype=dtype,
12
  low_cpu_mem_usage=True,
 
 
13
  transformer=None,
14
  vae=None,
15
  scheduler=None,
 
17
  feature_extractor=None,
18
  )
19
 
20
+ tokenizer = pipe.tokenizer
21
+ text_encoder = pipe.text_encoder.to(device)
22
+
23
+
24
  def encode_text(prompt: str):
25
  if not prompt.strip():
26
  return "Prompt vide", None
27
 
28
+ # Tokenisation simple, sans chat template
29
+ inputs = tokenizer(
30
+ prompt,
31
+ return_tensors="pt",
32
+ padding=True,
33
+ truncation=True,
34
+ max_length=512,
35
+ ).to(device)
36
+
37
  with torch.inference_mode():
38
+ outputs = text_encoder(
39
+ **inputs,
40
+ output_hidden_states=True,
41
+ use_cache=False,
42
+ )
43
+
44
+ # On prend la dernière couche cachée : [B, L, D]
45
+ embeds = outputs.hidden_states[-1].to("cpu")
46
 
 
47
  file_path = "embeds.pt"
48
+ torch.save(embeds, file_path)
49
 
50
+ return f"Embeddings shape: {tuple(embeds.shape)}", file_path
51
 
52
 
 
 
 
53
  demo = gr.Interface(
54
  fn=encode_text,
55
  inputs=gr.Textbox(label="Prompt", placeholder="Écris ton texte ici..."),
56
  outputs=[
57
  gr.Textbox(label="Infos"),
58
+ gr.File(label="Fichier .pt des embeddings"),
59
  ],
60
+ title="Encodeur Texte FLUX.2 (Mistral-3) — Minimal",
61
+ description="Encode le prompt avec le text encoder FLUX.2 sans chat template.",
62
  )
63
 
64
+ if __name__ == "__main__":
65
+ demo.launch()