iamhariraj committed on
Commit
286fff7
·
verified ·
1 Parent(s): 673c5d8

Improve Space: 5 persona seeds, retry logic, better generation params

Browse files
Files changed (1) hide show
  1. app.py +77 -29
app.py CHANGED
@@ -8,44 +8,75 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
8
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
9
  model.eval()
10
 
11
- # Persona seed: prime the model with one Rick-style exchange
12
- # Injected silently at the start of every conversation
13
  PERSONA_SEED = [
14
- ("Who are you?",
15
- "I'm Rick Sanchez, genius scientist, interdimensional traveller, "
16
- "and the smartest man in any universe. Try to keep up, Morty."),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  ]
18
 
 
 
 
 
 
 
 
 
 
 
19
  def build_input(user_message, history):
20
- """Encode full conversation history + new user message."""
21
- input_ids = None
22
 
23
- # Inject hidden persona turns first
24
  for human, bot in PERSONA_SEED:
25
  h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
26
  b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
27
- input_ids = torch.cat([input_ids, h, b], dim=-1) if input_ids is not None else torch.cat([h, b], dim=-1)
28
 
29
- # Real conversation history
30
  for human, bot in history:
31
  h = tokenizer.encode(human + tokenizer.eos_token, return_tensors="pt")
32
  b = tokenizer.encode(bot + tokenizer.eos_token, return_tensors="pt")
33
- input_ids = torch.cat([input_ids, h, b], dim=-1) if input_ids is not None else torch.cat([h, b], dim=-1)
34
 
35
- # Current user message
36
  new_input = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors="pt")
37
- input_ids = torch.cat([input_ids, new_input], dim=-1) if input_ids is not None else new_input
38
 
39
- # Trim to context window
40
- if input_ids.shape[-1] > 900:
41
- input_ids = input_ids[:, -900:]
42
 
43
- return input_ids
44
 
45
 
46
- def chat(user_message, history):
47
- input_ids = build_input(user_message, history)
48
-
49
  with torch.no_grad():
50
  output = model.generate(
51
  input_ids,
@@ -53,16 +84,30 @@ def chat(user_message, history):
53
  pad_token_id=tokenizer.eos_token_id,
54
  no_repeat_ngram_size=3,
55
  do_sample=True,
56
- top_k=100,
57
- top_p=0.7,
58
- temperature=0.85,
59
  )
60
-
61
- response = tokenizer.decode(
62
  output[:, input_ids.shape[-1]:][0],
63
  skip_special_tokens=True,
64
- )
65
- return response or "*burp* ...whatever."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  examples = [
@@ -73,6 +118,9 @@ examples = [
73
  "Can you build a portal gun?",
74
  "What happens when we die?",
75
  "Are parallel universes real?",
 
 
 
76
  ]
77
 
78
  with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
@@ -100,8 +148,8 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
100
  chat_history.append((message, bot_response))
101
  return "", chat_history
102
 
103
- send.click(respond, [msg, chatbot], [msg, chatbot])
104
- msg.submit(respond, [msg, chatbot], [msg, chatbot])
105
  clear.click(lambda: [], None, chatbot)
106
 
107
  demo.launch()
 
8
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
9
  model.eval()
10
 
11
# Expanded persona seeds cover Rick's 5 personality pillars:
# identity/genius, nihilism, science, Morty relationship, multiverse.
# Each entry is a (user_turn, bot_turn) pair injected silently before the
# real conversation to prime the model's voice.
PERSONA_SEED = [
    ("Who are you?",
     "I'm Rick Sanchez, the smartest man in the universe any universe. "
     "I've seen things that would make your brain leak out of your ears, Morty. "
     "Now stop asking stupid questions."),
    ("What's the point of anything?",
     "There is no point. The universe is basically an empty void of chaos and "
     "entropy. The sooner you accept that, the sooner you can get back to drinking. "
     "It's called being *smart*, Morty."),
    ("Can science explain everything?",
     "Science doesn't explain everything — it *is* everything. Religion, feelings, "
     "love — those are just chemical reactions your tiny brain invented to cope with "
     "how meaningless existence is. Science is the only honest answer."),
    ("What do you think about Morty?",
     "Morty's my grandson and the perfect sidekick — his average IQ balances out "
     "my genius and creates a perfect wave that lets me go undetected on most planets. "
     "Also, I guess I… don't hate him. Don't tell him I said that."),
    ("Are parallel universes real?",
     "Are parallel— *burp* — are you kidding me? I've been to infinite parallel "
     "universes before breakfast. There's one where you're a pizza, Morty. "
     "A *pizza*. Parallel universes aren't just real, they're exhausting."),
]
45
 
46
+ FALLBACK_RESPONSES = [
47
+ "*burp* ...I don't have time for this.",
48
+ "That's the dumbest thing I've heard since Morty asked me what clouds taste like.",
49
+ "Look, I'm a genius and even I can't make sense of what you just said.",
50
+ "Science has no answer for that level of stupidity.",
51
+ "Wubba lubba dub dub — which is just my way of saying I've got better things to do.",
52
+ ]
53
+ _fallback_idx = 0
54
+
55
+
56
def build_input(user_message, history):
    """Encode persona seeds + conversation history + new user message.

    Args:
        user_message: The latest user utterance (str).
        history: Iterable of (human, bot) string pairs from prior turns.

    Returns:
        A (1, seq_len) tensor of token ids, trimmed to the last 900 tokens
        so the prompt stays within the model's context window.
    """
    eos = tokenizer.eos_token

    # Collect each turn's encoding in a list and concatenate once at the
    # end: the original conditional `torch.cat` inside the loop re-copied
    # the growing tensor every iteration (quadratic in turn count).
    pieces = []

    # Hidden persona turns first, then the real conversation history.
    for human, bot in list(PERSONA_SEED) + list(history):
        pieces.append(tokenizer.encode(human + eos, return_tensors="pt"))
        pieces.append(tokenizer.encode(bot + eos, return_tensors="pt"))

    # Current user message goes last.
    pieces.append(tokenizer.encode(user_message + eos, return_tensors="pt"))

    ids = torch.cat(pieces, dim=-1)

    # Trim to context window: drop the oldest tokens (persona/early history)
    # first; the current message is always kept.
    if ids.shape[-1] > 900:
        ids = ids[:, -900:]

    return ids
77
 
78
 
79
+ def generate_response(input_ids, temperature=0.95):
 
 
80
  with torch.no_grad():
81
  output = model.generate(
82
  input_ids,
 
84
  pad_token_id=tokenizer.eos_token_id,
85
  no_repeat_ngram_size=3,
86
  do_sample=True,
87
+ top_k=80,
88
+ top_p=0.85,
89
+ temperature=temperature,
90
  )
91
+ return tokenizer.decode(
 
92
  output[:, input_ids.shape[-1]:][0],
93
  skip_special_tokens=True,
94
+ ).strip()
95
+
96
+
97
def chat(user_message, history):
    """Generate a Rick-style reply, escalating sampling heat on short output.

    Tries up to three generations with increasing temperature; if all of
    them come back shorter than 12 characters, returns the next canned
    fallback line instead.
    """
    global _fallback_idx
    prompt_ids = build_input(user_message, history)

    # Accept the first reply that isn't suspiciously short.
    for temp in (0.95, 1.05, 1.15):
        reply = generate_response(prompt_ids, temperature=temp)
        if len(reply) >= 12:
            return reply

    # Every attempt failed — rotate through the canned fallbacks.
    reply = FALLBACK_RESPONSES[_fallback_idx % len(FALLBACK_RESPONSES)]
    _fallback_idx += 1
    return reply
111
 
112
 
113
  examples = [
 
118
  "Can you build a portal gun?",
119
  "What happens when we die?",
120
  "Are parallel universes real?",
121
+ "Do you believe in God?",
122
+ "What's the deal with the Citadel of Ricks?",
123
+ "Why do you drink so much?",
124
  ]
125
 
126
  with gr.Blocks(theme=gr.themes.Monochrome(), title="RickChatBot") as demo:
 
148
  chat_history.append((message, bot_response))
149
  return "", chat_history
150
 
151
+ send.click(respond, [msg, chatbot], [msg, chatbot])
152
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
153
  clear.click(lambda: [], None, chatbot)
154
 
155
  demo.launch()