File size: 1,673 Bytes
5bf1181 28585f6 4e93c10 aef300d c014548 aef300d 4e93c10 28585f6 aef300d 28585f6 aef300d 3e5acc0 4e93c10 49afb79 28585f6 4e93c10 8cd5853 4510a31 4e93c10 4510a31 6548162 4e93c10 3e5acc0 4e93c10 6562a48 4e93c10 8cd5853 68c7eda 4e93c10 8cd5853 68c7eda 4e93c10 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoTokenizer, Qwen2ForCausalLM
# Target device and precision. NOTE(review): `device` is not referenced
# anywhere else in the visible code; only `dtype` is passed to the model.
device = "cpu"
dtype = torch.float32

# Qwen2-0.5B checkpoint (hidden size 896), used here as the text encoder.
_MODEL_ID = "Qwen/Qwen2-0.5B"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
text_encoder = Qwen2ForCausalLM.from_pretrained(_MODEL_ID, torch_dtype=dtype)

# Per-token projection: 896 -> 4096 (the original note says this targets
# FLUX.1-Schnell). The layer is created with PyTorch's default random
# initialisation; no weights are loaded for it in the visible code.
proj_tokens = nn.Linear(in_features=896, out_features=4096)

# Pooled projection: 4096 -> 3072, also randomly initialised.
proj_pooled = nn.Linear(in_features=4096, out_features=3072)
def encode(prompt: str):
    """Encode *prompt* with the Qwen2 model and save projected embeddings.

    The last hidden state of ``text_encoder`` is passed through the
    module-level ``proj_tokens`` layer (896 -> 4096 in this file). The
    result is averaged over the sequence axis and passed through
    ``proj_pooled`` (4096 -> 3072 in this file). Both tensors are written
    with ``torch.save`` to fixed file names in the current working
    directory.

    Args:
        prompt: Text to encode. An empty or whitespace-only prompt is
            replaced by the tokenizer's EOS token, or ``"."`` if the
            tokenizer has none.

    Returns:
        A 3-tuple ``(shape, "embeds.pt", "pooled.pt")`` where ``shape`` is
        ``str()`` of the per-token embedding tensor's shape and the other
        two items are the paths of the files just written.
    """
    # Fall back to a placeholder so the tokenizer always gets non-empty text.
    if not prompt or not prompt.strip():
        prompt = tokenizer.eos_token or "."

    tokens = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Inference only: without no_grad() every request would build an autograd
    # graph through the whole model and the saved tensors would carry
    # requires_grad=True.
    with torch.no_grad():
        out = text_encoder(
            **tokens,
            output_hidden_states=True,
            use_cache=False,
        )
        # Last-layer hidden states, one vector per token: [1, L, 896].
        hidden = out.hidden_states[-1]
        # Per-token projection through proj_tokens: [1, L, 4096].
        token_embeds = proj_tokens(hidden)
        # Mean over the sequence axis, then the pooled projection: [1, 3072].
        pooled = proj_pooled(token_embeds.mean(dim=1))

    # NOTE(review): the file names are fixed, so each call overwrites the
    # output of the previous one.
    torch.save(token_embeds, "embeds.pt")
    torch.save(pooled, "pooled.pt")
    return str(token_embeds.shape), "embeds.pt", "pooled.pt"
# Gradio front end: one prompt textbox in; the shape string and the two
# files written by encode() out.
# NOTE(review): the file labels mention 2048 / 768, while the projection
# layers defined in this file output 4096 / 3072 features. Confirm which
# is intended.
_prompt_input = gr.Textbox(label="Prompt")
_result_outputs = [
    gr.Textbox(label="Shape"),
    gr.File(label="Embeddings 2048"),
    gr.File(label="Pooled 768"),
]

demo = gr.Interface(
    fn=encode,
    inputs=_prompt_input,
    outputs=_result_outputs,
    title="External Text Encoder — FLUX.1‑Schnell Compatible",
)

# Start the web server as soon as the script is executed.
demo.launch()
|