"""Gradio demo comparing two ways of scoring Chinese sentence similarity.

Two back-ends are offered in the UI:

* a multilingual Sentence-Transformers model, compared with ``util.cos_sim``;
* ``bert-base-chinese``, comparing the hidden state at sequence position 0
  (the ``[CLS]`` slot) of each sentence with scikit-learn's
  ``cosine_similarity``.
"""

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertTokenizer, BertModel

# Labels shown in the method selector. The same constants drive the dispatch
# in ``calculate_similarity`` so the UI choices and the branches cannot drift.
METHOD_SENTENCE_TRANSFORMERS = "Sentence Transformers"
METHOD_BERT_CLS = "BERT CLS"

# Load models for the different methods once, at import time, so each button
# click only pays for inference.
st_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
bert_model_name = "bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = BertModel.from_pretrained(bert_model_name)


def _sentence_transformer_similarity(sentence1, sentence2):
    """Return the cosine similarity of the two ``st_model`` embeddings."""
    embedding1 = st_model.encode(sentence1, convert_to_tensor=True)
    embedding2 = st_model.encode(sentence2, convert_to_tensor=True)
    # ``.item()`` unwraps the single-element result into a Python number.
    return util.cos_sim(embedding1, embedding2).item()


def _bert_cls_similarity(sentence1, sentence2):
    """Return the cosine similarity of the two position-0 BERT hidden states."""
    inputs1 = tokenizer(
        sentence1, return_tensors="pt", truncation=True, padding=True
    )
    inputs2 = tokenizer(
        sentence2, return_tensors="pt", truncation=True, padding=True
    )

    # Inference only: skip autograd bookkeeping for both forward passes.
    with torch.no_grad():
        outputs1 = bert_model(**inputs1)
        outputs2 = bert_model(**inputs2)

    # Keep only sequence position 0 of the last layer for each sentence.
    cls_embedding1 = outputs1.last_hidden_state[:, 0, :].numpy()
    cls_embedding2 = outputs2.last_hidden_state[:, 0, :].numpy()

    # ``cosine_similarity`` returns a matrix; one sentence per side means the
    # only entry is [0][0].
    return cosine_similarity(cls_embedding1, cls_embedding2)[0][0]


def calculate_similarity(method, sentence1, sentence2):
    """Score how similar two sentences are using the selected method.

    Args:
        method: One of the selector labels (``"Sentence Transformers"`` or
            ``"BERT CLS"``). Any other value, including ``None`` when nothing
            is selected, is treated as "no method chosen".
        sentence1: First sentence.
        sentence2: Second sentence.

    Returns:
        The cosine similarity as a number, or the message ``"未選擇演算法"``
        when ``method`` is not one of the known labels.
    """
    if method == METHOD_SENTENCE_TRANSFORMERS:
        return _sentence_transformer_similarity(sentence1, sentence2)
    if method == METHOD_BERT_CLS:
        return _bert_cls_similarity(sentence1, sentence2)
    return "未選擇演算法"


def load_example():
    """Return the pair of example sentences used by the "fill example" button."""
    return "今天的天氣真好", "今天天氣非常晴朗"


# Gradio UI
def build_ui():
    """Build and return the Gradio ``Blocks`` app (without launching it)."""
    with gr.Blocks() as demo:
        gr.Markdown("## 中文句子相似度計算 Demo")

        with gr.Row():
            sentence1_input = gr.Textbox(
                label="句子 1", placeholder="輸入第一個句子"
            )
            sentence2_input = gr.Textbox(
                label="句子 2", placeholder="輸入第二個句子"
            )

        method_selector = gr.Radio(
            choices=[METHOD_SENTENCE_TRANSFORMERS, METHOD_BERT_CLS],
            label="選擇演算法",
        )
        similarity_output = gr.Textbox(label="相似度結果", interactive=False)

        with gr.Row():
            calculate_button = gr.Button("計算相似度")
            example_button = gr.Button("填入預設句子")

        # Input order must match the parameter order of calculate_similarity.
        calculate_button.click(
            calculate_similarity,
            inputs=[method_selector, sentence1_input, sentence2_input],
            outputs=similarity_output,
        )
        example_button.click(
            load_example,
            inputs=[],
            outputs=[sentence1_input, sentence2_input],
        )

    return demo


# ``demo`` stays at module level so tools that import this module can find it.
demo = build_ui()

# Launch the app only when run as a script; importing the module must not
# start a blocking web server.
if __name__ == "__main__":
    demo.launch()