jgrosjean committed on
Commit
556bbce
·
verified ·
1 Parent(s): 7089e08

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from transformers import AutoModel, AutoTokenizer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+
7
# Sentence-SwissBERT checkpoint used for all sentence embeddings in this app.
model_name = "jgrosjean-mathesis/sentence-swissbert"

# Tokenizer and encoder are loaded once at import time and reused by every
# call to generate_sentence_embedding.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
11
+
12
def generate_sentence_embedding(sentence, language):
    """Return a mean-pooled sentence embedding from the module-level model.

    Args:
        sentence: Text to embed; it is passed to the tokenizer as-is and
            truncated to at most 512 tokens.
        language: Language selector containing one of the codes ``"de"``,
            ``"fr"``, ``"it"`` or ``"rm"``. Matching is by substring; when
            several codes occur, the last one in that order is used.

    Returns:
        The model's ``last_hidden_state`` averaged over the token axis,
        counting only positions where the attention mask is set.

    Raises:
        ValueError: If ``language`` is empty or contains none of the
            supported codes.
    """
    # Language code -> adapter name understood by the model.
    adapters = (
        ("de", "de_CH"),
        ("fr", "fr_CH"),
        ("it", "it_CH"),
        ("rm", "rm_CH"),
    )
    matches = []
    if language:
        matches = [locale for code, locale in adapters if code in language]
    if not matches:
        # The model is shared module state: without this check an unknown
        # language would silently reuse the adapter left by the previous call.
        raise ValueError(
            f"Unsupported language {language!r}; expected one of: de, fr, it, rm"
        )
    model.set_default_language(matches[-1])

    # Tokenize input sentence
    inputs = tokenizer(
        sentence,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling: zero out masked positions, then divide the summed token
    # vectors by the number of unmasked tokens.
    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp guards against division by zero if every position is masked.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / token_counts
38
+
39
def calculate_cosine_similarities(source_sentence, source_language, target_sentence_1, target_language_1, target_sentence_2, target_language_2, target_sentence_3, target_language_3):
    """Rank three target sentences by cosine similarity to a source sentence.

    Each sentence is embedded with ``generate_sentence_embedding`` using its
    own language selector, and every target is scored against the source.

    Args:
        source_sentence: Sentence the targets are compared against.
        source_language: Language selector for ``source_sentence``.
        target_sentence_1, target_sentence_2, target_sentence_3: Sentences
            to score.
        target_language_1, target_language_2, target_language_3: Language
            selectors for the corresponding target sentences.

    Returns:
        A string with one ``"<sentence>: <score>"`` line per target, sorted
        by score in descending order. The first (best) line is wrapped in
        ``**`` and every line ends with a newline.
    """
    source_embedding = generate_sentence_embedding(source_sentence, source_language)

    targets = (
        (target_sentence_1, target_language_1),
        (target_sentence_2, target_language_2),
        (target_sentence_3, target_language_3),
    )

    # Keep (sentence, score) pairs in a list rather than a dict keyed by the
    # sentence text, so identical target sentences do not overwrite each other.
    scored = []
    for sentence, language in targets:
        target_embedding = generate_sentence_embedding(sentence, language)
        # cosine_similarity returns a matrix; [0][0] is the single pair score.
        score = cosine_similarity(source_embedding, target_embedding)[0][0]
        scored.append((sentence, score))

    # The sort is stable, so equal scores keep their input order.
    scored.sort(key=lambda pair: pair[1], reverse=True)

    lines = [f"{sentence}: {score!s}" for sentence, score in scored]
    # Highlight the best match.
    lines[0] = "**" + lines[0] + "**"
    return "\n".join(lines) + "\n"
61
+
62
def main():
    """Build the Gradio interface for the similarity demo and launch it."""
    language_codes = ["de", "fr", "it", "rm"]

    # One (label, placeholder, preselected language) row per sentence, in the
    # positional order of calculate_cosine_similarities' parameters.
    sentence_rows = [
        ("source sentence", "Der Zug fährt um 9 Uhr in Zürich ab.", "de"),
        ("target sentence 1", "Le train arrive à Lausanne à 11 heures.", "fr"),
        ("target sentence 2", "Alla stazione di Lugano ci sono diversi binari.", "it"),
        ("target sentence 3", "A Cuera van biars trens ellas muntognas.", "rm"),
    ]

    input_components = []
    for label, placeholder, default_language in sentence_rows:
        # Each sentence textbox is followed by its language dropdown.
        input_components.append(
            gr.Textbox(lines=1, placeholder=placeholder, label=label)
        )
        input_components.append(
            gr.Dropdown(list(language_codes), value=default_language, label="language")
        )

    result_box = gr.Textbox(label="Cosine similarity scores", type="text", lines=3)

    demo = gr.Interface(
        fn=calculate_cosine_similarities,
        inputs=input_components,
        outputs=result_box,
    )
    demo.launch(share=True)


if __name__ == "__main__":
    main()