File size: 4,475 Bytes
29a9e17
 
 
 
 
 
 
 
d70f1fd
29a9e17
 
 
 
d70f1fd
035151b
29a9e17
 
 
 
 
91203bd
29a9e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from pathlib import Path
from model import Seq2SeqLightning
import torch
import spacy
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import wandb
import os
# Weights & Biases coordinates of the model artifact: <team>/<project>/<name:version>.
wandb_team = "maneel"
wandb_project = "attentionvisualizer"
wandb_model = "attentionvisualizer:v3"
wandb_model_path = "/".join((wandb_team, wandb_project, wandb_model))

# Fetch the spaCy pipelines through the shell every time the script starts.
for _spacy_download_command in (
    "python3 -m spacy download en_core_web_sm",
    "python3 -m spacy download de_core_news_sm",
):
    os.system(_spacy_download_command)

# A W&B run has to exist before an artifact can be requested through it.
wandb.init()
current_folder = Path().parent
path = wandb.use_artifact(wandb_model_path).download(root=current_folder)

# The checkpoint is opened by a relative file name, i.e. it is expected to sit
# in the working directory after the download above.
checkpoint_path = 'model.ckpt'
# NOTE(review): torch.load unpickles the stored objects (the tokenizers are
# read straight out of the checkpoint) -- only load checkpoints you trust.
checkpoint = torch.load(
    checkpoint_path,
    map_location=torch.device('cpu'),
)

# Pull the vocabularies and hyper-parameters out of the checkpoint, then
# rebuild the Lightning module from them and restore its weights.
enc_tokenizer = checkpoint['enc_tokenizer']
de_tokenizer = checkpoint['de_tokenizer']
params = checkpoint['hyper_parameters']['params']
inf_model = Seq2SeqLightning(
    enc_tokenizer=enc_tokenizer,
    de_tokenizer=de_tokenizer,
    params=params,
)
inf_model.load_state_dict(checkpoint['state_dict'])

# spaCy pipeline used to split raw input sentences into tokens.
# NOTE(review): this is the German pipeline while the UI text further down
# advertises English -> German -- confirm which language the encoder expects.
nlp = spacy.load('de_core_news_sm')

## Inference function
def inference(model, enc_tokenizer, dec_tokenizer, sentence, max_len=50):
    """Greedily decode a translation of ``sentence`` with ``model``.

    Args:
        model: Object providing ``eval()``, ``make_src_mask()``,
            ``make_trg_mask()``, ``encoder()`` and ``decoder()`` as they are
            called below.
        enc_tokenizer: Source vocabulary; indexed with a token string to
            obtain its integer id.
        dec_tokenizer: Target vocabulary; indexed with ``"<sos>"`` and
            ``"<eos>"`` and providing ``lookup_tokens(ids)``.
        sentence: Either a raw string (tokenized with the module-level
            ``nlp`` pipeline) or an already tokenized sequence of strings.
            Tokens are lower-cased in both cases.
        max_len: Maximum number of decoding steps before giving up on
            seeing ``<eos>``.

    Returns:
        ``(tokens, attention)`` where ``tokens`` is the list of predicted
        target tokens without ``<sos>`` and without a trailing ``<eos>``, and
        ``attention`` is the second value returned by the last
        ``model.decoder`` call (``None`` if ``max_len`` is below 1).
    """
    model.eval()

    if isinstance(sentence, str):
        words = [token.text.lower() for token in nlp(sentence)]
    else:
        words = [token.lower() for token in sentence]

    src_tokens = ["<sos>", *words, "<eos>"]
    src_indexes = [enc_tokenizer[token] for token in src_tokens]

    # Use the vocabulary that was passed in, not a module-level global, so the
    # function works with whatever tokenizer the caller supplies.
    sos_index = dec_tokenizer["<sos>"]
    eos_index = dec_tokenizer["<eos>"]

    trg_indexes = [sos_index]
    attention = None

    with torch.no_grad():
        src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)
        src_mask = model.make_src_mask(src_tensor)
        enc_src = model.encoder(src_tensor, src_mask)

        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # Greedy choice: most likely token at the last target position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                break

    # Drop <sos>; drop the final token only when it really is <eos>, so a
    # translation cut off at max_len keeps its last predicted token.
    decoded = trg_indexes[1:]
    if decoded and decoded[-1] == eos_index:
        decoded = decoded[:-1]

    trg_tokens = dec_tokenizer.lookup_tokens(decoded)
    return trg_tokens, attention
## Attentions
def display_attention(sentence, translation, attention, n_heads=8, n_rows=4, n_cols=2):
    """Draw one attention heat map per head on a single matplotlib figure.

    Args:
        sentence: Source sentence, either a raw string (tokenized with the
            module-level ``nlp`` pipeline) or a sequence of token strings.
            Tokens are lower-cased and wrapped in ``<sos>``/``<eos>`` for the
            x-axis labels.
        translation: Sequence of target tokens used as y-axis labels.
        attention: Tensor indexed as ``attention.squeeze(0)[head]`` to obtain
            a 2-D map per head (presumably shaped
            ``(1, n_heads, target_len, source_len)`` -- confirm against the
            decoder).
        n_heads: Number of heads to draw.
        n_rows: Number of subplot rows.
        n_cols: Number of subplot columns; ``n_rows * n_cols`` must equal
            ``n_heads``.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the one
        drawn here.
    """
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    assert n_rows * n_cols == n_heads

    # Labels and the per-head tensor are identical for every subplot, so
    # build them once instead of on every loop iteration.
    xticks = ['<sos>'] + tokens + ['<eos>']
    yticks = list(translation)
    head_maps = attention.squeeze(0)

    fig = plt.figure(figsize=(15, 25))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)

        _attention = head_maps[i].cpu().detach().numpy()
        cax = ax.matshow(_attention, cmap='bwr', alpha=0.6)

        ax.tick_params(labelsize=12)
        # Pin one tick per cell and attach the labels to exactly those ticks.
        # The locator must NOT be swapped for MultipleLocator(1) afterwards:
        # that locator emits an extra tick one step before the visible range,
        # and because the fixed labels are assigned by tick index every label
        # would land one cell away from the token it describes.
        ax.set_xticks(range(len(xticks)))
        ax.set_yticks(range(len(yticks)))
        ax.set_xticklabels(xticks, rotation=45)
        ax.set_yticklabels(yticks)

        ax.set_title(f"Head {i + 1}")
        cbar = fig.colorbar(cax)
        cbar.set_label('Attention Score')
    return plt



# Define the translation function
def translate_sentence(input_text):
    """Translate ``input_text`` and plot the attention of the translation.

    Runs ``inference`` with the module-level model and tokenizers, then hands
    the result to ``display_attention``.

    Returns:
        A tuple of the predicted tokens joined by single spaces and the value
        returned by ``display_attention``.
    """
    predicted_tokens, attention_weights = inference(
        inf_model.model,
        enc_tokenizer,
        de_tokenizer,
        input_text,
    )
    attention_plot = display_attention(input_text, predicted_tokens, attention_weights)
    translated_text = " ".join(predicted_tokens)
    return translated_text, attention_plot


# # Create a Gradio interface
# gr_interface = gr.Interface(
#     fn=translate_sentence,
#     inputs="text",
#     outputs="text",
#     live=True,  # Set to True to enable live updates
#     title="Translation with Attention Visualization",
#     description="Translate English to German with attention visualization.",
# )

# @app.get("/")
# def read_root():
#     return {"message": "Welcome to the translation service!"}

# # Define a route for the Gradio interface
# @app.get("/start_gradio")
# async def start_gradio_interface():
#     return gr_interface

# Build the Gradio interface around translate_sentence and start serving it.
gr.Interface(
    fn=translate_sentence,
    inputs="text",
    # First output is the translated text, second the attention figure.
    # gr.Plot is created without `type`: that keyword is not one of the
    # component's documented parameters.
    outputs=["text", gr.Plot()],
    live=True,  # re-run the translation whenever the input changes
    title="Translation with Attention Visualization",
    description="Translate English to German with attention visualization.",
).launch()