# Hugging Face Space (page-scrape residue; status was "Sleeping")
| import os | |
| import time | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
# Hub identifier of the multilingual E5 text-embedding model.
MODEL = "intfloat/multilingual-e5-base"
print("Loading tokenizer and model:", MODEL)
# Prefer GPU when one is visible; every tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)
model.to(device)
# Inference only: switch off dropout / training-mode layers.
model.eval()
def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings over the sequence axis, weighted by the mask.

    Padding positions (mask == 0) contribute nothing to the sum; the
    denominator is the number of real tokens per sequence, clamped away
    from zero so an all-padding row cannot divide by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def embed_texts(texts, batch_size=8):
    """Embed one string or a list of strings into L2-normalized vectors.

    Returns a single 1-D numpy array when given a bare string, otherwise
    a 2-D array with one row per input text. Runs the model in batches
    under no_grad and moves results back to the CPU.
    """
    is_single = isinstance(texts, str)
    if is_single:
        texts = [texts]
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            ids = enc["input_ids"].to(device)
            mask = enc["attention_mask"].to(device)
            out = model(input_ids=ids, attention_mask=mask, return_dict=True)
            pooled = mean_pooling(out.last_hidden_state, mask)
            chunks.append(F.normalize(pooled, p=2, dim=1).cpu().numpy())
    stacked = np.vstack(chunks)
    return stacked[0] if is_single else stacked
def cosine_similarity(a, b):
    """Cosine similarity of two 1-D vectors; 0.0 when either has zero norm."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
def similarity(s1, s2):
    """Embed both texts and return their cosine similarity as a 4-decimal string."""
    pair = embed_texts([s1, s2])
    return f"{cosine_similarity(pair[0], pair[1]):.4f}"
def search(query, docs_text, topk=3):
    """Rank newline-separated documents by cosine similarity to the query.

    Blank lines in the corpus are ignored. Returns a formatted top-K
    listing (one "rank. score\\ndocument" entry per match), or the string
    "Corpus is empty" when no documents remain after filtering.
    """
    corpus = [line.strip() for line in docs_text.splitlines() if line.strip()]
    if not corpus:
        return "Corpus is empty"
    # Embed documents and query in one call; the query is the final row.
    vectors = embed_texts(corpus + [query])
    doc_vecs, query_vec = vectors[:-1], vectors[-1]
    # Embeddings are already unit-normalized; dividing by the norms (plus a
    # tiny epsilon) is a no-op safeguard against non-normalized inputs.
    sims = (doc_vecs @ query_vec) / (
        np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec) + 1e-12
    )
    ranking = np.argsort(sims)[::-1][: int(topk)]
    results = [
        f"{rank}. score={sims[idx]:.4f}\n{corpus[idx]}"
        for rank, idx in enumerate(ranking, start=1)
    ]
    return "\n\n".join(results)
# Gradio UI: two tabs, one for pairwise similarity and one for semantic search.
demo = gr.Blocks()
with demo:
    gr.Markdown("# multilingual-e5-base — multilingual embedding tester")
    # Tab 1: cosine similarity between two free-text inputs.
    with gr.Tab("Similarity"):
        t1 = gr.Textbox(label="Text 1", value="Hello world / Привет мир")
        t2 = gr.Textbox(label="Text 2", value="Greetings planet / Привет, планета")
        btn = gr.Button("Compute similarity")
        out = gr.Textbox(label="Cosine similarity", interactive=False)
        btn.click(fn=similarity, inputs=[t1, t2], outputs=out)
    # Tab 2: rank a newline-separated corpus against a query.
    with gr.Tab("Semantic search"):
        q = gr.Textbox(label="Query", value="climate change")
        corpus = gr.Textbox(label="Corpus (one document per line)", lines=12, value=(
            "Climate summit discussed emissions reductions.\n"
            "Local sports team won the championship.\n"
            "New research on climate change effects published.\n"
            "Economy grows despite challenges."))
        # precision=0 makes the Number component return an integer-valued Top-K.
        k = gr.Number(label="Top-K", value=3, precision=0)
        btn2 = gr.Button("Search")
        out2 = gr.Textbox(label="Results", lines=12)
        btn2.click(fn=search, inputs=[q, corpus, k], outputs=out2)
    gr.Markdown("---\nModel: intfloat/multilingual-e5-base — uses Transformers AutoModel; runs on GPU if available.")
# Launch
if __name__ == "__main__":
    # Warmup: run one tiny embedding so the first real request is fast and any
    # model/device problem surfaces before the server starts. Best-effort only —
    # a warmup failure is logged, not fatal.
    try:
        _ = embed_texts(["Hello world"])  # warm cache
    except Exception as e:
        print("Warmup failed:", e)
    # Bind to all interfaces on the conventional Gradio/Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)