"""Gradio demo: classify text as human-written or AI-generated with DistilBERT.

Besides the two class probabilities, the app returns an HTML attribution
view built from Captum layer integrated gradients over the embedding layer.
"""
import gradio as gr
import torch
import torch.nn.functional as F
import random
from captum.attr import LayerConductance, LayerIntegratedGradients
from visualization import VisualizationDataRecord, visualize_text
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification

# Loaded once at import time and shared by every request.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained("./checkpoint-57000")


def summarize_attributions(attributions):
    """Sum attributions over the last dimension and scale them by their norm.

    Args:
        attributions: Attribution tensor. The last dimension is summed out and
            a leading dimension of size 1 (if present) is squeezed away.

    Returns:
        The reduced tensor divided by its ``torch.norm``. If that norm is
        exactly zero the reduced tensor is returned unscaled, so an all-zero
        attribution does not turn into NaNs.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm == 0:
        return attributions
    return attributions / norm


def forward_func(inputs, attention_mask):
    """Run the classifier and return class probabilities.

    This is the forward callable handed to ``LayerIntegratedGradients``.

    Args:
        inputs: Token ids, passed to the model as ``input_ids``.
        attention_mask: Attention mask passed to the model alongside the ids.

    Returns:
        Softmax of the model logits taken over the last (class) dimension.
    """
    outputs = model(input_ids=inputs, attention_mask=attention_mask)
    # Explicit dim: the implicit-dimension form of softmax is deprecated.
    return F.softmax(outputs.logits, dim=-1)


def model_inference(text):
    """Classify ``text`` and build an attribution visualisation for it.

    Args:
        text: The input string entered in the Gradio textbox.

    Returns:
        A ``(labels, html)`` tuple: ``labels`` maps ``"human_written"`` and
        ``"AI_generated"`` to float probabilities, and ``html`` is the markup
        produced by ``visualize_text`` for the attribution record.
    """
    encodings = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    # One gradient-free forward pass supplies both the label scores and the
    # prediction shown in the visualisation.
    with torch.no_grad():
        probs = forward_func(input_ids, attention_mask)
    class_probs = probs[0]
    labels = {
        "human_written": float(class_probs[0]),
        "AI_generated": float(class_probs[1]),
    }

    # Baseline for integrated gradients: same shape as the input, all ids 0.
    # NOTE(review): id 0 is presumably the padding token for this vocabulary;
    # confirm against the tokenizer.
    ref_input_ids = torch.zeros_like(input_ids)

    lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
    # NOTE(review): target=0 always attributes towards output index 0 (the
    # "human_written" score), regardless of which class was predicted.
    attributions_start, delta_start = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        target=0,
    )
    attributions_start_sum = summarize_attributions(attributions_start)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    # NOTE(review): this is the id of the first input token, passed below as a
    # label string. It looks like a leftover from a question-answering
    # example; confirm what the visualisation is expected to show here.
    ground_truth_start_ind = input_ids[0][0].numpy()

    # forward_func already returns probabilities, so the top probability is
    # read directly rather than applying softmax a second time.
    predicted_prob = torch.max(class_probs)
    predicted_class = torch.argmax(class_probs)

    # Positional arguments are kept in the original order. The record class
    # comes from the local ``visualization`` module; presumably it mirrors
    # Captum's (attributions, pred prob, pred class, true class, attr class,
    # attr score, tokens, convergence delta) -- TODO confirm.
    start_position_vis = VisualizationDataRecord(
        attributions_start_sum,
        predicted_prob,
        predicted_class,
        predicted_class,
        str(ground_truth_start_ind),
        attributions_start_sum.sum(),
        all_tokens,
        delta_start,
    )

    print('\033[1m', 'Visualizations For Start Position', '\033[0m')
    img = visualize_text([start_position_vis])

    # The wrapper strings around the rendered markup are currently empty.
    html = "" + img + ""
    return labels, html


input_text = gr.Textbox(placeholder="Enter sentence here...")
# Second output renders the attribution markup; "highlight" and "image" were
# previously tried as alternatives to "html".
output = [gr.Label(num_top_classes=2), "html"]

demo = gr.Interface(
    model_inference,
    input_text,
    output,
    examples=[
        ["Is this text written by a human?"],
        ["Or is this text generated by an AI model?"],
    ],
)

demo.launch()