entfane committed on
Commit
d7cb09b
·
verified ·
1 Parent(s): fbfce85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -35,12 +35,19 @@ def analyze(user_message, assistant_reply):
35
 
36
  scores = torch.sigmoid(values[0]).cpu().numpy()
37
 
38
- # Clean up GPT-2 special characters
 
 
 
 
 
 
39
  def clean(tok):
40
  return tok.replace("Ġ", " ").replace("Ċ", "\\n").strip() or tok
41
 
42
  labels = [f"{clean(tok)} [{i}]" for i, tok in enumerate(tokens)]
43
- df = pd.DataFrame({"token": labels, "value score": scores.tolist()})
 
44
 
45
  stats = (
46
  f"**Tokens:** {len(tokens)} | "
 
35
 
36
  scores = torch.sigmoid(values[0]).cpu().numpy()
37
 
38
+ # Only keep tokens that belong to the assistant reply
39
+ # Find where the assistant reply starts in the token list
40
+ reply_tokens = tokenizer(assistant_reply, return_tensors="pt").input_ids[0].tolist()
41
+ n_reply = len(reply_tokens)
42
+ tokens = tokens[-n_reply:]
43
+ scores = scores[-n_reply:]
44
+
45
  def clean(tok):
46
  return tok.replace("Ġ", " ").replace("Ċ", "\\n").strip() or tok
47
 
48
  labels = [f"{clean(tok)} [{i}]" for i, tok in enumerate(tokens)]
49
+ df = pd.DataFrame({"token": labels, "value score": scores.tolist(), "order": list(range(len(tokens)))})
50
+ df = df.sort_values("order").drop(columns="order")
51
 
52
  stats = (
53
  f"**Tokens:** {len(tokens)} | "