File size: 1,673 Bytes
5bf1181
28585f6
4e93c10
aef300d
c014548
aef300d
 
 
4e93c10
28585f6
aef300d
28585f6
aef300d
 
3e5acc0
4e93c10
49afb79
 
 
 
 
 
28585f6
4e93c10
 
 
8cd5853
4510a31
4e93c10
 
 
 
 
 
 
4510a31
6548162
 
 
 
 
 
4e93c10
 
3e5acc0
4e93c10
 
6562a48
4e93c10
8cd5853
 
68c7eda
4e93c10
8cd5853
 
68c7eda
4e93c10
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoTokenizer, Qwen2ForCausalLM

# Inference configuration.
# NOTE(review): `device` is declared but nothing in this section moves the
# model or the projection layers onto it — confirm whether that is intended.
device = "cpu"
dtype = torch.float32

# Hugging Face checkpoint used for both the tokenizer and the encoder.
# Qwen2 0.5B has hidden_size = 896, which fixes the projection input width.
_MODEL_ID = "Qwen/Qwen2-0.5B"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
text_encoder = Qwen2ForCausalLM.from_pretrained(_MODEL_ID, torch_dtype=dtype)

# Per-token projection: Qwen hidden size (896) -> 4096.
# The layer is freshly initialised here (no weights are loaded for it).
# NOTE(review): confirm 4096 against what the downstream FLUX.1-Schnell
# pipeline expects for token embeddings.
proj_tokens = nn.Linear(in_features=896, out_features=4096)

# Pooled projection: mean-pooled 4096-wide token embeddings -> 3072.
# NOTE(review): confirm 3072 is the pooled width the consumer expects.
proj_pooled = nn.Linear(in_features=4096, out_features=3072)



def encode(prompt: str):
    """Encode *prompt* with the Qwen2 encoder and save projected embeddings.

    The prompt is tokenized (truncated to at most 512 tokens) and run through
    ``text_encoder``. The last hidden state is projected per token by
    ``proj_tokens``; the projected tokens are then averaged over the sequence
    axis and passed through ``proj_pooled``. Both results are written with
    ``torch.save`` to fixed paths in the current working directory.

    Args:
        prompt: Text to encode. A ``None``, empty, or whitespace-only prompt
            is replaced by the tokenizer's EOS token, or ``"."`` when the
            tokenizer defines none.

    Returns:
        A 3-tuple ``(shape, embeds_path, pooled_path)``: the string form of
        the per-token embedding shape, then the paths of the saved per-token
        and pooled tensors (``"embeds.pt"`` and ``"pooled.pt"``).
    """
    # Fall back to a minimal prompt so the tokenizer never sees empty input.
    if not prompt or not prompt.strip():
        prompt = tokenizer.eos_token or "."

    tokens = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Pure inference: without no_grad every request would build an autograd
    # graph through the encoder and the projections, and the tensors written
    # to disk would be saved with requires_grad=True.
    with torch.no_grad():
        out = text_encoder(
            **tokens,
            output_hidden_states=True,
            use_cache=False,
        )

        # Last entry of hidden_states, i.e. the encoder's final hidden layer.
        hidden = out.hidden_states[-1]

        # Per-token projection to proj_tokens' output width.
        token_embeds = proj_tokens(hidden)

        # Pooled vector: mean over dim 1 (the sequence axis), then projected
        # to proj_pooled's output width.
        pooled_embeds = proj_pooled(token_embeds.mean(dim=1))

    # NOTE(review): the paths are constant, so each call overwrites the files
    # written by the previous one.
    torch.save(token_embeds, "embeds.pt")
    torch.save(pooled_embeds, "pooled.pt")

    return str(token_embeds.shape), "embeds.pt", "pooled.pt"

# Gradio UI: a single prompt textbox in; the embedding shape string and the
# two saved tensor files out, matching the 3-tuple returned by `encode`.
demo = gr.Interface(
    fn=encode,
    inputs=gr.Textbox(label="Prompt"),
    outputs=[
        gr.Textbox(label="Shape"),
        # The widths shown in the labels are read from the projection layers
        # so they always match the tensors actually written; the previous
        # hard-coded "2048" / "768" did not match the layers' output sizes.
        gr.File(label=f"Embeddings {proj_tokens.out_features}"),
        gr.File(label=f"Pooled {proj_pooled.out_features}"),
    ],
    title="External Text Encoder — FLUX.1‑Schnell Compatible",
)

# Start the web server as soon as the script runs.
demo.launch()