File size: 4,197 Bytes
ebbf749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import time
import numpy as np
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Hugging Face model id of the embedding model served by this app.
MODEL = "intfloat/multilingual-e5-base"

print("Loading tokenizer and model:", MODEL)
# Prefer GPU when available; every encoded batch is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)
model.to(device)
# Inference only: eval() disables dropout; gradients are suppressed
# per call via torch.no_grad() in embed_texts.
model.eval()

# simple mean pooling using attention mask
def mean_pooling(last_hidden_state, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, dim=1)
    sum_mask = input_mask_expanded.sum(dim=1)
    # avoid division by zero
    sum_mask = torch.clamp(sum_mask, min=1e-9)
    return sum_embeddings / sum_mask


def embed_texts(texts, batch_size=8):
    """Encode one string or a list of strings into L2-normalized embeddings.

    Returns a 1-D numpy vector when given a single str, otherwise a 2-D
    numpy array with one row per input text. Encoding runs in batches of
    ``batch_size`` on the module-level model/device.
    """
    is_single = isinstance(texts, str)
    if is_single:
        texts = [texts]

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            ids = encoded["input_ids"].to(device)
            mask = encoded["attention_mask"].to(device)
            out = model(input_ids=ids, attention_mask=mask, return_dict=True)
            pooled = mean_pooling(out.last_hidden_state, mask)
            # unit-length rows so downstream dot products equal cosine similarity
            chunks.append(F.normalize(pooled, p=2, dim=1).cpu().numpy())

    stacked = np.vstack(chunks)
    return stacked[0] if is_single else stacked


def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors.

    Returns 0.0 when either vector has zero norm instead of dividing by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)


def similarity(s1, s2):
    """Return the cosine similarity of two texts as a 4-decimal string.

    E5 models require every input to start with a "query: " or "passage: "
    prefix; the model card prescribes "query: " for symmetric tasks such as
    semantic similarity, and warns that omitting it degrades quality.
    """
    e = embed_texts([f"query: {s1}", f"query: {s2}"])
    score = cosine_similarity(e[0], e[1])
    return f"{score:.4f}"


def search(query, docs_text, topk=3):
    """Rank the newline-separated documents in ``docs_text`` against ``query``.

    Returns the top-k matches as a formatted string (rank, cosine score,
    document text), or "Corpus is empty" when no documents are given.
    """
    docs = [d.strip() for d in docs_text.splitlines() if d.strip()]
    if not docs:
        return "Corpus is empty"
    # E5 asymmetric retrieval: documents take a "passage: " prefix and the
    # query a "query: " prefix (model card); the raw doc text is still shown.
    embs = embed_texts([f"passage: {d}" for d in docs] + [f"query: {query}"])
    D, q = embs[:-1], embs[-1]
    # embeddings are already unit-length; renormalizing is a cheap safeguard
    scores = (D @ q) / (np.linalg.norm(D, axis=1) * np.linalg.norm(q) + 1e-12)
    # Gradio's Number can deliver 0 or a negative value; always show >= 1 hit
    k = max(1, int(topk))
    order = np.argsort(scores)[::-1][:k]
    lines = []
    for rank, idx in enumerate(order, start=1):
        lines.append(f"{rank}. score={scores[idx]:.4f}\n{docs[idx]}")
    return "\n\n".join(lines)


# Gradio UI: one tab for pairwise similarity, one for corpus search.

with gr.Blocks() as demo:
    gr.Markdown("# multilingual-e5-base — multilingual embedding tester")

    with gr.Tab("Similarity"):
        text_a = gr.Textbox(label="Text 1", value="Hello world / Привет мир")
        text_b = gr.Textbox(label="Text 2", value="Greetings planet / Привет, планета")
        sim_button = gr.Button("Compute similarity")
        sim_output = gr.Textbox(label="Cosine similarity", interactive=False)
        sim_button.click(fn=similarity, inputs=[text_a, text_b], outputs=sim_output)

    with gr.Tab("Semantic search"):
        query_box = gr.Textbox(label="Query", value="climate change")
        corpus_box = gr.Textbox(
            label="Corpus (one document per line)",
            lines=12,
            value=(
                "Climate summit discussed emissions reductions.\n"
                "Local sports team won the championship.\n"
                "New research on climate change effects published.\n"
                "Economy grows despite challenges."
            ),
        )
        topk_box = gr.Number(label="Top-K", value=3, precision=0)
        search_button = gr.Button("Search")
        results_box = gr.Textbox(label="Results", lines=12)
        search_button.click(fn=search, inputs=[query_box, corpus_box, topk_box], outputs=results_box)

    gr.Markdown("---\nModel: intfloat/multilingual-e5-base — uses Transformers AutoModel; runs on GPU if available.")

# Launch the app when run as a script.
if __name__ == "__main__":
    # One embedding pass up front warms the model (weights on device,
    # kernels compiled) before the first request; failure is non-fatal.
    try:
        embed_texts(["Hello world"])
    except Exception as exc:
        print("Warmup failed:", exc)
    demo.launch(server_name="0.0.0.0", server_port=7860)