Encoder / app.py
lea97338's picture
Update app.py
49afb79 verified
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoTokenizer, Qwen2ForCausalLM
# Runtime configuration. NOTE(review): `device` is not referenced anywhere in
# the visible part of this file; the models simply stay on their default device.
device = "cpu"
dtype = torch.float32

# Checkpoint shared by the tokenizer and the language model used as encoder.
_QWEN_CHECKPOINT = "Qwen/Qwen2-0.5B"
# Width of the Qwen2-0.5B hidden states (896 according to the original note).
_QWEN_HIDDEN_SIZE = 896
# Output width of the per-token projection.
_TOKEN_EMBED_DIM = 4096
# Output width of the pooled projection.
_POOLED_EMBED_DIM = 3072

tokenizer = AutoTokenizer.from_pretrained(_QWEN_CHECKPOINT)
text_encoder = Qwen2ForCausalLM.from_pretrained(
    _QWEN_CHECKPOINT,
    torch_dtype=dtype,
)

# Per-token projection: 896 -> 4096.
# NOTE(review): no weights are loaded for either projection in this file, so
# they appear to be freshly (randomly) initialised on every start - confirm
# whether trained weights are meant to be loaded here.
proj_tokens = nn.Linear(
    in_features=_QWEN_HIDDEN_SIZE,
    out_features=_TOKEN_EMBED_DIM,
)
# Pooled projection: 4096 -> 3072, applied to the mean of the projected tokens.
proj_pooled = nn.Linear(
    in_features=_TOKEN_EMBED_DIM,
    out_features=_POOLED_EMBED_DIM,
)
def encode(prompt: str) -> tuple:
    """Encode *prompt* with the Qwen text encoder and save the projections.

    The prompt is tokenized (truncated to 512 tokens), passed through
    ``text_encoder``, and the last hidden state is projected by the
    module-level ``proj_tokens`` layer. The projected tokens are averaged over
    the sequence axis and passed through ``proj_pooled``. Both results are
    written with ``torch.save`` to ``embeds.pt`` and ``pooled.pt`` in the
    current working directory.

    An empty or whitespace-only prompt is replaced by the tokenizer's EOS
    token (or ``"."`` when the tokenizer has none).

    Args:
        prompt: Text to encode.

    Returns:
        A 3-tuple ``(shape, embeds_path, pooled_path)``: the shape of the
        per-token embedding tensor as a string, followed by the paths of the
        two saved files.
    """
    if not prompt or not prompt.strip():
        prompt = tokenizer.eos_token or "."

    tokens = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Pure inference: without no_grad() every request would build an autograd
    # graph and the saved tensors would be graph-attached / require grad.
    with torch.no_grad():
        out = text_encoder(
            **tokens,
            output_hidden_states=True,
            use_cache=False,
        )

        # Last entry of the hidden-state tuple: [1, L, hidden_size].
        hidden = out.hidden_states[-1]

        # Per-token projection: [1, L, proj_tokens.out_features].
        token_embeds = proj_tokens(hidden)

        # Mean over the sequence axis, then pooled projection:
        # [1, proj_pooled.out_features]. A single prompt is encoded per call,
        # so no padding positions take part in the mean.
        pooled_embeds = proj_pooled(token_embeds.mean(dim=1))

    # NOTE: fixed file names in the working directory - each call overwrites
    # the files written by the previous one.
    torch.save(token_embeds, "embeds.pt")
    torch.save(pooled_embeds, "pooled.pt")

    return str(token_embeds.shape), "embeds.pt", "pooled.pt"
# Gradio UI: one prompt textbox in; the embedding shape plus the two saved
# tensor files out (same order as the tuple returned by `encode`).
# The file labels take their dimension from the projection layers themselves
# so they cannot drift from the sizes actually produced (the previous
# hard-coded "2048" / "768" labels did not match the layers).
demo = gr.Interface(
    fn=encode,
    inputs=gr.Textbox(label="Prompt"),
    outputs=[
        gr.Textbox(label="Shape"),
        gr.File(label=f"Embeddings {proj_tokens.out_features}"),
        gr.File(label=f"Pooled {proj_pooled.out_features}"),
    ],
    title="External Text Encoder — FLUX.1‑Schnell Compatible",
)

# Start the web server as soon as the script is executed.
demo.launch()