Spaces: Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import random | |
| from captum.attr import LayerConductance, LayerIntegratedGradients | |
| from visualization import VisualizationDataRecord, visualize_text | |
| from transformers import DistilBertTokenizerFast | |
| from transformers import DistilBertForSequenceClassification | |
# Tokenizer comes from the 'distilbert-base-uncased' identifier; the
# sequence-classification weights are loaded from a local checkpoint
# directory (presumably a training checkpoint at step 57000 — confirm).
_TOKENIZER_ID = 'distilbert-base-uncased'
_CHECKPOINT_DIR = "./checkpoint-57000"

tokenizer = DistilBertTokenizerFast.from_pretrained(_TOKENIZER_ID)
model = DistilBertForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
def summarize_attributions(attributions):
    """Collapse per-dimension attributions to one score per token and L2-normalize.

    Args:
        attributions: Attribution tensor whose last dimension is summed away;
            a leading dimension of size 1 (the batch) is then squeezed.

    Returns:
        Per-token scores scaled to unit L2 norm. When every score is zero the
        all-zero vector is returned as-is instead of a vector of NaNs.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    # Dividing by a zero norm would produce NaN for every token and break the
    # downstream visualization; an all-zero vector is already "normalized".
    if norm == 0:
        return attributions
    return attributions / norm
def forward_func(inputs, attention_mask):
    """Return class probabilities from the global classification ``model``.

    Also serves as the forward function handed to Captum's
    ``LayerIntegratedGradients`` in ``model_inference``.

    Args:
        inputs: Token id tensor, passed to the model as ``input_ids``.
        attention_mask: Attention mask matching ``inputs``.

    Returns:
        Softmax probabilities over the last dimension of the model's logits.
    """
    outputs = model(input_ids=inputs, attention_mask=attention_mask)
    # Explicit dim: implicit-dim softmax is deprecated (and warns). For 2-D
    # logits the implicit choice was dim 1, i.e. the same axis as dim=-1.
    return F.softmax(outputs.logits, dim=-1)
def model_inference(text):
    """Classify ``text`` and render a token-attribution visualization.

    Args:
        text: Input string to classify.

    Returns:
        A ``(labels, html)`` tuple. ``labels`` maps ``"human_written"`` and
        ``"AI_generated"`` to the model's probabilities for classes 0 and 1;
        ``html`` is the visualization produced by ``visualize_text``.
    """
    class_names = ("human_written", "AI_generated")
    # Attributions are always computed for class 0 ("human_written"), so the
    # sign of every token score is relative to that class.
    target_class = 0

    encodings = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    # The plain prediction needs no autograd graph; reuse it for display
    # instead of running the model a second time.
    with torch.no_grad():
        probs = forward_func(input_ids, attention_mask)
    sample_probs = probs[0]
    labels = {name: float(p) for name, p in zip(class_names, sample_probs.tolist())}

    # All-zero token ids as the integrated-gradients baseline (presumably the
    # [PAD] id for this vocabulary — TODO confirm it is the intended reference).
    ref_input_ids = torch.zeros_like(input_ids)
    lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        # Must be a tuple: ``(mask)`` is only a parenthesized expression.
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        target=target_class,
    )
    word_attributions = summarize_attributions(attributions)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    pred_class = sample_probs.argmax()
    # NOTE(review): positional order assumes the local VisualizationDataRecord
    # mirrors Captum's (word_attributions, pred_prob, pred_class, true_class,
    # attr_class, attr_score, raw_input_ids, convergence_score) — confirm.
    record = VisualizationDataRecord(
        word_attributions,
        # forward_func already returns probabilities; softmaxing them again
        # (as previously done) squashed the displayed confidence.
        sample_probs.max(),
        pred_class,
        pred_class,  # no ground truth at inference time; prediction stands in
        class_names[target_class],  # the class the attributions explain
        word_attributions.sum(),
        tokens,
        delta,
    )
    print('\033[1m', 'Visualizations For Start Position', '\033[0m')
    img = visualize_text([record])
    # Captum's own visualize_text returns an IPython HTML object rather than a
    # str; unwrap its ``.data`` when present so either return type works.
    html = getattr(img, "data", img)
    return labels, html
input_text = gr.Textbox(placeholder="Enter sentence here...")
# ``gr.outputs.*`` was deprecated in Gradio 3 and removed in Gradio 4; the
# top-level components (as already used for the Textbox) are the supported
# API. ``gr.HTML()`` is what the former "html" shortcut string resolved to.
output = [gr.Label(num_top_classes=2), gr.HTML()]
demo = gr.Interface(
    model_inference,
    input_text,
    output,
    examples=[
        ["Is this text written by a human?"],
        ["Or is this text generated by an AI model?"],
    ],
)
demo.launch()