|
|
import gradio as gr |
|
|
import torch |
|
|
|
|
|
import bdh |
|
|
from bdh import BDHConfig, BDH |
|
|
|
|
|
|
|
|
# CPU-only inference: this demo ships without trained weights, so the
# (randomly initialized) model is small enough that no GPU is needed.
device = torch.device("cpu")








# Demo-sized BDH hyperparameters. vocab_size=256 means the model works
# directly on raw bytes (one token per byte), which is why generate_text
# encodes/decodes UTF-8 below.
# NOTE(review): these values look intentionally tiny for a live demo;
# confirm against the defaults declared in bdh.BDHConfig.
config = BDHConfig(


    n_layer=2,


    n_embd=128,


    n_head=4,


    mlp_internal_dim_multiplier=32,


    vocab_size=256,


)




# Instantiate the model on CPU and switch to eval mode (disables
# dropout/batch-norm training behavior for deterministic inference paths).
model = BDH(config).to(device)


model.eval()
|
|
|
|
|
@torch.no_grad()
def generate_text(prompt: str, max_tokens: int = 50) -> str:
    """Generate a byte-level continuation of *prompt* with the BDH model.

    Args:
        prompt: UTF-8 text used to seed generation. Empty, whitespace-only,
            or missing prompts short-circuit with a help message.
        max_tokens: Number of new bytes to sample. Gradio sliders may pass
            this value as a float, so it is coerced to ``int`` before use.

    Returns:
        The prompt plus the generated bytes, decoded as UTF-8 with invalid
        sequences replaced (so arbitrary model output never raises).
    """
    # Guard against None / empty / whitespace-only input from the UI.
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    # The model's vocabulary is raw bytes (vocab_size=256): encode the
    # prompt to UTF-8 and add a leading batch dimension.
    idx = torch.tensor(
        list(prompt.encode("utf-8")),
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)

    out = model.generate(
        idx,
        # Coerce explicitly: Gradio Slider values can arrive as float.
        max_new_tokens=int(max_tokens),
        top_k=5,
    )

    # `out` includes the original prompt bytes. Convert via tolist() rather
    # than iterating the tensor element-by-element, and decode with
    # replacement so non-UTF-8 output cannot crash the demo.
    return bytes(
        out.squeeze(0).to(torch.uint8).cpu().tolist()
    ).decode(errors="replace")
|
|
|
|
|
|
|
|
# Assemble the Gradio UI for the demo, then start the local server.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 100, value=50, step=10, label="Max tokens"),
    ],
    outputs=gr.Textbox(label="BDH Output"),
    title="BDH – Beyond Transformer (Live Demo)",
    description="Inference-only demo of the BDH architecture",
)
demo.launch()
|
|
|