Spaces:
Sleeping
Sleeping
Improve Space: 5 persona seeds, retry logic, better generation params
Browse files
app.py
CHANGED
|
@@ -8,44 +8,75 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
| 8 |
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
| 9 |
model.eval()
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
#
|
| 13 |
PERSONA_SEED = [
|
| 14 |
-
(
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
]
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def build_input(user_message, history):
|
| 20 |
-
"""Encode
|
| 21 |
-
|
| 22 |
|
| 23 |
-
# Inject hidden persona turns first
|
| 24 |
for human, bot in PERSONA_SEED:
|
| 25 |
h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
|
| 26 |
b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
|
| 27 |
-
|
| 28 |
|
| 29 |
-
# Real conversation history
|
| 30 |
for human, bot in history:
|
| 31 |
h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
|
| 32 |
b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
|
| 33 |
-
|
| 34 |
|
| 35 |
-
# Current user message
|
| 36 |
new_input = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors="pt")
|
| 37 |
-
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
input_ids = input_ids[:, -900:]
|
| 42 |
|
| 43 |
-
return
|
| 44 |
|
| 45 |
|
| 46 |
-
def
|
| 47 |
-
input_ids = build_input(user_message, history)
|
| 48 |
-
|
| 49 |
with torch.no_grad():
|
| 50 |
output = model.generate(
|
| 51 |
input_ids,
|
|
@@ -53,16 +84,30 @@ def chat(user_message, history):
|
|
| 53 |
pad_token_id=tokenizer.eos_token_id,
|
| 54 |
no_repeat_ngram_size=3,
|
| 55 |
do_sample=True,
|
| 56 |
-
top_k=
|
| 57 |
-
top_p=0.
|
| 58 |
-
temperature=
|
| 59 |
)
|
| 60 |
-
|
| 61 |
-
response = tokenizer.decode(
|
| 62 |
output[:, input_ids.shape[-1]:][0],
|
| 63 |
skip_special_tokens=True,
|
| 64 |
-
)
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
examples = [
|
|
@@ -73,6 +118,9 @@ examples = [
|
|
| 73 |
"Can you build a portal gun?",
|
| 74 |
"What happens when we die?",
|
| 75 |
"Are parallel universes real?",
|
|
|
|
|
|
|
|
|
|
| 76 |
]
|
| 77 |
|
| 78 |
with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
|
|
@@ -100,8 +148,8 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
|
|
| 100 |
chat_history.append((message, bot_response))
|
| 101 |
return "", chat_history
|
| 102 |
|
| 103 |
-
send.click(respond,
|
| 104 |
-
msg.submit(respond,
|
| 105 |
clear.click(lambda: [], None, chatbot)
|
| 106 |
|
| 107 |
demo.launch()
|
|
|
|
| 8 |
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
| 9 |
model.eval()
|
| 10 |
|
| 11 |
+
# Expanded persona seeds — covers Rick's 5 personality pillars:
|
| 12 |
+
# identity/genius, nihilism, science, Morty relationship, multiverse
|
| 13 |
PERSONA_SEED = [
|
| 14 |
+
(
|
| 15 |
+
"Who are you?",
|
| 16 |
+
"I'm Rick Sanchez, the smartest man in the universe — any universe. "
|
| 17 |
+
"I've seen things that would make your brain leak out of your ears, Morty. "
|
| 18 |
+
"Now stop asking stupid questions.",
|
| 19 |
+
),
|
| 20 |
+
(
|
| 21 |
+
"What's the point of anything?",
|
| 22 |
+
"There is no point. The universe is basically an empty void of chaos and "
|
| 23 |
+
"entropy. The sooner you accept that, the sooner you can get back to drinking. "
|
| 24 |
+
"It's called being *smart*, Morty.",
|
| 25 |
+
),
|
| 26 |
+
(
|
| 27 |
+
"Can science explain everything?",
|
| 28 |
+
"Science doesn't explain everything — it *is* everything. Religion, feelings, "
|
| 29 |
+
"love — those are just chemical reactions your tiny brain invented to cope with "
|
| 30 |
+
"how meaningless existence is. Science is the only honest answer.",
|
| 31 |
+
),
|
| 32 |
+
(
|
| 33 |
+
"What do you think about Morty?",
|
| 34 |
+
"Morty's my grandson and the perfect sidekick — his average IQ balances out "
|
| 35 |
+
"my genius and creates a perfect wave that lets me go undetected on most planets. "
|
| 36 |
+
"Also, I guess I… don't hate him. Don't tell him I said that.",
|
| 37 |
+
),
|
| 38 |
+
(
|
| 39 |
+
"Are parallel universes real?",
|
| 40 |
+
"Are parallel— *burp* — are you kidding me? I've been to infinite parallel "
|
| 41 |
+
"universes before breakfast. There's one where you're a pizza, Morty. "
|
| 42 |
+
"A *pizza*. Parallel universes aren't just real, they're exhausting.",
|
| 43 |
+
),
|
| 44 |
]
|
| 45 |
|
| 46 |
+
FALLBACK_RESPONSES = [
|
| 47 |
+
"*burp* ...I don't have time for this.",
|
| 48 |
+
"That's the dumbest thing I've heard since Morty asked me what clouds taste like.",
|
| 49 |
+
"Look, I'm a genius and even I can't make sense of what you just said.",
|
| 50 |
+
"Science has no answer for that level of stupidity.",
|
| 51 |
+
"Wubba lubba dub dub — which is just my way of saying I've got better things to do.",
|
| 52 |
+
]
|
| 53 |
+
_fallback_idx = 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
def build_input(user_message, history):
|
| 57 |
+
"""Encode persona seeds + conversation history + new user message."""
|
| 58 |
+
ids = None
|
| 59 |
|
|
|
|
| 60 |
for human, bot in PERSONA_SEED:
|
| 61 |
h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
|
| 62 |
b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
|
| 63 |
+
ids = torch.cat([ids, h, b], dim=-1) if ids is not None else torch.cat([h, b], dim=-1)
|
| 64 |
|
|
|
|
| 65 |
for human, bot in history:
|
| 66 |
h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
|
| 67 |
b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
|
| 68 |
+
ids = torch.cat([ids, h, b], dim=-1) if ids is not None else torch.cat([h, b], dim=-1)
|
| 69 |
|
|
|
|
| 70 |
new_input = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors="pt")
|
| 71 |
+
ids = torch.cat([ids, new_input], dim=-1) if ids is not None else new_input
|
| 72 |
|
| 73 |
+
if ids.shape[-1] > 900:
|
| 74 |
+
ids = ids[:, -900:]
|
|
|
|
| 75 |
|
| 76 |
+
return ids
|
| 77 |
|
| 78 |
|
| 79 |
+
def generate_response(input_ids, temperature=0.95):
|
|
|
|
|
|
|
| 80 |
with torch.no_grad():
|
| 81 |
output = model.generate(
|
| 82 |
input_ids,
|
|
|
|
| 84 |
pad_token_id=tokenizer.eos_token_id,
|
| 85 |
no_repeat_ngram_size=3,
|
| 86 |
do_sample=True,
|
| 87 |
+
top_k=80,
|
| 88 |
+
top_p=0.85,
|
| 89 |
+
temperature=temperature,
|
| 90 |
)
|
| 91 |
+
return tokenizer.decode(
|
|
|
|
| 92 |
output[:, input_ids.shape[-1]:][0],
|
| 93 |
skip_special_tokens=True,
|
| 94 |
+
).strip()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def chat(user_message, history):
|
| 98 |
+
global _fallback_idx
|
| 99 |
+
input_ids = build_input(user_message, history)
|
| 100 |
+
|
| 101 |
+
# Retry up to 3 times with increasing temperature if response is too short
|
| 102 |
+
for temp in [0.95, 1.05, 1.15]:
|
| 103 |
+
response = generate_response(input_ids, temperature=temp)
|
| 104 |
+
if len(response) >= 12:
|
| 105 |
+
return response
|
| 106 |
+
|
| 107 |
+
# All retries failed — use a rotating fallback
|
| 108 |
+
fb = FALLBACK_RESPONSES[_fallback_idx % len(FALLBACK_RESPONSES)]
|
| 109 |
+
_fallback_idx += 1
|
| 110 |
+
return fb
|
| 111 |
|
| 112 |
|
| 113 |
examples = [
|
|
|
|
| 118 |
"Can you build a portal gun?",
|
| 119 |
"What happens when we die?",
|
| 120 |
"Are parallel universes real?",
|
| 121 |
+
"Do you believe in God?",
|
| 122 |
+
"What's the deal with the Citadel of Ricks?",
|
| 123 |
+
"Why do you drink so much?",
|
| 124 |
]
|
| 125 |
|
| 126 |
with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
|
|
|
|
| 148 |
chat_history.append((message, bot_response))
|
| 149 |
return "", chat_history
|
| 150 |
|
| 151 |
+
send.click(respond, [msg, chatbot], [msg, chatbot])
|
| 152 |
+
msg.submit(respond, [msg, chatbot], [msg, chatbot])
|
| 153 |
clear.click(lambda: [], None, chatbot)
|
| 154 |
|
| 155 |
demo.launch()
|