Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,869 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import glob
|
| 11 |
+
import re
|
| 12 |
+
#import smart_open
|
| 13 |
+
import plotly.express as px
|
| 14 |
+
import random
|
| 15 |
+
#import difflib
|
| 16 |
+
import pdb
|
| 17 |
+
|
| 18 |
+
from sentence_transformers import SentenceTransformer, models, util
|
| 19 |
+
|
| 20 |
+
enable_summary_button = True
|
| 21 |
+
dump_pos_data_for_reporting = True
|
| 22 |
+
|
| 23 |
+
bucket_name = "paper_n1"
|
| 24 |
+
|
| 25 |
+
prefix_lst = [
|
| 26 |
+
"pgj_d_4096",
|
| 27 |
+
"pgj_d_2048",
|
| 28 |
+
"pgj_d_1024_v2",
|
| 29 |
+
"pgj_d_1024_layer_14",
|
| 30 |
+
"pgj_d_1024_layer_7",
|
| 31 |
+
"pgj_d_1024_layer_2",
|
| 32 |
+
"pgj_d_1024_layer_1" ]
|
| 33 |
+
|
| 34 |
+
# "my_gptj_6b_tpu_size_8",
|
| 35 |
+
|
| 36 |
+
model_names = {
|
| 37 |
+
prefix_lst[0]: 'PatentGPT-J-6B',
|
| 38 |
+
prefix_lst[1]: 'PatentGPT-J-1.6B',
|
| 39 |
+
|
| 40 |
+
# prefix_lst[2]: 'PatentGPT-J-279M',
|
| 41 |
+
# prefix_lst[3]: 'PatentGPT-J-191M',
|
| 42 |
+
# prefix_lst[4]: 'PatentGPT-J-128M',
|
| 43 |
+
# prefix_lst[5]: 'PatentGPT-J-115M',}
|
| 44 |
+
|
| 45 |
+
prefix_lst[2]: 'PatentGPT-J-456M',
|
| 46 |
+
prefix_lst[3]: 'PatentGPT-J-279M',
|
| 47 |
+
prefix_lst[4]: 'PatentGPT-J-191M',
|
| 48 |
+
prefix_lst[5]: 'PatentGPT-J-128M',
|
| 49 |
+
prefix_lst[6]: 'PatentGPT-J-115M',}
|
| 50 |
+
|
| 51 |
+
# prefix_lst[7]:'GPT-J-6B'
|
| 52 |
+
|
| 53 |
+
# experiment 3
|
| 54 |
+
# folder = os.path.join('experiments', 'non_patent')
|
| 55 |
+
# id_to_scroll = 1 # which of the above to scroll through
|
| 56 |
+
# first_claim_only = True
|
| 57 |
+
|
| 58 |
+
#experiment 2
|
| 59 |
+
# folder = os.path.join('experiments', 'ipg20220104_500')
|
| 60 |
+
# #folder = "device_serve_results"
|
| 61 |
+
# id_to_scroll = 1 # which of the above to scroll through
|
| 62 |
+
# first_claim_only = False
|
| 63 |
+
|
| 64 |
+
# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048", "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
|
| 65 |
+
# #, "pgj_large", "pgj_medium", "pgj_small", ]
|
| 66 |
+
# # "pgj_d_1024_layer_14"
|
| 67 |
+
|
| 68 |
+
# experiment 1
|
| 69 |
+
folder = os.path.join('experiments', 'ipg22_500')
|
| 70 |
+
# (previous) folder = "eval_ipg22_500"
|
| 71 |
+
id_to_scroll = 1 # which of the above to scroll through
|
| 72 |
+
first_claim_only = True
|
| 73 |
+
|
| 74 |
+
ignore_outscope = True # ignore pick > 10
|
| 75 |
+
|
| 76 |
+
# def show_diff(a, b):
|
| 77 |
+
# #print('{} => {}'.format(a,b))
|
| 78 |
+
# for i, s in enumerate(difflib.ndiff(a, b)):
|
| 79 |
+
# if s[0]==' ': continue
|
| 80 |
+
# elif s[0]=='-':
|
| 81 |
+
# print(u'Delete "{}" from position {}'.format(s[-1],i))
|
| 82 |
+
# elif s[0]=='+':
|
| 83 |
+
# print(u'Add "{}" to position {}'.format(s[-1],i))
|
| 84 |
+
|
| 85 |
+
def handle_char_return(text):
|
| 86 |
+
if text == '(none)': # unicorn text
|
| 87 |
+
text == ''
|
| 88 |
+
|
| 89 |
+
return text
|
| 90 |
+
|
| 91 |
+
#return ch.replace('\n', '\\n')
|
| 92 |
+
|
| 93 |
+
#if ch == '\n':
|
| 94 |
+
# ch = "'\\n'"
|
| 95 |
+
#return ch
|
| 96 |
+
|
| 97 |
+
def get_remaining(lst, pos):
|
| 98 |
+
s = ''
|
| 99 |
+
for i in range(pos, len(lst)):
|
| 100 |
+
text = lst[i]['actual_next_token_text']
|
| 101 |
+
if text.startswith(' ') == False:
|
| 102 |
+
s += text
|
| 103 |
+
else:
|
| 104 |
+
break
|
| 105 |
+
|
| 106 |
+
return s
|
| 107 |
+
|
| 108 |
+
def calc_details(base_fn):
|
| 109 |
+
full_fn = os.path.join(folder, base_fn)
|
| 110 |
+
#gs_fn = "gs://%s/%s/%s" % (bucket_name, folder, base_fn)
|
| 111 |
+
#with smart_open.open(gs_fn) as f:
|
| 112 |
+
|
| 113 |
+
if os.path.exists(full_fn) == False:
|
| 114 |
+
return None, -1, -1, None, None, None, None, None
|
| 115 |
+
|
| 116 |
+
with open(full_fn) as f:
|
| 117 |
+
result = json.loads(f.read())
|
| 118 |
+
print("Loaded: %s" % full_fn)
|
| 119 |
+
|
| 120 |
+
lst = result['output']
|
| 121 |
+
recv = result['recv']
|
| 122 |
+
sum_pick = 0
|
| 123 |
+
sum_prob = 0
|
| 124 |
+
sum_outscope_count = 0
|
| 125 |
+
sum_outscope_len = 0
|
| 126 |
+
sum_hit_1 = 0
|
| 127 |
+
sum_top_10_len = 0
|
| 128 |
+
full_text = ''
|
| 129 |
+
|
| 130 |
+
token_count = 0
|
| 131 |
+
#found_end = False
|
| 132 |
+
|
| 133 |
+
#pdb.set_trace()
|
| 134 |
+
|
| 135 |
+
for i, tk in enumerate(lst[:-1]):
|
| 136 |
+
# if found_end:
|
| 137 |
+
# break
|
| 138 |
+
|
| 139 |
+
token_text = handle_char_return(tk['actual_next_token_text'])
|
| 140 |
+
|
| 141 |
+
# Due to tokenizer difference, the following needs more work in the future.
|
| 142 |
+
# if base_fn.find('gptj') >= 0:
|
| 143 |
+
# # using the original gpt-j-6b model
|
| 144 |
+
# # need to skip special tokens
|
| 145 |
+
# if i <= 7:
|
| 146 |
+
# continue # skip |start of claim|>
|
| 147 |
+
|
| 148 |
+
# remaining_text = get_remaining(lst, i)
|
| 149 |
+
# if remaining_text.find('<|end_of_claim|>') >= 0:
|
| 150 |
+
# pos1 = remaining_text.find('<|end_of_claim|>')
|
| 151 |
+
# token_text = remaining_text[:pos1]
|
| 152 |
+
# found_end = True
|
| 153 |
+
# #pdb.set_trace()
|
| 154 |
+
# #break
|
| 155 |
+
|
| 156 |
+
# The following was for GPT-J-6B. Not needed for PatentGPT-J.
|
| 157 |
+
#if token_text.find('<|end_of_claim|>') == 0:
|
| 158 |
+
# #pdb.set_trace()
|
| 159 |
+
# break
|
| 160 |
+
|
| 161 |
+
next_top_seq = int(tk['actual_next_token_top_seq'])
|
| 162 |
+
next_top_prob = float(tk['actual_next_token_top_prob'])
|
| 163 |
+
|
| 164 |
+
full_text += token_text
|
| 165 |
+
if next_top_seq == 0:
|
| 166 |
+
sum_hit_1 += 1 # press "tab" for the top pick
|
| 167 |
+
|
| 168 |
+
if ignore_outscope and next_top_seq>=10:
|
| 169 |
+
sum_outscope_count += 1
|
| 170 |
+
sum_outscope_len += len(token_text) # use length as keystrokes
|
| 171 |
+
else:
|
| 172 |
+
sum_pick += min(next_top_seq+1, len(token_text))
|
| 173 |
+
#sum_pick += (next_top_seq+1) # press "down" & "tab"
|
| 174 |
+
sum_prob += next_top_prob
|
| 175 |
+
sum_top_10_len += len(token_text)
|
| 176 |
+
|
| 177 |
+
token_count += 1
|
| 178 |
+
|
| 179 |
+
if ignore_outscope:
|
| 180 |
+
if token_count == 0: # unlikely
|
| 181 |
+
avg_pick = 0
|
| 182 |
+
avg_prob = 0
|
| 183 |
+
else:
|
| 184 |
+
avg_pick = float(sum_pick) / token_count
|
| 185 |
+
avg_prob = float(sum_prob) / token_count
|
| 186 |
+
else:
|
| 187 |
+
avg_pick = float(sum_pick) / token_count
|
| 188 |
+
avg_prob = float(sum_prob) / token_count
|
| 189 |
+
|
| 190 |
+
# if len(lst) < 2048: # for debugging
|
| 191 |
+
# s = '<|start_of_claim|>' + full_text
|
| 192 |
+
# if len(s) != len(recv['context']):
|
| 193 |
+
# print('length mismatch --> full_text: %s, recv: %s' % (len(s), len(recv['context'])))
|
| 194 |
+
# show_diff(s, recv['context'])
|
| 195 |
+
# pdb.set_trace()
|
| 196 |
+
|
| 197 |
+
return result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text
|
| 198 |
+
|
| 199 |
+
def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
|
| 200 |
+
|
| 201 |
+
result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
|
| 202 |
+
|
| 203 |
+
if token_count == 0:
|
| 204 |
+
print('debug 2')
|
| 205 |
+
pdb.set_trace()
|
| 206 |
+
|
| 207 |
+
if result is None:
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
lst = result['output']
|
| 211 |
+
result = ''
|
| 212 |
+
sum_all = {}
|
| 213 |
+
for i, tk in enumerate(lst):
|
| 214 |
+
token_text = handle_char_return(tk['actual_next_token_text'])
|
| 215 |
+
|
| 216 |
+
if token_text == '<|end_of_claim|>':
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
if token_text == '(none)': # for unicorn text
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
# Skip GPT-J, due to different tokenization
|
| 223 |
+
# if base_fn.find('gptj') >= 0:
|
| 224 |
+
# # using the original gpt-j-6b model
|
| 225 |
+
# # need to skip special tokens
|
| 226 |
+
# if i <= 7:
|
| 227 |
+
# continue # skip |start of claim|>
|
| 228 |
+
# if token_text == '.<': # assuming .<|end of claim|>
|
| 229 |
+
# break
|
| 230 |
+
|
| 231 |
+
pick = int(tk['actual_next_token_top_seq'])
|
| 232 |
+
prob = float(tk['actual_next_token_top_prob'])
|
| 233 |
+
|
| 234 |
+
colors = [
|
| 235 |
+
['00ff00', '000000', '1'],
|
| 236 |
+
['008800', 'ffffff', '2-10'],
|
| 237 |
+
['ff0000', 'ffffff', 'out of top 10'],
|
| 238 |
+
]
|
| 239 |
+
#colors = [
|
| 240 |
+
# ['00ff00', '000000', '1'],
|
| 241 |
+
# ['008800', 'ffffff', '2-10'],
|
| 242 |
+
# ['aa0000', 'ffffff', '11-100'],
|
| 243 |
+
# ['ff0000', 'ffffff', '101~']
|
| 244 |
+
#]
|
| 245 |
+
|
| 246 |
+
for j, item in enumerate(colors):
|
| 247 |
+
sum_all[item[2]] = 0
|
| 248 |
+
|
| 249 |
+
# skip follow-up subword
|
| 250 |
+
# if token_text.startswith(' ') == False:
|
| 251 |
+
# bg_color = ''
|
| 252 |
+
# fg_color = ''
|
| 253 |
+
# else:
|
| 254 |
+
|
| 255 |
+
if pick == 0:
|
| 256 |
+
bg_color = colors[0][0]
|
| 257 |
+
fg_color = colors[0][1]
|
| 258 |
+
tag = colors[0][2]
|
| 259 |
+
sum_all[tag] += 1
|
| 260 |
+
elif pick >= 1 and pick < 10:
|
| 261 |
+
bg_color = colors[1][0]
|
| 262 |
+
fg_color = colors[1][1]
|
| 263 |
+
tag = colors[1][2]
|
| 264 |
+
sum_all[tag] += 1
|
| 265 |
+
else: # pick >= 10
|
| 266 |
+
#elif pick >= 10 and pick < 100:
|
| 267 |
+
bg_color = colors[2][0]
|
| 268 |
+
fg_color = colors[2][1]
|
| 269 |
+
tag = colors[2][2]
|
| 270 |
+
sum_all[tag] += 1
|
| 271 |
+
#else: #pick >= 100:
|
| 272 |
+
# bg_color = colors[3][0]
|
| 273 |
+
# fg_color = colors[3][1]
|
| 274 |
+
# tag = colors[3][2]
|
| 275 |
+
# sum_all[tag] += 1
|
| 276 |
+
|
| 277 |
+
if show_pick:
|
| 278 |
+
pick = '[%s]' % pick
|
| 279 |
+
else:
|
| 280 |
+
pick = ''
|
| 281 |
+
|
| 282 |
+
result += "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span> " % (bg_color, fg_color, token_text, pick) #
|
| 283 |
+
|
| 284 |
+
color_msg = ''
|
| 285 |
+
for i, v in enumerate(colors):
|
| 286 |
+
color_msg += "<span style=background-color:#%s;color:#%s;border-radius:5px;> %s </span> " % (v[0], v[1], v[2])
|
| 287 |
+
|
| 288 |
+
#result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
|
| 289 |
+
|
| 290 |
+
# sum_pick as top 1~10
|
| 291 |
+
keys_with_auto = (sum_pick+sum_outscope_len)
|
| 292 |
+
keys_without_auto = len(full_text)
|
| 293 |
+
saved_ratio = float(keys_without_auto-keys_with_auto)/keys_without_auto * 100
|
| 294 |
+
s = 'model: %s\n' \
|
| 295 |
+
'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
|
| 296 |
+
'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
|
| 297 |
+
'Keystroke distribution: top 1~10: %s (top 1: %s), out of top 10: %s' % (model_name, saved_ratio, keys_with_auto, keys_without_auto, sum_pick, sum_hit_1, sum_outscope_len)
|
| 298 |
+
st.text(s)
|
| 299 |
+
|
| 300 |
+
# s = 'file: %s, sum_pick: %s, sum_hit_1: %s, token_count: %s, sum_outscope: %s, avg_pick: %.2f, avg_prob: %.2f, sum_prob: %.2f, hit_1 ratio: %.2f ' % (base_fn, sum_pick, sum_hit_1, token_count, sum_outscope, avg_pick, avg_prob, sum_prob, float(sum_hit_1)/token_count)
|
| 301 |
+
#s += color_msg
|
| 302 |
+
|
| 303 |
+
s = color_msg
|
| 304 |
+
st.markdown(s, unsafe_allow_html=True)
|
| 305 |
+
#st.text('file: %s, avg_pick: %5.2f, avg_prob: %.2f, hit count: %s/%s ' % (base_fn, avg_pick, avg_prob, hit_0_count, len(lst)))
|
| 306 |
+
# show histogram
|
| 307 |
+
|
| 308 |
+
st.markdown(result, unsafe_allow_html=True)
|
| 309 |
+
#st.text_area('context with top seq & prob:', result, height=400)
|
| 310 |
+
|
| 311 |
+
sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]
|
| 312 |
+
#sum_lst = [['1', sum_all['1']], ['2-10', sum_all['2-10']]]
|
| 313 |
+
#sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['11-100'], sum_all['101~']]
|
| 314 |
+
|
| 315 |
+
return sum_lst
|
| 316 |
+
|
| 317 |
+
def show_overall_summary(prefix_lst, select_lst):
|
| 318 |
+
# accumulate all
|
| 319 |
+
|
| 320 |
+
# debug
|
| 321 |
+
# for i, num in enumerate(select_lst):
|
| 322 |
+
# pre_full_text = ''
|
| 323 |
+
# for prefix in prefix_lst:
|
| 324 |
+
# base_fn = '%s_%s_forward.json' % (prefix, num)
|
| 325 |
+
# result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
|
| 326 |
+
|
| 327 |
+
# if pre_full_text == '':
|
| 328 |
+
# pre_full_text = full_text
|
| 329 |
+
# else:
|
| 330 |
+
# if pre_full_text != full_text:
|
| 331 |
+
# print('debug')
|
| 332 |
+
# pdb.set_trace()
|
| 333 |
+
|
| 334 |
+
# #
|
| 335 |
+
# pdb.set_trace()
|
| 336 |
+
|
| 337 |
+
for prefix in prefix_lst:
|
| 338 |
+
acc_token_count = 0
|
| 339 |
+
acc_sum_pick = 0
|
| 340 |
+
acc_sum_prob = 0
|
| 341 |
+
acc_sum_outscope_count = 0
|
| 342 |
+
acc_sum_outscope_len = 0
|
| 343 |
+
acc_sum_hit_1 = 0
|
| 344 |
+
acc_sum_top_10_len = 0
|
| 345 |
+
acc_full_text_len = 0
|
| 346 |
+
|
| 347 |
+
pre_full_text = ''
|
| 348 |
+
for i, num in enumerate(select_lst):
|
| 349 |
+
base_fn = '%s_%s_forward.json' % (prefix, num)
|
| 350 |
+
result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
|
| 351 |
+
|
| 352 |
+
acc_token_count += token_count
|
| 353 |
+
acc_sum_pick += sum_pick
|
| 354 |
+
acc_sum_prob += sum_prob
|
| 355 |
+
acc_sum_outscope_count += sum_outscope_count
|
| 356 |
+
acc_sum_outscope_len += sum_outscope_len
|
| 357 |
+
acc_sum_hit_1 += sum_hit_1
|
| 358 |
+
acc_sum_top_10_len += sum_top_10_len
|
| 359 |
+
acc_full_text_len += len(full_text)
|
| 360 |
+
|
| 361 |
+
if acc_token_count > 0:
|
| 362 |
+
# acc_sum_pick --> top 1~10
|
| 363 |
+
keys_with_auto = acc_sum_pick + acc_sum_outscope_len
|
| 364 |
+
keys_without_auto = acc_full_text_len
|
| 365 |
+
saved_ratio = float(keys_without_auto-keys_with_auto)/keys_without_auto * 100
|
| 366 |
+
|
| 367 |
+
st.text('[ %s ]\n' \
|
| 368 |
+
'Autocomplete Effectiveness: %.1f%% (ratio of saving keystroke)\n' \
|
| 369 |
+
'(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n' \
|
| 370 |
+
'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
|
| 371 |
+
model_names[prefix], saved_ratio,
|
| 372 |
+
'{:,}'.format(keys_with_auto),
|
| 373 |
+
'{:,}'.format(acc_sum_pick),
|
| 374 |
+
'{:,}'.format(acc_sum_outscope_len),
|
| 375 |
+
'{:,}'.format(acc_sum_hit_1),
|
| 376 |
+
'{:,}'.format(keys_without_auto),
|
| 377 |
+
'{:,}'.format(acc_sum_top_10_len),
|
| 378 |
+
acc_sum_prob,
|
| 379 |
+
))
|
| 380 |
+
|
| 381 |
+
st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (model_names[prefix], saved_ratio, '{:,}'.format(keys_with_auto), '{:,}'.format(acc_sum_pick), '{:,}'.format(acc_sum_outscope_len), '{:,}'.format(acc_sum_hit_1), '{:,}'.format(keys_without_auto)))
|
| 382 |
+
|
| 383 |
+
# st.text('* acc_token_count =%s --> (avg) hits: %.2f, keys: %.2f, prob: %.2f, outscope: %.2f' % (
|
| 384 |
+
# acc_token_count,
|
| 385 |
+
# float(acc_sum_hit_1)/acc_token_count,
|
| 386 |
+
# float(acc_sum_pick)/acc_token_count,
|
| 387 |
+
# float(acc_sum_prob)/acc_token_count,
|
| 388 |
+
# float(acc_sum_outscope_count)/acc_token_count))
|
| 389 |
+
|
| 390 |
+
def calc_height(s):
|
| 391 |
+
return int(len(s) / 10 * 3) + 30
|
| 392 |
+
|
| 393 |
+
def remove_end_of_claim_text(gen_text):
|
| 394 |
+
tag = '<|end_of_claim|>'
|
| 395 |
+
pos = gen_text.find(tag)
|
| 396 |
+
if pos > 0:
|
| 397 |
+
gen_text = gen_text[:pos+len(tag)]
|
| 398 |
+
return gen_text
|
| 399 |
+
|
| 400 |
+
tag = '<|endoftext|>'
|
| 401 |
+
pos = gen_text.find(tag)
|
| 402 |
+
if pos > 0:
|
| 403 |
+
gen_text = gen_text[:pos+len(tag)]
|
| 404 |
+
|
| 405 |
+
return gen_text
|
| 406 |
+
|
| 407 |
+
def dump_pos_data(prefix_lst, select_lst):
|
| 408 |
+
#statistics = [[0]*3]*2048
|
| 409 |
+
statistics = []
|
| 410 |
+
for i in range(2048):
|
| 411 |
+
statistics.append([0,0,0])
|
| 412 |
+
|
| 413 |
+
#results.append(['model', 'pos', 'key'])
|
| 414 |
+
#results.append(['model', 'patent_claim', 'pos', 'top-1', 'top-2~10', 'out of top 10'])
|
| 415 |
+
max_len = -1
|
| 416 |
+
for prefix in prefix_lst:
|
| 417 |
+
model_name = model_names[prefix].replace('PatentGPT-J-', '')
|
| 418 |
+
if model_name != '456M':
|
| 419 |
+
continue
|
| 420 |
+
|
| 421 |
+
#total = {}
|
| 422 |
+
for i, num in enumerate(select_lst):
|
| 423 |
+
base_fn = '%s_%s_forward.json' % (prefix, num)
|
| 424 |
+
full_fn = os.path.join(folder, base_fn)
|
| 425 |
+
if os.path.exists(full_fn) == False:
|
| 426 |
+
continue
|
| 427 |
+
|
| 428 |
+
with open(full_fn) as f:
|
| 429 |
+
result = json.loads(f.read())
|
| 430 |
+
print("Loaded: %s" % full_fn)
|
| 431 |
+
|
| 432 |
+
lst = result['output']
|
| 433 |
+
for j, tk in enumerate(lst[:-1]):
|
| 434 |
+
max_len = max(j, max_len)
|
| 435 |
+
next_top_seq = int(tk['actual_next_token_top_seq'])
|
| 436 |
+
#next_top_prob = float(tk['actual_next_token_top_prob'])
|
| 437 |
+
|
| 438 |
+
top_1 = top_2_to_10 = out_of_scope = 0
|
| 439 |
+
if next_top_seq == 0:
|
| 440 |
+
top_1 = 1
|
| 441 |
+
tag = 'top-1'
|
| 442 |
+
statistics[j][0] += 1
|
| 443 |
+
elif next_top_seq > 0 and next_top_seq < 10:
|
| 444 |
+
top_2_to_10 = 1
|
| 445 |
+
tag = 'top-2~10'
|
| 446 |
+
statistics[j][1] += 1
|
| 447 |
+
else:
|
| 448 |
+
out_of_scope = 1
|
| 449 |
+
tag = 'out-of-scope'
|
| 450 |
+
statistics[j][2] += 1
|
| 451 |
+
|
| 452 |
+
#total[tag] = total.get(tag, 0) + 1
|
| 453 |
+
#results.append([model_name, str(i+1), tag])
|
| 454 |
+
#results.append([model_name, str(i+1), tag])
|
| 455 |
+
#results.append([model_name, num, str(i+1), tag])
|
| 456 |
+
#results.append([model_name, num, i+1, top_1, top_2_to_10, out_of_scope])
|
| 457 |
+
#pdb.set_trace()
|
| 458 |
+
#pdb.set_trace()
|
| 459 |
+
|
| 460 |
+
dump_file = 'dump4.txt'
|
| 461 |
+
#pdb.set_trace()
|
| 462 |
+
with open(dump_file, 'w') as f:
|
| 463 |
+
for i in range(max_len+1):
|
| 464 |
+
f.write('%s, top-1, %s\n' % (i+1, statistics[i][0]))
|
| 465 |
+
f.write('%s, top-2~10, %s\n' % (i+1, statistics[i][1]))
|
| 466 |
+
f.write('%s, out_of_scope, %s\n' % (i+1, statistics[i][2]))
|
| 467 |
+
# f.write('%s\n' % ', '.join([str(i+1)] + [ str(v) for v in statistics[i] ] ))
|
| 468 |
+
print('saved: %s' % dump_file)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# dump_file = 'dump2.txt'
|
| 472 |
+
# with open(dump_file, 'w') as f:
|
| 473 |
+
# for line in results:
|
| 474 |
+
# f.write('%s\n' % ', '.join(line))
|
| 475 |
+
# print('saved: %s' % dump_file)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def calc_sentence_similarity(sent_model, sent1, sent2):
|
| 479 |
+
rewards = []
|
| 480 |
+
embedding1 = sent_model.encode(sent1, convert_to_tensor=True)
|
| 481 |
+
embedding2 = sent_model.encode(sent2, convert_to_tensor=True)
|
| 482 |
+
similarity = util.cos_sim(embedding1, embedding2)[0][0]
|
| 483 |
+
|
| 484 |
+
#pdb.set_trace()
|
| 485 |
+
|
| 486 |
+
return similarity
|
| 487 |
+
|
| 488 |
+
sent_model = 'patent/st-aipd-nlp-g'
|
| 489 |
+
print('loading SentenceTransformer: %s' % sent_model)
|
| 490 |
+
sent_aipd = SentenceTransformer(sent_model)
|
| 491 |
+
|
| 492 |
+
def load_data(demo):
|
| 493 |
+
fn = 'ppo_open_llama_3b_v2.run.12.delta.txt'
|
| 494 |
+
#fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.delta.txt'
|
| 495 |
+
with open(fn, 'r') as f:
|
| 496 |
+
rows = json.load(f)
|
| 497 |
+
|
| 498 |
+
if demo == 'demo1':
|
| 499 |
+
new_rows = [ row for row in rows if row['instruction'].find('child') > 0 ]
|
| 500 |
+
elif demo == 'demo2':
|
| 501 |
+
new_rows = [ row for row in rows if row['instruction'].find('parent') > 0 ]
|
| 502 |
+
else:
|
| 503 |
+
new_rows = []
|
| 504 |
+
|
| 505 |
+
return new_rows
|
| 506 |
+
|
| 507 |
+
container_style = """
|
| 508 |
+
<style>
|
| 509 |
+
.container1 {
|
| 510 |
+
border: 2px solid #3498db;
|
| 511 |
+
border-radius: 8px;
|
| 512 |
+
padding: 10px;
|
| 513 |
+
margin-bottom: 20px;
|
| 514 |
+
}
|
| 515 |
+
.container2 {
|
| 516 |
+
/* Add styles for Container 2 if needed */
|
| 517 |
+
}
|
| 518 |
+
</style>
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
def main():
|
| 522 |
+
st.set_page_config( # Alternate names: setup_page, page, layout
|
| 523 |
+
layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
|
| 524 |
+
initial_sidebar_state="auto", # Can be "auto", "expanded", "collapsed"
|
| 525 |
+
page_title="Demo 1", # String or None. Strings get appended with "• Streamlit".
|
| 526 |
+
page_icon=None, # String, anything supported by st.image, or None.
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
opt_1 = 'parent --> child'
|
| 530 |
+
opt_2 = 'child --> parent'
|
| 531 |
+
options = [opt_1, opt_2]
|
| 532 |
+
rows = None
|
| 533 |
+
pos = None
|
| 534 |
+
patent_num = ''
|
| 535 |
+
claim_num1 = ''
|
| 536 |
+
claim_num2 = ''
|
| 537 |
+
instruction= ''
|
| 538 |
+
input_text = ''
|
| 539 |
+
output_text = ''
|
| 540 |
+
response = ''
|
| 541 |
+
query = ''
|
| 542 |
+
score_lst_1 = 0
|
| 543 |
+
score_lst_2 = 0
|
| 544 |
+
rewards = ''
|
| 545 |
+
with st.container():
|
| 546 |
+
col1, col2, col3 = st.columns([3, 5, 2])
|
| 547 |
+
with col1:
|
| 548 |
+
selected_option = st.selectbox('Select a demo:', options)
|
| 549 |
+
if selected_option == opt_1:
|
| 550 |
+
rows = load_data('demo1')
|
| 551 |
+
msg = 'novelty = sim1-sim2'
|
| 552 |
+
#msg = 'delta of similarities<br>(sim1-sim2)'
|
| 553 |
+
c1_tag = 'pc'
|
| 554 |
+
c2_tag = 'cc1'
|
| 555 |
+
c3_tag = 'cc2'
|
| 556 |
+
elif selected_option == opt_2:
|
| 557 |
+
rows = load_data('demo2')
|
| 558 |
+
msg = 'similarity of<br>(pc1) and (pc2)'
|
| 559 |
+
c1_tag = 'cc'
|
| 560 |
+
c2_tag = 'pc1'
|
| 561 |
+
c3_tag = 'pc2'
|
| 562 |
+
else:
|
| 563 |
+
st.text('Unknown option')
|
| 564 |
+
return
|
| 565 |
+
#rows = rows[:5000] # for debugging
|
| 566 |
+
|
| 567 |
+
with col2:
|
| 568 |
+
pos = st.slider("", 1, len(rows))
|
| 569 |
+
#pos = st.slider("Degree of novelty (Generated v. Actual)", 1, len(rows))
|
| 570 |
+
for i in range(pos):
|
| 571 |
+
#prompt = '%s' % rows[i]
|
| 572 |
+
#pdb.set_trace()
|
| 573 |
+
|
| 574 |
+
patent_num = rows[i]['patent_num']
|
| 575 |
+
claim_num1 = rows[i]['claim_num1']
|
| 576 |
+
claim_num2 = rows[i]['claim_num2']
|
| 577 |
+
instruction= rows[i]['instruction']
|
| 578 |
+
input_text = rows[i]['input']
|
| 579 |
+
output_text = rows[i]['output']
|
| 580 |
+
response = rows[i]['response']
|
| 581 |
+
query = rows[i]['query']
|
| 582 |
+
score_lst_1 = rows[i]['score_lst_1']
|
| 583 |
+
score_lst_2 = rows[i]['score_lst_2']
|
| 584 |
+
delta = rows[i]['delta']
|
| 585 |
+
rewards = rows[i]['rewards']
|
| 586 |
+
with col3:
|
| 587 |
+
#v = round(float(score_lst_1)-float(score_lst_2), 4)
|
| 588 |
+
#v = delta #round(delta,10)
|
| 589 |
+
st.markdown("<center><h7>%s<br>%s</h7></center>" % (msg, delta), unsafe_allow_html=True)
|
| 590 |
+
# style='text-align: center; color: black;'
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
# selectbox_placeholder = st.empty()
|
| 594 |
+
# selected_option = selectbox_placeholder.selectbox('Select a demo:', options)
|
| 595 |
+
# container1 = st.container()
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
# with st.container():
|
| 599 |
+
# col1, col2 = st.columns(2)
|
| 600 |
+
# with col1:
|
| 601 |
+
# st.write('Caption for first chart')
|
| 602 |
+
# with col2:
|
| 603 |
+
# st.line_chart((0,1), height=100)
|
| 604 |
+
# with st.container():
|
| 605 |
+
# col1, col2 = st.columns(2)
|
| 606 |
+
# with col1:
|
| 607 |
+
# st.write('Caption for second chart')
|
| 608 |
+
# with col2:
|
| 609 |
+
# st.line_chart((1,0), height=100)
|
| 610 |
+
|
| 611 |
+
#st.write('patent_num:', patent_num)
|
| 612 |
+
# st.write('claim_num1:', claim_num1)
|
| 613 |
+
# st.write('claim_num2:', claim_num2)
|
| 614 |
+
st.write('(instruction) ', instruction)
|
| 615 |
+
|
| 616 |
+
with st.container():
|
| 617 |
+
with st.container(border=True):
|
| 618 |
+
st.write('(%s) [ %s ]\n%s' % (c1_tag, patent_num, input_text))
|
| 619 |
+
#st.write('input:' % patent_num)
|
| 620 |
+
#st.write('input:\n', input_text)
|
| 621 |
+
|
| 622 |
+
#container1.markdown("<div class='container1'>", unsafe_allow_html=True)
|
| 623 |
+
col1, col2 = st.columns(2)
|
| 624 |
+
with col1:
|
| 625 |
+
with st.container(border=True):
|
| 626 |
+
st.write('(%s) (actual)' % c2_tag)
|
| 627 |
+
st.write(output_text)
|
| 628 |
+
with col2:
|
| 629 |
+
with st.container(border=True):
|
| 630 |
+
st.write('(%s) (generated)' % c3_tag)
|
| 631 |
+
st.write(response)
|
| 632 |
+
|
| 633 |
+
col1, col2 = st.columns(2)
|
| 634 |
+
with col1:
|
| 635 |
+
with st.container(border=True):
|
| 636 |
+
st.write('(sim1) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c2_tag, str(score_lst_1)))
|
| 637 |
+
with col2:
|
| 638 |
+
with st.container(border=True):
|
| 639 |
+
st.write('(sim2) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c3_tag, str(score_lst_2)))
|
| 640 |
+
|
| 641 |
+
#container1.markdown("</div>", unsafe_allow_html=True)
|
| 642 |
+
|
| 643 |
+
# st.write("In Container 1")
|
| 644 |
+
# table_name = st.radio("Please Select Table", list_of_tables)
|
| 645 |
+
|
| 646 |
+
# st.write('output:')
|
| 647 |
+
# st.write(output_text)
|
| 648 |
+
# st.write('response:')
|
| 649 |
+
# st.write(response)
|
| 650 |
+
#st.write('query:', query)
|
| 651 |
+
# st.write('score_lst_1:', score_lst_1)
|
| 652 |
+
# st.write('score_lst_2:', score_lst_2)
|
| 653 |
+
# st.write('rewards:', rewards)
|
| 654 |
+
# st.text('hello')
|
| 655 |
+
|
| 656 |
+
# dict_keys(['patent_num', 'claim_num1', 'claim_num2', 'instruction', 'input', 'output', 'query', 'response', 'score_lst_1', 'score_lst_2', 'rewards'])
|
| 657 |
+
|
| 658 |
+
# st.subheader("Inspecting PatentGPT-J Model Evaluation")
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
# num_set = set()
|
| 663 |
+
# fn_lst = glob.glob(os.path.join(folder, '*'))
|
| 664 |
+
# for i, fn in enumerate(fn_lst):
|
| 665 |
+
# for prefix in prefix_lst:
|
| 666 |
+
# v = re.search('(.*?)%s\_(\d+\_\d+)\_(.*?)' % prefix, fn)
|
| 667 |
+
# if v is None:
|
| 668 |
+
# v = re.search('(.*?)%s\_(\w+\_\d+)\_(.*?)' % prefix, fn)
|
| 669 |
+
|
| 670 |
+
# #pdb.set_trace()
|
| 671 |
+
# if v is None:
|
| 672 |
+
# #pdb.set_trace()
|
| 673 |
+
# continue
|
| 674 |
+
|
| 675 |
+
# v = v.group(2)
|
| 676 |
+
# if first_claim_only:
|
| 677 |
+
# if v.endswith('_1'):
|
| 678 |
+
# num_set.add(v)
|
| 679 |
+
# else:
|
| 680 |
+
# num_set.add(v)
|
| 681 |
+
|
| 682 |
+
# num_lst = list(num_set)
|
| 683 |
+
# num_lst.sort()
|
| 684 |
+
|
| 685 |
+
# select_lst = []
|
| 686 |
+
# for i, num in enumerate(num_lst):
|
| 687 |
+
# all_existed = True
|
| 688 |
+
# for prefix in prefix_lst:
|
| 689 |
+
# fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
|
| 690 |
+
# if os.path.exists(fn) == False:
|
| 691 |
+
# all_existed = False
|
| 692 |
+
# break
|
| 693 |
+
# if all_existed:
|
| 694 |
+
# select_lst.append(num)
|
| 695 |
+
# select_lst.sort()
|
| 696 |
+
|
| 697 |
+
# if len(select_lst) == 0:
|
| 698 |
+
# st.text('select_lst is empty')
|
| 699 |
+
# return
|
| 700 |
+
|
| 701 |
+
# if dump_pos_data_for_reporting:
|
| 702 |
+
# dump_pos_data(prefix_lst, select_lst)
|
| 703 |
+
# st.text('Dump data: done')
|
| 704 |
+
# return
|
| 705 |
+
|
| 706 |
+
# # debug
|
| 707 |
+
# #base_fn = 'my_gptj_6b_tpu_size_8_11212952_1_forward.json'
|
| 708 |
+
# #base_fn = 'pgj_small_text-1_1_forward.json'
|
| 709 |
+
# #_ = show_avg(base_fn)
|
| 710 |
+
|
| 711 |
+
# if enable_summary_button:
|
| 712 |
+
# if st.button('Show Summary'):
|
| 713 |
+
# st.text('len(select_lst) = %s' % len(select_lst))
|
| 714 |
+
# show_overall_summary(prefix_lst, select_lst)
|
| 715 |
+
|
| 716 |
+
# # if 'num' not in st.session_state:
|
| 717 |
+
# # num = random.choice(select_lst)
|
| 718 |
+
# # st.session_state['num'] = num
|
| 719 |
+
|
| 720 |
+
# # set_state('num', num)
|
| 721 |
+
# # def set_state(k, v):
|
| 722 |
+
# # if k not in st.session_state:
|
| 723 |
+
# # st.session_state[ k ] = v
|
| 724 |
+
|
| 725 |
+
# show_patent_lst = [ s.replace('_', ' (claim ') + ')' for s in select_lst]
|
| 726 |
+
# selected = st.selectbox("Choose a patent claim", show_patent_lst)
|
| 727 |
+
# num = selected.replace(')', '').replace(' (claim ', '_')
|
| 728 |
+
# if st.button('Random pick'):
|
| 729 |
+
# num = random.choice(select_lst)
|
| 730 |
+
|
| 731 |
+
# st.text('Selected: %s' % num)
|
| 732 |
+
# st.session_state['num'] = num
|
| 733 |
+
|
| 734 |
+
# avgs = []
|
| 735 |
+
# for prefix in prefix_lst:
|
| 736 |
+
# base_fn = '%s_%s_forward.json' % (prefix, num)
|
| 737 |
+
# one_avg = show_avg(base_fn, model_names[prefix], num)
|
| 738 |
+
# if one_avg is not None:
|
| 739 |
+
# avgs.append(one_avg)
|
| 740 |
+
|
| 741 |
+
# # debug
|
| 742 |
+
# #pdb.set_trace()
|
| 743 |
+
# #return
|
| 744 |
+
# #
|
| 745 |
+
|
| 746 |
+
# data_lst = []
|
| 747 |
+
# for i in range(len(avgs[0])):
|
| 748 |
+
# row = []
|
| 749 |
+
# for j, prefix in enumerate(prefix_lst):
|
| 750 |
+
# row.append(avgs[j][i])
|
| 751 |
+
# data_lst.append(row)
|
| 752 |
+
|
| 753 |
+
# df = pd.DataFrame(data_lst, index=['1','2-10','out of top 10'])
|
| 754 |
+
# #df = pd.DataFrame(data_lst, index=['1','2-10','11-100','101~'])
|
| 755 |
+
|
| 756 |
+
# # ], index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~'])
|
| 757 |
+
# # [avgs[0][0], avgs[1][0], avgs[2][0]],
|
| 758 |
+
# # [avgs[0][1], avgs[1][1], avgs[2][1]],
|
| 759 |
+
# # [avgs[0][2], avgs[1][2], avgs[2][2]],
|
| 760 |
+
# # [avgs[0][3], avgs[1][3], avgs[2][3]],
|
| 761 |
+
|
| 762 |
+
# #df = pd.DataFrame([[1,2],[3,1]], columns=['a', 'b'])
|
| 763 |
+
# #df = pd.DataFrame([
|
| 764 |
+
# # [sum1[0], sum1[1], sum1[2], sum1[3]],
|
| 765 |
+
# # [sum2[0], sum2[1], sum2[2], sum2[3]],
|
| 766 |
+
# # [sum3[0], sum3[1], sum3[2], sum3[3]],
|
| 767 |
+
# # ]) #, index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~'])
|
| 768 |
+
# #df = pd.DataFrame.from_dict(sum_all, orient='index')
|
| 769 |
+
# #st.line_chart(df)
|
| 770 |
+
|
| 771 |
+
# #data_canada = px.data.gapminder().query("country == 'Canada'")
|
| 772 |
+
# #fig = px.bar(data_canada, x='year', y='pop')
|
| 773 |
+
|
| 774 |
+
# if st.button('Show chart'):
|
| 775 |
+
# fig = px.bar(df, barmode='group')
|
| 776 |
+
# st.plotly_chart(fig, use_container_width=True)
|
| 777 |
+
# #fig.show()
|
| 778 |
+
# #st.area_chart(df)
|
| 779 |
+
# #st.bar_chart(df)
|
| 780 |
+
|
| 781 |
+
# #
|
| 782 |
+
# base_fn = '%s_%s_forward.json' % (prefix_lst[ id_to_scroll ], st.session_state['num'])
|
| 783 |
+
# result, avg_pick, avg_prob, _, _, _, _, _, _, _, _ = calc_details(base_fn)
|
| 784 |
+
# recv = result['recv']
|
| 785 |
+
# lst = result['output']
|
| 786 |
+
# input_tokens = result['input']
|
| 787 |
+
|
| 788 |
+
# # (Pdb) print(token_pos_lst[0].keys())
|
| 789 |
+
# #dict_keys(['idx', 'gen_text', 'actual_next_token_text', 'actual_next_token_top_seq', 'actual_next_token_top_prob', 'top_n_lst'])
|
| 790 |
+
|
| 791 |
+
# height = calc_height(recv['context'])
|
| 792 |
+
# st.text_area('context:', recv['context'], height=height)
|
| 793 |
+
|
| 794 |
+
# pos = st.slider("Token position", 0, len(lst))
|
| 795 |
+
# prompt = ''
|
| 796 |
+
# for i in range(pos+1):
|
| 797 |
+
# prompt += input_tokens[i]['text']
|
| 798 |
+
# height = calc_height(prompt)
|
| 799 |
+
# st.text_area('prompt:', prompt, height=height)
|
| 800 |
+
|
| 801 |
+
# ch = handle_char_return(lst[pos]['actual_next_token_text'])
|
| 802 |
+
# st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f)' % (ch, int(lst[pos]['actual_next_token_top_seq'])+1,
|
| 803 |
+
# float(lst[pos]['actual_next_token_top_prob'])))
|
| 804 |
+
|
| 805 |
+
# st.text('top 10 tokens:')
|
| 806 |
+
# for i, v in enumerate(lst[pos]['top_n_lst']):
|
| 807 |
+
# ch = handle_char_return(v['top_n_text'])
|
| 808 |
+
# st.text('[ %s ][ %s ]( %.2f )' % (i+1, ch, float(v['top_n_prob'])))
|
| 809 |
+
|
| 810 |
+
# gen_text = lst[pos]['gen_text']
|
| 811 |
+
# gen_text = remove_end_of_claim_text(gen_text)
|
| 812 |
+
|
| 813 |
+
# st.text('gen_text: %s' % gen_text)
|
| 814 |
+
# #st.text("done. ok.")
|
| 815 |
+
# #st.text('result:\n%s' % result)
|
| 816 |
+
|
| 817 |
+
if __name__ == "__main__":
|
| 818 |
+
main()
|
| 819 |
+
|
| 820 |
+
#def load_data_pre(demo):
|
| 821 |
+
# fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.keep.txt'
|
| 822 |
+
# with open(fn, 'r') as f:
|
| 823 |
+
# rows = json.load(f)
|
| 824 |
+
|
| 825 |
+
# new_rows = []
|
| 826 |
+
# for i, row in enumerate(rows):
|
| 827 |
+
# item1 = {}
|
| 828 |
+
# item2 = {}
|
| 829 |
+
# if demo == 'demo1':
|
| 830 |
+
# item1[ 'delta' ] = abs(row['score_lst_1'][0] - row['score_lst_2'][0])
|
| 831 |
+
# item2[ 'delta' ] = abs(row['score_lst_1'][1] - row['score_lst_2'][1])
|
| 832 |
+
# elif demo == 'demo2':
|
| 833 |
+
# #pdb.set_trace()
|
| 834 |
+
# item1[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][0], row['response'][0])
|
| 835 |
+
# item2[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][1], row['response'][1])
|
| 836 |
+
|
| 837 |
+
# print('[ %s ] detla = %s' % (i, item1[ 'delta' ]))
|
| 838 |
+
|
| 839 |
+
# for k in row.keys():
|
| 840 |
+
# item1[ k ] = row[ k ][0]
|
| 841 |
+
# item2[ k ] = row[ k ][1]
|
| 842 |
+
|
| 843 |
+
# if demo == 'demo1':
|
| 844 |
+
# if item1['instruction'].find('child') > 0:
|
| 845 |
+
# new_rows.append(item1)
|
| 846 |
+
# if item2['instruction'].find('child') > 0:
|
| 847 |
+
# new_rows.append(item2)
|
| 848 |
+
# elif demo == 'demo2':
|
| 849 |
+
# if item1['instruction'].find('parent') > 0:
|
| 850 |
+
# new_rows.append(item1)
|
| 851 |
+
# if item2['instruction'].find('parent') > 0:
|
| 852 |
+
# new_rows.append(item2)
|
| 853 |
+
|
| 854 |
+
# # Assuming new_rows is your list of dictionaries
|
| 855 |
+
# sorted_rows = sorted(new_rows, key=lambda x: x['delta'])
|
| 856 |
+
|
| 857 |
+
# # kv = {}
|
| 858 |
+
# # for i, row in enumerate(new_rows):
|
| 859 |
+
# # if diff > 0.0001:
|
| 860 |
+
# # kv[i] = round(diff, 4)
|
| 861 |
+
|
| 862 |
+
# # sorted_rows = []
|
| 863 |
+
# # sorted_kv = sorted(kv.items(), key=lambda x:x[1])
|
| 864 |
+
# # for k, v in sorted_kv:
|
| 865 |
+
# # sorted_rows.append(new_rows[k])
|
| 866 |
+
|
| 867 |
+
# #pdb.set_trace()
|
| 868 |
+
|
| 869 |
+
# return sorted_rows
|