"""Gradio demo: encode a prompt with Qwen2-0.5B and project it to larger widths.

The script loads the Qwen2-0.5B causal LM, takes its last hidden states for a
prompt, and pushes them through two linear layers:

* ``proj_tokens``: Qwen hidden size -> ``TOKEN_DIM`` (per-token embeddings)
* ``proj_pooled``: ``TOKEN_DIM`` -> ``POOLED_DIM`` (mean-pooled embedding)

Both tensors are written to ``embeds.pt`` / ``pooled.pt`` in the working
directory and offered for download in the UI.

NOTE: the two projection layers are freshly (randomly) initialised on every
process start and are never trained or loaded from a checkpoint here, so the
saved embeddings are not reproducible across restarts.
"""
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoTokenizer, Qwen2ForCausalLM

MODEL_ID = "Qwen/Qwen2-0.5B"
MAX_PROMPT_TOKENS = 512

# Output width of the per-token projection.
TOKEN_DIM = 4096
# Output width of the pooled projection.
# NOTE(review): the original comments/labels advertised 768 here while the
# code used 3072; diffusers' FLUX transformer appears to expect a 768-wide
# pooled projection -- confirm which value the downstream pipeline needs.
POOLED_DIM = 3072

device = "cpu"
dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
text_encoder = Qwen2ForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
).to(device)

# Per-token projection: Qwen hidden size (896 for Qwen2-0.5B) -> TOKEN_DIM.
# The input width is read from the model config instead of being hard-coded.
proj_tokens = nn.Linear(text_encoder.config.hidden_size, TOKEN_DIM).to(
    device=device, dtype=dtype
)

# Pooled projection: TOKEN_DIM -> POOLED_DIM.
proj_pooled = nn.Linear(TOKEN_DIM, POOLED_DIM).to(device=device, dtype=dtype)


def encode(prompt: str) -> tuple[str, str, str]:
    """Encode *prompt* and save the projected embeddings to disk.

    Args:
        prompt: Text to encode. If empty or whitespace-only, the tokenizer's
            EOS token (or ``"."`` when there is none) is encoded instead.

    Returns:
        A 3-tuple ``(shape, embeds_path, pooled_path)``: the string form of
        the per-token embedding shape, and the paths of the two saved files
        (``"embeds.pt"`` and ``"pooled.pt"``, relative to the working
        directory; each call overwrites the previous files).
    """
    if not prompt or not prompt.strip():
        prompt = tokenizer.eos_token or "."

    tokens = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
    ).to(device)

    # Pure inference: skip autograd bookkeeping so no graph is built and the
    # saved tensors are plain data.
    with torch.no_grad():
        out = text_encoder(
            **tokens,
            output_hidden_states=True,
            use_cache=False,
        )

        # Last-layer hidden states: [1, L, hidden_size].
        hidden = out.hidden_states[-1]

        # Per-token projection: [1, L, TOKEN_DIM].
        token_embeds = proj_tokens(hidden)

        # Mean over the sequence axis, then pooled projection: [1, POOLED_DIM].
        pooled = proj_pooled(token_embeds.mean(dim=1))

    # Store on CPU so the files load on machines without the compute device.
    token_embeds = token_embeds.cpu()
    pooled = pooled.cpu()

    torch.save(token_embeds, "embeds.pt")
    torch.save(pooled, "pooled.pt")

    return str(token_embeds.shape), "embeds.pt", "pooled.pt"


demo = gr.Interface(
    fn=encode,
    inputs=gr.Textbox(label="Prompt"),
    outputs=[
        gr.Textbox(label="Shape"),
        # Labels are derived from the constants so they cannot drift from the
        # actual tensor widths again.
        gr.File(label=f"Embeddings {TOKEN_DIM}"),
        gr.File(label=f"Pooled {POOLED_DIM}"),
    ],
    title="External Text Encoder — FLUX.1‑Schnell Compatible",
)

if __name__ == "__main__":
    demo.launch()