# Source listing captured from a Hugging Face Space page.
# Space status at capture time: Runtime error. File size: 4,475 bytes.
import gradio as gr
import os
import subprocess
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import spacy
import torch
import wandb

from model import Seq2SeqLightning
# --- One-time startup: fetch the spaCy pipelines and the trained checkpoint ---

# Weights & Biases coordinates of the model artifact to serve.
wandb_team = "maneel"
wandb_project = "attentionvisualizer"
wandb_model = "attentionvisualizer:v3"
wandb_model_path = f"{wandb_team}/{wandb_project}/{wandb_model}"

# Install the spaCy pipelines with the interpreter that is running this app;
# a bare "python3" resolved from PATH may belong to a different environment.
# This stays best-effort as before: a failed download is reported, not fatal.
for _spacy_model in ("en_core_web_sm", "de_core_news_sm"):
    _download = subprocess.run(
        [sys.executable, "-m", "spacy", "download", _spacy_model],
        check=False,
    )
    if _download.returncode != 0:
        print(
            f"warning: spaCy download of {_spacy_model} exited with "
            f"status {_download.returncode}",
            file=sys.stderr,
        )

# Download the checkpoint artifact into the working directory.
wandb.init()
current_folder = Path().parent  # Path() is ".", and so is its parent
path = wandb.use_artifact(wandb_model_path).download(root=current_folder)
# Locate the checkpoint inside the directory the download reported instead of
# assuming it landed in the current working directory.
checkpoint_path = str(Path(path) / 'model.ckpt')

# NOTE(security): torch.load relies on pickle, so loading a checkpoint can
# execute code embedded in it. Only point this at artifacts you trust.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
enc_tokenizer = checkpoint['enc_tokenizer']
de_tokenizer = checkpoint['de_tokenizer']
params = checkpoint['hyper_parameters']['params']

# Rebuild the Lightning wrapper from the stored hyper-parameters, then restore
# the trained weights into it.
inf_model = Seq2SeqLightning(
    enc_tokenizer=enc_tokenizer,
    de_tokenizer=de_tokenizer,
    params=params,
)
inf_model.load_state_dict(checkpoint['state_dict'])

# spaCy pipeline used by inference() and display_attention() to tokenize raw
# input text.
nlp = spacy.load('de_core_news_sm')
## Inference function
def inference(model, enc_tokenizer, dec_tokenizer, sentence, max_len=50):
    """Greedily decode a translation of ``sentence`` and return it with attention.

    Args:
        model: Object exposing ``eval()``, ``make_src_mask``, ``make_trg_mask``,
            ``encoder`` and ``decoder``, called as shown in the body.
        enc_tokenizer: Source vocabulary; indexed with a token string to get
            its integer id.
        dec_tokenizer: Target vocabulary; indexed with a token string to get
            its id, and providing ``lookup_tokens(ids)`` for the reverse
            direction.
        sentence: Either a raw string (tokenized with the module-level ``nlp``
            pipeline) or an iterable of token strings. Tokens are lower-cased
            and wrapped in ``<sos>`` / ``<eos>``.
        max_len: Maximum number of decoding steps before giving up on seeing
            ``<eos>``.

    Returns:
        A ``(tokens, attention)`` tuple. ``tokens`` is the decoded token list
        without the leading ``<sos>`` and without a trailing ``<eos>`` (when
        one was produced). ``attention`` is the second value returned by the
        last ``model.decoder`` call, or ``None`` if no step ran.
    """
    model.eval()

    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = ["<sos>"] + tokens + ["<eos>"]

    src_indexes = [enc_tokenizer[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)
    src_mask = model.make_src_mask(src_tensor)

    # Use the tokenizer that was passed in, not a module-level global.
    eos_index = dec_tokenizer["<eos>"]
    trg_indexes = [dec_tokenizer["<sos>"]]
    attention = None

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # Greedy choice: highest-scoring token at the last position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                break

    # Drop <sos>; drop the final token only when it really is <eos>, so that
    # hitting max_len does not silently discard a translated token.
    decoded = trg_indexes[1:]
    if decoded and decoded[-1] == eos_index:
        decoded = decoded[:-1]
    trg_tokens = dec_tokenizer.lookup_tokens(decoded)
    return trg_tokens, attention
## Attentions
def display_attention(sentence, translation, attention, n_heads=8, n_rows=4, n_cols=2):
    """Draw one attention heat map per head on an ``n_rows`` x ``n_cols`` grid.

    Args:
        sentence: Source sentence, either a raw string (tokenized with the
            module-level ``nlp`` pipeline) or an iterable of token strings.
            Tokens are lower-cased and wrapped in ``<sos>`` / ``<eos>`` to
            form the x-axis labels.
        translation: Sequence of target tokens used as the y-axis labels.
        attention: Tensor of attention weights; after ``squeeze(0)`` it is
            indexed by head.
        n_heads: Number of heads to plot.
        n_rows: Rows in the subplot grid.
        n_cols: Columns in the subplot grid.

    Returns:
        The ``matplotlib.pyplot`` module, with the newly drawn figure as its
        current figure.

    Raises:
        AssertionError: If ``n_rows * n_cols`` does not equal ``n_heads``.
    """
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    assert n_rows * n_cols == n_heads

    xticks = ['<sos>'] + tokens + ['<eos>']
    yticks = translation

    # assumes attention is (1, n_heads, target_len, source_len) - TODO confirm
    head_maps = attention.squeeze(0).cpu().detach().numpy()

    fig = plt.figure(figsize=(15, 25))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        cax = ax.matshow(head_maps[i], cmap='bwr', alpha=0.6)
        ax.tick_params(labelsize=12)

        # Pin one tick per token and label it. No MultipleLocator afterwards:
        # replacing the locator would discard these fixed positions and can
        # shift every label by one cell on matplotlib versions that assign
        # labels by tick order.
        ax.set_xticks(range(len(xticks)))
        ax.set_yticks(range(len(yticks)))
        ax.set_xticklabels(xticks, rotation=45)
        ax.set_yticklabels(yticks)
        ax.set_title("Head {}".format(i + 1))

        # Each heat map is colour-scaled on its own, so each head gets its own
        # colour bar, attached explicitly to that head's axes.
        cbar = fig.colorbar(cax, ax=ax)
        cbar.set_label('Attention Score')
    return plt
# Define the translation function
def translate_sentence(input_text):
    """Gradio callback: translate ``input_text`` and plot the attention.

    Runs ``inference`` on the loaded model with the checkpoint's tokenizers,
    then passes the result to ``display_attention``.

    Returns:
        A 2-tuple of the translated tokens joined by single spaces and the
        value returned by ``display_attention``.
    """
    decoded = inference(inf_model.model, enc_tokenizer, de_tokenizer, input_text)
    words, attn = decoded
    plot = display_attention(input_text, words, attn)
    text = " ".join(words)
    return text, plot
# # Create a Gradio interface
# gr_interface = gr.Interface(
# fn=translate_sentence,
# inputs="text",
# outputs="text",
# live=True, # Set to True to enable live updates
# title="Translation with Attention Visualization",
# description="Translate English to German with attention visualization.",
# )
# @app.get("/")
# def read_root():
# return {"message": "Welcome to the translation service!"}
# # Define a route for the Gradio interface
# @app.get("/start_gradio")
# async def start_gradio_interface():
# return gr_interface
# Build the Gradio UI and start serving it.
# NOTE(review): raw input is tokenized with spaCy's 'de_core_news_sm' pipeline
# elsewhere in this file - confirm the English-to-German direction stated in
# the description below.
demo = gr.Interface(
    fn=translate_sentence,
    inputs="text",
    # gr.Plot defines no `type` option; passing one is ignored by Gradio
    # releases that accept extra keyword arguments and is a TypeError on
    # releases that do not, so it is omitted.
    outputs=["text", gr.Plot()],
    live=True,  # re-run the callback as the input changes
    title="Translation with Attention Visualization",
    description="Translate English to German with attention visualization.",
)
demo.launch()