# realTextGPT / app.py
# goodboyanush's picture
# Update app.py
# e8ffac3
import gradio as gr
import torch
import torch.nn.functional as F
import random
from captum.attr import LayerConductance, LayerIntegratedGradients
from visualization import VisualizationDataRecord, visualize_text
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
# Tokenizer is loaded from the 'distilbert-base-uncased' pretrained name.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
# Classifier weights are loaded from the local path "./checkpoint-57000"
# (presumably a fine-tuned training checkpoint -- confirm). Both objects are
# module-level globals used by the functions defined below.
model = DistilBertForSequenceClassification.from_pretrained("./checkpoint-57000")
def summarize_attributions(attributions):
    """Collapse attributions to one normalized score per position.

    The input is summed over its last dimension, the leading dimension is
    squeezed away when it has size 1, and the result is divided by its norm.

    Args:
        attributions: Attribution tensor; assumed to be shaped
            (1, seq_len, embedding_dim) as produced by the caller -- confirm.

    Returns:
        The summed attributions scaled to unit norm. When the norm is zero
        (all attributions are zero, or the tensor is empty) the summed tensor
        is returned unscaled instead of dividing by zero and producing NaNs.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    # A zero norm means every entry is zero; dividing would yield NaN.
    if norm == 0:
        return attributions
    return attributions / norm
def forward_func(inputs, attention_mask):
    """Return class probabilities from the module-level ``model``.

    This is the forward function handed to ``LayerIntegratedGradients``, so
    attributions are computed with respect to softmax probabilities rather
    than raw logits.

    Args:
        inputs: Token-id tensor, passed to the model as ``input_ids``.
        attention_mask: Mask tensor, passed to the model as ``attention_mask``.

    Returns:
        Softmax of the model's logits taken over the last dimension.
    """
    outputs = model(input_ids=inputs, attention_mask=attention_mask)
    # Pass dim explicitly: calling softmax without it is deprecated and warns.
    # For 2-D logits the implicit choice was dim=1, so the result is unchanged.
    probabilities = F.softmax(outputs.logits, dim=-1)
    return probabilities
def model_inference(text):
    """Classify ``text`` and build a token-attribution visualization.

    Args:
        text: Input string to classify.

    Returns:
        A ``(labels, html)`` tuple. ``labels`` maps ``"human_written"`` and
        ``"AI_generated"`` to float probabilities (class indices 0 and 1).
        ``html`` is the value returned by ``visualize_text`` for the single
        attribution record (assumed to be an HTML string, since the previous
        implementation concatenated it with ``str`` -- confirm).
    """
    encodings = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    # The plain prediction needs no autograd graph.
    with torch.no_grad():
        probs = F.softmax(model(**encodings).logits, dim=-1)[0]
    labels = {"human_written": float(probs[0]), "AI_generated": float(probs[1])}

    # Attribute the probability of class 0 ("human_written") to the embedding
    # layer, using an all-zero token-id sequence as the baseline
    # (id 0 is presumably the padding token for this vocabulary -- confirm).
    target_class = 0
    ref_input_ids = torch.zeros_like(input_ids)
    lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        # Trailing comma makes this a real one-element tuple.
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        target=target_class,
    )
    token_attributions = summarize_attributions(attributions)
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # `probs` are already probabilities; applying softmax again (as was done
    # before) would distort the confidence shown in the visualization.
    predicted_class = torch.argmax(probs)
    record = VisualizationDataRecord(
        token_attributions,
        torch.max(probs),
        predicted_class,
        # No ground-truth label exists at inference time, so the predicted
        # class is passed in the "true class" slot as well.
        predicted_class,
        # Label the record with the class the attributions were computed for
        # (previously the id of the first input token was shown here).
        str(target_class),
        token_attributions.sum(),
        all_tokens,
        delta,
    )
    html = visualize_text([record])
    return labels, html
# Gradio UI wiring: one text box in; a two-class label plus an HTML panel out.
input_text = gr.Textbox(placeholder="Enter sentence here...")
# Use the top-level gr.Label component, consistent with gr.Textbox above;
# the gr.outputs.* namespace is deprecated.
output = [gr.Label(num_top_classes=2), "html"]
demo = gr.Interface(
    model_inference,
    input_text,
    output,
    examples=[
        ["Is this text written by a human?"],
        ["Or is this text generated by an AI model?"],
    ],
)
demo.launch()