| import torch |
| import torch.nn as nn |
| import gradio as gr |
| from transformers import AutoTokenizer, Qwen2ForCausalLM |
|
|
# Placement and precision shared by the text encoder and both projection layers.
device = "cpu"
dtype = torch.float32

# Checkpoint used for both the tokenizer and the text encoder.
MODEL_ID = "Qwen/Qwen2-0.5B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
text_encoder = Qwen2ForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
).to(device)
# This script only runs inference: freeze the weights so forward passes do not
# build an autograd graph and saved tensors do not carry gradient state.
text_encoder.eval()
text_encoder.requires_grad_(False)

# Per-token projection from the encoder's hidden width to 4096 features.
# The input width is read from the model config (896 for this checkpoint) so
# the layer keeps matching if MODEL_ID is changed.
# NOTE(review): this layer is randomly initialised on every start-up and no
# checkpoint is loaded for it, so outputs differ between runs — confirm that
# trained weights are not supposed to be loaded here.
proj_tokens = nn.Linear(
    text_encoder.config.hidden_size,
    4096,
    device=device,
    dtype=dtype,
).requires_grad_(False)

# Projection applied to the mean-pooled 4096-wide token features.
# NOTE(review): the UI labels this output "Pooled 768" but the layer emits
# 3072 features — confirm which size is intended. Also randomly initialised.
proj_pooled = nn.Linear(
    4096,
    3072,
    device=device,
    dtype=dtype,
).requires_grad_(False)
|
|
|
|
|
|
def encode(prompt: str):
    """Encode *prompt* and write the projected embeddings to disk.

    The prompt is tokenized (truncated to 512 tokens), run through
    ``text_encoder``, and the last hidden state is passed through
    ``proj_tokens``. The projected token features are mean-pooled over the
    sequence axis and passed through ``proj_pooled``.

    Args:
        prompt: Text to encode. ``None``, empty, or whitespace-only input is
            replaced by the tokenizer's EOS token (or ``"."`` if it has none).

    Returns:
        A 3-tuple ``(shape, embeds_path, pooled_path)``: the shape of the
        per-token embedding tensor as a string, followed by the paths of the
        saved per-token and pooled tensors.

    Side effects:
        Overwrites ``embeds.pt`` and ``pooled.pt`` in the current working
        directory on every call.
        NOTE(review): the file names are fixed, so overlapping calls would
        overwrite each other's files — confirm this is acceptable.
    """
    # Fall back to a placeholder so the encoder always receives some text
    # (presumably an empty string could tokenize to zero tokens — TODO confirm).
    if not prompt or not prompt.strip():
        prompt = tokenizer.eos_token or "."

    # Keep the inputs on the same device as the models.
    tokens = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    # Inference only: skip autograd bookkeeping for the whole forward pass.
    with torch.no_grad():
        out = text_encoder(
            **tokens,
            output_hidden_states=True,
            use_cache=False,
        )
        hidden = out.hidden_states[-1]

        token_embeds = proj_tokens(hidden)

        # Plain mean over the sequence axis; only a single prompt is encoded
        # per call (presumably no padding positions are present — verify if
        # batching is ever added).
        pooled = proj_pooled(token_embeds.mean(dim=1))

    # Store plain CPU tensors so the files load without a matching device and
    # without gradient state attached.
    token_embeds = token_embeds.detach().cpu()
    pooled = pooled.detach().cpu()

    torch.save(token_embeds, "embeds.pt")
    torch.save(pooled, "pooled.pt")

    return str(token_embeds.shape), "embeds.pt", "pooled.pt"
|
|
# Single text input for the prompt.
_prompt_input = gr.Textbox(label="Prompt")

# Output widgets, in the same order as the tuple returned by `encode`:
# shape string, per-token embeddings file, pooled embeddings file.
# NOTE(review): the labels mention 2048 / 768, while the projection layers in
# this file are nn.Linear(896, 4096) and nn.Linear(4096, 3072) — confirm which
# numbers are intended before changing either side.
_result_outputs = [
    gr.Textbox(label="Shape"),
    gr.File(label="Embeddings 2048"),
    gr.File(label="Pooled 768"),
]

demo = gr.Interface(
    fn=encode,
    inputs=_prompt_input,
    outputs=_result_outputs,
    title="External Text Encoder — FLUX.1‑Schnell Compatible",
)

# Start the web UI as soon as the module is executed.
demo.launch()
|
|