goodboyanush commited on
Commit
e8ffac3
Β·
1 Parent(s): f7b75de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -35
app.py CHANGED
@@ -28,42 +28,41 @@ def model_inference(text):
28
 
29
  logits = F.softmax(outputs.logits)
30
  logits = logits[0].detach().numpy()
31
- labels = {"human_written": 0, "AI_generated": 0, "Girlfriend_written": 1}
32
-
33
- html = "<html>With loads of Love <3 </html>"
34
- # ref_input_ids = torch.zeros_like(encodings['input_ids'])
35
- # lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
36
- # attributions_start, delta_start = lig.attribute(inputs=encodings['input_ids'],
37
- # baselines=ref_input_ids,
38
- # additional_forward_args=(encodings['attention_mask']),
39
- # return_convergence_delta=True,
40
- # target=0)
41
-
42
- # attributions_start_sum = summarize_attributions(attributions_start)
43
- # start_scores = forward_func(encodings['input_ids'], encodings['attention_mask'])
44
- # indices = encodings['input_ids'][0].detach().tolist()
45
- # all_tokens = tokenizer.convert_ids_to_tokens(indices)
46
- # ground_truth_start_ind = encodings['input_ids'][0][0].numpy()
47
-
48
- # start_position_vis = VisualizationDataRecord(
49
- # attributions_start_sum,
50
- # torch.max(torch.softmax(start_scores[0], dim=0)),
51
- # torch.argmax(start_scores),
52
- # torch.argmax(start_scores),
53
- # str(ground_truth_start_ind),
54
- # attributions_start_sum.sum(),
55
- # all_tokens,
56
- # delta_start)
57
-
58
-
59
- # print('\033[1m', 'Visualizations For Start Position', '\033[0m')
60
- # img = visualize_text([start_position_vis])
61
 
62
- # html = (
63
- # ""
64
- # + img
65
- # + ""
66
- # )
67
 
68
 
69
  return labels, html
 
28
 
29
  logits = F.softmax(outputs.logits)
30
  logits = logits[0].detach().numpy()
31
+ labels = {"human_written": float(logits[0]), "AI_generated": float(logits[1])}
32
+
33
+ ref_input_ids = torch.zeros_like(encodings['input_ids'])
34
+ lig = LayerIntegratedGradients(forward_func, model.distilbert.embeddings)
35
+ attributions_start, delta_start = lig.attribute(inputs=encodings['input_ids'],
36
+ baselines=ref_input_ids,
37
+ additional_forward_args=(encodings['attention_mask']),
38
+ return_convergence_delta=True,
39
+ target=0)
40
+
41
+ attributions_start_sum = summarize_attributions(attributions_start)
42
+ start_scores = forward_func(encodings['input_ids'], encodings['attention_mask'])
43
+ indices = encodings['input_ids'][0].detach().tolist()
44
+ all_tokens = tokenizer.convert_ids_to_tokens(indices)
45
+ ground_truth_start_ind = encodings['input_ids'][0][0].numpy()
46
+
47
+ start_position_vis = VisualizationDataRecord(
48
+ attributions_start_sum,
49
+ torch.max(torch.softmax(start_scores[0], dim=0)),
50
+ torch.argmax(start_scores),
51
+ torch.argmax(start_scores),
52
+ str(ground_truth_start_ind),
53
+ attributions_start_sum.sum(),
54
+ all_tokens,
55
+ delta_start)
56
+
57
+
58
+ print('\033[1m', 'Visualizations For Start Position', '\033[0m')
59
+ img = visualize_text([start_position_vis])
 
60
 
61
+ html = (
62
+ ""
63
+ + img
64
+ + ""
65
+ )
66
 
67
 
68
  return labels, html