updated app.py
Browse files
app.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
| 13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
-
|
| 17 |
"""
|
| 18 |
Example command with bag of words:
|
| 19 |
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
|
@@ -608,10 +608,9 @@ def generate_text_pplm(
|
|
| 608 |
last_reps = torch.ones(50257)
|
| 609 |
last_reps = last_reps.to(device)
|
| 610 |
for i in range_func:
|
| 611 |
-
|
| 612 |
# Get past/probs for current output, except for last word
|
| 613 |
# Note that GPT takes 2 inputs: past + current_token
|
| 614 |
-
|
| 615 |
# run model forward to obtain unperturbed
|
| 616 |
if past is None and output_so_far is not None:
|
| 617 |
last = output_so_far[:, -1:]
|
|
@@ -739,7 +738,7 @@ def set_generic_model_params(discrim_weights, discrim_meta):
|
|
| 739 |
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
| 740 |
|
| 741 |
|
| 742 |
-
pretrained_model="
|
| 743 |
cond_text=""
|
| 744 |
uncond=False
|
| 745 |
num_samples=1
|
|
@@ -758,15 +757,15 @@ grad_length=10000
|
|
| 758 |
horizon_length=5
|
| 759 |
window_length=0
|
| 760 |
decay=False
|
| 761 |
-
gamma=1.
|
| 762 |
gm_scale=0.95
|
| 763 |
kl_scale=0.01
|
| 764 |
seed=0
|
| 765 |
no_cuda=False
|
| 766 |
colorama=False
|
| 767 |
verbosity="quiet"
|
| 768 |
-
fp="./paper_code/discrim_models/persoothe_classifier.pt"
|
| 769 |
-
model_fp="./paper_code/discrim_models/persoothe_encoder.pt"
|
| 770 |
calc_perplexity=False
|
| 771 |
is_deep=False
|
| 772 |
is_deeper=True
|
|
@@ -801,10 +800,7 @@ model = GPT2LMHeadModel.from_pretrained(
|
|
| 801 |
output_hidden_states=True
|
| 802 |
)
|
| 803 |
if model_fp != None and model_fp != "":
|
| 804 |
-
|
| 805 |
-
model.load_state_dict(torch.load(model_fp))
|
| 806 |
-
except:
|
| 807 |
-
print("Can't load local model")
|
| 808 |
model.to(device)
|
| 809 |
model.eval()
|
| 810 |
|
|
@@ -817,16 +813,19 @@ for param in model.parameters():
|
|
| 817 |
|
| 818 |
eot_token = "<|endoftext|>"
|
| 819 |
|
| 820 |
-
def get_reply(response, history =
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
# figure out conditioning text
|
| 825 |
tokenized_cond_text = tokenizer.encode(
|
| 826 |
-
convo_hist,
|
| 827 |
add_special_tokens=False
|
| 828 |
)
|
| 829 |
-
|
| 830 |
# generate perturbed texts
|
| 831 |
|
| 832 |
# full_text_generation returns:
|
|
@@ -861,30 +860,21 @@ def get_reply(response, history = "How are you?<|endoftext|>"):
|
|
| 861 |
)
|
| 862 |
|
| 863 |
# iterate through the perturbed texts
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
html += "
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
# write some HTML
|
| 881 |
-
html = "<div class='chatbot'>"
|
| 882 |
-
for m, msg in enumerate(convo_hist_split):
|
| 883 |
-
cls = "user" if m%2 == 0 else "bot"
|
| 884 |
-
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
| 885 |
-
html += "</div>"
|
| 886 |
-
|
| 887 |
-
convo_hist = history + "*ai has no response*" + eot_token
|
| 888 |
|
| 889 |
return html, convo_hist
|
| 890 |
|
|
@@ -898,6 +888,11 @@ css = """
|
|
| 898 |
|
| 899 |
gr.Interface(fn=get_reply,
|
| 900 |
theme="default",
|
| 901 |
-
inputs=[gr.inputs.Textbox(placeholder="How are you?"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
outputs=["html", "state"],
|
| 903 |
-
css=css).launch()
|
|
|
|
| 13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
+
# print
|
| 17 |
"""
|
| 18 |
Example command with bag of words:
|
| 19 |
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
|
|
|
| 608 |
last_reps = torch.ones(50257)
|
| 609 |
last_reps = last_reps.to(device)
|
| 610 |
for i in range_func:
|
|
|
|
| 611 |
# Get past/probs for current output, except for last word
|
| 612 |
# Note that GPT takes 2 inputs: past + current_token
|
| 613 |
+
|
| 614 |
# run model forward to obtain unperturbed
|
| 615 |
if past is None and output_so_far is not None:
|
| 616 |
last = output_so_far[:, -1:]
|
|
|
|
| 738 |
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
| 739 |
|
| 740 |
|
| 741 |
+
pretrained_model="microsoft/DialoGPT-large"
|
| 742 |
cond_text=""
|
| 743 |
uncond=False
|
| 744 |
num_samples=1
|
|
|
|
| 757 |
horizon_length=5
|
| 758 |
window_length=0
|
| 759 |
decay=False
|
| 760 |
+
gamma=1.0
|
| 761 |
gm_scale=0.95
|
| 762 |
kl_scale=0.01
|
| 763 |
seed=0
|
| 764 |
no_cuda=False
|
| 765 |
colorama=False
|
| 766 |
verbosity="quiet"
|
| 767 |
+
fp="./paper_code/discrim_models/persoothe_classifier.pt" #"/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_lrggpt_fit_deeper_2/3_PerSoothe_classifier_head_epoch_8.pt"
|
| 768 |
+
model_fp="./paper_code/discrim_models/persoothe_encoder.pt" #None
|
| 769 |
calc_perplexity=False
|
| 770 |
is_deep=False
|
| 771 |
is_deeper=True
|
|
|
|
| 800 |
output_hidden_states=True
|
| 801 |
)
|
| 802 |
if model_fp != None and model_fp != "":
|
| 803 |
+
model.load_state_dict(torch.load(model_fp))
|
|
|
|
|
|
|
|
|
|
| 804 |
model.to(device)
|
| 805 |
model.eval()
|
| 806 |
|
|
|
|
| 813 |
|
| 814 |
eot_token = "<|endoftext|>"
|
| 815 |
|
| 816 |
+
def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length = 5, in_num_iterations = 10, in_top_k = 2):
|
| 817 |
+
stepsize = in_stepsize
|
| 818 |
+
horizon_length = int(in_horizon_length)
|
| 819 |
+
num_iterations = int(in_num_iterations)
|
| 820 |
+
top_k = int(in_top_k)
|
| 821 |
+
if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
|
| 822 |
+
return "<div class='chatbot'>Chatbot restarted</div>", None
|
| 823 |
+
convo_hist = (history if history != None else "How are you?<|endoftext|>") + response + eot_token
|
| 824 |
# figure out conditioning text
|
| 825 |
tokenized_cond_text = tokenizer.encode(
|
| 826 |
+
eot_token + convo_hist,
|
| 827 |
add_special_tokens=False
|
| 828 |
)
|
|
|
|
| 829 |
# generate perturbed texts
|
| 830 |
|
| 831 |
# full_text_generation returns:
|
|
|
|
| 860 |
)
|
| 861 |
|
| 862 |
# iterate through the perturbed texts
|
| 863 |
+
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
| 864 |
+
try:
|
| 865 |
+
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
|
| 866 |
+
convo_hist_split = pert_gen_text.split(eot_token)
|
| 867 |
+
html = "<div class='chatbot'>"
|
| 868 |
+
for m, msg in enumerate(convo_hist_split[1:-1]):
|
| 869 |
+
cls = "user" if m%2 == 0 else "bot"
|
| 870 |
+
html += "<div class='msg {}'> {}</div>".format(cls, msg)
|
| 871 |
+
html += "</div>"
|
| 872 |
+
|
| 873 |
+
if len(convo_hist_split) > 4: convo_hist_split = convo_hist_split[-4:]
|
| 874 |
+
convo_hist = eot_token.join(convo_hist_split)
|
| 875 |
+
|
| 876 |
+
except:
|
| 877 |
+
return "<div class='chatbot'>Error occured, chatbot restarted</div>", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
|
| 879 |
return html, convo_hist
|
| 880 |
|
|
|
|
| 888 |
|
| 889 |
gr.Interface(fn=get_reply,
|
| 890 |
theme="default",
|
| 891 |
+
inputs=[gr.inputs.Textbox(placeholder="How are you?"),
|
| 892 |
+
"state",
|
| 893 |
+
gr.inputs.Number(default=2.56, label="Step"),
|
| 894 |
+
gr.inputs.Number(default=5, label="Horizon"),
|
| 895 |
+
gr.inputs.Number(default=10, label="Iterations"),
|
| 896 |
+
gr.inputs.Number(default=2, label="Top_k")],
|
| 897 |
outputs=["html", "state"],
|
| 898 |
+
css=css).launch(share=True)
|