Spaces:
Runtime error
Runtime error
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import spacy
import torch
import wandb

from model import Seq2SeqLightning
# Weights & Biases coordinates of the trained model artifact.
wandb_team = "maneel"
wandb_project = "attentionvisualizer"
wandb_model = "attentionvisualizer:v3"
# Fully qualified artifact reference in the form "<team>/<project>/<name:version>".
wandb_model_path = "/".join((wandb_team, wandb_project, wandb_model))
# Make sure both spaCy pipelines are installed before they are loaded.
# - Use the running interpreter (sys.executable) rather than whatever `python3`
#   is on PATH, so the packages land in the environment this app imports from.
# - Pass an argument list (no shell) and skip the download when the pipeline is
#   already installed.
# The download stays best-effort (exit status not checked, as with os.system):
# a missing pipeline still surfaces as an error from spacy.load().
# NOTE(review): only de_core_news_sm is loaded in this file; en_core_web_sm
# appears unused — confirm before removing it.
for _spacy_model in ("en_core_web_sm", "de_core_news_sm"):
    if not spacy.util.is_package(_spacy_model):
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", _spacy_model],
            check=False,
        )
# Start a W&B run so the model artifact can be fetched through it.
wandb.init()
current_folder = Path().parent  # Path('.').parent is Path('.'), i.e. the CWD
# download() returns the directory the artifact's files were written to.
path = wandb.use_artifact(wandb_model_path).download(root=current_folder)
# Read the checkpoint from where the artifact was actually downloaded instead
# of assuming it landed in the current working directory.
checkpoint_path = str(Path(path) / "model.ckpt")
# NOTE(review): torch.load unpickles arbitrary Python objects, and this
# checkpoint appears to carry tokenizer objects (de_tokenizer exposes
# lookup_tokens), not just tensors. Only load artifacts you trust. On
# PyTorch >= 2.6 torch.load defaults to weights_only=True, which may reject
# such entries — confirm against the pinned torch version.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
enc_tokenizer = checkpoint["enc_tokenizer"]
de_tokenizer = checkpoint["de_tokenizer"]
params = checkpoint["hyper_parameters"]["params"]
inf_model = Seq2SeqLightning(
    enc_tokenizer=enc_tokenizer, de_tokenizer=de_tokenizer, params=params
)
inf_model.load_state_dict(checkpoint["state_dict"])
# spaCy pipeline used to tokenize raw input text.
# NOTE(review): this is the German pipeline, while the UI description says
# "English to German" — confirm which direction the checkpoint was trained for.
nlp = spacy.load("de_core_news_sm")
| ## Inference function | |
def inference(model, enc_tokenizer, dec_tokenizer, sentence, max_len=50):
    """Greedily decode a translation of *sentence* with *model*.

    Args:
        model: Seq2seq network exposing ``eval``, ``make_src_mask``,
            ``make_trg_mask``, ``encoder`` and ``decoder`` (as called below).
        enc_tokenizer: Source vocabulary, indexed as ``enc_tokenizer[token]``.
        dec_tokenizer: Target vocabulary, indexed as ``dec_tokenizer[token]``
            and providing ``lookup_tokens(indexes)``.
        sentence: Raw text (tokenized with the module-level spaCy pipeline
            ``nlp``) or an already-tokenized sequence of strings. Tokens are
            lower-cased and wrapped in ``<sos>`` / ``<eos>``.
        max_len: Maximum number of target tokens to generate (default 50).

    Returns:
        ``(trg_tokens, attention)``: the predicted target tokens without the
        leading ``<sos>`` and without the trailing ``<eos>`` (when one was
        produced), and the attention tensor returned by the decoder on the
        final decoding step.

    Raises:
        ValueError: If *max_len* is less than 1 (no step would produce an
            attention tensor to return).
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    model.eval()
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = ["<sos>"] + tokens + ["<eos>"]
    # NOTE(review): a plain dict vocabulary would raise KeyError on unseen
    # tokens; presumably the vocab maps them to <unk> — confirm.
    src_indexes = [enc_tokenizer[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)  # batch of one
    src_mask = model.make_src_mask(src_tensor)

    # Use the target vocabulary passed in, not a module-level global.
    sos_index = dec_tokenizer["<sos>"]
    eos_index = dec_tokenizer["<eos>"]
    trg_indexes = [sos_index]
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # Greedy choice: most likely token at the last target position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                break

    trg_tokens = dec_tokenizer.lookup_tokens(trg_indexes)[1:]  # drop <sos>
    # Drop <eos> only if decoding actually produced it. When max_len is hit,
    # the last token is a real prediction and must be kept.
    if trg_indexes[-1] == eos_index:
        trg_tokens = trg_tokens[:-1]
    return trg_tokens, attention
## Attention visualization
def display_attention(sentence, translation, attention, n_heads=8, n_rows=4, n_cols=2):
    """Plot one attention heat map per head for a source/translation pair.

    Args:
        sentence: Source text (tokenized with the module-level spaCy pipeline
            ``nlp``) or an already-tokenized sequence of strings. Tokens are
            lower-cased and wrapped in ``<sos>`` / ``<eos>`` to label the
            x-axis, mirroring how ``inference`` encodes the source.
        translation: Predicted target tokens, used as y-axis labels. Rows
            beyond ``len(translation)`` (e.g. the step that predicted
            ``<eos>``) are left unlabeled.
        attention: Attention tensor. After ``squeeze(0)`` it is indexed by
            head, and each head's slice is drawn as a 2-D matrix (presumably
            ``[trg_len, src_len]`` — confirm against the decoder).
        n_heads: Number of heads to draw.
        n_rows: Number of subplot rows.
        n_cols: Number of subplot columns. ``n_rows * n_cols`` must equal
            ``n_heads``.

    Returns:
        A ``matplotlib.figure.Figure`` with one subplot and colorbar per head.
    """
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    assert n_rows * n_cols == n_heads, "n_rows * n_cols must equal n_heads"

    # Loop-invariant work: axis labels and the per-head numpy matrices.
    xticks = ["<sos>"] + tokens + ["<eos>"]
    yticks = list(translation)
    heads = attention.squeeze(0).cpu().detach().numpy()

    # Build the figure without pyplot's figure manager. A pyplot-managed figure
    # stays registered (and in memory) until plt.close(), and this function runs
    # on every translation request.
    fig = plt.Figure(figsize=(15, 25))
    for head in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, head + 1)
        cax = ax.matshow(heads[head], cmap="bwr", alpha=0.6)
        ax.tick_params(labelsize=12)
        # Exactly one fixed tick per token. Replacing the locator afterwards
        # with MultipleLocator(1) would add a tick at -1 and shift every label
        # one position off its row/column.
        ax.set_xticks(range(len(xticks)))
        ax.set_yticks(range(len(yticks)))
        ax.set_xticklabels(xticks, rotation=45)
        ax.set_yticklabels(yticks)
        ax.set_title("Head {}".format(head + 1))
        cbar = fig.colorbar(cax)
        cbar.set_label("Attention Score")
    return fig
| # Define the translation function | |
def translate_sentence(input_text):
    """Translate *input_text* and build its attention plot for the UI.

    Uses the module-level model ``inf_model`` and vocabularies
    ``enc_tokenizer`` / ``de_tokenizer``.

    Returns:
        A ``(translation, plot)`` pair: the predicted tokens joined with single
        spaces, and the figure produced by ``display_attention``. For empty or
        whitespace-only input, returns ``("", None)`` without running the model.
    """
    # The interface runs live, so this fires on every edit of the textbox,
    # including when it is cleared. Skip the model and the 8-panel figure for
    # blank input; None clears the plot output.
    if not input_text or not input_text.strip():
        return "", None
    translation, attention = inference(inf_model.model, enc_tokenizer, de_tokenizer, input_text)
    figure = display_attention(input_text, translation, attention)
    return " ".join(translation), figure
| # # Create a Gradio interface | |
| # gr_interface = gr.Interface( | |
| # fn=translate_sentence, | |
| # inputs="text", | |
| # outputs="text", | |
| # live=True, # Set to True to enable live updates | |
| # title="Translation with Attention Visualization", | |
| # description="Translate English to German with attention visualization.", | |
| # ) | |
| # @app.get("/") | |
| # def read_root(): | |
| # return {"message": "Welcome to the translation service!"} | |
| # # Define a route for the Gradio interface | |
| # @app.get("/start_gradio") | |
| # async def start_gradio_interface(): | |
| # return gr_interface | |
# Build the Gradio interface and launch it (the FastAPI/uvicorn variant above is commented out and unused).
# Gradio UI: text in; translated text plus the per-head attention figure out.
# NOTE(review): the description says English -> German, but input text is
# tokenized with the German spaCy pipeline — confirm the translation direction.
demo = gr.Interface(
    fn=translate_sentence,
    inputs="text",
    # gr.Plot renders matplotlib figures directly; it takes no `type` argument
    # (newer Gradio releases reject unknown keyword arguments).
    outputs=["text", gr.Plot()],
    live=True,  # re-runs the model and redraws the figure on every edit
    title="Translation with Attention Visualization",
    description="Translate English to German with attention visualization.",
)
demo.launch()