# Maneel
# Requirementex.txt
# 035151b
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import spacy
import torch
import wandb

from model import Seq2SeqLightning
# --- Model artifact, tokenizers and spaCy pipeline (runs at import time) ---
wandb_team = "maneel"
wandb_project = "attentionvisualizer"
wandb_model = "attentionvisualizer:v3"
wandb_model_path = f"{wandb_team}/{wandb_project}/{wandb_model}"

# Install the spaCy pipelines into the interpreter that is running this script;
# a bare "python3" on PATH may belong to another environment, in which case
# spacy.load() below would not find the package. Failures are deliberately not
# fatal here (as with the previous os.system calls); spacy.load() reports them.
# NOTE(review): en_core_web_sm is not loaded anywhere in this file — confirm it is needed.
for _spacy_model in ("en_core_web_sm", "de_core_news_sm"):
    subprocess.run([sys.executable, "-m", "spacy", "download", _spacy_model], check=False)

wandb.init()
current_folder = Path().parent  # Path("."), i.e. the current working directory
path = wandb.use_artifact(wandb_model_path).download(root=current_folder)
# Resolve the checkpoint inside the directory the artifact was downloaded to,
# rather than assuming it lands in the working directory.
checkpoint_path = Path(path) / 'model.ckpt'
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# The checkpoint also stores tokenizer objects, not just tensors; recent PyTorch
# releases default to weights_only=True and may refuse them — confirm against
# the pinned torch version.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
enc_tokenizer = checkpoint['enc_tokenizer']
de_tokenizer = checkpoint['de_tokenizer']
params = checkpoint['hyper_parameters']['params']
inf_model = Seq2SeqLightning(enc_tokenizer=enc_tokenizer, de_tokenizer=de_tokenizer, params=params)
inf_model.load_state_dict(checkpoint['state_dict'])
# Source sentences are tokenized with the German pipeline.
# NOTE(review): the UI description says "English to German" — confirm which
# language the encoder vocabulary was built for.
nlp = spacy.load('de_core_news_sm')
## Inference function
def inference(model, enc_tokenizer, dec_tokenizer, sentence, max_len=50):
    """Greedily decode a translation of ``sentence`` with ``model``.

    Args:
        model: Seq2seq model exposing ``eval``, ``make_src_mask``, ``encoder``,
            ``make_trg_mask`` and ``decoder`` (as called below).
        enc_tokenizer: Source vocabulary, indexed as ``enc_tokenizer[token]``.
        dec_tokenizer: Target vocabulary, indexed as ``dec_tokenizer[token]``
            and decoded with ``dec_tokenizer.lookup_tokens(indices)``.
        sentence: A raw string (tokenized with the module-level spaCy ``nlp``
            pipeline) or an already-tokenized sequence of strings. Tokens are
            lower-cased and wrapped in ``<sos>``/``<eos>``.
        max_len: Maximum number of target tokens to generate (default 50).

    Returns:
        ``(tokens, attention)``: the predicted target tokens without the
        ``<sos>``/``<eos>`` markers, and the attention returned by the decoder
        on the final decoding step.

    Raises:
        ValueError: if ``max_len`` is less than 1 (no step would run, so no
            attention would exist).
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    model.eval()
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = ["<sos>"] + tokens + ["<eos>"]
    src_indexes = [enc_tokenizer[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)  # add a batch dimension
    src_mask = model.make_src_mask(src_tensor)

    # Use the tokenizer passed in (previously the global de_tokenizer was used).
    sos_idx = dec_tokenizer["<sos>"]
    eos_idx = dec_tokenizer["<eos>"]
    trg_indexes = [sos_idx]
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            # The whole prefix is re-decoded each step; only the prediction at
            # the last position is kept (argmax over dim 2, then position -1).
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_idx:
                break

    # Drop <sos>; drop the last index only if it really is <eos>. When max_len
    # is reached without <eos>, the final index is a real token and is kept.
    generated = trg_indexes[1:]
    if generated and generated[-1] == eos_idx:
        generated = generated[:-1]
    return dec_tokenizer.lookup_tokens(generated), attention
## Attentions
def display_attention(sentence, translation, attention, n_heads=8, n_rows=4, n_cols=2):
    """Draw one attention heat-map per head and return ``matplotlib.pyplot``.

    Args:
        sentence: Source sentence, either a raw string (tokenized with the
            module-level spaCy ``nlp`` pipeline) or a sequence of tokens.
            Tokens are lower-cased and wrapped in ``<sos>``/``<eos>`` to form
            the x-axis labels.
        translation: Predicted target tokens, used as the y-axis labels.
        attention: Tensor where ``attention.squeeze(0)[i]`` is the 2-D map of
            head ``i`` (presumably ``(1, n_heads, trg_len, src_len)`` — TODO
            confirm against the decoder).
        n_heads: Number of heads to draw.
        n_rows, n_cols: Subplot grid; ``n_rows * n_cols`` must equal ``n_heads``.

    Returns:
        The ``matplotlib.pyplot`` module; its current figure holds the plot.

    Raises:
        ValueError: if the subplot grid does not match ``n_heads``.
    """
    # Explicit check rather than ``assert``, so it still runs under ``python -O``.
    if n_rows * n_cols != n_heads:
        raise ValueError(
            f"n_rows * n_cols ({n_rows} * {n_cols}) must equal n_heads ({n_heads})"
        )
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Labels and the CPU/numpy copy of the attention are identical for every head.
    xticks = ['<sos>'] + tokens + ['<eos>']
    yticks = translation
    head_maps = attention.squeeze(0).cpu().detach().numpy()

    # Reuse one labelled figure (cleared each call) instead of opening a new
    # figure per call that is never closed; with a live interface this runs
    # on every edit of the input.
    fig = plt.figure(num="attention", figsize=(15, 25), clear=True)
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        cax = ax.matshow(head_maps[i], cmap='bwr', alpha=0.6)
        ax.tick_params(labelsize=12)
        # Exactly one fixed tick per label. Without an extra locator override,
        # each label stays on its own cell.
        ax.set_xticks(range(len(xticks)))
        ax.set_yticks(range(len(yticks)))
        ax.set_xticklabels(xticks, rotation=45)
        ax.set_yticklabels(yticks)
        ax.set_title("Head {}".format(i + 1))
        cbar = fig.colorbar(cax)
        cbar.set_label('Attention Score')
    return plt
# Define the translation function
def translate_sentence(input_text):
    """Gradio callback: translate ``input_text`` and plot its attention.

    Runs greedy decoding with the module-level model and tokenizers, then
    returns a pair ``(text, plot)``. ``text`` is the predicted tokens joined
    by single spaces. ``plot`` is whatever ``display_attention`` returns for
    the final-step attention.
    """
    predicted_tokens, attn_weights = inference(
        inf_model.model,
        enc_tokenizer,
        de_tokenizer,
        input_text,
    )
    attention_plot = display_attention(input_text, predicted_tokens, attn_weights)
    translated_text = " ".join(predicted_tokens)
    return translated_text, attention_plot
# # Create a Gradio interface
# gr_interface = gr.Interface(
# fn=translate_sentence,
# inputs="text",
# outputs="text",
# live=True, # Set to True to enable live updates
# title="Translation with Attention Visualization",
# description="Translate English to German with attention visualization.",
# )
# @app.get("/")
# def read_root():
# return {"message": "Welcome to the translation service!"}
# # Define a route for the Gradio interface
# @app.get("/start_gradio")
# async def start_gradio_interface():
# return gr_interface
# Build the Gradio interface and launch it (the FastAPI/uvicorn variant above is commented out).
# Text box in; translated text and the attention plot out. live=True re-runs
# translate_sentence whenever the input changes.
demo = gr.Interface(
    fn=translate_sentence,
    inputs="text",
    # gr.Plot has no ``type`` argument (that option belongs to gr.Image); it
    # renders the matplotlib plot returned by translate_sentence as-is.
    outputs=["text", gr.Plot()],
    live=True,
    title="Translation with Attention Visualization",
    description="Translate English to German with attention visualization.",
)
demo.launch()