vadkos12 commited on
Commit
ebbf749
·
verified ·
1 Parent(s): c8b7c9c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import gradio as gr
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
MODEL = "intfloat/multilingual-e5-base"

print("Loading tokenizer and model:", MODEL)
# Prefer GPU when one is visible; everything below runs on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Load once at import time so Gradio handlers can reuse the same instances.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL).to(device)
model.eval()  # inference only — disable dropout etc.
19
+
20
# Mask-aware mean pooling over the token dimension.
def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings, counting only non-padding positions.

    last_hidden_state: (batch, seq, dim) float tensor of token embeddings.
    attention_mask:    (batch, seq) mask, 1 for real tokens and 0 for padding.
    Returns a (batch, dim) tensor of per-sequence mean embeddings.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    # clamp guards against division by zero for an all-padding row
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
28
+
29
+
30
def embed_texts(texts, batch_size=8):
    """Embed a string or list of strings with the module-level E5 model.

    Each text is tokenized (truncated to 512 tokens), mean-pooled and
    L2-normalized. Returns a 1-D numpy array for a single string input,
    otherwise a 2-D array with one row per input text.
    """
    is_single = isinstance(texts, str)
    if is_single:
        texts = [texts]

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            ids = encoded["input_ids"].to(device)
            mask = encoded["attention_mask"].to(device)
            outputs = model(input_ids=ids, attention_mask=mask, return_dict=True)
            pooled = mean_pooling(outputs.last_hidden_state, mask)
            # unit-normalize so cosine similarity reduces to a dot product
            chunks.append(F.normalize(pooled, p=2, dim=1).cpu().numpy())

    stacked = np.vstack(chunks)
    return stacked[0] if is_single else stacked
50
+
51
+
52
def cosine_similarity(a, b):
    """Cosine similarity of two 1-D vectors; 0.0 when either has zero norm."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # degenerate input (zero vector) — similarity is undefined, report 0
        return 0.0
    return float(np.dot(a, b) / norm_product)
58
+
59
+
60
def similarity(s1, s2):
    """Gradio handler: cosine similarity of two texts, formatted to 4 decimals."""
    vec_a, vec_b = embed_texts([s1, s2])
    return f"{cosine_similarity(vec_a, vec_b):.4f}"
64
+
65
+
66
def search(query, docs_text, topk=3):
    """Rank corpus lines by cosine similarity to the query.

    docs_text is split into one document per non-blank line; returns a
    plain-text report of the top-k documents with their scores, or a
    message when the corpus is empty.
    """
    corpus = [line.strip() for line in docs_text.splitlines() if line.strip()]
    if not corpus:
        return "Corpus is empty"
    vectors = embed_texts(corpus + [query])
    doc_vecs, query_vec = vectors[:-1], vectors[-1]
    # embeddings are already unit-norm; the epsilon keeps division safe anyway
    scores = (doc_vecs @ query_vec) / (
        np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec) + 1e-12
    )
    top = np.argsort(scores)[::-1][: int(topk)]
    report = [
        f"{rank}. score={scores[i]:.4f}\n{corpus[i]}"
        for rank, i in enumerate(top, start=1)
    ]
    return "\n\n".join(report)
79
+
80
+
81
# Gradio UI

with gr.Blocks() as demo:
    gr.Markdown("# multilingual-e5-base — multilingual embedding tester")

    # Tab 1: pairwise similarity between two free-text inputs.
    with gr.Tab("Similarity"):
        text_a = gr.Textbox(label="Text 1", value="Hello world / Привет мир")
        text_b = gr.Textbox(label="Text 2", value="Greetings planet / Привет, планета")
        sim_btn = gr.Button("Compute similarity")
        sim_out = gr.Textbox(label="Cosine similarity", interactive=False)
        sim_btn.click(fn=similarity, inputs=[text_a, text_b], outputs=sim_out)

    # Tab 2: rank a small line-delimited corpus against a query.
    with gr.Tab("Semantic search"):
        query_box = gr.Textbox(label="Query", value="climate change")
        corpus_box = gr.Textbox(
            label="Corpus (one document per line)",
            lines=12,
            value=(
                "Climate summit discussed emissions reductions.\n"
                "Local sports team won the championship.\n"
                "New research on climate change effects published.\n"
                "Economy grows despite challenges."
            ),
        )
        topk_box = gr.Number(label="Top-K", value=3, precision=0)
        search_btn = gr.Button("Search")
        search_out = gr.Textbox(label="Results", lines=12)
        search_btn.click(fn=search, inputs=[query_box, corpus_box, topk_box], outputs=search_out)

    gr.Markdown("---\nModel: intfloat/multilingual-e5-base — uses Transformers AutoModel; runs on GPU if available.")
106
+
107
# Launch
if __name__ == "__main__":
    # Warm up once so weights are resident on the device before serving.
    try:
        _ = embed_texts(["Hello world"])
    except Exception as e:
        # best-effort: a failed warmup should not prevent the UI from starting
        print("Warmup failed:", e)
    demo.launch(server_name="0.0.0.0", server_port=7860)