# text-embeddings / app.py
# Qwen3 text embeddings with Matryoshka dimensions
# (Hugging Face Space by xavier-fuentes, commit 9e1a1ca)
import json
import time
import tempfile
from typing import List, Tuple
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
import spaces
# Hugging Face model id of the embedding backbone loaded at startup.
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
# Matryoshka truncation sizes offered in the UI: vectors are cut to the
# leading N components and re-normalized. Assumes the model's hidden size
# is >= 768 — TODO confirm against the model config.
ALLOWED_DIMS = [768, 512, 256, 128]
# Caps for the on-screen JSON preview; the full output is only available
# via the downloadable JSON file.
MAX_DISPLAY_ROWS = 5
MAX_DISPLAY_VALUES = 16
# Load once at startup so every request reuses the same tokenizer/weights.
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
MODEL = AutoModel.from_pretrained(MODEL_NAME)
MODEL.eval()  # inference only: disables dropout / training-mode layers
if torch.cuda.is_available():
    # Move weights to GPU when one is present; otherwise stay on CPU.
    MODEL = MODEL.to("cuda")
def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
norms = np.linalg.norm(x, axis=1, keepdims=True)
return x / np.maximum(norms, eps)
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
summed = torch.sum(last_hidden_state * mask, dim=1)
counts = torch.clamp(mask.sum(dim=1), min=1e-9)
return summed / counts
def _encode(texts: List[str]) -> np.ndarray:
    """Tokenize *texts*, run the model, and mean-pool to one float32 row per text."""
    # Tokenization is pure preprocessing; only the forward pass needs no_grad.
    batch = TOKENIZER(
        texts,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )
    if torch.cuda.is_available():
        batch = {k: v.to("cuda") for k, v in batch.items()}
    with torch.no_grad():
        outputs = MODEL(**batch)
        pooled = _mean_pool(outputs.last_hidden_state, batch["attention_mask"])
    return pooled.detach().cpu().numpy().astype(np.float32)


def _build_preview(rows: List[List[float]], dimension: int) -> dict:
    """Build the truncated on-screen preview payload for *rows*."""
    display_preview = []
    for row in rows[:MAX_DISPLAY_ROWS]:
        preview_row = row[:MAX_DISPLAY_VALUES]
        if len(row) > MAX_DISPLAY_VALUES:
            preview_row = preview_row + ["..."]  # marker that values were elided
        display_preview.append(preview_row)
    return {
        "preview_embeddings": display_preview,
        "shown_rows": min(len(rows), MAX_DISPLAY_ROWS),
        "total_rows": len(rows),
        "shown_values_per_row": MAX_DISPLAY_VALUES,
        "actual_dimension": dimension,
        "note": "Download the JSON file for full embeddings.",
    }


def _write_full_json(rows: List[List[float]], dimension: int, elapsed: float) -> str:
    """Write the complete embeddings to a temp JSON file and return its path.

    delete=False is deliberate: Gradio serves the file after this handler
    returns, so it must outlive the context manager.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
        json.dump(
            {
                "model": MODEL_NAME,
                "dimension": dimension,
                "count": len(rows),
                "inference_time_seconds": elapsed,
                "embeddings": rows,
            },
            f,
            ensure_ascii=False,
        )
        return f.name


@spaces.GPU
def embed_texts(text_input: str, dimension: int) -> Tuple[str, str, str]:
    """Embed one text per non-empty input line at the requested Matryoshka dimension.

    Args:
        text_input: Raw textbox contents; each non-blank line is one text.
        dimension: Target embedding size; must be one of ALLOWED_DIMS.

    Returns:
        (preview_json, stats_json, file_path) — truncated preview payload,
        timing/count stats, and path to a JSON file with the full embeddings.

    Raises:
        gr.Error: If the dimension is not allowed or no non-empty lines exist.
    """
    # Gradio API callers may deliver the dropdown value as a str/float;
    # coerce before the membership check against the int list.
    try:
        dimension = int(dimension)
    except (TypeError, ValueError):
        raise gr.Error(f"Dimension must be one of {ALLOWED_DIMS}.") from None
    if dimension not in ALLOWED_DIMS:
        raise gr.Error(f"Dimension must be one of {ALLOWED_DIMS}.")
    texts = [line.strip() for line in (text_input or "").splitlines() if line.strip()]
    if not texts:
        raise gr.Error("Please provide at least one non-empty line of text.")
    start = time.perf_counter()
    # Normalize full vectors first, then Matryoshka-truncate and re-normalize.
    full_embeddings = _l2_normalize(_encode(texts))
    truncated = _l2_normalize(full_embeddings[:, :dimension])
    elapsed = time.perf_counter() - start
    result_list: List[List[float]] = truncated.tolist()
    preview_payload = _build_preview(result_list, dimension)
    stats = {
        "dimension": dimension,
        "count": len(result_list),
        "inference_time_seconds": round(elapsed, 4),
    }
    file_path = _write_full_json(result_list, dimension, elapsed)
    return json.dumps(preview_payload, indent=2), json.dumps(stats, indent=2), file_path
# UI copy rendered above the inputs, summarizing the embedding pipeline.
DESCRIPTION = """
Enter one text per line for batch embedding.
Embeddings are L2-normalized, then Matryoshka-truncated to your selected dimension and re-normalized.
"""
# Markdown credit line rendered at the bottom of the page.
FOOTER = (
    "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
    "[AI Enablement Academy](https://enablement.academy) | "
    "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
)
# Gradio UI: textbox + dimension dropdown in, preview/stats/download out.
with gr.Blocks(title="Text Embeddings - Qwen3 Embedding") as demo:
    gr.Markdown("# Text Embeddings - Qwen3 Embedding")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # One text per line; blank lines are dropped by embed_texts.
        text_input = gr.Textbox(
            label="Input text (single or batch, one per line)",
            lines=10,
            placeholder="Type one sentence per line...",
        )
        dimension = gr.Dropdown(
            choices=ALLOWED_DIMS,
            value=768,
            label="Embedding dimension (Matryoshka)",
        )
    with gr.Row():
        embed_btn = gr.Button("Generate Embeddings", variant="primary")
        clear_btn = gr.Button("Clear")
    # Preview is truncated for readability; the full result is in the file.
    embeddings_preview = gr.Code(label="Embeddings JSON Preview (truncated)", language="json")
    stats_output = gr.Code(label="Stats", language="json")
    download_file = gr.File(label="Download full embeddings JSON")
    embed_btn.click(
        fn=embed_texts,
        inputs=[text_input, dimension],
        outputs=[embeddings_preview, stats_output, download_file],
    )
    # Clear resets only the outputs, not the input textbox or dropdown.
    clear_btn.click(
        fn=lambda: ("", "", None),
        outputs=[embeddings_preview, stats_output, download_file],
    )
    gr.Markdown(FOOTER)
if __name__ == "__main__":
    # Enable Gradio's request queue before launching so concurrent requests
    # are serialized rather than contending for the model.
    demo.queue().launch()