lea97338 commited on
Commit
28585f6
·
verified ·
1 Parent(s): d407e28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -1,16 +1,21 @@
1
  import torch
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, Qwen2ForCausalLM
4
 
5
  device = "cpu"
6
  dtype = torch.float32
7
 
8
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
 
9
  text_encoder = Qwen2ForCausalLM.from_pretrained(
10
- "Qwen/Qwen2-3B",
11
  torch_dtype=dtype,
12
  )
13
 
 
 
 
14
  def encode(prompt):
15
  tokens = tokenizer(
16
  prompt,
@@ -27,23 +32,30 @@ def encode(prompt):
27
  use_cache=False,
28
  )
29
 
30
- embeds = out.hidden_states[-1] # [1, L, 2560]
31
- pooled = embeds.mean(dim=1) # [1, 2560]
 
 
 
 
 
 
32
 
33
- torch.save(embeds, "embeds.pt")
 
34
  torch.save(pooled, "pooled.pt")
35
 
36
- return str(embeds.shape), "embeds.pt", "pooled.pt"
37
 
38
  demo = gr.Interface(
39
  fn=encode,
40
  inputs=gr.Textbox(label="Prompt"),
41
  outputs=[
42
  gr.Textbox(label="Shape"),
43
- gr.File(label="Embeddings 2560"),
44
- gr.File(label="Pooled 2560")
45
  ],
46
- title="External Text Encoder — 2560 dims"
47
  )
48
 
49
  demo.launch()
 
1
  import torch
2
+ import torch.nn as nn
3
  import gradio as gr
4
  from transformers import AutoTokenizer, Qwen2ForCausalLM
5
 
6
  device = "cpu"
7
  dtype = torch.float32
8
 
9
+ # Charger Qwen 0.5B (léger, CPU OK)
10
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
11
  text_encoder = Qwen2ForCausalLM.from_pretrained(
12
+ "Qwen/Qwen2-0.5B",
13
  torch_dtype=dtype,
14
  )
15
 
16
+ # Projection 1536 → 2048 (pour FLUX.1-Schnell)
17
+ proj = nn.Linear(1536, 2048)
18
+
19
  def encode(prompt):
20
  tokens = tokenizer(
21
  prompt,
 
32
  use_cache=False,
33
  )
34
 
35
+ # Embeddings Qwen 1536 dims
36
+ embeds_1536 = out.hidden_states[-1] # [1, L, 1536]
37
+
38
+ # Projection → 2048 dims
39
+ embeds_2048 = proj(embeds_1536) # [1, L, 2048]
40
+
41
+ # pooled → moyenne
42
+ pooled = embeds_2048.mean(dim=1) # [1, 2048]
43
 
44
+ # Sauvegarde
45
+ torch.save(embeds_2048, "embeds.pt")
46
  torch.save(pooled, "pooled.pt")
47
 
48
+ return str(embeds_2048.shape), "embeds.pt", "pooled.pt"
49
 
50
  demo = gr.Interface(
51
  fn=encode,
52
  inputs=gr.Textbox(label="Prompt"),
53
  outputs=[
54
  gr.Textbox(label="Shape"),
55
+ gr.File(label="Embeddings 2048"),
56
+ gr.File(label="Pooled 2048")
57
  ],
58
+ title="External Text Encoder — 2048 dims (FLUX.1‑Schnell)"
59
  )
60
 
61
  demo.launch()