File size: 2,933 Bytes
5729414
aa6c0d9
fd6eb7a
5729414
 
 
 
fd6eb7a
58e3ad0
 
 
306fd6d
5729414
fd6eb7a
5729414
 
fd6eb7a
813cef7
 
5729414
fd6eb7a
5729414
813cef7
5729414
813cef7
 
5729414
fd6eb7a
f49fde3
 
 
 
 
5729414
813cef7
fd6eb7a
5729414
 
813cef7
5729414
 
fd6eb7a
813cef7
5729414
 
813cef7
79c163c
5729414
fd6eb7a
813cef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd6eb7a
813cef7
5729414
 
813cef7
 
fd6eb7a
 
 
5729414
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# app.py
import os
import gradio as gr
import torch
import spaces
from rxlm.rxt.models import RxTBeta
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub

# Optional Hugging Face access token; None when the HF_TOKEN env var is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub repository holding both the tokenizer and the model weights.
MODEL_REPO_ID = 'ReactiveAI/RxT-Beta-Micro-Supervised-AI'

tokenizer = load_tokenizer_from_hf_hub(MODEL_REPO_ID, token=HF_TOKEN)
model = RxTBeta.from_pretrained(MODEL_REPO_ID, token=HF_TOKEN, tokenizer=tokenizer)
model.share_components()
# This app only runs inference: make sure train-only layers (e.g. dropout) are off.
# Harmless no-op if from_pretrained already returned the model in eval mode.
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Pristine STM snapshot used to seed every new session and to reset on "Clear".
# `.cpu()` alone returns the *same* tensor when the model already lives on CPU, so a
# later chat could mutate this snapshot through aliasing (if export_stm_state hands
# back the live buffer — NOTE(review): not visible here). copy=True always detaches it.
initial_stm = model.export_stm_state().detach().to('cpu', copy=True)

# Maximum sequence length used both for tokenizing queries and for generation.
seq_len = 1024

@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float):
    """Stream one RxT interaction and hand the updated short-term memory back to the session.

    Loads the session's STM into the shared global model, streams the response token
    by token, and finally yields the STM exported after the interaction so the
    ``gr.State`` output carries it into the next turn.

    Args:
        message: User query for this turn.
        history: Chat history as ``[user, bot]`` pairs (Chatbot ``type='tuples'``).
        stm_state: Session STM tensor as previously produced by ``model.export_stm_state()``.
        temperature: Sampling temperature forwarded to ``model.interact``.
        top_p: Nucleus-sampling threshold forwarded to ``model.interact``.

    Yields:
        ``(history, stm_state)`` tuples for the Chatbot and State outputs. Intermediate
        yields carry the incoming STM; the last yield carries the updated STM.
    """
    history = history or []
    # Nothing to answer: leave the transcript and the memory exactly as they are.
    if not message or not message.strip():
        yield history, stm_state
        return

    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)

    model.load_stm_state(stm_state)

    response = ""

    # no_grad: pure inference, so don't build an autograd graph while generating
    # (harmless if model.interact already disables gradients internally).
    with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            # show_memory_update=True presumably renders the memory-update marker ids as text.
            response += model.stringify_token(token_id, show_memory_update=True)
            # The memory update is only complete once the generator is exhausted,
            # so keep sending the previous STM while streaming.
            yield history + [[message, response]], stm_state

    # Gradio ignores a generator's `return` value, so the updated memory has to be
    # *yielded* as the final output, otherwise the session State never changes.
    # copy=True guarantees an independent CPU tensor even when the model runs on CPU,
    # where `.cpu()` would be a no-op that could alias the model's own STM buffer.
    updated_stm = model.export_stm_state().detach().to('cpu', copy=True)
    yield history + [[message, response]], updated_stm

# Gradio UI: chat transcript, query box, sampling controls and per-session STM.
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-AI 270M (Supervised) Demo
    Experimental Reactive Transformer model fine-tuned for AI/Data Science knowledge-based chats
    and interactive Reactive AI documentation.

    ## Limitations
    The supervised version of the model is still at an intermediate stage and will be further improved
    in the Reinforcement Learning stages (the demo will be constantly updated), so the model may generate
    inaccurate answers and its memory retention is weak. However, it should still demonstrate the architecture's
    advantages, especially infinite context and no delays (small delays are caused by Spaces ZeroGPU allocation).
    """)

    chatbot = gr.Chatbot(height=600, type='tuples')
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Per-session short-term memory, seeded from the pristine snapshot.
    stm_state = gr.State(initial_stm.clone())

    chat_inputs = [msg, chatbot, stm_state, temp, top_p]
    chat_outputs = [chatbot, stm_state]

    # Pressing Enter and clicking "Send" do the same thing. Both triggers share one
    # explicit concurrency group because `chat` drives a single global model whose
    # STM is swapped in per request; overlapping runs would clobber each other's memory.
    for trigger in (msg.submit, send_btn.click):
        trigger(chat, chat_inputs, chat_outputs, queue=True, concurrency_id="rxt_chat").then(
            lambda: gr.update(value=""), outputs=msg
        )

    # Reset both the visible transcript and the session memory.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])


if __name__ == "__main__":
    # Turn on the request queue, then start the web server (queue() returns the Blocks itself).
    demo.queue().launch()