Jia0603 committed on
Commit
28a2ee2
·
verified ·
1 Parent(s): e284a04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -9,6 +9,8 @@ MODEL_ID = "LMSeed/GPT2-small-distilled-900M_None_ppo-1000K-seed42"
9
  device = 0 if torch.cuda.is_available() else -1
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
12
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
13
 
14
  if torch.cuda.is_available():
@@ -38,18 +40,18 @@ def generate_reply(prompt, max_new_tokens, temperature, top_p):
38
  def clean_reply(text):
39
 
40
  text = text.strip()
41
-
42
- for prefix in ["Assistant:", "assistant:", "User:", "user:"]:
43
- if text.startswith(prefix):
44
- text = text[len(prefix):].strip()
45
 
46
- lines = [l.strip() for l in text.split("\n")]
47
- lines = [l for l in lines if l]
48
 
49
- if len(lines) == 0:
50
- return ""
51
 
52
- return lines[0]
53
 
54
  def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
55
 
@@ -58,8 +60,8 @@ def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=
58
 
59
  # Build conversation history
60
  # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
61
- history_text = "The following is a friendly conversation between a human and an AI story-telling assistant. \
62
- The assistant should tell a story according to human's requirment.\n"
63
 
64
  for msg in chat_history:
65
  role = "Human" if msg["role"] == "user" else "AI"
 
9
# Select GPU device 0 when CUDA is available, otherwise fall back to CPU (-1,
# the convention used by transformers pipelines).
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so padding /
# batched generation works.  BUG FIX: `pad_token` expects the token *string*
# (tokenizer.eos_token), not the integer id (tokenizer.eos_token_id) that the
# original assigned.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
15
 
16
  if torch.cuda.is_available():
 
40
def clean_reply(text):
    """Normalize a raw model reply by trimming surrounding whitespace.

    Earlier revisions also stripped role prefixes ("Assistant:", "User:", ...)
    and kept only the first non-empty line; that post-processing was
    deliberately disabled in this commit so the full reply is returned.
    (The dead commented-out implementation has been removed.)

    Args:
        text: Raw generated text from the model.

    Returns:
        The input with leading/trailing whitespace removed.
    """
    return text.strip()
55
 
56
  def chat_with_model(user_message, chat_history, max_new_tokens=256, temperature=0.8, top_p=0.9):
57
 
 
60
 
61
  # Build conversation history
62
  # history_text = "The following is a friendly conversation between a human and an AI assistant.\n"
63
+ history_text = "" #"The following is a friendly conversation between a human and an AI story-telling assistant. \
64
+ # The assistant should tell a story according to human's requirment.\n"
65
 
66
  for msg in chat_history:
67
  role = "Human" if msg["role"] == "user" else "AI"