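"""Gradio chat demo for seconds-0/nsa-117m-byte.

The model uses a byte-level tokenizer (vocab=256), so the prompt format below
avoids special tokens. Responses are streamed token-by-token, and the UI
exposes decoding controls (temperature, top-p/top-k, repetition penalties)
to reduce the repetition small byte-level models are prone to.
"""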
import os
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.streamers import TextIteratorStreamer
MODEL_ID = os.environ.get("MODEL_ID", "seconds-0/nsa-117m-byte")
HF_TOKEN = os.environ.get("HF_TOKEN")  # pass the token string itself (None is fine), not a bool

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
SYS_PROMPT = (
    "You are a helpful assistant. Answer briefly and clearly. "
    "Avoid repeating characters. If unsure, say 'I don't know'."
)

FEW_SHOTS = [
    ("Hello", "Hello!"),
    ("What is the capital of France?", "Paris."),
    ("2+2?", "4."),
]
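# The few-shot turns above prime the 117M byte-level model toward short,
# well-formed answers by showing the "User:"/"Assistant:" turn format
# before the real conversation begins.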
def build_prompt(message: str, history: list[tuple[str, str]]) -> str:
    # Minimal, byte-tokenizer-friendly prompt (no special tokens)
    lines = [f"System: {SYS_PROMPT}"]
    for q, a in FEW_SHOTS:
        lines.append(f"User: {q}")
        lines.append(f"Assistant: {a}")
    for u, a in history:
        if u:
            lines.append(f"User: {u}")
        if a:
            lines.append(f"Assistant: {a}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)
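# For example, build_prompt("Hi there", []) returns:
#
#   System: You are a helpful assistant. Answer briefly and clearly. Avoid repeating characters. If unsure, say 'I don't know'.
#   User: Hello
#   Assistant: Hello!
#   User: What is the capital of France?
#   Assistant: Paris.
#   User: 2+2?
#   Assistant: 4.
#   User: Hi there
#   Assistant: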
def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
    prompt = build_prompt(message, history)
    x = tok(prompt, return_tensors="pt")
    x = {k: v.to(model.device) for k, v in x.items()}  # no-op on CPU, moves to GPU otherwise
    # Stream tokens as they are generated via TextIteratorStreamer
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **x,
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(temperature > 0.0),
        top_p=float(top_p),
        top_k=int(top_k),
        temperature=max(1e-6, float(temperature)),
        repetition_penalty=max(1.0, float(repetition_penalty)),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        streamer=streamer,
    )
    # generate() blocks, so run it in a thread and consume the streamer here
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        # Simple repetition guard: stop streaming early if the output ends in
        # 10 or more copies of the same character
        tail = partial[-200:]
        if len(tail) >= 10 and any(tail.endswith(c * 10) for c in set(tail)):
            break
        yield partial
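# Note: breaking out of the streamer loop above only stops the UI stream;
# model.generate keeps running in its background thread until it finishes.
# A server-side alternative (a sketch, not wired in above) is a custom
# StoppingCriteria that halts generation when the last N token ids repeat:
#
#   from transformers import StoppingCriteria, StoppingCriteriaList
#
#   class RepeatedTokenStop(StoppingCriteria):
#       def __init__(self, n: int = 10):
#           self.n = n
#
#       def __call__(self, input_ids, scores, **kwargs) -> bool:
#           tail = input_ids[0, -self.n:]
#           return tail.numel() == self.n and bool((tail == tail[0]).all())
#
#   # then add stopping_criteria=StoppingCriteriaList([RepeatedTokenStop()])
#   # to gen_kwargs so generation itself stops, not just the stream.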
with gr.Blocks() as demo:
    gr.Markdown("# NSA 117M Chat (byte tokenizer)")
    gr.Markdown("Byte-level tokenizer (vocab=256). Streaming enabled. Use the controls to reduce repetition.")
    chat = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(label="Message")
    with gr.Accordion("Decoding controls", open=False):
        max_new = gr.Slider(16, 512, value=128, step=16, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature (0 = greedy)")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=10, label="Top-k (0 disables)")
        rep_pen = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="Repetition penalty")
        ngram = gr.Slider(0, 6, value=3, step=1, label="No-repeat n-gram size (0 disables)")
    def user_submit(user_message, history):
        # Append the user turn with a placeholder bot reply, clear the textbox
        return "", history + [[user_message, None]]

    def bot_respond(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
        user_message = history[-1][0]
        gen = respond(
            user_message,
            [(u, a) for u, a in history[:-1] if u is not None and a is not None],
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            no_repeat_ngram_size,
        )
        partial = ""
        for part in gen:
            partial = part
            history[-1][1] = partial
            yield history

    msg.submit(user_submit, [msg, chat], [msg, chat]).then(
        bot_respond,
        [chat, max_new, temperature, top_p, top_k, rep_pen, ngram],
        [chat],
    )
if __name__ == "__main__":
    demo.queue()  # queueing is required for streaming (generator) handlers in older Gradio versions
    demo.launch()
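# To run locally (assuming torch, transformers, and gradio are installed):
#   python app.py
# On CPU the model loads in float32; with CUDA it uses bfloat16 and
# device_map="auto" (see the from_pretrained call above).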