lea97338 commited on
Commit
6548162
·
verified ·
1 Parent(s): 4510a31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -6,27 +6,30 @@ from transformers import AutoTokenizer, Qwen2ForCausalLM
6
  device = "cpu"
7
  dtype = torch.float32
8
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
10
  text_encoder = Qwen2ForCausalLM.from_pretrained(
11
  "Qwen/Qwen2-0.5B",
12
  torch_dtype=dtype,
13
  )
14
 
 
15
  proj = nn.Linear(1536, 2048)
16
 
17
  def encode(prompt: str):
18
- # 1) Nettoyage du prompt
19
  if prompt is None:
20
  prompt = ""
21
  prompt_clean = prompt.strip()
22
 
23
- # 2) Si vide on force un token valide
24
  if prompt_clean == "":
25
- if tokenizer.eos_token is not None:
26
  prompt_clean = tokenizer.eos_token
27
  else:
28
  prompt_clean = "."
29
 
 
30
  tokens = tokenizer(
31
  prompt_clean,
32
  return_tensors="pt",
@@ -35,17 +38,23 @@ def encode(prompt: str):
35
  max_length=512,
36
  )
37
 
38
- with torch.inference_mode():
39
- out = text_encoder(
40
- **tokens,
41
- output_hidden_states=True,
42
- use_cache=False,
43
- )
 
 
 
 
 
 
44
 
45
- embeds_1536 = out.hidden_states[-1] # [1, L, 1536]
46
- embeds_2048 = proj(embeds_1536) # [1, L, 2048]
47
- pooled = embeds_2048.mean(dim=1) # [1, 2048]
48
 
 
49
  torch.save(embeds_2048, "embeds.pt")
50
  torch.save(pooled, "pooled.pt")
51
 
@@ -57,7 +66,7 @@ demo = gr.Interface(
57
  outputs=[
58
  gr.Textbox(label="Shape"),
59
  gr.File(label="Embeddings 2048"),
60
- gr.File(label="Pooled 2048"),
61
  ],
62
  title="External Text Encoder — 2048 dims (FLUX.1‑Schnell)"
63
  )
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
+ # Qwen 0.5B (léger, CPU OK)
10
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
11
  text_encoder = Qwen2ForCausalLM.from_pretrained(
12
  "Qwen/Qwen2-0.5B",
13
  torch_dtype=dtype,
14
  )
15
 
16
+ # Projection 1536 -> 2048 pour FLUX.1-Schnell
17
  proj = nn.Linear(1536, 2048)
18
 
19
  def encode(prompt: str):
20
+ # Nettoyage du prompt
21
  if prompt is None:
22
  prompt = ""
23
  prompt_clean = prompt.strip()
24
 
25
+ # Si vide -> token de secours
26
  if prompt_clean == "":
27
+ if tokenizer.eos_token:
28
  prompt_clean = tokenizer.eos_token
29
  else:
30
  prompt_clean = "."
31
 
32
+ # Tokenisation
33
  tokens = tokenizer(
34
  prompt_clean,
35
  return_tensors="pt",
 
38
  max_length=512,
39
  )
40
 
41
+ # Encodage Qwen (SANS inference_mode)
42
+ out = text_encoder(
43
+ **tokens,
44
+ output_hidden_states=True,
45
+ use_cache=False,
46
+ )
47
+
48
+ # Embeddings Qwen (1536 dims)
49
+ embeds_1536 = out.hidden_states[-1] # [1, L, 1536]
50
+
51
+ # Projection -> 2048 dims
52
+ embeds_2048 = proj(embeds_1536) # [1, L, 2048]
53
 
54
+ # pooled -> moyenne
55
+ pooled = embeds_2048.mean(dim=1) # [1, 2048]
 
56
 
57
+ # Sauvegarde
58
  torch.save(embeds_2048, "embeds.pt")
59
  torch.save(pooled, "pooled.pt")
60
 
 
66
  outputs=[
67
  gr.Textbox(label="Shape"),
68
  gr.File(label="Embeddings 2048"),
69
+ gr.File(label="Pooled 2048")
70
  ],
71
  title="External Text Encoder — 2048 dims (FLUX.1‑Schnell)"
72
  )