Jia0603 committed on
Commit
a0fadff
·
verified ·
1 Parent(s): 9222068

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -34
app.py CHANGED
@@ -9,8 +9,9 @@ MODEL_ID = "LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
9
  device = 0 if torch.cuda.is_available() else -1
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
- # if tokenizer.pad_token is None:
13
- # tokenizer.pad_token = tokenizer.eos_token_id
 
14
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
15
 
16
  if torch.cuda.is_available():
@@ -19,31 +20,44 @@ if torch.cuda.is_available():
19
  def generate_reply(prompt, max_new_tokens, temperature, top_p):
20
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
21
 
 
 
22
  output_ids = model.generate(
23
  **inputs,
24
  max_new_tokens=int(max_new_tokens),
25
  do_sample=True,
26
  temperature=float(temperature),
27
  top_p=float(top_p),
28
- eos_token_id=None,
 
 
29
  pad_token_id=tokenizer.eos_token_id
30
  )
31
-
 
32
  text = tokenizer.decode(
33
  output_ids[0],
34
- skip_special_tokens=False,
35
  clean_up_tokenization_spaces=True
36
  )
37
 
38
  return text
39
 
40
  def clean_reply(text):
41
-
42
  text = text.strip()
 
 
 
 
 
 
 
 
 
43
 
44
- for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
45
- if text.startswith(prefix):
46
- text = text[len(prefix):].strip()
47
 
48
  # lines = [l.strip() for l in text.split("\n")]
49
  # lines = [l for l in lines if l]
@@ -51,39 +65,56 @@ def clean_reply(text):
51
  # if len(lines) == 0:
52
  # return ""
53
 
54
- # return lines[0]
55
- return text
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
58
-
59
  if chat_history is None:
60
  chat_history = []
 
61
 
62
- # Build conversation history
63
- # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
64
- history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
65
- The assistant should tell a story according to human's requirment.\n"
66
-
67
- for msg in chat_history:
68
- role = "Human" if msg["role"] == "user" else "AI"
69
- history_text += f"{role}: {msg['content']}\n"
70
-
71
- history_text += f"Human: {user_message}\nAI:"
72
-
73
- # -------- generate ----------
74
- raw = generate_reply(
75
- history_text,
76
  max_new_tokens,
77
  temperature,
78
  top_p
79
  )
80
- # Only keep new part
81
- reply = raw[len(history_text):]
82
- reply = clean_reply(reply)
83
- # ------------------------------
84
 
85
  chat_history.append({"role": "user", "content": user_message})
86
- chat_history.append({"role": "assistant", "content": reply})
87
 
88
  return "", chat_history, chat_history
89
 
@@ -96,12 +127,14 @@ with gr.Blocks() as demo:
96
  chat = gr.Chatbot(elem_id="chatbot", label="Conversation")
97
  msg = gr.Textbox(label="Your message")
98
  send = gr.Button("Send")
99
- max_tokens = gr.Slider(50, 512, value=256, label="max_new_tokens")
100
- temp = gr.Slider(0.6, 1.5, value=0.8, label="temperature")
101
- top_p = gr.Slider(0.1, 1.0, value=0.9, label="top_p")
102
 
103
  with gr.Column(scale=1):
104
  gr.Markdown("Model: " + MODEL_ID)
 
 
105
 
106
  state = gr.State([])
107
 
 
9
  device = 0 if torch.cuda.is_available() else -1
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
+ if tokenizer.pad_token is None:
13
+ tokenizer.pad_token = tokenizer.eos_token_id
14
+
15
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
16
 
17
  if torch.cuda.is_available():
 
20
def generate_reply(prompt, max_new_tokens, temperature, top_p):
    """Sample a continuation of *prompt* and return only the newly generated text.

    Args:
        prompt: Full prompt string fed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability cutoff.

    Returns:
        The decoded continuation, with the prompt tokens stripped and
        special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Remember how many tokens the prompt occupies so the prompt can be
    # sliced off the generated sequence below.
    input_len = inputs["input_ids"].shape[1]

    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        no_repeat_ngram_size=3,   # block verbatim 3-gram loops
        repetition_penalty=1.2,   # discourage repeated phrasing
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # BUG FIX: decode only the newly generated tokens. The original decoded
    # output_ids[0] (prompt included), so the full prompt leaked into the
    # reply even though generated_tokens was computed for exactly this
    # purpose — and clean_reply's stop words never match "User request:",
    # so the leak reached the chat window.
    generated_tokens = output_ids[0][input_len:]
    text = tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return text
45
 
46
def clean_reply(text):
    """Trim generated text at the first speaker marker (Human:/User:/AI:/Assistant:).

    Each marker is applied in turn; everything from the first occurrence of a
    marker onward is dropped. The result is whitespace-stripped.
    """
    trimmed = text.strip()
    for marker in ("Human:", "User:", "AI:", "Assistant:"):
        head, sep, _tail = trimmed.partition(marker)
        if sep:
            trimmed = head
    return trimmed.strip()
53
+
54
+ # def clean_reply(text):
55
+
56
+ # text = text.strip()
57
 
58
+ # for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
59
+ # if text.startswith(prefix):
60
+ # text = text[len(prefix):].strip()
61
 
62
  # lines = [l.strip() for l in text.split("\n")]
63
  # lines = [l for l in lines if l]
 
65
  # if len(lines) == 0:
66
  # return ""
67
 
68
+ # return lines[0]
 
69
 
70
+ # def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
71
+
72
+ # if chat_history is None:
73
+ # chat_history = []
74
+
75
+ # # Build conversation history
76
+ # # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
77
+ # history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
78
+ # The assistant should tell a story according to human's requirment.\n"
79
+
80
+ # for msg in chat_history:
81
+ # role = "Human" if msg["role"] == "user" else "AI"
82
+ # history_text += f"{role}: {msg['content']}\n"
83
+
84
+ # history_text += f"Human: {user_message}\nAI:"
85
+
86
+ # # -------- generate ----------
87
+ # raw = generate_reply(
88
+ # history_text,
89
+ # max_new_tokens,
90
+ # temperature,
91
+ # top_p
92
+ # )
93
+ # # Only keep new part
94
+ # reply = raw[len(history_text):]
95
+ # reply = clean_reply(reply)
96
+ # # ------------------------------
97
+
98
+ # chat_history.append({"role": "user", "content": user_message})
99
+ # chat_history.append({"role": "assistant", "content": reply})
100
+
101
+ # return "", chat_history, chat_history
102
def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
    """Handle one chat turn: build a story prompt, generate, and update history.

    Returns a 3-tuple (cleared textbox value, history, history) matching the
    Gradio (msg, chat, state) outputs.
    """
    history = chat_history if chat_history is not None else []
    prompt_text = f"User request: {user_message}\n\nHere is a long, creative story based on the request:\nOnce upon a time,"

    story = generate_reply(prompt_text, max_new_tokens, temperature, top_p)

    # Re-attach the opener that the prompt seeded, after trimming any stray
    # speaker tags from the raw generation.
    final_reply = "Once upon a time, " + clean_reply(story)

    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": final_reply})

    return "", history, history
120
 
 
127
  chat = gr.Chatbot(elem_id="chatbot", label="Conversation")
128
  msg = gr.Textbox(label="Your message")
129
  send = gr.Button("Send")
130
+ max_tokens = gr.Slider(50, 1025, value=300, label="max_new_tokens")
131
+ temp = gr.Slider(0.6, 1.5, value=1.0, label="temperature")
132
+ top_p = gr.Slider(0.1, 1.0, value=0.95, label="top_p")
133
 
134
  with gr.Column(scale=1):
135
  gr.Markdown("Model: " + MODEL_ID)
136
+ gr.Markdown("Note: GPT-2 is a base model. If prompts are too complex, \
137
+ it might get confused. This setup is optimized for storytelling.")
138
 
139
  state = gr.State([])
140