lea97338 commited on
Commit
e4f9b77
·
verified ·
1 Parent(s): 5f5e9b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -3,28 +3,29 @@ import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import tempfile, os
5
 
6
- # Le vrai text encoder utilisé par FLUX.2 Klein 4B
7
- REPO_ID = "black-forest-labs/FLUX.2-klein-4B/text_encoder"
8
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
12
- # Charger UNIQUEMENT le CausalLM Qwen utilisé par FLUX2 Klein
13
  tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
14
- text_encoder = AutoModelForCausalLM.from_pretrained(
15
  REPO_ID,
16
  torch_dtype=dtype,
17
  low_cpu_mem_usage=True,
18
  )
19
- text_encoder.to(device)
20
- text_encoder.eval()
 
 
 
21
 
22
  @torch.no_grad()
23
  def encode_text(prompt: str):
24
  if not prompt.strip():
25
  raise gr.Error("Prompt vide")
26
 
27
- # Tokenisation simple (pas de chat template)
28
  inputs = tokenizer(
29
  prompt,
30
  return_tensors="pt",
@@ -32,14 +33,12 @@ def encode_text(prompt: str):
32
  max_length=256
33
  ).to(device)
34
 
35
- # Sortie Qwen3 : hidden_states = [1, seq_len, 4096]
36
- outputs = text_encoder.model(**inputs, output_hidden_states=True)
37
- hidden = outputs.hidden_states[-1] # dernière couche
38
 
39
- # Projection FLUX2 : 4096 → 7680
40
- projected = text_encoder.model.project_out(hidden)
41
 
42
- # Sauvegarde
43
  fd, path = tempfile.mkstemp(suffix=".pt")
44
  os.close(fd)
45
  torch.save(projected.cpu(), path)
@@ -50,8 +49,8 @@ demo = gr.Interface(
50
  fn=encode_text,
51
  inputs=gr.Textbox(label="Prompt"),
52
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
53
- title="FLUX.2 Klein — Text Encoder Qwen3 Direct",
54
- description="Encode le texte avec Qwen3 + projection FLUX2 (4096→7680).",
55
  )
56
 
57
  demo.launch()
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import tempfile, os
5
 
6
+ # Le vrai CausalLM utilisé par FLUX2 Klein
7
+ REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
8
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
  REPO_ID,
15
  torch_dtype=dtype,
16
  low_cpu_mem_usage=True,
17
  )
18
+ model.to(device)
19
+ model.eval()
20
+
21
+ # Projection FLUX2 Klein : 4096 → 7680
22
+ project_out = torch.nn.Linear(4096, 7680, bias=False)
23
 
24
  @torch.no_grad()
25
  def encode_text(prompt: str):
26
  if not prompt.strip():
27
  raise gr.Error("Prompt vide")
28
 
 
29
  inputs = tokenizer(
30
  prompt,
31
  return_tensors="pt",
 
33
  max_length=256
34
  ).to(device)
35
 
36
+ outputs = model.model(**inputs, output_hidden_states=True)
37
+ hidden = outputs.hidden_states[-1] # [1, seq_len, 4096]
 
38
 
39
+ # Projection FLUX2 Klein
40
+ projected = project_out(hidden) # [1, seq_len, 7680]
41
 
 
42
  fd, path = tempfile.mkstemp(suffix=".pt")
43
  os.close(fd)
44
  torch.save(projected.cpu(), path)
 
49
  fn=encode_text,
50
  inputs=gr.Textbox(label="Prompt"),
51
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
52
+ title="FLUX.2 Klein — Text Encoder Qwen Direct",
53
+ description="Encode le texte avec Qwen2.5 + projection FLUX2 (4096→7680).",
54
  )
55
 
56
  demo.launch()