Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import io | |
| from torch import nn | |
| from transformers import AutoModelForSequenceClassification | |
| from sklearn.pipeline import Pipeline | |
| from skorch import NeuralNetClassifier | |
| from skorch.callbacks import LRScheduler, ProgressBar | |
| from skorch.hf import HuggingfacePretrainedTokenizer | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from skorch.callbacks import EarlyStopping | |
| from sklearn.metrics import precision_recall_fscore_support | |
| from sklearn.metrics import balanced_accuracy_score | |
| from modAL.models import ActiveLearner | |
| from modAL.uncertainty import uncertainty_sampling | |
class BertModule(nn.Module):
    """Wrap a Hugging Face sequence-classification model as a skorch module.

    Args:
        name: Model identifier or path handed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes for the classification head.
    """

    def __init__(self, name, num_labels):
        super().__init__()
        # Stored on the instance so reset_weights() can rebuild the model.
        self.name, self.num_labels = name, num_labels
        self.reset_weights()

    def reset_weights(self):
        """(Re)load pretrained weights into ``self.bert``."""
        checkpoint = self.name
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            checkpoint, num_labels=self.num_labels
        )

    def forward(self, **kwargs):
        """Pass keyword arguments to the wrapped model and return its logits.

        NOTE(review): kwargs are presumably tokenizer outputs
        (input_ids / attention_mask, ...) — confirm against the pipeline.
        """
        return self.bert(**kwargs).logits
MAX_EPOCHS = 5
BATCH_SIZE = 12
# NOTE(review): 584 is presumably the number of training examples the learner
# was fit on. 584 // 12 + 1 == 49 batches per epoch, i.e. the "+ 1" covers the
# trailing partial batch. Confirm before retraining on a different dataset.
TRAIN_SIZE = 584
num_training_steps = MAX_EPOCHS * (TRAIN_SIZE // BATCH_SIZE + 1)


def lr_schedule(current_step):
    """Return the linear-decay learning-rate multiplier for ``current_step``.

    The multiplier falls linearly from 1.0 at step 0 to 0.0 at
    ``num_training_steps``. Steps at or beyond that point return 0.0 instead of
    going negative. NOTE(review): presumably used as the ``lr_lambda`` of a
    ``LambdaLR`` scheduler — confirm against the training code.

    Args:
        current_step: Number of scheduler steps taken so far.

    Returns:
        A float multiplier in ``[0.0, 1.0]`` for non-negative steps.
    """
    remaining = float(num_training_steps - current_step)
    factor = remaining / float(max(1, num_training_steps))
    # Clamp rather than assert. An ``assert factor > 0`` raises as soon as
    # current_step reaches num_training_steps (factor == 0), and asserts are
    # stripped entirely under ``python -O``.
    return max(0.0, factor)
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that restores pickled torch storages onto the CPU.

    Pickled torch storages are rebuilt through
    ``torch.storage._load_from_bytes``. Intercepting that lookup lets the raw
    payload be re-loaded with ``map_location='cpu'``, so storages saved from
    another device can be restored on a CPU-only host. Every other global is
    resolved by the standard ``pickle.Unpickler`` logic.

    Security: like any unpickler, this can execute arbitrary code contained in
    the input stream. Only feed it trusted files.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting a CPU loader for torch storages."""
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def load_storage_on_cpu(payload):
            # NOTE(review): newer torch releases default torch.load to
            # weights_only=True; confirm learner.bin still loads there.
            return torch.load(io.BytesIO(payload), map_location='cpu')

        return load_storage_on_cpu
# Restore the fitted learner with every torch storage forced onto the CPU.
# NOTE(review): unpickling resolves classes/functions by name, which is
# presumably why BertModule and lr_schedule are defined in this file. Only load
# trusted files: unpickling can execute arbitrary code.
with open('learner.bin', 'rb') as learner_file:
    learner = CPU_Unpickler(learner_file).load()
def dg_predict(tweet):
    """Classify one tweet with the loaded learner and return an emoji verdict.

    Args:
        tweet: The text to classify.

    Returns:
        '🐕' when the learner predicts class 1, otherwise '🐈'.
    """
    # NOTE(review): class 1 presumably marks the positive ("dog whistle")
    # label — confirm against the training labels.
    predicted = learner.predict([tweet])[0]
    if predicted == 1:
        return '🐕'
    return '🐈'
# Example inputs offered in the Gradio UI. These strings are runtime data and
# are intentionally kept in the original Swedish.
examples = [
    "Såå kulturberikande med explosioner utanför porten 🤗 #tackmagda",
    "Att WEF ska kunna styra över svenskar är fascistiskt",
    "Det råder inget tvivel om att svensk kultur inte är svensk längre, utan dessa MENA kommer hit och lever på bidrag och skjuter i våra förorter",
    "Kriminella måste skickas hem där de kommer ifrån!! Återvandring!!",
    "En global konspiration från Bryssel för att försvaga vår suveränitet, tills vi alla måste buga för Soros",
]

# Text-in / text-out UI around dg_predict.
iface = gr.Interface(fn=dg_predict, inputs="text", outputs="text", examples=examples)

# Start the server only when executed as a script (e.g. ``python app.py``), so
# that importing this module does not block inside launch().
if __name__ == "__main__":
    iface.launch()