lea97338 commited on
Commit
5f5e9b6
·
verified ·
1 Parent(s): 926e9ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -35
app.py CHANGED
@@ -1,66 +1,57 @@
1
- # app.py (Space ENCODER)
2
-
3
  import torch
4
  import gradio as gr
5
- from diffusers import Flux2Pipeline
6
  import tempfile, os
7
 
8
- REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
 
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
12
- # On NE CHARGE QUE la partie texte (comme tu voulais)
13
- pipe = Flux2Pipeline.from_pretrained(
 
14
  REPO_ID,
15
- transformer=None,
16
- vae=None,
17
- scheduler=None,
18
  torch_dtype=dtype,
19
  low_cpu_mem_usage=True,
20
  )
21
-
22
- # On s'assure de ne garder que ce qui sert à l'encodage texte
23
- pipe.transformer = None
24
- pipe.vae = None
25
- pipe.scheduler = None
26
-
27
- pipe.to(device)
28
-
29
 
30
  @torch.no_grad()
31
  def encode_text(prompt: str):
32
  if not prompt.strip():
33
  raise gr.Error("Prompt vide")
34
 
35
- # encode_prompt renvoie EXACTEMENT ce que FLUX2 veut
36
- prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
37
- prompt=prompt, # IMPORTANT : string simple, pas dict
38
- device=device,
39
- num_images_per_prompt=1,
40
- )
 
 
 
 
 
41
 
42
- # On sauvegarde les trois tensors ensemble
43
- data = {
44
- "prompt_embeds": prompt_embeds.cpu(),
45
- "pooled_prompt_embeds": pooled_prompt_embeds.cpu(),
46
- "text_ids": text_ids.cpu(),
47
- }
48
 
 
49
  fd, path = tempfile.mkstemp(suffix=".pt")
50
  os.close(fd)
51
- torch.save(data, path)
52
 
53
  return path
54
 
55
-
56
  demo = gr.Interface(
57
  fn=encode_text,
58
  inputs=gr.Textbox(label="Prompt"),
59
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
60
- title="FLUX.2 Klein — Text Encoder Officiel",
61
- description="Renvoie les embeddings EXACTS que FLUX2 Klein attend.",
62
  )
63
 
64
- # api_name par défaut = "/predict", si tu veux explicitement :
65
- # demo.launch(api_name="/encode_text")
66
  demo.launch()
 
 
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import tempfile, os
5
 
6
+ # Le vrai text encoder utilisé par FLUX.2 Klein 4B
7
+ REPO_ID = "black-forest-labs/FLUX.2-klein-4B/text_encoder"
8
+
9
  device = "cpu"
10
  dtype = torch.float32
11
 
12
+ # Charger UNIQUEMENT le CausalLM Qwen utilisé par FLUX2 Klein
13
+ tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
14
+ text_encoder = AutoModelForCausalLM.from_pretrained(
15
  REPO_ID,
 
 
 
16
  torch_dtype=dtype,
17
  low_cpu_mem_usage=True,
18
  )
19
+ text_encoder.to(device)
20
+ text_encoder.eval()
 
 
 
 
 
 
21
 
22
  @torch.no_grad()
23
  def encode_text(prompt: str):
24
  if not prompt.strip():
25
  raise gr.Error("Prompt vide")
26
 
27
+ # Tokenisation simple (pas de chat template)
28
+ inputs = tokenizer(
29
+ prompt,
30
+ return_tensors="pt",
31
+ truncation=True,
32
+ max_length=256
33
+ ).to(device)
34
+
35
+ # Sortie Qwen3 : hidden_states = [1, seq_len, 4096]
36
+ outputs = text_encoder.model(**inputs, output_hidden_states=True)
37
+ hidden = outputs.hidden_states[-1] # dernière couche
38
 
39
+ # Projection FLUX2 : 4096 7680
40
+ projected = text_encoder.model.project_out(hidden)
 
 
 
 
41
 
42
+ # Sauvegarde
43
  fd, path = tempfile.mkstemp(suffix=".pt")
44
  os.close(fd)
45
+ torch.save(projected.cpu(), path)
46
 
47
  return path
48
 
 
49
  demo = gr.Interface(
50
  fn=encode_text,
51
  inputs=gr.Textbox(label="Prompt"),
52
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
53
+ title="FLUX.2 Klein — Text Encoder Qwen3 Direct",
54
+ description="Encode le texte avec Qwen3 + projection FLUX2 (4096→7680).",
55
  )
56
 
 
 
57
  demo.launch()