# app.py
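# Gradio demo for the RxT-Beta-Micro (Supervised) Reactive Transformer.
# Instead of reprocessing the whole chat history each turn, the model carries a
# short-term memory (STM) tensor between interactions; that tensor is
# round-tripped through gr.State so each browser session keeps its own memory.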
import os
import gradio as gr
import torch
import spaces
from rxlm.rxt.models import RxTBeta
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

# Access token for the model repo, read from the Space's secrets (may be None for public repos).
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta-Micro-Supervised-AI', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Micro-Supervised-AI', token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()  # wire up the modules shared between the model's components after loading

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Keep a CPU copy of the freshly initialized STM so each session can start from
# (and be reset to) a clean memory state.
initial_stm = model.export_stm_state().cpu()

seq_len = 1024  # max sequence length for a single interaction (query + generated response)

@spaces.GPU  # ZeroGPU: a GPU is attached only while this function runs
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float):
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)

    # Restore this session's memory state before generating.
    model.load_stm_state(stm_state)

    # Stream partial responses; yield the unchanged state until generation finishes,
    # since both outputs (chatbot, stm_state) must be provided on every yield.
    response = ""
    for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
        response += model.stringify_token(token_id, show_memory_update=True)
        yield history + [[message, response]], stm_state

    # Gradio ignores `return` values in generator functions, so the updated STM
    # must be yielded as the final output or the session state never advances.
    yield history + [[message, response]], model.export_stm_state().cpu()

with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-AI 270M (Supervised) Demo
    Experimental Reactive Transformer model fine-tuned for AI/Data Science knowledge-based chats
    and interactive Reactive AI documentation.

    ## Limitations
    The supervised version of the model is still at an intermediate stage and will be further improved
    in the Reinforcement Learning stages (the demo will be updated continuously), so the model may generate
    inaccurate answers and its memory retention is weak. However, it should still demonstrate the
    advantages of the architecture, especially infinite context and no delays.
    """)

    chatbot = gr.Chatbot(height=600, type='tuples')
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    
    # Per-session short-term memory, seeded from the clean STM snapshot.
    stm_state = gr.State(initial_stm.clone())
    
    msg.submit(chat, [msg, chatbot, stm_state, temp, top_p], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )

    send_btn.click(chat, [msg, chatbot, stm_state, temp, top_p], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    
    # Reset both the visible chat and the session's memory state.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])


if __name__ == "__main__":
    demo.queue()
    demo.launch()
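
# To run outside of Spaces: `python app.py` (export HF_TOKEN first if the model repo is gated).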