Gabriela Nicole Gonzalez Saez committed on
Commit
077a6f7
·
1 Parent(s): 45a34ac

Add app files

Browse files
Files changed (3) hide show
  1. app.py +723 -0
  2. plotsjs.js +744 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from time import time
3
+
4
+ import torch
5
+ import os
6
+ # import nltk
7
+ import argparse
8
+ import random
9
+ import numpy as np
10
+ import faiss
11
+ from argparse import Namespace
12
+ from tqdm.notebook import tqdm
13
+ from torch.utils.data import DataLoader
14
+ from functools import partial
15
+ from sklearn.manifold import TSNE
16
+
17
+ from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
+
19
# Cache of FAISS indexes / vocabulary metadata built from the reference
# translations; written by first_function() (via build_reference()) and
# read by build_search() and filtered_projection().
metadata_all = {}

# Helsinki-NLP MarianMT checkpoints, one per supported language pair.
model_es = "Helsinki-NLP/opus-mt-en-es"
model_fr = "Helsinki-NLP/opus-mt-en-fr"
model_zh = "Helsinki-NLP/opus-mt-en-zh"

# NOTE: tokenizers and models for all three pairs are loaded eagerly at
# import time, so app start-up is slow but translation requests are not.
tokenizer_es = AutoTokenizer.from_pretrained(model_es)
tokenizer_fr = AutoTokenizer.from_pretrained(model_fr)
tokenizer_zh = AutoTokenizer.from_pretrained(model_zh)

model_tr_es = MarianMTModel.from_pretrained(model_es)
model_tr_fr = MarianMTModel.from_pretrained(model_fr)
model_tr_zh = MarianMTModel.from_pretrained(model_zh)

# Lookup tables keyed by the language-pair id used throughout the app.
dict_models = {
    'en-es': model_es,
    'en-fr': model_fr,
    'en-zh': model_zh,
}

dict_models_tr = {
    'en-es': model_tr_es,
    'en-fr': model_tr_fr,
    'en-zh': model_tr_zh,
}

dict_tokenizer_tr = {
    'en-es': tokenizer_es,
    'en-fr': tokenizer_fr,
    'en-zh': tokenizer_zh,
}
49
+
50
+
51
def translation_model(w1, model):
    """Translate one sentence and expose the model internals needed downstream.

    Args:
        w1 (str): source sentence to translate.
        model (str): language-pair key into the module-level dicts (e.g. 'en-es').

    Returns:
        tuple: (decoded target text,
                full ``generate()`` output (beam search, with hidden states and scores),
                source token ids,
                encoder embedding-table vectors of the source tokens,
                decoder embedding-table vectors of the generated sequence).
    """
    tokenizer = dict_tokenizer_tr[model]
    translator = dict_models_tr[model]

    inputs = tokenizer(w1, return_tensors="pt")
    # Raw embedding-table vectors for the source tokens (before the encoder).
    input_embeddings = translator.get_encoder().embed_tokens(inputs.input_ids)

    num_ret_seq = 1
    translated = translator.generate(
        **inputs,
        num_beams=5,
        num_return_sequences=num_ret_seq,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=True,
        output_scores=True,
    )

    tgt_text = tokenizer.decode(translated.sequences[0], skip_special_tokens=True)

    # Raw embedding-table vectors for the generated target tokens (before the decoder).
    target_embeddings = translator.get_decoder().embed_tokens(translated.sequences)

    return tgt_text, translated, inputs.input_ids, input_embeddings, target_embeddings
71
+
72
def create_vocab_multiple(embeddings_list, model):
    """Build a token-id vocabulary from a list of embedded sentences.

    Args:
        embeddings_list (list): entries carrying 'tokens' (per-sentence
            token-id lists) and 'embeddings' (matching embedding vectors).
        model (str): language-pair key, used to decode token ids to text.

    Returns:
        tuple(dict, list): vocabulary keyed by token id with
            {'token', 'count', 'text', 'embed'} (the embedding of the first
            occurrence is kept), and the per-sentence token-id lists.
    """
    print("START VOCAB CREATION MULTIPLE \n \n ")
    vocab = {}
    sentence_tokens_text_list = []
    for embeddings in embeddings_list:
        for sent_i, sentence in enumerate(embeddings['tokens']):
            sentence_tokens = []
            for tok_i, token in enumerate(sentence):
                sentence_tokens.append(token)
                entry = vocab.get(token)
                if entry is None:
                    # First sighting: remember text and first-occurrence embedding.
                    vocab[token] = {
                        'token': token,
                        'count': 1,
                        'text': dict_tokenizer_tr[model].decode([token]),
                        'embed': embeddings['embeddings'][sent_i][tok_i],
                    }
                else:
                    entry['count'] += 1
            sentence_tokens_text_list.append(sentence_tokens)
    print("END VOCAB CREATION MULTIPLE \n \n ")
    return vocab, sentence_tokens_text_list
104
+
105
def vocab_words_all_prefix(token_embeddings, model, sufix="@@", prefix='▁'):
    """Rebuild words from subword tokens and collect one embedding per word.

    SentencePiece marks the start of a word with `prefix` ('▁'); continuation
    tokens carry no prefix.  A word's embedding is the mean of its tokens'
    embeddings.  (`sufix` is accepted for API compatibility but unused.)

    Args:
        token_embeddings (list): entries carrying 'tokens' (per-sentence
            token-id lists) and 'embeddings' (matching vectors).
        model (str): language-pair key, used to map ids back to token strings.

    Returns:
        tuple(dict, list): word vocabulary with
            {'word', 'text', 'count', 'tokens_ids', 'embed'} and the
            per-sentence word lists.
    """
    vocab = {}
    sentence_words_text_list = []
    # BUGFIX: the original left n_prefix undefined (NameError) when prefix
    # was falsy; default to 0 so an empty prefix means "every token starts a word".
    n_prefix = len(prefix) if prefix else 0

    for input_sentences in token_embeddings:
        for sent_i, sentence in enumerate(input_sentences['tokens']):
            words_text_list = []
            word = ''
            tokens_ids = []
            embeddings = []
            ids_to_tokens = dict_tokenizer_tr[model].convert_ids_to_tokens(sentence)

            to_save = False
            for tok_i, token_text in enumerate(ids_to_tokens):
                token_id = sentence[tok_i]
                if token_text[:n_prefix] == prefix:
                    # A new word starts: flush the word accumulated so far.
                    if to_save:
                        vocab[word] = {
                            'word': word,
                            'text': word,
                            'count': 1,
                            'tokens_ids': tokens_ids,
                            'embed': np.mean(np.array(embeddings), 0).tolist()
                        }
                        words_text_list.append(word)
                    tokens_ids = [token_id]
                    embeddings = [input_sentences['embeddings'][sent_i][tok_i]]
                    word = token_text[n_prefix:]
                    to_save = True
                elif token_text in dict_tokenizer_tr[model].special_tokens_map.values():
                    # Special tokens (</s>, <pad>, ...) close any pending word
                    # and are stored as stand-alone single-token entries.
                    if to_save:
                        vocab[word] = {
                            'word': word,
                            'text': word,
                            'count': 1,
                            'tokens_ids': tokens_ids,
                            'embed': np.mean(np.array(embeddings), 0).tolist()
                        }
                        words_text_list.append(word)
                    tokens_ids = [token_id]
                    embeddings = [input_sentences['embeddings'][sent_i][tok_i]]
                    vocab[token_text] = {
                        'word': token_text,
                        'count': 1,
                        # BUGFIX: the original stored the stale previous `word`
                        # as the display text; the special token itself is meant.
                        'text': token_text,
                        'tokens_ids': tokens_ids,
                        'embed': np.mean(np.array(embeddings), 0).tolist()
                    }
                    words_text_list.append(token_text)
                    to_save = False
                else:
                    # Continuation token: extend the current word; we cannot
                    # know yet whether it is final, so nothing is saved here.
                    to_save = True
                    word += token_text
                    tokens_ids.append(token_id)
                    embeddings.append(input_sentences['embeddings'][sent_i][tok_i])

            if to_save:
                # BUGFIX: the original pre-assigned `vocab[word] = tokens_ids`
                # right before this test, which made `not (word in vocab)`
                # always false and left a bare id list where a dict entry is
                # expected (TypeError on the count update below).
                if word not in vocab:
                    vocab[word] = {
                        'word': word,
                        'count': 1,
                        'text': word,
                        'tokens_ids': tokens_ids,
                        'embed': np.mean(np.array(embeddings), 0).tolist()
                    }
                    words_text_list.append(word)
                else:
                    vocab[word]['count'] = vocab[word]['count'] + 1
            sentence_words_text_list.append(words_text_list)

    return vocab, sentence_words_text_list
192
+
193
+ # nb_ids.append(token_values['token']) # for x in vocab_tokens]
194
+ # nb_embds.append(token_values['embed']) # for x in vocab_tokens]
195
+
196
def create_index_voronoi(vocab):
    """Build an IVF (Voronoi-cell) FAISS index over a token vocabulary.

    Args:
        vocab (dict): token vocabulary as built by create_vocab_multiple().

    Returns:
        tuple: (trained ``faiss.IndexIVFFlat``, metadata dict mapping the
            faiss row position to {'token', 'text'}).
    """
    nb_embds = []  # embeddings, in insertion order (row i of the index)
    metadata = {}  # faiss row id -> token info
    for i_pos, token_values in enumerate(vocab.values()):
        nb_embds.append(token_values['embed'])
        metadata[i_pos] = {'token': token_values['token'], 'text': token_values['text']}

    xb = np.array(nb_embds).astype('float32')  # elements to index
    # Dimensionality taken from the data (the original also dead-assigned
    # d = 1024 first; that constant was never used and has been removed).
    d = len(xb[0])

    nlist = 5  # number of Voronoi cells
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)
    index.train(xb)
    index.add(xb)

    return index, metadata
223
+
224
def create_index_voronoi_words(vocab):
    """Build an IVF (Voronoi-cell) FAISS index over a word vocabulary.

    Word-level counterpart of create_index_voronoi(); the metadata carries
    the word and its component token ids.

    Args:
        vocab (dict): word vocabulary as built by vocab_words_all_prefix().

    Returns:
        tuple: (trained ``faiss.IndexIVFFlat``, metadata dict mapping the
            faiss row position to {'word', 'tokens', 'text'}).
    """
    nb_embds = []  # embeddings, in insertion order (row i of the index)
    metadata = {}  # faiss row id -> word info
    for i_pos, token_values in enumerate(vocab.values()):
        nb_embds.append(token_values['embed'])
        metadata[i_pos] = {'word': token_values['word'], 'tokens': token_values['tokens_ids'], 'text': token_values['text']}

    xb = np.array(nb_embds).astype('float32')  # elements to index
    # Dimensionality taken from the data (dead `d = 1024` removed).
    d = len(xb[0])

    nlist = 5  # number of Voronoi cells
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)
    index.train(xb)
    index.add(xb)

    return index, metadata
251
+
252
def search_query_vocab(index, vocab_queries, topk = 10, limited_search = []):
    """Find the top-k nearest reference words for every query word.

    Args:
        index: trained FAISS index over the reference word vocabulary.
        vocab_queries (dict): query word vocabulary; entries carry 'word',
            'tokens_ids', 'text' and 'embed'.
        topk (int, optional): number of neighbours per query. Defaults to 10.
        limited_search (list, optional): unused; kept for API compatibility.

    Returns:
        tuple: distance matrix D, index matrix I, and a dict mapping each
            query row position to its word metadata.
    """
    entries = list(vocab_queries.values())
    metadata = {
        qi_pos: {'word': entry['word'], 'tokens': entry['tokens_ids'], 'text': entry['text']}
        for qi_pos, entry in enumerate(entries)
    }
    # Query matrix: one row per vocabulary entry, in iteration order.
    xq = np.array([entry['embed'] for entry in entries]).astype('float32')

    D, I = index.search(xq, topk)
    return D, I, metadata
283
+
284
def search_query_vocab_token(index, vocab_queries, topk = 10, limited_search = []):
    """Find the top-k nearest reference tokens for every query token.

    Token-level counterpart of search_query_vocab().

    Args:
        index: trained FAISS index over the reference token vocabulary.
        vocab_queries (dict): query token vocabulary; entries carry 'token',
            'text' and 'embed'.
        topk (int, optional): number of neighbours per query. Defaults to 10.
        limited_search (list, optional): unused; kept for API compatibility.

    Returns:
        tuple: distance matrix D, index matrix I, and a dict mapping each
            query row position to its token metadata.
    """
    entries = list(vocab_queries.values())
    metadata = {
        qi_pos: {'token': entry['token'], 'text': entry['text']}
        for qi_pos, entry in enumerate(entries)
    }
    # Query matrix: one row per vocabulary entry, in iteration order.
    xq = np.array([entry['embed'] for entry in entries]).astype('float32')

    D, I = index.search(xq, topk)
    return D, I, metadata
304
+
305
def build_search(query_embeddings, model, type="input"):
    """Search the cached reference indexes for every query token and word.

    Uses the FAISS indexes stored in the global ``metadata_all`` (filled by
    build_reference() via first_function()) for the requested side.

    Args:
        query_embeddings (list): embedded sentences to search with.
        model (str): language-pair key.
        type (str): 'input' (source side) or 'output' (target side).

    Returns:
        dict: per granularity ('tokens'/'words'): raw FAISS results (D, I,
            meta), the query vocabulary, the similarity dict and the
            per-sentence key lists.
    """
    global metadata_all

    # Build the query vocabularies at both granularities.
    vocab_queries, sentence_tokens_list = create_vocab_multiple(query_embeddings, model)
    words_vocab_queries, sentence_words_list = vocab_words_all_prefix(query_embeddings, model, sufix="@@", prefix="▁")

    # --- token-level search against the cached reference index ---
    index_vor_tokens = metadata_all[type]['tokens'][1]
    md_tokens = metadata_all[type]['tokens'][2]
    D, I, meta = search_query_vocab_token(index_vor_tokens, vocab_queries)

    similar_tokens = {}
    qi_pos = 0
    for dist, ind in zip(D, I):
        try:
            similar_tokens[str(meta[qi_pos]['token'])] = {
                'token': meta[qi_pos]['token'],
                'text': meta[qi_pos]['text'],
                # FAISS returns -1 for "no neighbour found" slots; skip them.
                "similar_topk": [md_tokens[i_index]['token'] for i_index in ind if (i_index != -1)],
                "distance": [dist[i] for (i, i_index) in enumerate(ind) if (i_index != -1)],
            }
        except Exception:  # narrowed from a bare `except:`; keep best-effort behaviour
            print("\n ERROR ", qi_pos, dist, ind)
        qi_pos += 1

    # --- word-level search against the cached reference index ---
    index_vor_words = metadata_all[type]['words'][1]
    md_words = metadata_all[type]['words'][2]
    Dw, Iw, metaw = search_query_vocab(index_vor_words, words_vocab_queries)

    similar_words = {}
    qi_pos = 0
    for dist, ind in zip(Dw, Iw):
        try:
            similar_words[str(metaw[qi_pos]['word'])] = {
                'word': metaw[qi_pos]['word'],
                'text': metaw[qi_pos]['word'],
                "similar_topk": [md_words[i_index]['word'] for i_index in ind if (i_index != -1)],
                "distance": [dist[i] for (i, i_index) in enumerate(ind) if (i_index != -1)],
            }
        except Exception:  # narrowed from a bare `except:`; keep best-effort behaviour
            print("\n ERROR ", qi_pos, dist, ind)
        qi_pos += 1

    return {'tokens': {'D': D, 'I': I, 'meta': meta, 'vocab_queries': vocab_queries, 'similar': similar_tokens, 'sentence_key_list': sentence_tokens_list},
            'words': {'D': Dw, 'I': Iw, 'meta': metaw, 'vocab_queries': words_vocab_queries, 'sentence_key_list': sentence_words_list, 'similar': similar_words}
            }
362
+
363
def build_reference(all_embeddings, model):
    """Build the reference vocabularies and FAISS indexes for one side.

    Args:
        all_embeddings (list): embedded reference sentences.
        model (str): language-pair key.

    Returns:
        dict: {'tokens': [vocab, index, metadata],
               'words':  [vocab, index, metadata]}.
    """
    # Vocabularies at both granularities.
    vocab, sentence_tokens = create_vocab_multiple(all_embeddings, model)
    words_vocab, sentences = vocab_words_all_prefix(all_embeddings, model, sufix="@@", prefix="▁")

    # One Voronoi/IVF index per granularity.
    index_tokens, meta_tokens = create_index_voronoi(vocab)
    index_words, meta_words = create_index_voronoi_words(words_vocab)

    return {
        'tokens': [vocab, index_tokens, meta_tokens],
        'words': [words_vocab, index_words, meta_words],
    }
377
+
378
+
379
def embds_input_projection_vocab(vocab, key="token"):
    """Project the vocabulary embeddings to 2-D with t-SNE.

    Args:
        vocab (dict): vocabulary entries carrying `key`, 'text' and 'embed'.
        key (str): identifier field to carry along ('token' or 'word').

    Returns:
        list: rows [x, y, id, text, position]; coordinates fall back to
            (0, 0) for every entry when t-SNE fails (e.g. too few samples).
    """
    t0 = time()

    nb_ids = []      # ordered identifier list
    nb_text = []     # ordered display-text list
    nb_embds = []    # ordered embedding list
    tnse_error = []  # zero coordinates, used as fallback when t-SNE fails
    for _, token_values in vocab.items():
        tnse_error.append([0, 0])
        nb_ids.append(token_values[key])
        nb_text.append(token_values['text'])
        nb_embds.append(token_values['embed'])

    X = np.array(nb_embds).astype('float32')  # elements to project
    try:
        tsne = TSNE(random_state=0, n_iter=1000)
        tsne_results = tsne.fit_transform(X)
        # Zip-like rows: [[x, y, id, text, running position], ...]
        tsne_results = np.c_[tsne_results, nb_ids, nb_text, range(len(nb_ids))]
    except Exception:  # narrowed from a bare `except:` (t-SNE needs enough samples)
        tsne_results = np.c_[tnse_error, nb_ids, nb_text, range(len(nb_ids))]

    t1 = time()
    print("t-SNE: %.2g sec" % (t1 - t0))
    # (Leftover debug print of the full result matrix removed.)

    return tsne_results.tolist()
406
+
407
def filtered_projection(similar_key, vocab, type="input", key="word"):
    """Project the query vocabulary together with its similar reference entries.

    Args:
        similar_key (dict): similarity results per query; each value carries
            a 'similar_topk' list of reference keys.
        vocab (dict): query vocabulary to project.
        type (str): 'input' or 'output' side of the global ``metadata_all``.
        key (str): 'words' or 'tokens' (plural; the singular form passed on
            to the projection is derived by dropping the trailing 's').

    Returns:
        dict: text -> [x, y, id, text, position]; empty when t-SNE fails.
    """
    global metadata_all
    vocab_proj = vocab.copy()

    # Collect every reference entry that appears in some query's top-k.
    source_words_voc_similar = set()
    for key_i in similar_key:
        words_set = similar_key[key_i]
        source_words_voc_similar.update(words_set['similar_topk'])

    print(len(source_words_voc_similar))
    # Pull those entries from the cached reference vocabulary and project
    # them together with the queries.
    source_embeddings_filtered = {key_value: metadata_all[type][key][0][key_value] for key_value in source_words_voc_similar}
    vocab_proj.update(source_embeddings_filtered)

    try:
        result_TSNE = embds_input_projection_vocab(vocab_proj, key=key[:-1])  # singular => without 's'
        dict_projected_embds_all = {str(embds[2]): [embds[0], embds[1], embds[2], embds[3], embds[4]] for embds in result_TSNE}
    except Exception:  # narrowed from a bare `except:`; keep the empty-dict fallback
        print('TSNE error', type, key)
        dict_projected_embds_all = {}

    return dict_projected_embds_all
433
+
434
def first_function(w1, model):
    """Translate the reference sentences and build the FAISS reference indexes.

    Gradio callback for the "Reference Translation" button.  Stores the
    resulting vocabularies/indexes in the global ``metadata_all`` for the
    later queries issued by first_function_tr().

    Args:
        w1 (str): newline-separated reference source sentences.
        model (str): language-pair key (e.g. 'en-es').

    Returns:
        list: [full translated text, params tuple of the LAST sentence].
    """
    global metadata_all
    # Translate sentence by sentence, collecting embeddings for both sides.
    sentences = w1.split("\n")
    all_sentences = []
    translated_text = ''
    input_embeddings = []
    output_embeddings = []
    for sentence in sentences:
        params = translation_model(sentence, model)
        all_sentences.append(params)
        translated_text += params[0] + ' \n'
        input_embeddings.append({
            'embeddings': params[3].detach(),  # encoder token embeddings
            'tokens': params[2].tolist(),  # one translation = one sentence
        })
        output_embeddings.append({
            'embeddings': params[4].detach(),  # decoder token embeddings
            'tokens': params[1].sequences.tolist(),
        })

    ## Build the FAISS reference indexes for both sides of the translation.
    result_input = build_reference(input_embeddings, model)
    result_output = build_reference(output_embeddings, model)

    metadata_all = {'input': result_input, 'output': result_output}

    # NOTE(review): `params` is the last loop iteration's value — presumably
    # intentional, since the UI only displays the translated text.
    return [translated_text, params]
474
+
475
def first_function_tr(w1, model, var2={}):
    """Translate the analysis sentences and search the reference indexes.

    Gradio callback for the "Translation" button.  Requires ``metadata_all``
    to have been populated by first_function() beforehand.

    Args:
        w1 (str): newline-separated sentences to analyse.
        model (str): language-pair key.
        var2 (dict, optional): unused; kept for the Gradio wiring.
            NOTE(review): mutable default argument — harmless while unused.

    Returns:
        list: [translated text, [full json payload, output-words view,
            output-tokens view]] consumed by plotsjs.js.
    """
    global metadata_all
    print("SEARCH -- ")
    # Translate sentence by sentence, collecting embeddings for both sides.
    sentences = w1.split("\n")
    all_sentences = []
    translated_text = ''
    input_embeddings = []
    output_embeddings = []
    for sentence in sentences:
        params = translation_model(sentence, model)
        all_sentences.append(params)
        translated_text += params[0] + ' \n'
        input_embeddings.append({
            'embeddings': params[3].detach(),  # encoder token embeddings
            'tokens': params[2].tolist(),  # one translation = one sentence
        })
        output_embeddings.append({
            'embeddings': params[4].detach(),  # decoder token embeddings
            'tokens': params[1].sequences.tolist(),
        })

    ## Search both sides against the prebuilt FAISS reference indexes.
    result_search = {}
    result_search['input'] = build_search(input_embeddings, model, type='input')
    result_search['output'] = build_search(output_embeddings, model, type='output')

    # Assemble the JSON consumed by the front-end: for each side and each
    # granularity, the similarity lists, the filtered t-SNE projection and
    # the per-sentence key lists.
    json_out = {'input': {'tokens': {}, 'words': {}}, 'output': {'tokens': {}, 'words': {}}}
    dict_projected = {}
    for type in ['input', 'output']:
        dict_projected[type] = {}
        for key in ['tokens', 'words']:
            similar_key = result_search[type][key]['similar']
            vocab = result_search[type][key]['vocab_queries']
            dict_projected[type][key] = filtered_projection(similar_key, vocab, type=type, key=key)
            json_out[type][key]['similar_queries'] = similar_key
            json_out[type][key]['tnse'] = dict_projected[type][key]
            json_out[type][key]['key_text_list'] = result_search[type][key]['sentence_key_list']

    return [translated_text, [ json_out, json_out['output']['words'], json_out['output']['tokens']] ]
549
+
550
+
551
+
552
## HTML scaffolding injected through gr.HTML(); plotsjs.js fills the divs.
# `html` is the full single-block layout (currently unused in the UI below,
# kept for reference); `html0` holds the hidden words/tokens selector;
# html_col1/2/3 split the plots across three Gradio columns.
html = """
<html>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min"></script>
<script async data-require="d3@3.5.3" data-semver="3.5.3"
src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>
<body>
<div id="select_div">
<select id="select_type" class="form-select" aria-label="select example" hidden>
<option selected value="words">Words</option>
<option value="tokens">Tokens</option>
</select>
</div>
<div id="d3_embed_div">
<div class="row">
<div class="col-6">
<div id="d3_embeds_input_words" class="d3_embed words"></div>
</div>
<div class="col-6">
<div id="d3_embeds_output_words" class="d3_embed words"></div>

</div>
<div class="col-6">
<div id="d3_embeds_input_tokens" class="d3_embed tokens"></div>
</div>
<div class="col-6">
<div id="d3_embeds_output_tokens" class="d3_embed tokens"></div>
</div>
</div>
</div>
<div id="d3_graph_div">
<div class="row">
<div class="col-4">
<div id="d3_graph_input_words" class="d3_graph words"></div>

</div>
<div class="col-4">
<div id="similar_input_words" class=""></div>
</div>
<div class="col-4">
<div id="d3_graph_output_words" class="d3_graph words"></div>
<div id="similar_output_words" class="d3_graph words"></div>
</div>
</div>
<div class="row">
<div class="col-6">
<div id="d3_graph_input_tokens" class="d3_graph tokens"></div>
<div id="similar_input_tokens" class="d3_graph tokens"></div>
</div>
<div class="col-6">
<div id="d3_graph_output_tokens" class="d3_graph tokens"></div>
<div id="similar_output_tokens" class="d3_graph tokens"></div>
</div>
</div>
</div>
</body>

</html>
"""
# Minimal variant: only the (hidden) words/tokens selector.
html0 = """
<html>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min"></script>
<script async data-require="d3@3.5.3" data-semver="3.5.3"
src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>
<body>
<div id="select_div">
<select id="select_type" class="form-select" aria-label="select example" hidden>
<option selected value="words">Words</option>
<option value="tokens">Tokens</option>
</select>
</div>
</body>

</html>
"""

# Column 2 of the plot row: source-side graphs.
html_col1 = """
<div id="d3_graph_input_words" class="d3_graph words"></div>
<div id="d3_graph_input_tokens" class="d3_graph tokens"></div>
"""

# Column 1 of the plot row: similar-item lists for both sides.
html_col2 = """
<div id="similar_input_words" class=""></div>
<div id="similar_output_words" class=""></div>
<div id="similar_input_tokens" class=" "></div>
<div id="similar_output_tokens" class=" "></div>

"""


# Column 3 of the plot row: target-side graphs.
html_col3 = """
<div id="d3_graph_output_words" class="d3_graph words"></div>
<div id="d3_graph_output_tokens" class="d3_graph tokens"></div>
"""
654
def second_function(w1, j2):
    """Acknowledge the JS plotting step (Gradio change-event callback).

    Args:
        w1: text output forwarded from the first callback.
        j2: JSON payload forwarded from the first callback.

    Returns:
        str: fixed acknowledgement message shown in a hidden textbox.
    """
    # Log what arrived after the JS side ran; the return value only marks completion.
    print("second_function -- after the js", w1, j2)
    return "transition to second js function finished."
659
+
660
+
661
# Gradio UI: two text panels (reference corpus / sentences to analyse) and
# three HTML columns that plotsjs.js fills with d3 visualisations.
with gr.Blocks(js="plotsjs.js") as demo:
    gr.Markdown(
        """
        # MAKE NMT Workshop \t `Embeddings representation`
        """)
    with gr.Row():
        with gr.Column(scale=1):
            # Language-pair selector shared by both translation panels.
            model_radio_c = gr.Radio(choices=['en-es', 'en-zh', 'en-fr'], value="en-es", label= '', container=False)

        with gr.Column(scale=2):
            gr.Markdown(
                """
                ### Reference Translation Sentences
                Enter at least 50 sentences to be used as comparison.
                This is submitted just once.
                """)
            in_text = gr.Textbox(lines=2, label="reference source text")
            out_text = gr.Textbox(label="reference target text", interactive=False)
            out_text2 = gr.Textbox(visible=False)  # hidden bridge for second_function
            var2 = gr.JSON(visible=False)          # hidden JSON bridge to the JS side
            btn = gr.Button("Reference Translation")


        with gr.Column(scale=3):
            gr.Markdown(
                """
                ### Translation Sentences
                Sentences to be analysed.
                """)
            in_text_tr = gr.Textbox(lines=2, label="source text")
            out_text_tr = gr.Textbox(label="target text", interactive=False)
            out_text2_tr = gr.Textbox(visible=False)  # hidden bridge for second_function
            var2_tr = gr.JSON(visible=False)          # hidden JSON bridge to the JS side
            btn_faiss = gr.Button("Translation ")


    with gr.Row():
        with gr.Column(scale=1):
            input_mic = gr.HTML(html0)       # hidden words/tokens selector
            input_html2 = gr.HTML(html_col2)  # similar-item lists

        with gr.Column(scale=2):
            input_html1 = gr.HTML(html_col1)  # source-side graphs

        with gr.Column(scale=2):
            input_html3 = gr.HTML(html_col3)  # target-side graphs

    ## first function: input w1, model; returns out_text, var2; runs the JS hook first.
    btn.click(first_function, [in_text, model_radio_c], [out_text,var2], js="(in_text,model_radio_c) => testFn_out(in_text,model_radio_c)") #should return an output comp.
    btn_faiss.click(first_function_tr, [in_text_tr, model_radio_c], [out_text_tr,var2_tr], js="(in_text_tr,model_radio_c) => testFn_out(in_text_tr,model_radio_c)") #should return an output comp.
    ## second function: triggered by the first function's text output changing;
    ## its JS hook receives the JSON payload and draws the plots.
    out_text.change(second_function, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)") #
    out_text_tr.change(second_function, [out_text_tr, var2_tr], out_text2_tr, js="(out_text_tr,var2_tr) => testFn_out_json_tr(var2_tr)") #


if __name__ == "__main__":
    demo.launch()
plotsjs.js ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ async () => {
3
+ // set testFn() function on globalThis, so you html onlclick can access it
4
+
5
+
6
    // Smoke-test helper: writes into a #demo element (not present in the
    // current layout; kept from early experiments).
    globalThis.testFn = () => {
        document.getElementById('demo').innerHTML = "Hello?"
    };
9
+
10
+ const d3 = await import("https://cdn.jsdelivr.net/npm/d3@7/+esm");
11
+ // const d3 = await import("https://cdn.jsdelivr.net/npm/d3@5/+esm");
12
+ const $ = await import("https://cdn.jsdelivr.net/npm/jquery@3.7.1/dist/jquery.min.js");
13
+
14
+ globalThis.$ = $;
15
+ globalThis.d3 = d3;
16
+
17
    // Demo helper: draws a 50x50 rect in #viz that turns red on hover.
    globalThis.d3Fn = () => {
        d3.select('#viz').append('svg')
            .append('rect')
            .attr('width', 50)
            .attr('height', 50)
            .attr('fill', 'black')
            .on('mouseover', function(){d3.select(this).attr('fill', 'red')})
            .on('mouseout', function(){d3.select(this).attr('fill', 'black')});
    };
27
+
28
    // Pre-submit hook for both translate buttons: logs and forwards the
    // textbox value and selected model to the Python callback unchanged.
    globalThis.testFn_out = (val,model_radio_c) => {
        console.log(val, "testFn_out");
        return([val,model_radio_c]);
    };
34
+
35
+
36
    // Post-callback hook for the reference translation: currently only logs
    // the payload; plotting happens in testFn_out_json_tr.
    globalThis.testFn_out_json = (data) => {
        console.log(data, "testFn_out_json --");
        return(['string', {}])
    }
42
+
43
    // Post-callback hook for the analysed translation: receives the json_out
    // payload from first_function_tr and renders every plot.
    globalThis.testFn_out_json_tr = (data) => {
        // data[0]['input'|'output']['words'|'tokens'] holds the payload.
        console.log(data, "testFn_out_json_tr new");
        var $ = jQuery;
        console.log("$('#d3_embeddings')");
        console.log($('#d3_embeddings'));

        d3.select("#d3_embeds_source").html("here");

        // Reveal the words/tokens selector and re-render on change.
        console.log(d3.select("#select_type").node().value);
        d3.select("#select_type").attr("hidden", null);
        d3.select("#select_type").on("change", change);
        change();

        // One network plot per side and granularity.
        // NOTE(review): embeddings_network is defined further down this file.
        ['input', 'output'].forEach(text_type => {
            ['tokens', 'words'].forEach(text_key => {
                data_i = data[0][text_type][text_key];
                embeddings_network([], data_i['tnse'], data_i['similar_queries'], type=text_type +"_"+text_key, )
            });
        });

        return(['string', {}])

    }
87
+
88
// Toggle between the "tokens" and "words" visualizations: hide every embed
// and graph panel, then un-hide the four panels matching the dropdown value.
// (`show_type` stays an implicit global, exactly as the original did.)
function change() {
    show_type = d3.select("#select_type").node().value;
    // Hide all panels first.
    [".d3_embed", ".d3_graph"].forEach(selector => {
        d3.selectAll(selector).attr("hidden", '');
    });
    // Reveal input/output embed + graph panels for the selected type.
    ["#d3_embeds_input_", "#d3_embeds_output_", "#d3_graph_input_", "#d3_graph_output_"]
        .forEach(prefix => {
            d3.select(prefix + show_type).attr("hidden", null);
        });
}
99
+
100
+
101
+
102
// Build a token-similarity graph for one panel and render it via networkPlot().
// Parameters (from visible usage):
//   tokens_text           – unused, kept for call-site compatibility
//   dict_projected_embds  – projection dict: id -> array where index 3 is the
//                           token text (index 4 is used elsewhere as a class suffix)
//   similar_vocab_queries – { token_id: { similar_topk: [...], distance: [...] } }
//   type                  – panel suffix (e.g. "input_tokens"); targets #d3_graph_<type>
// Side effects: clears and repopulates the #d3_graph_<type> container.
// NOTE(review): loop variables (prev_node, sent_token, value, token_text, ...)
// are undeclared and therefore global — confirm panels rendered in sequence
// cannot interfere with each other.
function embeddings_network(tokens_text, dict_projected_embds, similar_vocab_queries, type="source", ){
    // tokens_text : not used;
    // dict_projected_embds = tnse
    console.log("Each token is a node; distance if in similar list", type );
    console.log(tokens_text, dict_projected_embds, similar_vocab_queries);
    // similar_vocab_queries_target[key]['similar_topk']

    var nodes_tokens = {}   // unused
    var nodeHash = {};      // id -> node object; dedupes nodes across queries
    var nodes = []; // [{id: , label: }]
    var edges = []; // [{source: , target: weight: }]
    var edges_ids = []; // [{source: , target: weight: }]  (id-only mirror of `edges`)

    // similar_vocab_queries {key: {similar_topk : [], distance : []}}
    console.log('similar_vocab_queries', similar_vocab_queries);
    prev_node = '';
    for ([sent_token, value] of Object.entries(similar_vocab_queries)) {
        // console.log('dict_projected_embds',sent_token, parseInt(sent_token), value, dict_projected_embds);
        // sent_token = parseInt(sent_token); // Object.entries assumes key:string;
        token_text = dict_projected_embds[sent_token][3]
        // One "sentence" node per query token (type_i 0 -> blue in networkPlot).
        if (!nodeHash[sent_token]) {
            nodeHash[sent_token] = {id: sent_token, label: token_text, type: 'sentence', type_i: 0};
            nodes.push(nodeHash[sent_token]);
        }
        sim_tokens = value['similar_topk']
        dist_tokens = value['distance']

        // One "similar" node (type_i 1 -> red) plus a distance-weighted edge
        // for each nearest neighbour of this query token.
        for (let index = 0; index < sim_tokens.length; index++) {
            const sim = sim_tokens[index];
            const dist = dist_tokens[index];

            token_text_sim = dict_projected_embds[sim][3]
            if (!nodeHash[sim]) {
                nodeHash[sim] = {id: sim, label: token_text_sim, type:'similar', type_i: 1};
                nodes.push(nodeHash[sim]);
            }
            edges.push({source: nodeHash[sent_token], target: nodeHash[sim], weight: dist});
            edges_ids.push({source: sent_token, target: sim, weight: dist});
        }

        // Chain consecutive query tokens with weight-1 edges to preserve
        // sentence order in the layout.
        if (prev_node != '' ) {
            edges.push({source: nodeHash[prev_node], target:nodeHash[sent_token], weight: 1});
            edges_ids.push({source: prev_node, target: sent_token, weight: 1});
        }
        prev_node = sent_token;

    }
    console.log("TYPE", type, edges, nodes, edges_ids, similar_vocab_queries)
    // d3.select('#d3_graph_input_tokens').html(networkPlot({nodes: nodes, links:edges}, similar_vocab_queries, div_type=type) );
    // type +"_"+key
    // Replace any previous rendering of this panel, then mount the new SVG.
    // (NOTE(review): `div_type=type` is again a global assignment passed positionally.)
    d3.select('#d3_graph_'+type).html("");
    d3.select('#d3_graph_'+type).append(function(){return networkPlot({nodes: nodes, links:edges}, similar_vocab_queries, dict_projected_embds,div_type=type);});

    // $('#d3_embeds_network_target').html(networkPlot({nodes: nodes, links:edges}));
    // $('#d3_embeds_network_'+type).html(etworkPlot({nodes: nodes, link:edges}));
}
158
+
159
// Build per-sentence projection arrays and a token->sentences index, then
// render them with scatterPlot() into #d3_embeds_<type>.
// NOTE(review): `get_sentences()` is not defined in this file chunk — confirm
// it exists globally before this runs; its result is computed but the loop
// below iterates source_tokens_text_list instead (see commented line).
function embeddings_graph(data, source_tokens_text_list, source_tokens, similar_vocab_queries, type="source") {
    /*
    ### source
    data: dict_projected_embds_all = { token_id: [tns1, tns2, token_id, token_text] ...}
    ### target
    */
    console.log("embeddings_graph");
    active_sentences = get_sentences();
    console.log("active_sentences", active_sentences, type); // working

    active_sentences_tokens_text = active_sentences.map((x) => source_tokens_text_list[x]);
    active_sentences_tokens = active_sentences.map((x) => source_tokens[x]);

    console.log(active_sentences_tokens);

    data_sentences = []          // per sentence: [[x, y, text, sent_i, tok_j, token_id], ...]
    dict_token_sentence_id = {}  // token text -> list of sentence indices containing it
    // active_sentences_tokens.forEach((sentence, i) => {
    source_tokens_text_list.forEach((sentence, i) => {
        /// opt1
        proj = []
        sentence.forEach((tok, tok_j) => {
            console.log("tok,tok_j", tok, tok_j);
            token_text = source_tokens_text_list[i][tok_j];
            // NOTE(review): `tok` is used as a key into `data`, so here the
            // sentence entries are assumed to be token ids — confirm callers
            // pass ids (not text) in source_tokens_text_list.
            proj.push([data[tok][0], data[tok][1], token_text, i, tok_j, tok])
            if (token_text in dict_token_sentence_id){
                dict_token_sentence_id[token_text].push(i);
            }
            else{
                dict_token_sentence_id[token_text] = [i];
            }
        });
        data_sentences.push(proj);
    });
    console.log("data_sentences error here in target", data_sentences);

    console.log(data);

    $('#d3_embeds_' + type).html(scatterPlot(data, data_sentences, dict_token_sentence_id, similar_vocab_queries, 'd3_embeds_'+type, type ));
}
199
+
200
+
201
+ /*
202
+ data: dict_projected_embds_all = { token_id: [tns1, tns2, token_id, token_text] ...}
203
+ */
204
/*
Render a 2-D scatter of projected token embeddings with per-sentence labels,
hover highlighting of similar tokens, and a click-driven legend.
  data: dict_projected_embds_all = { token_id: [tns1, tns2, token_id, token_text] ...}
        (indexed by div_type before use)
  data_sentences: per-sentence arrays of [x, y, text, sent_i, tok_j, token_id]
  dict_token_sentence_id: token text -> list of sentence indices containing it
  similar_vocab_queries: { token_id: { similar_topk: [...], distance: [...] } }
  div_name: container id for the tooltip; div_type: "source"/"target" panel key
Returns the detached <svg> node (the caller inserts it into the DOM).
*/
function scatterPlot(data, data_sentences, dict_token_sentence_id, similar_vocab_queries, div_name, div_type="source", {
    width = 400, // outer width, in pixels
    height , // outer height, in pixels
    r = 3, // radius of nodes
    padding = 1, // horizontal padding for first and last column
    // text = d => d[2],
} = {}){
    // data_dict = data[div_type];
    // Shallow-copy the per-panel slices so later mutation cannot leak out.
    var data_dict = { ...data[div_type] };
    data = Object.values(data[div_type]);
    // similar_vocab_queries = similar_vocab_queries[div_type];
    var similar_vocab_queries = { ...similar_vocab_queries[div_type] };
    console.log("div_type, data, data_dict, data_sentences, dict_token_sentence_id, similar_vocab_queries");
    console.log(div_type, data, data_dict, data_sentences, dict_token_sentence_id, similar_vocab_queries);

    // Create the SVG container.
    var margin = {top: 10, right: 10, bottom: 30, left: 50 },
    width = width - margin.left - margin.right,
    height = 400 - margin.top - margin.bottom;

    // append the svg object to the body of the page
    var svg = d3.create("svg")
        // .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)

    svg.append("g")
        .attr("transform",
              "translate(" + margin.left + "," + margin.top + ")");

    // const svg = d3.create("svg")
    //     .attr("width", width)
    //     .attr("height", height);

    // Add X axis — domain from the data extent of the first projection axis.
    min_value_x = d3.min(data, d => d[0])
    max_value_x = d3.max(data, d => d[0])


    var x = d3.scaleLinear()
        .domain([min_value_x, max_value_x])
        .range([ margin.left , width ]);

    svg.append("g")
        // .attr("transform", "translate("+ margin.left +"," + height + ")")
        .attr("transform", "translate(0," + height + ")")
        .call(d3.axisBottom(x));

    // Add Y axis — domain from the second projection axis; range inverted
    // because SVG y grows downward.
    min_value_y = d3.min(data, d => d[1])
    max_value_y = d3.max(data, d => d[1])

    var y = d3.scaleLinear()
        .domain([min_value_y, max_value_y])
        .range([ height, margin.top]);

    svg.append("g")
        .attr("transform", "translate("+ margin.left +", 0)")
        .call(d3.axisLeft(y));

    // Background layer: one red dot per vocabulary point, classed by d[2]
    // so hover handlers can find '.dot-<id>' later.
    svg.selectAll()
        .data(data)
        .enter()
        .append('circle')
        .attr("class", function (d) { return "dot-" + d[2] } )
        // .attr("cx", function (d) { return x(d[0] + margin.left); } )
        .attr("cx", function (d) { return x(d[0]); } )
        .attr("cy", function (d) { return y(d[1] - margin.bottom); } )
        .attr("r", 5)
        .style("fill", "#e85252")
        .style("fillOpacity",0.2)
        .style("stroke", "#000000ff")
        .style("strokeWidth", 1)
        .style("opacity", 0.7);

    // svg.selectAll()
    //     .data(data)
    //     .enter()
    //     .append('text')
    //     .text(d => d[3])
    //     .attr("class", function (d) { return "text-" + d[2] } )
    //     // .attr("cx", function (d) { return x(d[0] + margin.left); } )
    //     .attr("x", function (d) { return x(d[0]); } )
    //     .attr("y", function (d) { return y(d[1] - margin.bottom); } )
    //     .attr("dy", "0.35em");

    // colors = ['#cb1dd1',"#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"];
    colors = ['#6689c6',"#e0ac2b", "#e0ac2b", "#cb1dd1", "#cb1dd1", "#cb1dd1"];

    // create a tooltip
    var Tooltip = d3.select("#"+div_name)
        .append("div")
        .style("opacity", 0)
        .attr("class", "tooltip")
        .style("background-color", "white")
        .style("border", "solid")
        .style("border-width", "2px")
        .style("border-radius", "5px")
        .style("padding", "5px")
        .text("I'm a circle!");

    // const colorScale = d3.scaleOrdinal()
    //     .domain(domain_values)
    //     .range(["#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"]);
    // colorScale(d.group)

    // Foreground layer: per-sentence labels + interactive blue dots.
    for (let i_snt = 0; i_snt < data_sentences.length; i_snt++) {
        const sentence = data_sentences[i_snt];
        // similar_tokens;
        console.log("sentence: ", sentence);

        svg.selectAll()
            .data(sentence)
            .enter()
            .append('text')
            .text(d => d[2])
            .attr("class", function (d) { return "text-" + d[2] + " sent-" + i_snt } )
            // .attr("cx", function (d) { return x(d[0] + margin.left); } )
            .attr("x", function (d) { return x(d[0]); } )
            .attr("y", function (d) { return y(d[1] - margin.bottom); } )
            .attr("dy", "0.35em")
            .attr("sentence_i", i_snt );

        svg.selectAll()
            .data(sentence)
            .enter()
            .append('circle')
            .attr("class", function (d) { return "dot " + d[2] + " " + i_snt } )
            // .attr("cx", function (d) { return x(d[0] + margin.left); } )
            .attr("cx", function (d) { return x(d[0]); } )
            .attr("cy", function (d) { return y(d[1] - margin.bottom); } )
            .attr("sentence_i", i_snt )
            .attr("r", 6)
            .style("fill", colors[0])
            .style("fillOpacity",0.2)
            .style("stroke", "#000000")
            .style("strokeWidth", 1)
            .style("opacity", 1)
            .on('click', change_legend )
            .on('mouseover', highlight_mouseover )
            .on('mouseout', highlight_mouseout )
            // .on("mousemove", mousemove);


    }


    // Click handler: datum i is [x, y, text, sent_i, tok_j, token_id].
    // Shows the sentences containing the clicked token and its neighbours.
    // NOTE(review): uses the global `type` (not div_type) for the legend id.
    function change_legend(d,i) {
        console.log(d,i);
        if (i[2] in dict_token_sentence_id){
            show_sentences(dict_token_sentence_id[i[2]], i[2]);

            show_similar_tokens(i[5], '#d3_legend_similar_'+type);

            console.log(dict_token_sentence_id[i[2]]);
        }
        else{console.log("no sentence")};
    }

    // Hover handler: enlarge the hovered dot and its nearest neighbours.
    // NOTE(review): the line defining token_id is commented out, so
    // `token_id` here is an undeclared/stale global — this likely reads the
    // wrong token's neighbours (or throws). Probable intended fix:
    // re-enable `token_id = parseInt(i[5])`. Same issue in highlight_mouseout.
    function highlight_mouseover(d,i) {
        console.log("highlight_mouseover", d,i);
        // token_id = parseInt(i[5])
        similar_ids = similar_vocab_queries[token_id]['similar_topk'];
        d3.select(this).transition()
            .duration('50')
            .style('opacity', '1')
            .attr("r", 12)

        similar_ids.forEach(similar_token => {
            d3.selectAll('.dot-' + similar_token).attr("r",12 ).style('opacity', '1')//.raise()
        });

        Tooltip
            .style("opacity", 1)
            .style("visibility", "visible")
            // .style("top", (event.pageY-height)+"px").style("left",(event.pageX-width)+"px")
        d3.select(this)
            .style("stroke", "red")
            .attr("strokeWidth", 2)
            .style("opacity", 0.7)

        // .html("The exact value of<br>this cell is: ")
        // .style("left", (d3.mouse(this)[0]+70) + "px")
        // .style("top", (d3.mouse(this)[1]) + "px")

    }
    // Mouse-out handler: restore sizes/opacity (see token_id note above).
    function highlight_mouseout(d,i) {
        // token_id = parseInt(i[5])
        console.log("similar_vocab_queries", similar_vocab_queries);
        similar_ids = similar_vocab_queries[token_id]['similar_topk'];
        // clean_sentences();
        d3.select(this).transition()
            .duration('50')
            .style('opacity', '.7')
            .attr("r", 6)

        similar_ids.forEach(similar_token => {
            d3.selectAll('.dot-' + similar_token).attr("r",6 ).style('opacity', '.7')
        });

        Tooltip
            .style("opacity", 0)
        d3.select(this)
            .style("stroke", "none")
            .style("opacity", 0.8)
    }

    // Currently unused (registration commented out above): position the
    // tooltip next to the pointer.
    function mousemove(d,i) {
        console.log("mousemove", d, i)
        pointer = d3.pointer(d);
        Tooltip
            .html("The exact value of<br> ")
            // .style("top", ((e.pageY ) - (height*2)) +"px")
            // .attr("transform", `translate(${pointer[0]},0)`)
            .style("top", height - pointer[1] +"px")
            .style("left", pointer[0]+"px")
    }


    // Render every sentence containing `token` into the legend panel,
    // highlighting occurrences of the token in yellow.
    function show_sentences(sentences_id, token) {

        // Show sentences with token "token"
        d3.select('#d3_legend_data_'+div_type).html("");
        console.log("show_sentences", data_sentences, sentences_id);
        sentences_id.forEach(sent_id => {
            console.log(data_sentences[sent_id])
            // console.log(data_sentences[sent_id].map( x => x[2] ));
            // p = d3.select('#d3_legend_data').append("p").enter();
            d3.select('#d3_legend_data_'+div_type)
                .selectAll().append("p")
                .data(data_sentences[sent_id])
                .enter()
                .append('text')
                .attr('class_data', sent_id)
                .attr('class_id', d => d[5])
                .style("background", d=> {if (d[2]== token) return "yellow"} )
                .text( d => d[2] + " ");
            d3.select('#d3_legend_data_'+div_type).append("p").enter();
        });
        // $("#d3_legend_data")
        // data_sentences
    }

    // Clear the legend panel (currently only referenced from commented code).
    function clean_sentences() {
        d3.select('#d3_legend_data_'+div_type).html("");
    }

    // List the nearest neighbours of `token` with their distances.
    function show_similar_tokens(token, div_name_similar= '#d3_legend_similar_') {
        d3.select(div_name_similar).html("");
        console.log("token", token);
        console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
        token_data = similar_vocab_queries[token];
        console.log(token, token_data);
        var decForm = d3.format(".3f");

        d3.select(div_name_similar)
            .selectAll().append("p")
            .data(token_data['similar_topk'])
            .enter()
            .append("p").append('text')
            // .attr('class_data', sent_id)
            .attr('class_id', d => d)
            .style("background", d=> {if (d == token) return "yellow"} )
            // .text( d => d + " \n ");
            .text((d,i) => do_text(d,i) );

        function do_text(d,i){
            console.log("do_text d,i" );
            console.log(d,i);
            console.log("data_dict[d], data_dict");
            // console.log(data_dict[d], data_dict);
            // return data_dict[d][3] + " " + decForm(token_data['distance'][i]) + " ";
            return " " + decForm(token_data['distance'][i]) + " ";
        }


    }
    // data_sentences

    // .attr('x', (d) => x_scale(d[0]) + margin.left)
    // .attr('y', (d) => y_scale(d[1]) + margin_top_extra)
    // .attr("rx", 4)
    // .attr("ry", 4)
    // .attr("stroke", "#F7F7F7")
    // .attr("stroke-width","2px")
    // .attr('width', x_scale.bandwidth())
    // .attr('height', (d) => height_text);
    // // .attr('fill', (d) => color_scale(d.value));

    // Add dots
    // svg.append('g')
    //     // .selectAll("dot")
    //     .data(data)
    //     .enter()
    //     .append("circle")
    //     .attr("class", function (d) { return "dot " + d[2] } )
    //     .attr("cx", function (d) { return x(d[0]); } )
    //     .attr("cy", function (d) { return y(d[1]); } )
    //     .attr("r", 5)
    //     .style("fill", function (d) { return color(d.Species) } )
    //     .on("mouseover", highlight)
    //     .on("mouseleave", doNotHighlight )



    return svg.node();
}
511
+
512
+
513
+
514
/*
Force-directed network of query tokens (blue, type_i 0) and their nearest
neighbours (red, type_i 1), with hover highlighting and a click legend.
  data: { nodes: [{id, label, type, type_i}], links: [{source, target, weight}] }
  similar_vocab_queries: { token_id: { similar_topk: [...], distance: [...] } }
  dict_proj: projection dict; index 3 = token text, index 4 = node class suffix
  div_type: panel suffix used in CSS class names and the legend container id
Returns the detached <svg> node.
NOTE(review): positions are only applied on the simulation "end" event, so
the graph appears after the layout converges rather than animating.
*/
function networkPlot(data, similar_vocab_queries,dict_proj, div_type="source", {
    width = 400, // outer width, in pixels
    height , // outer height, in pixels
    r = 3, // radius of nodes
    padding = 1, // horizontal padding for first and last column
    // text = d => d[2],
} = {}){
    // data_dict = data;
    data = data// [div_type];
    similar_vocab_queries = similar_vocab_queries// [div_type];
    console.log("data, similar_vocab_queries, div_type");
    console.log(data, similar_vocab_queries, div_type);

    // Create the SVG container.
    // NOTE(review): the trailing `//` comments disable the commas, so ASI makes
    // `height = 400` a plain assignment to the parameter — intentional? confirm.
    var margin = {top: 10, right: 10, bottom: 30, left: 50 },
    width = width //- margin.left - margin.right,
    height = 400 //- margin.top - margin.bottom;

    width_box = width + margin.left + margin.right;
    height_box = height + margin.top + margin.bottom
    totalWidth = width*2;
    // append the svg object to the body of the page
    // const parent = d3.create("div");
    // const body = parent.append("div")
    //     .style("overflow-x", "scroll")
    //     .style("-webkit-overflow-scrolling", "touch");


    var svg = d3.create("svg")
        // var svg = body.create("svg")
        // .style("display", "block")
        // .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)
        // .attr("viewBox", [-width_box / 2, -height_box / 2, width_box, height_box])
        // .attr("viewBox", [0, 0, width, height]);
        // .attr("style", "max-width: 100%; height: auto;");

    // svg.append("g")
    //     .attr("transform",
    //         "translate(" + margin.left + "," + margin.top + ")");



    // Initialize the links — weight 1 marks sentence-order chain edges.
    var link = svg
        .selectAll("line")
        .data(data.links)
        .enter()
        .append("line")
        .style("fill", d => d.weight == 1 ? "#dfd5d5" : "#000000") // , "#69b3a2" : "#69b3a2")
        .style("stroke", "#aaa")



    // Node labels, classed 'text_token-<suffix><div_type>' so handlers can
    // address them per panel.
    var text = svg
        .selectAll("text")
        .data(data.nodes)
        .enter()
        .append("text")
        .style("text-anchor", "middle")
        .attr("y", 15)
        .attr("class", d => 'text_token-'+ dict_proj[d.id][4] + div_type)
        .attr("div-type", div_type)
        // .attr("class", d => 'text_token-'+ d.index)
        .text(function (d) {return d.label} )
        // .on('mouseover', function(d) { (d.type_i == 0) ? highlight_mouseover_text : console.log(0)} )
        // .on('mouseover', function(d) { (d.type_i == 0) ? highlight_mouseout_text : '' } )
        // .on('mouseout', highlight_mouseout_text )
        // .join('text')
        // .text(function(d) {
        //     return d.id
        // })

    // Initialize the nodes — blue for query tokens, red for similar tokens.
    var node = svg
        .selectAll("circle")
        .data(data.nodes)
        .enter()
        .append("circle")
        .attr("r", 6)
        // .attr("class", d => 'node_token-'+ d.id)
        .attr("class", d => 'node_token-'+ dict_proj[d.id][4] + div_type)
        .attr("div-type", div_type)
        .style("fill", d => d.type_i ? "#e85252" : "#6689c6") // , "#69b3a2" : "#69b3a2")
        .on('mouseover', highlight_mouseover )
        // .on('mouseover', function(d) { return (d.type_i == 0) ? highlight_mouseover : console.log(0)} )
        .on('mouseout',highlight_mouseout )
        .on('click', change_legend )
        // .on('click', show_similar_tokens )



    // Let's list the force we wanna apply on the network
    // NOTE(review): d3.forceManyBody() takes no constructor argument — the
    // -400 here is ignored and the default strength (-30) applies; the
    // intended call is d3.forceManyBody().strength(-400). Confirm before changing.
    var simulation = d3.forceSimulation(data.nodes) // Force algorithm is applied to data.nodes
        .force("link", d3.forceLink() // This force provides links between nodes
            .id(function(d) { return d.id; }) // This provide the id of a node
            .links(data.links) // and this the list of links
        )
        .force("charge", d3.forceManyBody(-400)) // This adds repulsion between nodes. Play with the -400 for the repulsion strength
        .force("center", d3.forceCenter(width / 2, height / 2)) // This force attracts nodes to the center of the svg area
        // .force("collision", d3.forceCollide())
        .on("end", ticked);

    // This function is run at each iteration of the force algorithm, updating the nodes position.
    function ticked() {
        link
            .attr("x1", function(d) { return d.source.x; })
            .attr("y1", function(d) { return d.source.y; })
            .attr("x2", function(d) { return d.target.x; })
            .attr("y2", function(d) { return d.target.y; });

        node
            .attr("cx", function (d) { return d.x+3; })
            .attr("cy", function(d) { return d.y-3; });

        text
            .attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; })
    }

    // Hover: enlarge a query node (type_i 0) and all of its neighbours in
    // this panel. Sets the globals `token_id` and `type` as side effects.
    function highlight_mouseover(d,i) {
        console.log("highlight_mouseover", d,i, d3.select(this).attr("div-type"));
        if (i.type_i == 0 ){
            token_id = i.id
            similar_ids = similar_vocab_queries[token_id]['similar_topk'];
            d3.select(this).transition()
                .duration('50')
                .style('opacity', '1')
                .attr("r", 12)
            type = d3.select(this).attr("div-type")
            similar_ids.forEach(similar_token => {
                node_id_name = dict_proj[similar_token][4]
                d3.selectAll('.node_token-'+ node_id_name + type).attr("r",12 ).style('opacity', '1')//.raise()
                // d3.selectAll('.text_token-'+ node_id_name).raise()
            });
        }
    }


    // Mouse-out: restore the query node and its neighbours to normal size.
    function highlight_mouseout(d,i) {
        if (i.type_i == 0 ){
            token_id = i.id
            console.log("similar_vocab_queries", similar_vocab_queries, "this type:", d3.select(this).attr("div-type"));
            similar_ids = similar_vocab_queries[token_id]['similar_topk'];
            // clean_sentences();
            d3.select(this).transition()
                .duration('50')
                .style('opacity', '.7')
                .attr("r", 6)
            type = d3.select(this).attr("div-type")
            similar_ids.forEach(similar_token => {
                node_id_name = dict_proj[similar_token][4]
                d3.selectAll('.node_token-' + node_id_name + type).attr("r",6 ).style('opacity', '.7')
                d3.selectAll("circle").raise()
            });
        }
    }

    // Click: populate the '#similar_<type>' legend for the clicked node.
    // NOTE(review): depends on the global `type` set by the hover handlers
    // above — a click without a preceding hover on this panel may target the
    // wrong legend container.
    function change_legend(d,i,j) {
        console.log(d,i,dict_proj);
        if (i['id'] in dict_proj){
            // show_sentences(dict_proj[i[2]], i[2]);

            show_similar_tokens(i['id'], '#similar_'+type);

            console.log(dict_proj[i['id']]);
        }
        else{console.log("no sentence")};
    }

    // List the nearest neighbours of `token` (text + distance, 3 decimals).
    function show_similar_tokens(token, div_name_similar='#similar_input_tokens') {
        d3.select(div_name_similar).html("");
        console.log("token", token);
        console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
        token_data = similar_vocab_queries[token];
        console.log(token, token_data);
        var decForm = d3.format(".3f");

        d3.select(div_name_similar)
            .selectAll().append("p")
            .data(token_data['similar_topk'])
            .enter()
            .append("p").append('text')
            // .attr('class_data', sent_id)
            .attr('class_id', d => d)
            .style("background", d=> {if (d == token) return "yellow"} )
            // .text( d => d + " \n ");
            .text((d,i) => do_text(d,i) );

        function do_text(d,i){
            console.log("do_text d,i" );
            console.log(d,i);
            console.log("data_dict[d], data_dict");
            console.log(dict_proj[d], dict_proj);
            return dict_proj[d][3] + " " + decForm(token_data['distance'][i]) + " ";
        }


    }

    // svg.call(d3.zoom()
    //     .extent([[0, 0], [width, height]])
    //     .scaleExtent([1, 8])
    //     .on("zoom", zoomed));

    // function zoomed({transform}) {
    //     circle.attr("transform", d => `translate(${transform.apply(d)})`);
    // }

    // svg.call(
    //     d3.zoom().on("zoom", (event) => {
    //         g.attr("transform", event.transform);
    //     })
    // );
    // body.node().scrollBy(totalWidth, 0);


    return svg.node();
    // return parent.node();

};
735
+
736
+
737
+
738
+
739
+
740
+
741
+
742
+
743
+
744
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ inseq
2
+ bertviz
3
+ jupyter
4
+ faiss-cpu