# NOTE: Hugging Face Space status banner captured with this source read "Runtime error".
| import json | |
| import time | |
| import tempfile | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| import spaces | |
# Hugging Face model id for the embedding backbone.
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
# Matryoshka truncation sizes offered in the UI dropdown.
ALLOWED_DIMS = [768, 512, 256, 128]
# Preview limits: rows and per-row values shown in the UI; the full data
# is only available via the downloadable JSON file.
MAX_DISPLAY_ROWS = 5
MAX_DISPLAY_VALUES = 16
# Load once at startup
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
MODEL = AutoModel.from_pretrained(MODEL_NAME)
MODEL.eval()  # inference only — disables dropout/training behavior
# Move the model to GPU when one is visible at import time.
# NOTE(review): `spaces` is imported above but @spaces.GPU is never applied to
# the inference function; on a ZeroGPU Space CUDA is unavailable without that
# decorator and this check is False at import — confirm the target hardware.
if torch.cuda.is_available():
    MODEL = MODEL.to("cuda")
| def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray: | |
| norms = np.linalg.norm(x, axis=1, keepdims=True) | |
| return x / np.maximum(norms, eps) | |
| def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | |
| mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float() | |
| summed = torch.sum(last_hidden_state * mask, dim=1) | |
| counts = torch.clamp(mask.sum(dim=1), min=1e-9) | |
| return summed / counts | |
def _build_preview(result_list: List[List[float]], dimension: int) -> dict:
    """Build the truncated, human-readable preview payload for the UI."""
    display_preview = []
    for row in result_list[:MAX_DISPLAY_ROWS]:
        preview_row = row[:MAX_DISPLAY_VALUES]
        if len(row) > MAX_DISPLAY_VALUES:
            # Ellipsis marker signals the row was cut for display.
            preview_row = preview_row + ["..."]
        display_preview.append(preview_row)
    return {
        "preview_embeddings": display_preview,
        "shown_rows": min(len(result_list), MAX_DISPLAY_ROWS),
        "total_rows": len(result_list),
        "shown_values_per_row": MAX_DISPLAY_VALUES,
        "actual_dimension": dimension,
        "note": "Download the JSON file for full embeddings.",
    }


def _write_embeddings_file(result_list: List[List[float]], dimension: int, elapsed: float) -> str:
    """Write the full embeddings to a temp JSON file and return its path."""
    # delete=False: Gradio serves the file after this function returns.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
        json.dump(
            {
                "model": MODEL_NAME,
                "dimension": dimension,
                "count": len(result_list),
                "inference_time_seconds": elapsed,
                "embeddings": result_list,
            },
            f,
            ensure_ascii=False,
        )
        return f.name


def embed_texts(text_input: str, dimension: int) -> Tuple[str, str, str]:
    """Embed one text per input line; return (preview_json, stats_json, file_path).

    Embeddings are mean-pooled, L2-normalized, Matryoshka-truncated to
    ``dimension``, then re-normalized. The full result is written to a
    temporary JSON file for download; the two JSON strings feed the UI.

    Raises:
        gr.Error: invalid dimension, empty input, or dimension larger
            than the model's hidden size.
    """
    # NOTE(review): on a ZeroGPU Space this function must be decorated with
    # @spaces.GPU for CUDA access (`spaces` is imported at module level but
    # unused) — confirm the Space hardware before deploying.
    if dimension not in ALLOWED_DIMS:
        raise gr.Error(f"Dimension must be one of {ALLOWED_DIMS}.")
    texts = [line.strip() for line in (text_input or "").splitlines() if line.strip()]
    if not texts:
        raise gr.Error("Please provide at least one non-empty line of text.")

    start = time.perf_counter()
    with torch.no_grad():
        batch = TOKENIZER(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt",
        )
        # Follow the model's actual device instead of re-probing CUDA, so the
        # inputs land wherever MODEL was placed (CPU, cuda, cuda:1, ...).
        device = next(MODEL.parameters()).device
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = MODEL(**batch)
        pooled = _mean_pool(outputs.last_hidden_state, batch["attention_mask"])
    full_embeddings = pooled.detach().cpu().numpy().astype(np.float32)

    # Normalize full vectors first (Matryoshka truncation assumes this).
    full_embeddings = _l2_normalize(full_embeddings)
    # Guard: slicing past the hidden size would silently yield fewer columns
    # than requested while the payload still reported `dimension` as accurate.
    if dimension > full_embeddings.shape[1]:
        raise gr.Error(
            f"Dimension {dimension} exceeds the model's hidden size "
            f"({full_embeddings.shape[1]})."
        )
    # Matryoshka truncation and re-normalization.
    truncated = _l2_normalize(full_embeddings[:, :dimension])
    elapsed = time.perf_counter() - start

    result_list: List[List[float]] = truncated.tolist()
    preview_payload = _build_preview(result_list, dimension)
    stats = {
        "dimension": dimension,
        "count": len(result_list),
        "inference_time_seconds": round(elapsed, 4),
    }
    file_path = _write_embeddings_file(result_list, dimension, elapsed)
    return json.dumps(preview_payload, indent=2), json.dumps(stats, indent=2), file_path
# UI copy rendered above the inputs (Markdown).
DESCRIPTION = """
Enter one text per line for batch embedding.
Embeddings are L2-normalized, then Matryoshka-truncated to your selected dimension and re-normalized.
"""
# Footer with author/attribution links, rendered as Markdown at the bottom of the page.
FOOTER = (
    "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
    "[AI Enablement Academy](https://enablement.academy) | "
    "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
)
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(title="Text Embeddings - Qwen3 Embedding") as demo:
    gr.Markdown("# Text Embeddings - Qwen3 Embedding")
    gr.Markdown(DESCRIPTION)

    # Input row: text area plus the Matryoshka dimension selector.
    with gr.Row():
        input_box = gr.Textbox(
            label="Input text (single or batch, one per line)",
            lines=10,
            placeholder="Type one sentence per line...",
        )
        dim_dropdown = gr.Dropdown(
            choices=ALLOWED_DIMS,
            value=768,
            label="Embedding dimension (Matryoshka)",
        )

    # Action row.
    with gr.Row():
        run_button = gr.Button("Generate Embeddings", variant="primary")
        reset_button = gr.Button("Clear")

    # Outputs: truncated preview, stats summary, and the full-JSON download.
    preview_box = gr.Code(label="Embeddings JSON Preview (truncated)", language="json")
    stats_box = gr.Code(label="Stats", language="json")
    result_file = gr.File(label="Download full embeddings JSON")

    run_button.click(
        fn=embed_texts,
        inputs=[input_box, dim_dropdown],
        outputs=[preview_box, stats_box, result_file],
    )
    reset_button.click(
        fn=lambda: ("", "", None),
        outputs=[preview_box, stats_box, result_file],
    )

    gr.Markdown(FOOTER)

if __name__ == "__main__":
    demo.queue().launch()