lea97338 commited on
Commit
4aac80a
·
verified ·
1 Parent(s): 71c4872

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -29
app.py CHANGED
@@ -2,14 +2,8 @@ import torch
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
5
- device = "cpu"
6
- dtype = torch.float32
7
-
8
- # On charge uniquement tokenizer + text_encoder
9
  pipe = Flux2Pipeline.from_pretrained(
10
  "black-forest-labs/FLUX.2-klein-4B",
11
- torch_dtype=dtype,
12
- low_cpu_mem_usage=True,
13
  transformer=None,
14
  vae=None,
15
  scheduler=None,
@@ -18,21 +12,16 @@ pipe = Flux2Pipeline.from_pretrained(
18
  )
19
 
20
  tokenizer = pipe.tokenizer
21
- text_encoder = pipe.text_encoder.to(device)
22
-
23
 
24
  def encode_text(prompt: str):
25
- if not prompt.strip():
26
- return "Prompt vide", None
27
-
28
- # Tokenisation simple, sans chat template
29
  inputs = tokenizer(
30
  prompt,
31
  return_tensors="pt",
32
  padding=True,
33
  truncation=True,
34
  max_length=512,
35
- ).to(device)
36
 
37
  with torch.inference_mode():
38
  outputs = text_encoder(
@@ -41,25 +30,15 @@ def encode_text(prompt: str):
41
  use_cache=False,
42
  )
43
 
44
- # On prend la dernière couche cachée : [B, L, D]
45
- embeds = outputs.hidden_states[-1].to("cpu")
46
-
47
- file_path = "embeds.pt"
48
- torch.save(embeds, file_path)
49
-
50
- return f"Embeddings shape: {tuple(embeds.shape)}", file_path
51
 
 
 
52
 
53
  demo = gr.Interface(
54
  fn=encode_text,
55
- inputs=gr.Textbox(label="Prompt", placeholder="Écris ton texte ici..."),
56
- outputs=[
57
- gr.Textbox(label="Infos"),
58
- gr.File(label="Fichier .pt des embeddings"),
59
- ],
60
- title="Encodeur Texte FLUX.2 (Mistral-3) — Minimal",
61
- description="Encode le prompt avec le text encoder FLUX.2 sans chat template.",
62
  )
63
 
64
- if __name__ == "__main__":
65
- demo.launch()
 
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
4
 
 
 
 
 
5
  pipe = Flux2Pipeline.from_pretrained(
6
  "black-forest-labs/FLUX.2-klein-4B",
 
 
7
  transformer=None,
8
  vae=None,
9
  scheduler=None,
 
12
  )
13
 
14
  tokenizer = pipe.tokenizer
15
+ text_encoder = pipe.text_encoder
 
16
 
17
  def encode_text(prompt: str):
 
 
 
 
18
  inputs = tokenizer(
19
  prompt,
20
  return_tensors="pt",
21
  padding=True,
22
  truncation=True,
23
  max_length=512,
24
+ )
25
 
26
  with torch.inference_mode():
27
  outputs = text_encoder(
 
30
  use_cache=False,
31
  )
32
 
33
+ embeds = outputs.hidden_states[-1] # [B, L, 4096]
 
 
 
 
 
 
34
 
35
+ torch.save(embeds, "embeds.pt")
36
+ return f"shape={tuple(embeds.shape)}", "embeds.pt"
37
 
38
  demo = gr.Interface(
39
  fn=encode_text,
40
+ inputs=gr.Textbox(),
41
+ outputs=[gr.Textbox(), gr.File()],
 
 
 
 
 
42
  )
43
 
44
+ demo.launch()