# BDHMODEL / app.py — Hugging Face Space by vesakkivignesh
# (commit c7cd6db, verified; message: "Update app.py")
import gradio as gr
import torch
import bdh
from bdh import BDHConfig, BDH
# Spaces typically run on CPU, so pin everything there explicitly.
device = torch.device("cpu")

# Deliberately tiny hyperparameters so the live demo stays responsive on CPU.
_demo_hparams = dict(
    n_layer=2,
    n_embd=128,
    n_head=4,
    mlp_internal_dim_multiplier=32,
    vocab_size=256,  # byte-level vocabulary: one token per byte value
)
config = BDHConfig(**_demo_hparams)

# Instantiate the model and switch to inference mode (disables dropout etc.).
model = BDH(config).to(device)
model.eval()
@torch.no_grad()
def generate_text(prompt: str, max_tokens: int = 50) -> str:
    """Continue *prompt* with byte-level samples from the BDH model.

    Args:
        prompt: UTF-8 text used as the generation context.
        max_tokens: Number of new byte-tokens to sample. The Gradio
            Slider delivers a float, so the value is coerced to ``int``
            before being passed to ``model.generate``.

    Returns:
        The decoded prompt plus continuation. Invalid UTF-8 byte
        sequences (likely, since sampling is over raw bytes) are
        replaced rather than raising.
    """
    if not prompt:  # idiomatic empty-string check
        return "Please enter a prompt."
    # Encode as raw UTF-8 bytes: with vocab_size=256 each byte is one
    # token. unsqueeze(0) adds the batch dimension -> shape (1, n_bytes).
    idx = torch.tensor(
        bytearray(prompt, "utf-8"),
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)
    # int(...) guards against Gradio handing us a float token count.
    out = model.generate(
        idx,
        max_new_tokens=int(max_tokens),
        top_k=5,
    )
    # Tokens back to bytes, then decode; errors="replace" avoids a
    # UnicodeDecodeError on malformed sequences from random sampling.
    return bytes(
        out.squeeze(0).to(torch.uint8).cpu()
    ).decode(errors="replace")
# Wire the generator into a simple Gradio UI and start serving.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 100, value=50, step=10, label="Max tokens"),
    ],
    outputs=gr.Textbox(label="BDH Output"),
    title="BDH – Beyond Transformer (Live Demo)",
    description="Inference-only demo of the BDH architecture",
)
demo.launch()