qavit commited on
Commit
ff0961c
·
verified ·
1 Parent(s): 8ea5b54

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from transformers import BertTokenizer, BertModel
4
+ import torch
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
+ # Load models for different methods
8
+ st_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
9
+ bert_model_name = "bert-base-chinese"
10
+ tokenizer = BertTokenizer.from_pretrained(bert_model_name)
11
+ bert_model = BertModel.from_pretrained(bert_model_name)
12
+
13
+ def calculate_similarity(method, sentence1, sentence2):
14
+ if method == "Sentence Transformers":
15
+ embedding1 = st_model.encode(sentence1, convert_to_tensor=True)
16
+ embedding2 = st_model.encode(sentence2, convert_to_tensor=True)
17
+ similarity = util.cos_sim(embedding1, embedding2).item()
18
+ elif method == "BERT CLS":
19
+ inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding=True)
20
+ inputs2 = tokenizer(sentence2, return_tensors="pt", truncation=True, padding=True)
21
+ with torch.no_grad():
22
+ outputs1 = bert_model(**inputs1)
23
+ outputs2 = bert_model(**inputs2)
24
+ cls_embedding1 = outputs1.last_hidden_state[:, 0, :].numpy()
25
+ cls_embedding2 = outputs2.last_hidden_state[:, 0, :].numpy()
26
+ similarity = cosine_similarity(cls_embedding1, cls_embedding2)[0][0]
27
+ else:
28
+ similarity = "未選擇演算法"
29
+ return similarity
30
+
31
+ def load_example():
32
+ return "今天的天氣真好", "今天天氣非常晴朗"
33
+
34
+ # Gradio UI
35
+ def build_ui():
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("## 中文句子相似度計算 Demo")
38
+
39
+ with gr.Row():
40
+ sentence1_input = gr.Textbox(label="句子 1", placeholder="輸入第一個句子")
41
+ sentence2_input = gr.Textbox(label="句子 2", placeholder="輸入第二個句子")
42
+
43
+ method_selector = gr.Radio(choices=["Sentence Transformers", "BERT CLS"], label="選擇演算法")
44
+
45
+ similarity_output = gr.Textbox(label="相似度結果", interactive=False)
46
+
47
+ with gr.Row():
48
+ calculate_button = gr.Button("計算相似度")
49
+ example_button = gr.Button("填入預設句子")
50
+
51
+ calculate_button.click(calculate_similarity,
52
+ inputs=[method_selector, sentence1_input, sentence2_input],
53
+ outputs=similarity_output)
54
+
55
+ example_button.click(load_example,
56
+ inputs=[],
57
+ outputs=[sentence1_input, sentence2_input])
58
+
59
+ return demo
60
+
61
+ # Launch the app
62
+ demo = build_ui()
63
+ demo.launch()