lea97338 commited on
Commit
926e9ee
·
verified ·
1 Parent(s): 5f49f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import gradio as gr
3
  from diffusers import Flux2Pipeline
@@ -7,6 +9,7 @@ REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
7
  device = "cpu"
8
  dtype = torch.float32
9
 
 
10
  pipe = Flux2Pipeline.from_pretrained(
11
  REPO_ID,
12
  transformer=None,
@@ -16,34 +19,40 @@ pipe = Flux2Pipeline.from_pretrained(
16
  low_cpu_mem_usage=True,
17
  )
18
 
19
- # On supprime les parties inutiles
20
  pipe.transformer = None
21
  pipe.vae = None
22
  pipe.scheduler = None
23
 
24
  pipe.to(device)
25
 
 
26
  @torch.no_grad()
27
  def encode_text(prompt: str):
28
  if not prompt.strip():
29
  raise gr.Error("Prompt vide")
30
 
31
- # FLUX2 Klein attend un format chat Qwen3
32
- messages = [
33
- {"role": "user", "content": prompt}
34
- ]
35
-
36
- prompt_embeds, _, _ = pipe.encode_prompt(
37
- prompt=messages,
38
  device=device,
39
  num_images_per_prompt=1,
40
  )
41
 
 
 
 
 
 
 
 
42
  fd, path = tempfile.mkstemp(suffix=".pt")
43
  os.close(fd)
44
- torch.save(prompt_embeds.cpu(), path)
 
45
  return path
46
 
 
47
  demo = gr.Interface(
48
  fn=encode_text,
49
  inputs=gr.Textbox(label="Prompt"),
@@ -52,4 +61,6 @@ demo = gr.Interface(
52
  description="Renvoie les embeddings EXACTS que FLUX2 Klein attend.",
53
  )
54
 
 
 
55
  demo.launch()
 
1
+ # app.py (Space ENCODER)
2
+
3
  import torch
4
  import gradio as gr
5
  from diffusers import Flux2Pipeline
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
12
+ # On NE CHARGE QUE la partie texte (comme tu voulais)
13
  pipe = Flux2Pipeline.from_pretrained(
14
  REPO_ID,
15
  transformer=None,
 
19
  low_cpu_mem_usage=True,
20
  )
21
 
22
+ # On s'assure de ne garder que ce qui sert à l'encodage texte
23
  pipe.transformer = None
24
  pipe.vae = None
25
  pipe.scheduler = None
26
 
27
  pipe.to(device)
28
 
29
+
30
  @torch.no_grad()
31
  def encode_text(prompt: str):
32
  if not prompt.strip():
33
  raise gr.Error("Prompt vide")
34
 
35
+ # encode_prompt renvoie EXACTEMENT ce que FLUX2 veut
36
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
37
+ prompt=prompt, # IMPORTANT : string simple, pas dict
 
 
 
 
38
  device=device,
39
  num_images_per_prompt=1,
40
  )
41
 
42
+ # On sauvegarde les trois tensors ensemble
43
+ data = {
44
+ "prompt_embeds": prompt_embeds.cpu(),
45
+ "pooled_prompt_embeds": pooled_prompt_embeds.cpu(),
46
+ "text_ids": text_ids.cpu(),
47
+ }
48
+
49
  fd, path = tempfile.mkstemp(suffix=".pt")
50
  os.close(fd)
51
+ torch.save(data, path)
52
+
53
  return path
54
 
55
+
56
  demo = gr.Interface(
57
  fn=encode_text,
58
  inputs=gr.Textbox(label="Prompt"),
 
61
  description="Renvoie les embeddings EXACTS que FLUX2 Klein attend.",
62
  )
63
 
64
+ # api_name par défaut = "/predict", si tu veux explicitement :
65
+ # demo.launch(api_name="/encode_text")
66
  demo.launch()