import os
from threading import Thread

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.streamers import TextIteratorStreamer

MODEL_ID = os.environ.get("MODEL_ID", "seconds-0/nsa-117m-byte")
# from_pretrained accepts token as a bool: True means "use the locally stored
# Hugging Face token", False means "send no token".
USE_TOKEN = os.environ.get("HF_TOKEN") is not None

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=USE_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    token=USE_TOKEN,
)

SYS_PROMPT = (
    "You are a helpful assistant. Answer briefly and clearly. "
    "Avoid repeating characters. If unsure, say 'I don't know'."
)

FEW_SHOTS = [
    ("Hello", "Hello!"),
    ("What is the capital of France?", "Paris."),
    ("2+2?", "4."),
]


def build_prompt(message: str, history: list[tuple[str, str]]) -> str:
    # Minimal, byte-tokenizer-friendly prompt (no special tokens).
    lines = [f"System: {SYS_PROMPT}"]
    for q, a in FEW_SHOTS:
        lines.append(f"User: {q}")
        lines.append(f"Assistant: {a}")
    for u, a in history:
        if u:
            lines.append(f"User: {u}")
        if a:
            lines.append(f"Assistant: {a}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)


def respond(message, history, max_new_tokens, temperature, top_p, top_k,
            repetition_penalty, no_repeat_ngram_size):
    prompt = build_prompt(message, history)
    x = tok(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        x = {k: v.to(model.device) for k, v in x.items()}

    # Stream tokens as they are generated: generate() runs in a background
    # thread and pushes decoded text through the TextIteratorStreamer.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **x,
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(temperature > 0.0),  # temperature == 0 -> greedy decoding
        top_p=float(top_p),
        top_k=int(top_k),  # 0 disables top-k filtering in transformers
        temperature=max(1e-6, float(temperature)),
        repetition_penalty=max(1.0, float(repetition_penalty)),
        no_repeat_ngram_size=int(no_repeat_ngram_size),  # 0 disables the constraint
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        # Simple repetition guard: stop streaming early if the output ends in
        # ten or more identical characters. Note generate() keeps running in
        # its thread; we just stop surfacing its output.
        tail = partial[-200:]
        if len(tail) >= 10 and any(tail.endswith(c * 10) for c in set(tail)):
            break
        yield partial
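# Optional sanity check: a minimal sketch of what build_prompt() emits,
# gated behind a PROMPT_DEBUG env var (that variable name is an assumption
# for this example, not part of the app's interface).
if os.environ.get("PROMPT_DEBUG"):
    # Prints something like:
    #   System: You are a helpful assistant. ...
    #   User: Hello
    #   Assistant: Hello!
    #   ...
    #   User: Hi
    #   Assistant:
    print(build_prompt("Hi", [("How are you?", "Fine, thanks.")]))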
with gr.Blocks() as demo:
    gr.Markdown("# NSA 117M Chat (byte tokenizer)")
    gr.Markdown(
        "Byte-level tokenizer (vocab=256). Streaming enabled. "
        "Use controls to reduce repetition."
    )
    chat = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(label="Message")
    with gr.Accordion("Decoding controls", open=False):
        max_new = gr.Slider(16, 512, value=128, step=16, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature (0 = greedy)")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=10, label="Top-k (0 disables)")
        rep_pen = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="Repetition penalty")
        ngram = gr.Slider(0, 6, value=3, step=1, label="No-repeat n-gram size (0 disables)")

    def user_submit(user_message, history):
        # Clear the textbox and append the new user turn with a pending reply.
        return "", history + [[user_message, None]]

    def bot_respond(history, max_new_tokens, temperature, top_p, top_k,
                    repetition_penalty, no_repeat_ngram_size):
        user_message = history[-1][0]
        gen = respond(
            user_message,
            # Only pass completed (user, assistant) pairs as history.
            [(u, a) for u, a in history[:-1] if u is not None and a is not None],
            max_new_tokens, temperature, top_p, top_k,
            repetition_penalty, no_repeat_ngram_size,
        )
        for part in gen:
            history[-1][1] = part
            yield history

    msg.submit(user_submit, [msg, chat], [msg, chat]).then(
        bot_respond,
        [chat, max_new, temperature, top_p, top_k, rep_pen, ngram],
        [chat],
    )

if __name__ == "__main__":
    # queue() is required for streaming (generator) handlers on older Gradio
    # versions; it is harmless where queueing is already on by default.
    demo.queue()
    demo.launch()
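# Usage sketch (assumed invocation; both env vars are optional, and the
# filename app.py is an assumption):
#   MODEL_ID=seconds-0/nsa-117m-byte HF_TOKEN=hf_... python app.py
# On Hugging Face Spaces with the Gradio SDK, the conventional entry point is
# an app.py exposing a Blocks object named demo, as this script does.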