# Captured from a Hugging Face Spaces page; Space status at capture time: Runtime error
import gradio as gr
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
# Define the BertClassifier class
class BertClassifier(nn.Module):
    """Multi-label classification head on top of a BERT encoder.

    The encoder's ``pooler_output`` is projected to ``num_classes`` logits by a
    linear layer and passed through a sigmoid, so every class receives an
    independent probability.

    Args:
        bert: Encoder module; it must expose ``config.hidden_size`` and return
            an object with a ``pooler_output`` attribute.
        num_classes: Number of output labels.
    """

    def __init__(self, bert: "BertModel", num_classes: int):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
        # Fused sigmoid + BCE computed from logits: numerically stable, unlike
        # BCELoss applied to probabilities that have already saturated to 0/1.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Encode the inputs and score every class.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask:
                Forwarded unchanged to ``self.bert``.
            labels: Optional multi-hot targets with the same shape as the
                logits; integer or float tensors are accepted.

        Returns:
            ``(loss, probs)``. ``probs`` holds sigmoid probabilities, one per
            class along the last dimension. ``loss`` is the binary
            cross-entropy against ``labels``, or the int ``0`` when
            ``labels`` is None.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        logits = self.classifier(outputs.pooler_output)
        probs = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            # BCE targets must be floating point; cast so integer multi-hot
            # labels work as well.
            loss = self.criterion(logits, labels.to(logits.dtype))
        return loss, probs
# Load the tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
model = BertClassifier(bert_model, num_classes=7)


def _load_checkpoint(module, path):
    """Strictly load the state_dict stored at *path* into *module* on CPU.

    Checkpoints written with older ``transformers`` releases can contain
    ``bert.embeddings.position_ids``. Newer releases register that buffer as
    non-persistent, so a strict ``load_state_dict`` rejects it as an
    "unexpected key" and the app crashes at startup. Only such stale
    ``position_ids`` entries are dropped. Every other key must still match
    exactly.
    """
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    expected_keys = module.state_dict().keys()
    stale_keys = [
        key for key in state_dict
        if key.endswith('position_ids') and key not in expected_keys
    ]
    for key in stale_keys:
        del state_dict[key]
    module.load_state_dict(state_dict)


# Load the model weights from the .pkl file (expected to hold a state_dict)
_load_checkpoint(model, 'bert_classifier_mltc.pkl')
# Inference mode: disables dropout in the encoder.
model.eval()
# Define prediction function
def predict(text):
    """Score one comment against every toxicity label.

    Args:
        text: Comment from the Gradio textbox. ``None`` is treated as an
            empty string.

    Returns:
        Dict mapping each label name to its sigmoid probability, suitable
        as the value of a ``gr.Label`` output.
    """
    if text is None:
        text = ""
    # Use the tokenizer's own attention mask. Deriving it from
    # ``input_ids != pad_token_id`` would also mask a literal "[PAD]" typed by
    # the user, because the tokenizer can map that text to the pad id.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        _, probs = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])
    # One comment -> batch of one; row 0 holds one probability per class.
    confidence_scores = probs[0].tolist()
    # Order must match the classifier's output columns (presumably the label
    # order used in training -- verify against the training code).
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', 'clean']
    return dict(zip(labels, confidence_scores))
# Example comments offered beneath the input box. Each example is wrapped in
# its own one-item list: one value per input component (here, the textbox).
examples = [
    [comment]
    for comment in (
        "This is a positive and friendly comment.",
        "I'm not happy with your service.",
        "You are a great person!",
        "How can one defame someone who thinks the Fort Hood Shooting is justified by God's law?",
        "A Bisexual, like a homosexual or a heterosexual, is not defined by sexual activity. (Much like a 15 year old boy who is attracted to a girl sexually but has never had sex is still straight). A person who is actually sexually attracted/aroused by the same sex as well as the opposite sex is bisexual.",
        "I'LL SAY NIGGER IF I WANT TO BRO",
        "Do you need a verifiable source for that statement? Check your own user page you fat, miserable balding fuck.",
        "Fuck you bitch! I will go to your house and kill you with my gun as well! Prepare to die!",
        "you're hot, i will rape you ;)",
        "WOULDN'T BE THE FIRST TIME BITCH. FUCK YOU I'LL FIND OUT WHERE YOU LIVE, SODOMIZE YOUR WIFE AND THEN BURN YOUR HOUSE DOWN. FUCK YOU YOU FUCKING QUEER.",
        "What a motherfucking piece of crap those fuckheads for blocking us!",
        "Get a life you animal fucker. Shut up you fucking nigger. Fuck off and shit your pants full of all the shit you can fill them with. 144.131.176.126",
        "HOPE YOUR HEAD GETS CUT OFF AND SOMEONE WIPS THERE ASS WITH IT AND THEN STABS YOU IN YOUR HEART",
        "you people are pretty overzealous with this whole free thing. get a fucking life, you fucking niggers !!!23 16!!!",
        "Stupid peace of shit stop deleting my stuff asshole go die and fall in a hole go to hell!",
        "Bye! Don't look, come or think of comming back! Tosser.",
    )
]
# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=10, placeholder="Enter a comment here..."),
    # Show the scores of all 7 labels, not just the top few.
    outputs=gr.Label(num_top_classes=7),
    examples=examples,
    title="Toxic Comment Classification",
    # The model is BERT plus a linear multi-label head (no GNN), so describe it
    # accurately in the UI.
    description=(
        "Multi-label toxicity classification: scores a comment for toxic, "
        "severe_toxic, obscene, threat, insult, identity_hate and clean "
        "using a BERT-based classifier."
    ),
)
iface.launch()