seconds-0 commited on
Commit
bd6edfe
·
verified ·
1 Parent(s): ae4db83

Add decoding controls, few-shot prompt, repetition guard

Browse files
Files changed (1) hide show
  1. app.py +78 -22
app.py CHANGED
@@ -18,17 +18,36 @@ model = AutoModelForCausalLM.from_pretrained(
18
  )
19
 
20
 
21
- def respond(message, history):
22
- # Simpler prompt (byte tokenizer is sensitive to unfamiliar special tokens)
23
- turns = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  for u, a in history:
25
  if u:
26
- turns.append(f"User: {u}")
27
  if a:
28
- turns.append(f"Assistant: {a}")
29
- turns.append(f"User: {message}")
30
- turns.append("Assistant:")
31
- prompt = "\n".join(turns)
 
 
 
 
32
  x = tok(prompt, return_tensors="pt")
33
  if torch.cuda.is_available():
34
  x = {k: v.to(model.device) for k, v in x.items()}
@@ -36,12 +55,13 @@ def respond(message, history):
36
  streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
37
  gen_kwargs = dict(
38
  **x,
39
- max_new_tokens=128,
40
- do_sample=True,
41
- top_p=0.9,
42
- temperature=0.7,
43
- repetition_penalty=1.2,
44
- no_repeat_ngram_size=3,
 
45
  streamer=streamer,
46
  )
47
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
@@ -49,17 +69,53 @@ def respond(message, history):
49
  partial = ""
50
  for new_text in streamer:
51
  partial += new_text
 
 
 
 
52
  yield partial
53
 
54
 
55
- demo = gr.ChatInterface(
56
- fn=respond,
57
- title="NSA 117M Chat (byte tokenizer)",
58
- description=(
59
- "Byte-level tokenizer (vocab=256). No KV cache in v1; streaming enabled."
60
- ),
61
- examples=[["Write a haiku about sparse attention."], ["Explain NSA branches succinctly."]],
62
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  if __name__ == "__main__":
 
18
  )
19
 
20
 
21
+ SYS_PROMPT = (
22
+ "You are a helpful assistant. Answer briefly and clearly. "
23
+ "Avoid repeating characters. If unsure, say 'I don't know'."
24
+ )
25
+
26
+ FEW_SHOTS = [
27
+ ("Hello", "Hello!"),
28
+ ("What is the capital of France?", "Paris."),
29
+ ("2+2?", "4."),
30
+ ]
31
+
32
+
33
+ def build_prompt(message: str, history: list[tuple[str, str]]) -> str:
34
+ # Minimal, byte-tokenizer-friendly prompt (no special tokens)
35
+ lines = [f"System: {SYS_PROMPT}"]
36
+ for q, a in FEW_SHOTS:
37
+ lines.append(f"User: {q}")
38
+ lines.append(f"Assistant: {a}")
39
  for u, a in history:
40
  if u:
41
+ lines.append(f"User: {u}")
42
  if a:
43
+ lines.append(f"Assistant: {a}")
44
+ lines.append(f"User: {message}")
45
+ lines.append("Assistant:")
46
+ return "\n".join(lines)
47
+
48
+
49
+ def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
50
+ prompt = build_prompt(message, history)
51
  x = tok(prompt, return_tensors="pt")
52
  if torch.cuda.is_available():
53
  x = {k: v.to(model.device) for k, v in x.items()}
 
55
  streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
56
  gen_kwargs = dict(
57
  **x,
58
+ max_new_tokens=int(max_new_tokens),
59
+ do_sample=bool(temperature > 0.0),
60
+ top_p=float(top_p),
61
+ top_k=int(top_k),
62
+ temperature=max(1e-6, float(temperature)),
63
+ repetition_penalty=max(1.0, float(repetition_penalty)),
64
+ no_repeat_ngram_size=int(no_repeat_ngram_size),
65
  streamer=streamer,
66
  )
67
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
 
69
  partial = ""
70
  for new_text in streamer:
71
  partial += new_text
72
+ # Simple repetition guard: if too many identical trailing chars, stop early
73
+ tail = partial[-200:]
74
+ if len(tail) >= 10 and any(tail.endswith(c * 10) for c in set(tail)):
75
+ break
76
  yield partial
77
 
78
 
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("# NSA 117M Chat (byte tokenizer)")
81
+ gr.Markdown("Byte-level tokenizer (vocab=256). Streaming enabled. Use controls to reduce repetition.")
82
+ chat = gr.Chatbot()
83
+ with gr.Row():
84
+ msg = gr.Textbox(label="Message")
85
+ with gr.Accordion("Decoding controls", open=False):
86
+ max_new = gr.Slider(16, 512, value=128, step=16, label="Max new tokens")
87
+ temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature (0 = greedy)")
88
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
89
+ top_k = gr.Slider(0, 200, value=50, step=10, label="Top-k (0 disables)")
90
+ rep_pen = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="Repetition penalty")
91
+ ngram = gr.Slider(0, 6, value=3, step=1, label="No-repeat n-gram size (0 disables)")
92
+
93
+ def user_submit(user_message, history):
94
+ return "", history + [[user_message, None]]
95
+
96
+ def bot_respond(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
97
+ user_message = history[-1][0]
98
+ gen = respond(
99
+ user_message,
100
+ [(u, a) for u, a in history[:-1] if u is not None and a is not None],
101
+ max_new_tokens,
102
+ temperature,
103
+ top_p,
104
+ top_k,
105
+ repetition_penalty,
106
+ no_repeat_ngram_size,
107
+ )
108
+ partial = ""
109
+ for part in gen:
110
+ partial = part
111
+ history[-1][1] = partial
112
+ yield history
113
+
114
+ msg.submit(user_submit, [msg, chat], [msg, chat]).then(
115
+ bot_respond,
116
+ [chat, max_new, temperature, top_p, top_k, rep_pen, ngram],
117
+ [chat],
118
+ )
119
 
120
 
121
  if __name__ == "__main__":