lea97338 committed on
Commit
4510a31
·
verified ·
1 Parent(s): 28585f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -6,19 +6,29 @@ from transformers import AutoTokenizer, Qwen2ForCausalLM
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
- # Charger Qwen 0.5B (léger, CPU OK)
10
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
11
  text_encoder = Qwen2ForCausalLM.from_pretrained(
12
  "Qwen/Qwen2-0.5B",
13
  torch_dtype=dtype,
14
  )
15
 
16
- # Projection 1536 → 2048 (pour FLUX.1-Schnell)
17
  proj = nn.Linear(1536, 2048)
18
 
19
- def encode(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
20
  tokens = tokenizer(
21
- prompt,
22
  return_tensors="pt",
23
  padding=True,
24
  truncation=True,
@@ -32,16 +42,10 @@ def encode(prompt):
32
  use_cache=False,
33
  )
34
 
35
- # Embeddings Qwen 1536 dims
36
- embeds_1536 = out.hidden_states[-1] # [1, L, 1536]
37
-
38
- # Projection → 2048 dims
39
- embeds_2048 = proj(embeds_1536) # [1, L, 2048]
40
-
41
- # pooled → moyenne
42
- pooled = embeds_2048.mean(dim=1) # [1, 2048]
43
 
44
- # Sauvegarde
45
  torch.save(embeds_2048, "embeds.pt")
46
  torch.save(pooled, "pooled.pt")
47
 
@@ -53,7 +57,7 @@ demo = gr.Interface(
53
  outputs=[
54
  gr.Textbox(label="Shape"),
55
  gr.File(label="Embeddings 2048"),
56
- gr.File(label="Pooled 2048")
57
  ],
58
  title="External Text Encoder — 2048 dims (FLUX.1‑Schnell)"
59
  )
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
10
  text_encoder = Qwen2ForCausalLM.from_pretrained(
11
  "Qwen/Qwen2-0.5B",
12
  torch_dtype=dtype,
13
  )
14
 
 
15
  proj = nn.Linear(1536, 2048)
16
 
17
+ def encode(prompt: str):
18
+ # 1) Nettoyage du prompt
19
+ if prompt is None:
20
+ prompt = ""
21
+ prompt_clean = prompt.strip()
22
+
23
+ # 2) Si vide → on force un token valide
24
+ if prompt_clean == "":
25
+ if tokenizer.eos_token is not None:
26
+ prompt_clean = tokenizer.eos_token
27
+ else:
28
+ prompt_clean = "."
29
+
30
  tokens = tokenizer(
31
+ prompt_clean,
32
  return_tensors="pt",
33
  padding=True,
34
  truncation=True,
 
42
  use_cache=False,
43
  )
44
 
45
+ embeds_1536 = out.hidden_states[-1] # [1, L, 1536]
46
+ embeds_2048 = proj(embeds_1536) # [1, L, 2048]
47
+ pooled = embeds_2048.mean(dim=1) # [1, 2048]
 
 
 
 
 
48
 
 
49
  torch.save(embeds_2048, "embeds.pt")
50
  torch.save(pooled, "pooled.pt")
51
 
 
57
  outputs=[
58
  gr.Textbox(label="Shape"),
59
  gr.File(label="Embeddings 2048"),
60
+ gr.File(label="Pooled 2048"),
61
  ],
62
  title="External Text Encoder — 2048 dims (FLUX.1‑Schnell)"
63
  )