Spaces:
Runtime error
Runtime error
Commit Β·
e8ffac3
1
Parent(s): f7b75de
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,42 +28,41 @@ def model_inference(text):
|
|
| 28 |
|
| 29 |
logits = F.softmax(outputs.logits)
|
| 30 |
logits = logits[0].detach().numpy()
|
| 31 |
-
labels = {"human_written": 0, "AI_generated":
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# img = visualize_text([start_position_vis])
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
|
| 69 |
return labels, html
|
|
|
|
| 28 |
|
| 29 |
logits = F.softmax(outputs.logits)
|
| 30 |
logits = logits[0].detach().numpy()
|
| 31 |
+
labels = {"human_written": float(logits[0]), "AI_generated": float(logits[1])}
|
| 32 |
+
|
| 33 |
+
ref_input_ids = torch.zeros_like(encodings['input_ids'])
|
| 34 |
+
lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
|
| 35 |
+
attributions_start, delta_start = lig.attribute(inputs=encodings['input_ids'],
|
| 36 |
+
baselines=ref_input_ids,
|
| 37 |
+
additional_forward_args=(encodings['attention_mask']),
|
| 38 |
+
return_convergence_delta=True,
|
| 39 |
+
target=0)
|
| 40 |
+
|
| 41 |
+
attributions_start_sum = summarize_attributions(attributions_start)
|
| 42 |
+
start_scores = forward_func(encodings['input_ids'], encodings['attention_mask'])
|
| 43 |
+
indices = encodings['input_ids'][0].detach().tolist()
|
| 44 |
+
all_tokens = tokenizer.convert_ids_to_tokens(indices)
|
| 45 |
+
ground_truth_start_ind = encodings['input_ids'][0][0].numpy()
|
| 46 |
+
|
| 47 |
+
start_position_vis = VisualizationDataRecord(
|
| 48 |
+
attributions_start_sum,
|
| 49 |
+
torch.max(torch.softmax(start_scores[0], dim=0)),
|
| 50 |
+
torch.argmax(start_scores),
|
| 51 |
+
torch.argmax(start_scores),
|
| 52 |
+
str(ground_truth_start_ind),
|
| 53 |
+
attributions_start_sum.sum(),
|
| 54 |
+
all_tokens,
|
| 55 |
+
delta_start)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
|
| 59 |
+
img = visualize_text([start_position_vis])
|
|
|
|
| 60 |
|
| 61 |
+
html = (
|
| 62 |
+
""
|
| 63 |
+
+ img
|
| 64 |
+
+ ""
|
| 65 |
+
)
|
| 66 |
|
| 67 |
|
| 68 |
return labels, html
|