lea97338 commited on
Commit
4e93c10
·
verified ·
1 Parent(s): 8cd5853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -1,25 +1,35 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import AutoTokenizer, Qwen2ForCausalLM
4
 
5
  device = "cpu"
6
  dtype = torch.float32
7
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
9
  text_encoder = Qwen2ForCausalLM.from_pretrained(
10
  "Qwen/Qwen2-0.5B",
11
  torch_dtype=dtype,
12
  )
13
 
14
- # Qwen 0.5B 896 dims
15
- proj_tokens = nn.Linear(896, 2048) # pour prompt_embeds
16
- proj_pooled = nn.Linear(2048, 768) # pour pooled_prompt_embeds
17
 
18
- def encode(prompt):
19
- if not prompt.strip():
 
 
 
20
  prompt = tokenizer.eos_token or "."
21
 
22
- tokens = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
23
 
24
  out = text_encoder(
25
  **tokens,
@@ -27,14 +37,31 @@ def encode(prompt):
27
  use_cache=False,
28
  )
29
 
30
- hidden = out.hidden_states[-1] # [1, L, 896]
 
31
 
32
- embeds_2048 = proj_tokens(hidden) # [1, L, 2048]
 
33
 
 
34
  pooled_2048 = embeds_2048.mean(dim=1) # [1, 2048]
35
  pooled_768 = proj_pooled(pooled_2048) # [1, 768]
36
 
 
37
  torch.save(embeds_2048, "embeds.pt")
38
  torch.save(pooled_768, "pooled.pt")
39
 
40
- return "OK", "embeds.pt", "pooled.pt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import gradio as gr
4
  from transformers import AutoTokenizer, Qwen2ForCausalLM
5
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
+ # Qwen 0.5B = hidden_size 896
10
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
11
  text_encoder = Qwen2ForCausalLM.from_pretrained(
12
  "Qwen/Qwen2-0.5B",
13
  torch_dtype=dtype,
14
  )
15
 
16
+ # Projection 896 -> 2048 pour FLUX.1-Schnell
17
+ proj_tokens = nn.Linear(896, 2048)
 
18
 
19
+ # Projection pooled 2048 -> 768 (obligatoire pour Schnell)
20
+ proj_pooled = nn.Linear(2048, 768)
21
+
22
+ def encode(prompt: str):
23
+ if not prompt or prompt.strip() == "":
24
  prompt = tokenizer.eos_token or "."
25
 
26
+ tokens = tokenizer(
27
+ prompt,
28
+ return_tensors="pt",
29
+ padding=True,
30
+ truncation=True,
31
+ max_length=512,
32
+ )
33
 
34
  out = text_encoder(
35
  **tokens,
 
37
  use_cache=False,
38
  )
39
 
40
+ # Embeddings Qwen (896 dims)
41
+ embeds_896 = out.hidden_states[-1] # [1, L, 896]
42
 
43
+ # Projection -> 2048 dims
44
+ embeds_2048 = proj_tokens(embeds_896) # [1, L, 2048]
45
 
46
+ # pooled -> moyenne -> projection 768 dims
47
  pooled_2048 = embeds_2048.mean(dim=1) # [1, 2048]
48
  pooled_768 = proj_pooled(pooled_2048) # [1, 768]
49
 
50
+ # Sauvegarde
51
  torch.save(embeds_2048, "embeds.pt")
52
  torch.save(pooled_768, "pooled.pt")
53
 
54
+ return str(embeds_2048.shape), "embeds.pt", "pooled.pt"
55
+
56
+ demo = gr.Interface(
57
+ fn=encode,
58
+ inputs=gr.Textbox(label="Prompt"),
59
+ outputs=[
60
+ gr.Textbox(label="Shape"),
61
+ gr.File(label="Embeddings 2048"),
62
+ gr.File(label="Pooled 768")
63
+ ],
64
+ title="External Text Encoder — FLUX.1‑Schnell Compatible"
65
+ )
66
+
67
+ demo.launch()