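"""Gradio chat demo for seconds-0/nsa-117m-byte: streams generation from a
byte-level model and exposes decoding controls, a few-shot prompt, and a
simple repetition guard."""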
import os
import torch
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
MODEL_ID = os.environ.get("MODEL_ID", "seconds-0/nsa-117m-byte")
# Pass the token string itself; None falls back to the default auth behavior
# (e.g. a cached `huggingface-cli login`).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Byte-level tokenizer: vocab=256, one token per byte, no special chat tokens.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
SYS_PROMPT = (
"You are a helpful assistant. Answer briefly and clearly. "
"Avoid repeating characters. If unsure, say 'I don't know'."
)
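# A few short Q/A exchanges anchor the small byte-level model on the
# "User:/Assistant:" turn format and on terse answers.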
FEW_SHOTS = [
("Hello", "Hello!"),
("What is the capital of France?", "Paris."),
("2+2?", "4."),
]
def build_prompt(message: str, history: list[tuple[str, str]]) -> str:
# Minimal, byte-tokenizer-friendly prompt (no special tokens)
lines = [f"System: {SYS_PROMPT}"]
for q, a in FEW_SHOTS:
lines.append(f"User: {q}")
lines.append(f"Assistant: {a}")
for u, a in history:
if u:
lines.append(f"User: {u}")
if a:
lines.append(f"Assistant: {a}")
lines.append(f"User: {message}")
lines.append("Assistant:")
return "\n".join(lines)
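# For history=[] and message="Hi", build_prompt renders:
#   System: You are a helpful assistant. ...
#   User: Hello
#   Assistant: Hello!
#   User: What is the capital of France?
#   Assistant: Paris.
#   User: 2+2?
#   Assistant: 4.
#   User: Hi
#   Assistant: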
def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
prompt = build_prompt(message, history)
    x = tok(prompt, return_tensors="pt")
    # model.device covers both the device_map="auto" and CPU-only paths
    # (.to("cpu") is a no-op when everything is already on CPU).
    x = {k: v.to(model.device) for k, v in x.items()}
    # TextIteratorStreamer decodes tokens as generate() produces them on a
    # worker thread, letting us yield partial text to Gradio.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        **x,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        repetition_penalty=max(1.0, float(repetition_penalty)),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        streamer=streamer,
    )
    if do_sample:
        # Pass sampling knobs only when sampling: greedy decoding ignores
        # them, and recent transformers versions warn if they are set.
        gen_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
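    # Generation runs on a background thread; the loop below consumes the
    # streamer until it is exhausted or the repetition guard trips.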
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
partial = ""
for new_text in streamer:
partial += new_text
        # Repetition guard: stop streaming once the text ends in ten
        # identical characters. generate() keeps running on the worker
        # thread; we simply stop consuming its output.
        if len(partial) >= 10 and partial[-10:] == partial[-1] * 10:
            break
yield partial
with gr.Blocks() as demo:
gr.Markdown("# NSA 117M Chat (byte tokenizer)")
gr.Markdown("Byte-level tokenizer (vocab=256). Streaming enabled. Use controls to reduce repetition.")
chat = gr.Chatbot()
with gr.Row():
msg = gr.Textbox(label="Message")
with gr.Accordion("Decoding controls", open=False):
max_new = gr.Slider(16, 512, value=128, step=16, label="Max new tokens")
temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature (0 = greedy)")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
top_k = gr.Slider(0, 200, value=50, step=10, label="Top-k (0 disables)")
rep_pen = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="Repetition penalty")
ngram = gr.Slider(0, 6, value=3, step=1, label="No-repeat n-gram size (0 disables)")
def user_submit(user_message, history):
return "", history + [[user_message, None]]
def bot_respond(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
user_message = history[-1][0]
gen = respond(
user_message,
[(u, a) for u, a in history[:-1] if u is not None and a is not None],
max_new_tokens,
temperature,
top_p,
top_k,
repetition_penalty,
no_repeat_ngram_size,
)
partial = ""
for part in gen:
partial = part
history[-1][1] = partial
yield history
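    # Wiring: submitting the textbox first appends the user turn and clears
    # the box, then bot_respond streams the assistant reply into the chat.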
msg.submit(user_submit, [msg, chat], [msg, chat]).then(
bot_respond,
[chat, max_new, temperature, top_p, top_k, rep_pen, ngram],
[chat],
)
if __name__ == "__main__":
demo.launch()