RxT-Beta-Demo / app.py
AdamF92's picture
Update app.py
813cef7 verified
raw
history blame
2.78 kB
# app.py
import os
import gradio as gr
import torch
import spaces
from rxlm.rxt.models import RxTBeta
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Tokenizer and model weights are fetched from the same Hub repository.
_MODEL_REPO = 'ReactiveAI/RxT-Beta-Micro-Supervised-AI'
tokenizer = load_tokenizer_from_hf_hub(_MODEL_REPO, token=HF_TOKEN)
model = RxTBeta.from_pretrained(
    _MODEL_REPO,
    token=HF_TOKEN,
    tokenizer=tokenizer,
)
model.share_components()

# Run on the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

# CPU snapshot of the freshly loaded model's STM state; the UI clones it to
# start and to reset conversations.
initial_stm = model.export_stm_state().cpu()

# Value passed as max_seq_len to both query tokenization and generation.
seq_len = 1024
@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float):
    """Stream the model's answer to ``message`` and hand back the updated STM.

    This is a generator wired to two Gradio outputs (chatbot, STM state), so
    every ``yield`` produces a ``(chat_history, stm_state)`` pair.

    Args:
        message: The user's query text.
        history: Chat history so far, as a list of ``[user, bot]`` pairs.
        stm_state: STM state to load into the model before generating.
        temperature: Sampling temperature forwarded to ``model.interact``.
        top_p: Nucleus-sampling threshold forwarded to ``model.interact``.

    Yields:
        ``(history + [[message, partial_response]], stm_state)`` after each
        generated token, then one final pair whose second element is the
        STM state exported from the model after generation finished.
    """
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    model.load_stm_state(stm_state)

    response = ""
    for token_id in model.interact(
        **tokenized_query, max_seq_len=seq_len, temperature=temperature, top_p=top_p
    ):
        response += model.stringify_token(token_id, show_memory_update=True)
        # While streaming, pass the incoming state through unchanged so both
        # declared outputs always receive a value.
        yield history + [[message, response]], stm_state

    # A generator's `return` value never reaches Gradio, so the updated STM
    # has to be delivered through a final yield instead.
    yield history + [[message, response]], model.export_stm_state().cpu()
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    # Static header: what the model is and what to expect from it.
    gr.Markdown("""
# RxT-Beta-Micro-AI 270M (Supervised) Demo
Experimental Reactive Transformer model fine-tuned for AI/Data Science knowledge based chats
and interactive Reactive AI documentation.
## Limitations
Supervised version of the model is still in intermediate stage and will be further improved
in Reinforcement Learning stages (demo will be constantly updated), so model could generate
inaccurate answers and memory retention is weak. However, it should still demonstrate the architecture
advantages, especially infinite context and no delays.
""")

    chatbot = gr.Chatbot(height=600, type='tuples')

    # Query box with the send and reset buttons next to it.
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    # Sampling controls forwarded to `chat`.
    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Carries the STM tensor between calls; starts from a copy of the
    # initial snapshot so that snapshot itself is never handed out.
    stm_state = gr.State(initial_stm.clone())

    # Pressing Enter in the textbox and clicking "Send" do exactly the same
    # thing: run `chat`, then empty the query box once it has finished.
    chat_inputs = [msg, chatbot, stm_state, temp, top_p]
    chat_outputs = [chatbot, stm_state]
    for trigger in (msg.submit, send_btn.click):
        trigger(chat, chat_inputs, chat_outputs, queue=True).then(
            lambda: gr.update(value=""), outputs=msg
        )

    # Reset: wipe the visible history and restore a fresh copy of the
    # initial STM snapshot.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])
# Script entry point: enable the request queue, then start the web server.
if __name__ == "__main__":
    demo.queue().launch()