lea97338 commited on
Commit
fe4b89c
·
verified ·
1 Parent(s): 5a16e2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -3,12 +3,13 @@ import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import tempfile, os
5
 
6
- # Le vrai CausalLM utilisé par FLUX2 Klein
7
  REPO_ID = "Qwen/Qwen2.5-1.5B-Instruct"
8
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
13
  model = AutoModelForCausalLM.from_pretrained(
14
  REPO_ID,
@@ -26,6 +27,7 @@ def encode_text(prompt: str):
26
  if not prompt.strip():
27
  raise gr.Error("Prompt vide")
28
 
 
29
  inputs = tokenizer(
30
  prompt,
31
  return_tensors="pt",
@@ -33,12 +35,14 @@ def encode_text(prompt: str):
33
  max_length=256
34
  ).to(device)
35
 
 
36
  outputs = model.model(**inputs, output_hidden_states=True)
37
- hidden = outputs.hidden_states[-1] # [1, seq_len, 4096]
38
 
39
  # Projection FLUX2 Klein
40
- projected = project_out(hidden) # [1, seq_len, 7680]
41
 
 
42
  fd, path = tempfile.mkstemp(suffix=".pt")
43
  os.close(fd)
44
  torch.save(projected.cpu(), path)
@@ -49,8 +53,8 @@ demo = gr.Interface(
49
  fn=encode_text,
50
  inputs=gr.Textbox(label="Prompt"),
51
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
52
- title="FLUX.2 Klein — Text Encoder Qwen Direct",
53
- description="Encode le texte avec Qwen2.5 + projection FLUX2 (4096→7680).",
54
  )
55
 
56
  demo.launch()
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import tempfile, os
5
 
6
+ # Qwen 2.5 1.5B Instruct
7
  REPO_ID = "Qwen/Qwen2.5-1.5B-Instruct"
8
 
9
  device = "cpu"
10
  dtype = torch.float32
11
 
12
+ # Charger UNIQUEMENT le CausalLM
13
  tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
14
  model = AutoModelForCausalLM.from_pretrained(
15
  REPO_ID,
 
27
  if not prompt.strip():
28
  raise gr.Error("Prompt vide")
29
 
30
+ # Tokenisation simple
31
  inputs = tokenizer(
32
  prompt,
33
  return_tensors="pt",
 
35
  max_length=256
36
  ).to(device)
37
 
38
+ # Sortie Qwen : hidden_states = [1, seq_len, 4096]
39
  outputs = model.model(**inputs, output_hidden_states=True)
40
+ hidden = outputs.hidden_states[-1]
41
 
42
  # Projection FLUX2 Klein
43
+ projected = project_out(hidden) # [1, seq_len, 7680]
44
 
45
+ # Sauvegarde
46
  fd, path = tempfile.mkstemp(suffix=".pt")
47
  os.close(fd)
48
  torch.save(projected.cpu(), path)
 
53
  fn=encode_text,
54
  inputs=gr.Textbox(label="Prompt"),
55
  outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
56
+ title="FLUX2 Klein — Encoder Qwen2.5 1.5B",
57
+ description="Encode le texte avec Qwen2.5 1.5B + projection FLUX2 (4096→7680).",
58
  )
59
 
60
  demo.launch()