Spaces:
Sleeping
Sleeping
| ### 1. Imports and class names setup ### | |
| import gradio as gr | |
| import os | |
| import torch | |
| import torchtext | |
| from model import xlmr_base_encoder_model | |
| from timeit import default_timer as timer | |
| from torchdata.datapipes.iter import IterableWrapper | |
| from torch.utils.data import DataLoader | |
| import torchtext.functional as F | |
| import torchtext.transforms as T | |
| from torch.hub import load_state_dict_from_url | |
# Human-readable labels; index i names output i of the classifier (see predict).
class_names = ["Bad", "Good"]

### 2. Model and transforms preparation ###

# Build the two-class XLM-R encoder model together with its transforms.
model, transforms = xlmr_base_encoder_model(num_classes=2)

# Restore the saved weights, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint, so the .pth file must
# come from a trusted source.
model.load_state_dict(
    torch.load(f="xlmr_base_encoder.pth", map_location=torch.device("cpu"))
)
### 3. Predict function ###

# Special-token ids and length limit used by the tokenization pipeline.
_PADDING_IDX = 1
_BOS_IDX = 0
_EOS_IDX = 2
_MAX_SEQ_LEN = 256
_XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
_XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

# Tokenization pipeline, created on the first request and reused afterwards so
# the SentencePiece model and vocabulary are not reloaded for every prediction.
_text_transform = None


def _get_text_transform():
    """Return the shared text-to-token-ids pipeline, building it on first use."""
    global _text_transform
    if _text_transform is None:
        _text_transform = T.Sequential(
            T.SentencePieceTokenizer(_XLMR_SPM_MODEL_PATH),
            T.VocabTransform(load_state_dict_from_url(_XLMR_VOCAB_PATH)),
            # Leave room for the BOS and EOS tokens added below.
            T.Truncate(_MAX_SEQ_LEN - 2),
            T.AddToken(token=_BOS_IDX, begin=True),
            T.AddToken(token=_EOS_IDX, begin=False),
        )
    return _text_transform


def predict(string):
    """Classify a piece of text and time the prediction.

    Args:
        string: Raw input text to classify.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)`` where
        ``pred_labels_and_probs`` maps each name in ``class_names`` to its
        softmax probability and ``pred_time`` is the elapsed time in seconds,
        rounded to four decimals.
    """
    start_time = timer()

    # Tokenize the text and wrap it in a list so the model receives a batch
    # containing a single sequence.
    token_ids = _get_text_transform()(string)
    batch = F.to_tensor([token_ids], padding_value=_PADDING_IDX).to("cpu")

    # Pass the token ids through the model and turn the logits into probabilities.
    model.to("cpu")
    model.eval()
    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, probs[0])
    }

    # Calculate pred time
    pred_time = round(timer() - start_time, 4)

    # Return pred dict and pred time
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###

# Text shown at the top of the demo page.
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"

# Wire predict to a text box; its two return values feed the label widget
# and the timing widget, in that order.
demo = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time(s) "),
    ],
    title=title,
    description=description,
)

# Start serving the demo.
demo.launch()