"""Load a trained seq2seq checkpoint from Weights & Biases and run inference."""
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import spacy
import torch
import wandb

from model import Seq2SeqLightning

wandb_team = "maneel"
wandb_project = "attentionvisualizer"
wandb_model = "attentionvisualizer:v3"
wandb_model_path = f"{wandb_team}/{wandb_project}/{wandb_model}"

# Fetch the spaCy pipelines with the interpreter that is running this script,
# so they land in the same environment that `spacy.load` reads from below.
# check=False keeps this best-effort: a failed download is not fatal when the
# pipeline is already installed.
for _spacy_model in ("en_core_web_sm", "de_core_news_sm"):
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", _spacy_model],
        check=False,
    )

wandb.init()

# `Path().parent` is the relative path ".", i.e. the current working directory.
current_folder = Path().parent
path = wandb.use_artifact(wandb_model_path).download(root=current_folder)

# Read the checkpoint from wherever the artifact was actually downloaded.
checkpoint_path = Path(path) / "model.ckpt"
# NOTE(review): torch.load unpickles arbitrary Python objects (the tokenizers
# are stored in this file) - only load checkpoints from a trusted artifact.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

enc_tokenizer = checkpoint["enc_tokenizer"]
de_tokenizer = checkpoint["de_tokenizer"]
params = checkpoint["hyper_parameters"]["params"]

inf_model = Seq2SeqLightning(
    enc_tokenizer=enc_tokenizer,
    de_tokenizer=de_tokenizer,
    params=params,
)
inf_model.load_state_dict(checkpoint["state_dict"])

# NOTE(review): input text is tokenized with the German pipeline although the
# UI description says "English to German" - confirm which language the encoder
# vocabulary was built from.
nlp = spacy.load("de_core_news_sm")

# NOTE(review): both markers are empty strings in this file, which makes the
# start and end markers identical. They look like angle-bracket special tokens
# that were lost in transit - confirm against the vocabulary used in training.
SOS_TOKEN = ""
EOS_TOKEN = ""


## Inference function
def inference(model, enc_tokenizer, dec_tokenizer, sentence,
              max_len=50, sos_token=SOS_TOKEN, eos_token=EOS_TOKEN):
    """Greedily decode a translation of ``sentence``.

    Args:
        model: Object exposing ``eval``, ``make_src_mask``, ``make_trg_mask``,
            ``encoder`` and ``decoder``, called as shown in the body.
        enc_tokenizer: Source vocabulary; indexed with a token string to get
            its integer id.
        dec_tokenizer: Target vocabulary; indexed with a token string and
            queried through ``lookup_tokens``.
        sentence: Either a raw string (tokenized with the module-level ``nlp``
            pipeline) or an iterable of token strings. Tokens are lower-cased.
        max_len: Maximum number of decoding steps.
        sos_token: Marker placed before the source tokens and used to seed the
            target sequence.
        eos_token: Marker placed after the source tokens; decoding stops when
            it is predicted.

    Returns:
        A ``(tokens, attention)`` tuple. ``tokens`` is the list of predicted
        target tokens without the start marker and without the end marker.
        ``attention`` is the second value returned by ``model.decoder`` on the
        last decoding step, or ``None`` when ``max_len`` is 0.
    """
    model.eval()

    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = [sos_token] + tokens + [eos_token]

    src_indexes = [enc_tokenizer[token] for token in tokens]
    # unsqueeze(0) adds a leading dimension of size 1 (a single-sentence batch).
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)

    # Use the tokenizer that was passed in, not the module-level global.
    eos_index = dec_tokenizer[eos_token]
    trg_indexes = [dec_tokenizer[sos_token]]
    attention = None
    reached_eos = False

    with torch.no_grad():
        src_mask = model.make_src_mask(src_tensor)
        enc_src = model.encoder(src_tensor, src_mask)

        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # Greedy choice: arg-max over dim 2, taken at the last position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                reached_eos = True
                break

    # Always drop the start marker; drop the final entry only when it really is
    # the end marker, so a length-capped translation keeps its last token.
    kept_indexes = trg_indexes[1:-1] if reached_eos else trg_indexes[1:]
    trg_tokens = dec_tokenizer.lookup_tokens(kept_indexes)
    return trg_tokens, attention


## Attentions
def display_attention(sentence, translation, attention, n_heads=8, n_rows=4, n_cols=2):
    """Draw one attention heat-map per head and return the figure.

    Args:
        sentence: Source text, either a raw string (tokenized with the
            module-level ``nlp`` pipeline) or an iterable of token strings.
        translation: Sequence of target tokens used as y-axis labels.
        attention: Tensor of attention weights; after ``squeeze(0)`` it is
            indexed by head. Presumably shaped
            (1, heads, target_len, source_len) - TODO confirm against the
            decoder.
        n_heads: Number of heads to plot.
        n_rows: Rows in the subplot grid.
        n_cols: Columns in the subplot grid.

    Returns:
        The ``matplotlib.figure.Figure`` holding all subplots. It is already
        closed in pyplot's registry, so repeated calls do not accumulate open
        figures; the returned object can still be saved or rendered.

    Raises:
        ValueError: If ``n_rows * n_cols`` does not equal ``n_heads``.
    """
    if n_rows * n_cols != n_heads:
        raise ValueError(
            f"n_rows * n_cols must equal n_heads "
            f"(got {n_rows} * {n_cols} != {n_heads})"
        )

    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # NOTE(review): the two padding labels are empty strings here; they stand
    # in for the markers added around the source tokens at inference time -
    # confirm what text they should show.
    xticks = [''] + tokens + ['']
    yticks = translation

    # Convert once instead of once per head.
    head_maps = attention.squeeze(0).cpu().detach().numpy()

    fig = plt.figure(figsize=(15, 25))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        cax = ax.matshow(head_maps[i], cmap='bwr', alpha=0.6)
        ax.tick_params(labelsize=12)

        # Pin one tick per label so every label stays attached to its own
        # row/column of the matrix.
        ax.set_xticks(range(len(xticks)))
        ax.set_yticks(range(len(yticks)))
        ax.set_xticklabels(xticks, rotation=45)
        ax.set_yticklabels(yticks)
        ax.set_title("Head {}".format(i + 1))

        # Each matshow call gets its own colour scale, so each subplot carries
        # its own colorbar.
        cbar = fig.colorbar(cax, ax=ax)
        cbar.set_label('Attention Score')

    # Drop the figure from pyplot's global registry; the Figure object itself
    # remains valid for the caller.
    plt.close(fig)
    return fig


# Define the translation function
def translate_sentence(input_text):
    """Translate ``input_text`` and build its attention plot.

    Args:
        input_text: Raw text typed into the interface.

    Returns:
        A ``(translated_text, figure)`` tuple. For blank input the result is
        ``("", None)`` and the model is not run.
    """
    # The interface runs in live mode, so this is also called with an empty box.
    if not input_text or not input_text.strip():
        return "", None

    translation, attention = inference(
        inf_model.model, enc_tokenizer, de_tokenizer, input_text
    )
    figure = display_attention(input_text, translation, attention)
    return " ".join(translation), figure


# Build and launch the Gradio interface.
# NOTE(review): `type` is not among the documented gr.Plot parameters - confirm
# the installed Gradio version accepts it.
demo = gr.Interface(
    fn=translate_sentence,
    inputs="text",
    outputs=["text", gr.Plot(type="pil")],
    live=True,
    title="Translation with Attention Visualization",
    description="Translate English to German with attention visualization.",
)
demo.launch()