# app.py — author: achyut, commit c8c4967
from transformers import pipeline, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
import lime
from lime.lime_text import LimeTextExplainer
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
# Load the fine-tuned narrative classifier and its tokenizer from the HF Hub.
tokenizer = AutoTokenizer.from_pretrained("achyut/narrativemodel")
model = AutoModelForSequenceClassification.from_pretrained("achyut/narrativemodel")
# Use the GPU when one is present; an unconditional model.cuda() crashes on
# CPU-only hosts.
model.to("cuda" if torch.cuda.is_available() else "cpu")
# Inference only: make sure dropout is disabled.
model.eval()
import more_itertools
from transformers import pipeline
def my_preds(texts, batch_size=64):
    """Return class probabilities for *texts*; used as LIME's ``classifier_fn``.

    Args:
        texts: Sequence of input strings (LIME passes its perturbed samples).
        batch_size: Number of strings tokenized and scored per forward pass.

    Returns:
        ``np.ndarray`` with one row per input text and one column per model
        label, each row being softmax probabilities.
    """
    # Follow the model to whatever device it was placed on instead of
    # hard-coding "cuda".
    device = model.device
    probs = []
    # Inference only: no_grad avoids building an autograd graph for each of
    # LIME's (thousands of) perturbed samples.
    with torch.no_grad():
        for chunk in more_itertools.chunked(texts, batch_size):
            # truncation=True caps inputs at the tokenizer's max length; for
            # BERT-style models longer inputs otherwise error out.
            tokenized = tokenizer(
                chunk, return_tensors="pt", padding=True, truncation=True
            )
            outputs = model(
                tokenized["input_ids"].to(device),
                tokenized["attention_mask"].to(device),
            )
            # Normalize over the label axis; implicit-dim softmax is deprecated.
            probs.append(F.softmax(outputs.logits, dim=-1).cpu().numpy())
    if not probs:
        # np.vstack([]) raises; return an empty (0, num_labels) array instead.
        return np.empty((0, model.config.num_labels), dtype=np.float32)
    return np.vstack(probs)
title = "Narrative Detection and Feature Interpretability Using BERT"
description = """The BERT model outputs the probability scores for narrativity given a text input followed by a feature analysis plot implemented using LIME
"""
# Pipeline over the same Hub model id; func() uses it for the predicted label
# and score.
narrative_detection = pipeline(model="achyut/narrativemodel")
# Display names shown in the LIME plot, indexed by class id
# (0 -> LABEL_0 / non-narrative, 1 -> LABEL_1 / narrative).
class_names = ['non-narrative', 'narrative']
explainer = LimeTextExplainer(class_names=class_names)
def func(string):
    """Classify *string* and explain the prediction with LIME.

    Returns:
        A 3-tuple matching the Gradio outputs: the predicted label
        ("Narrative" or "Non-Narrative"), the probability of the narrative
        class rounded to 3 decimals, and a matplotlib figure of the LIME
        feature weights.
    """
    # LIME scores num_samples perturbed variants of the input via my_preds.
    exp = explainer.explain_instance(string, my_preds, num_features=25, num_samples=5000)
    top = narrative_detection(string)[0]
    is_narrative = top["label"] == "LABEL_1"
    # The pipeline reports the confidence of its *top* label only. Reporting
    # that directly would show the non-narrative confidence as the
    # "Narrativity Score" whenever LABEL_0 wins. Assumes the model has exactly
    # two labels (matching class_names), so P(narrative) is the complement.
    narrativity = top["score"] if is_narrative else 1.0 - top["score"]
    return (
        "Narrative" if is_narrative else "Non-Narrative",
        round(narrativity, 3),
        exp.as_pyplot_figure(),
    )
# Gradio UI: one text input; outputs are the label, the narrativity score and
# the LIME explanation plot returned by func().
demo = gr.Interface(
    fn=func,
    inputs=gr.Textbox(placeholder="Enter sentence here..."),
    # gr.outputs.* is deprecated (removed in Gradio 4); plain components
    # work as outputs.
    outputs=[
        gr.Textbox(label="Narrative Label"),
        gr.Textbox(label="Narrativity Score"),
        "plot",
    ],
    examples=[
        ["I was 5 years old when my dad gifted me a watch and i have been wearing it ever since."],
        ["You guys have no idea how many people you're putting in danger. The towers don't cause covied"],
        ["i took my vaccine a few days ago and i feel amazing"],
        ["I am a boy"],
        ["The world and humanity has had a lot in the last two years and we need to hit that refresh button"],
        ["if you went to Dr. Sam as a kid, don't worry about what's in the vaccine"],
    ],
    title=title,
    description=description,
)
demo.launch()