pcalhoun commited on
Commit
422e63e
·
1 Parent(s): 95dc8dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, transformers, peft, torch, gradio as gr
2
+
3
+ base_model = "h2oai/h2ogpt-4096-llama2-13b"
4
+ model = transformers.AutoModelForCausalLM.from_pretrained(
5
+ base_model,
6
+ load_in_8bit=True,
7
+ torch_dtype=torch.float16
8
+ )
9
+ tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)
10
+
11
+ lora_model = "pcalhoun/Llama-2-13b-Conversations-With-Tyler-Swift"
12
+ model = peft.PeftModel.from_pretrained(
13
+ model,
14
+ lora_model,
15
+ torch_dtype=torch.float16
16
+ )
17
+
18
+ def generate(prompt, extra_eos=[]):
19
+ inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
20
+ input_token_length = inputs.input_ids.shape[1]
21
+ outputs = model.generate(**inputs, max_length=4096)
22
+ text = tokenizer.batch_decode(outputs)[0]
23
+ return text
24
+
25
+ def create_next_prompt(title_string,description_string="",conversation_messages=[]):
26
+ if not len(description_string):
27
+ conversation_messages = []
28
+ prompt = """<s>### CONVERSATIONS WITH TYLER SWIFT ###
29
+
30
+ TITLE: """ + title_string.strip() + """
31
+
32
+ DESCRIPTION:"""
33
+ if not len(description_string):
34
+ return prompt
35
+ else:
36
+ prompt += " "+description_string.replace("\n\n","\n").strip() + "\n\n"
37
+ if not len(conversation_messages):
38
+ prompt += "### TYLER SWIFT:"
39
+ return prompt
40
+ else:
41
+ for message_data in conversation_messages:
42
+ prompt += "### " + message_data['speaker'].upper() + ": " + message_data['message'].strip()
43
+ if message_data['speaker'].upper() == "TYLER SWIFT":
44
+ prompt += "</s><s>"
45
+ prompt += "\n"
46
+ if conversation_messages[-1]["speaker"].upper() != "TYLER SWIFT":
47
+ prompt += "### TYLER SWIFT:"
48
+ return prompt
49
+
50
+ def deconstruct_returned_text(text):
51
+ #skip first line
52
+ text = "\n".join(text.split("\n")[1:]).strip()
53
+ title = text.split("\n")[0].replace("TITLE:","").strip()
54
+ text = "\n".join(text.split("\n")[1:]).strip()
55
+ description = text.split("\n\n")[0].replace("DESCRIPTION:","").strip()
56
+ text = "\n\n".join(text.split("\n\n")[1:]).strip()
57
+ conversation_text = text.replace("</s>", "").replace("<s>", "").split("<<")[0].strip()
58
+ return title,description,conversation_text
59
+
60
+ def generate_next(title,description,conversation_text):
61
+ if not len(title):
62
+ title = "Set a Title First"
63
+ return title,description,conversation_text
64
+ conversation = []
65
+ for line in conversation_text.split("\n"):
66
+ if "CONVERSATIONS WITH TYLER SWIFT" in line:
67
+ continue
68
+ if line.startswith("###"):
69
+ speaker = line.split(":")[0].replace("###","").strip()
70
+ message = ":".join(line.split(":")[1:]).strip()
71
+ conversation.append({"speaker":speaker,"message":message.replace("</s>", "").replace("<s>", "").strip()})
72
+ prompt = create_next_prompt(title,description,conversation)
73
+ generated_text = generate(prompt)
74
+ print("GENERATED TEXT:",generated_text)
75
+ title,description,conversation_text = deconstruct_returned_text(generated_text)
76
+ return title,description,conversation_text
77
+
78
+
79
+ with gr.Blocks() as demo:
80
+ title = gr.Textbox(label="Title")
81
+ description = gr.Textbox(label="Description")
82
+ conversation_text = gr.Textbox(label="Conversation")
83
+ generate_button = gr.Button()
84
+ generate_button.click(fn=generate_next, inputs=[title,description,conversation_text], outputs=[title,description,conversation_text])
85
+
86
+ demo.launch()
87
+
88
+