File size: 2,403 Bytes
14f1f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7f52f4
14f1f46
 
 
659202d
14f1f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7f52f4
14f1f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""GlubLM HuggingFace Space - Gradio chat interface."""
from __future__ import annotations

import logging

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from glublm.config import ModelConfig
from glublm.inference import generate
from glublm.model import GlubLM
from glublm.tokenizer import GlubTokenizer

REPO_ID = "DenSec02/glublm-18m"

_logger = logging.getLogger(__name__)

# Download weights and tokenizer from HF Hub
weights_path = hf_hub_download(REPO_ID, "model.safetensors")
tok_path = hf_hub_download(REPO_ID, "tokenizer.json")

# Build the model from the default ModelConfig, overriding only the vocabulary
# size so that it matches the downloaded tokenizer.
tok = GlubTokenizer.from_file(tok_path)
cfg = ModelConfig(vocab_size=tok.vocab_size)
model = GlubLM(cfg)

# strict=False keeps the load tolerant of key mismatches between checkpoint and
# model, but the mismatches are reported instead of being silently dropped:
# a parameter listed as missing keeps its freshly initialised value.
state = load_file(weights_path)
_load_result = model.load_state_dict(state, strict=False)
if _load_result.missing_keys:
    _logger.warning(
        "checkpoint %s is missing %d model key(s): %s",
        REPO_ID,
        len(_load_result.missing_keys),
        ", ".join(_load_result.missing_keys),
    )
if _load_result.unexpected_keys:
    _logger.warning(
        "checkpoint %s has %d key(s) the model does not use: %s",
        REPO_ID,
        len(_load_result.unexpected_keys),
        ", ".join(_load_result.unexpected_keys),
    )
model.eval()

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def chat(prompt: str, temperature: float, top_k: int, top_p: float, max_new_tokens: int) -> str:
    """Run one generation request coming from the UI.

    Forwards the prompt and sampling settings to ``generate`` together with
    the module-level ``model``, ``tok`` and ``device``.

    Slider widgets may hand their value over as a float even when the step is
    a whole number, so the count-like settings are coerced to ``int`` (and the
    continuous ones to ``float``) before they reach ``generate``.

    Args:
        prompt: Text typed by the user.
        temperature: Sampling temperature.
        top_k: Top-k cutoff; truncated to an integer.
        top_p: Top-p (nucleus) cutoff.
        max_new_tokens: Generation length limit; truncated to an integer.

    Returns:
        Whatever ``generate`` returns for this request (annotated as ``str``).
    """
    return generate(
        model=model,
        tokenizer=tok,
        prompt=prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        device=device,
    )


TAGLINE = "the language model that already forgot this sentence"

# Page layout: header, a two-column row (controls on the left, output on the
# right), the click wiring, and a footer with project links.
with gr.Blocks(title="GlubLM") as demo:
    # NOTE(review): the header advertises 35M parameters while the model card
    # linked in the footer is named "glublm-18m" — confirm which is current.
    gr.Markdown(
        f"# GlubLM\n> *{TAGLINE}*\n\n"
        "A 35M-parameter goldfish with a 10-second memory. "
        "[Try the Desk Pet](https://den-sec.github.io/glublm/desk-pet/)."
    )

    with gr.Row():
        # Left column: prompt box, sampling controls and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(value="hello", label="say something to the goldfish")
            temperature = gr.Slider(
                minimum=0.1, maximum=1.5, value=0.8, step=0.05, label="temperature"
            )
            top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="top-k")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="top-p")
            max_new = gr.Slider(minimum=8, maximum=64, value=32, step=1, label="max new tokens")
            btn = gr.Button("generate", variant="primary")

        # Right column: where the generated text is shown.
        with gr.Column():
            out = gr.Textbox(lines=6, label="glub says")

    # Input order must match the positional parameters of `chat`.
    btn.click(
        fn=chat,
        inputs=[prompt, temperature, top_k, top_p, max_new],
        outputs=out,
    )

    gr.Markdown(
        "Learn more: "
        + " - ".join(
            [
                "[GitHub](https://github.com/Den-Sec/glublm)",
                "[Model card](https://huggingface.co/DenSec02/glublm-18m)",
                "[Dataset](https://huggingface.co/datasets/DenSec02/glublm-60k-ted)",
            ]
        )
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()