Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def generate_diverging_colors(num_colors, palette='Set3'):  # courtesy of ChatGPT
    """Return ``num_colors`` hex colour codes sampled from a matplotlib colormap.

    Args:
        num_colors: How many colours to produce. Zero or a negative value
            yields an empty list without touching matplotlib.
        palette: Name of a registered matplotlib colormap; it is resampled
            to ``num_colors`` entries.

    Returns:
        A list of ``num_colors`` six-digit lowercase hex strings (``rrggbb``,
        no leading ``#``). The alpha channel of the colormap is ignored.
    """
    if num_colors <= 0:
        return []

    # ``matplotlib.cm.get_cmap`` (reached here as ``plt.cm.get_cmap``) was
    # deprecated in Matplotlib 3.7 and removed in 3.9. ``pyplot.get_cmap``
    # accepts the same (name, lut) arguments and exists in old and new
    # releases alike.
    cmap = plt.get_cmap(palette, num_colors)

    # Calling a colormap with integer indices returns one RGBA row per index.
    colors_rgba = cmap(np.arange(num_colors))

    # Channels are floats in [0, 1]; scale to 0..255 (truncating) and
    # concatenate as two hex digits each.
    return [
        f"{int(red * 255):02x}{int(green * 255):02x}{int(blue * 255):02x}"
        for red, green, blue, _alpha in colors_rgba
    ]
# SentencePiece marks the start of a word with U+2581 ("LOWER ONE EIGHTH
# BLOCK"). Written as an escape so the literal survives encoding round trips.
_WORD_START = "\u2581"

# Tokens that never take part in the alignment.
_SPECIAL_TOKENS = frozenset({"</s>", "<pad>", "<unk>", "<s>"})


def _merge_alignment(alignment, tokens):
    """Fuse per-token alignment entries into word-level groups.

    ``alignment`` holds ``[decoder_positions, encoder_positions]`` pairs and
    ``tokens`` the decoder token string for each pair. An entry is folded into
    the previous group when its token continues a word (does not start with
    the word-start marker) or when it shares an encoder position with that
    group. Entries whose token is a special token are dropped.
    """
    merged = []
    for (dec_positions, enc_positions), token in zip(alignment, tokens):
        if token in _SPECIAL_TOKENS:
            continue
        if merged:
            prev_dec, prev_enc = merged[-1]
            continues_word = not token.startswith(_WORD_START)
            shares_source = any(pos in prev_enc for pos in enc_positions)
            if continues_word or shares_source:
                prev_dec.extend(dec_positions)
                prev_enc.extend(enc_positions)
                continue
        merged.append([dec_positions, enc_positions])
    return merged


def _assign_color_ids(merged):
    """Map every aligned position to a colour id; return ``(mapping, count)``.

    Keys are ``("dec", position)`` or ``("enc", position)``. A group reuses the
    colour id of the first of its positions that already has one; otherwise it
    receives the next unused id.
    """
    color_ids = {}
    n_colors = 0
    for dec_positions, enc_positions in merged:
        labels = [("dec", pos) for pos in dec_positions]
        labels += [("enc", pos) for pos in enc_positions]
        color_id = next(
            (color_ids[label] for label in labels if label in color_ids), None
        )
        # Compare against None explicitly: colour id 0 is a valid id and must
        # be reused, not mistaken for "no colour found".
        if color_id is None:
            color_id = n_colors
            n_colors += 1
        for label in labels:
            color_ids.setdefault(label, color_id)
    return color_ids, n_colors


def _render_tokens(tokenizer, token_ids, colors, side, skip=()):
    """Render ``token_ids`` as one HTML string with a ``<span>`` per token.

    Positions listed in ``skip`` are omitted. Tokens whose ``(side, position)``
    key is present in ``colors`` are coloured with that hex code.
    """
    spans = []
    for pos, token_id in enumerate(token_ids):
        if pos in skip:
            continue
        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        hex_color = colors.get((side, pos))
        if hex_color is not None:
            style = f"color: #{hex_color}"
        else:
            # NOTE(review): "--color-text-body" is not a valid CSS colour
            # value on its own (a custom property needs var(...)); kept
            # unchanged so the emitted markup stays the same - confirm intent.
            style = "color: --color-text-body"
        # NOTE(review): token text is inserted without HTML escaping, so
        # tokens such as "</s>" or ones containing "<" / "&" reach the page
        # as markup - confirm whether that is acceptable for the caller.
        spans.append(f"<span style='{style}'>{token}</span>")
    body = "".join(spans).replace(_WORD_START, " ")
    return f"<span style='font-size: 25px'>{body}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids,
                threshold=0.4, skip_first_src=True, skip_second_src=False,
                layer=2, head=6):
    """Colour-code aligned source/target tokens using one cross-attention head.

    For each row of ``outputs.cross_attentions[layer][0][head]`` the encoder
    positions whose weight exceeds ``threshold`` are taken as aligned with
    that row's decoder position. Sub-word pieces and entries sharing encoder
    positions are merged into groups, each group is given a colour, and both
    token sequences are rendered as HTML.

    Args:
        outputs: Object exposing ``cross_attentions``; presumably a
            seq2seq model output where ``cross_attentions[layer][0][head]``
            is a (decoder_len, encoder_len) tensor - TODO confirm.
        tokenizer: Object providing ``convert_ids_to_tokens(list_of_ids)``.
        encoder_input_ids: Batched encoder ids; only element ``[0]`` is used.
        decoder_input_ids: Batched decoder ids; only element ``[0]`` is used.
        threshold: Attention weight above which a position counts as aligned.
        skip_first_src: Omit encoder position 0 from the rendered source.
        skip_second_src: Omit encoder position 1 from the rendered source.
        layer: Index into ``outputs.cross_attentions``.
        head: Attention head index within that layer.

    Returns:
        ``(srchtml, tgthtml)``: HTML for the encoder tokens and for the
        decoder tokens, in that order.
    """
    attention = outputs.cross_attentions[layer][0][head]
    decoder_ids = decoder_input_ids[0]

    alignment = []
    tokens = []
    for pos, weights in enumerate(attention):
        attended = (weights > threshold).nonzero().squeeze(-1).tolist()
        alignment.append([[pos], attended])
        # Index with a plain integer so both tensors and ordinary sequences
        # of ids are accepted.
        tokens.append(tokenizer.convert_ids_to_tokens([decoder_ids[pos]])[0])

    merged = _merge_alignment(alignment, tokens)
    color_ids, n_colors = _assign_color_ids(merged)

    # Nothing aligned means no palette is needed; skip colormap generation.
    palette = generate_diverging_colors(n_colors, palette="Set2") if n_colors else []
    colors = {label: palette[color_id] for label, color_id in color_ids.items()}

    skipped = set()
    if skip_first_src:
        skipped.add(0)
    if skip_second_src:
        skipped.add(1)

    tgthtml = _render_tokens(tokenizer, decoder_ids, colors, "dec")
    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], colors, "enc", skipped)
    return srchtml, tgthtml