lea97338 commited on
Commit
79b2755
·
verified ·
1 Parent(s): ed78c02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -5,29 +5,19 @@ import tempfile, os
5
 
6
  REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
7
  device = "cpu"
8
- dtype = torch.float32 # CPU-safe
9
 
10
- # On charge la pipeline mais on supprime tout sauf le text encoder
11
  pipe = Flux2Pipeline.from_pretrained(
12
  REPO_ID,
13
- transformer=None,
14
- vae=None,
15
- scheduler=None,
16
  torch_dtype=dtype,
17
  low_cpu_mem_usage=True,
18
  )
19
 
20
- # On supprime tout ce qui n'est PAS le text encoder
21
  pipe.transformer = None
22
  pipe.vae = None
23
  pipe.scheduler = None
24
 
25
- # On garde :
26
- # - pipe.tokenizer
27
- # - pipe.text_encoder
28
- # - pipe.text_encoder_2 (si présent)
29
- # - encode_prompt()
30
-
31
  pipe.to(device)
32
 
33
  @torch.no_grad()
@@ -35,18 +25,21 @@ def encode_text(prompt: str):
35
  if not prompt.strip():
36
  raise gr.Error("Prompt vide")
37
 
38
- # encode_prompt = embeddings EXACTS attendus par FLUX2
 
 
 
 
39
  prompt_embeds, _, _ = pipe.encode_prompt(
40
- prompt=prompt,
41
  device=device,
42
  num_images_per_prompt=1,
 
43
  )
44
 
45
- # Sauvegarde dans un fichier .pt
46
  fd, path = tempfile.mkstemp(suffix=".pt")
47
  os.close(fd)
48
  torch.save(prompt_embeds.cpu(), path)
49
-
50
  return path
51
 
52
  demo = gr.Interface(
 
5
 
6
  REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
7
  device = "cpu"
8
+ dtype = torch.float32
9
 
 
10
  pipe = Flux2Pipeline.from_pretrained(
11
  REPO_ID,
 
 
 
12
  torch_dtype=dtype,
13
  low_cpu_mem_usage=True,
14
  )
15
 
16
+ # On supprime les parties inutiles
17
  pipe.transformer = None
18
  pipe.vae = None
19
  pipe.scheduler = None
20
 
 
 
 
 
 
 
21
  pipe.to(device)
22
 
23
  @torch.no_grad()
 
25
  if not prompt.strip():
26
  raise gr.Error("Prompt vide")
27
 
28
+ # FLUX2 Klein attend un format chat Qwen3
29
+ messages = [
30
+ {"role": "user", "content": prompt}
31
+ ]
32
+
33
  prompt_embeds, _, _ = pipe.encode_prompt(
34
+ prompt=messages,
35
  device=device,
36
  num_images_per_prompt=1,
37
+ do_classifier_free_guidance=False,
38
  )
39
 
 
40
  fd, path = tempfile.mkstemp(suffix=".pt")
41
  os.close(fd)
42
  torch.save(prompt_embeds.cpu(), path)
 
43
  return path
44
 
45
  demo = gr.Interface(