# AdamF92's picture
# Update app.py
# 7441c66 verified
# app.py
import os
import gradio as gr
import torch
import spaces
from typing import Literal
from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub
def disable_flash_attention(model: "RxTBeta") -> None:
    """Turn off flash attention on every attention module of a RxTBeta model.

    Walks the decoder, encoder and memory-attention sub-modules and sets
    ``use_flash_attention = False`` on each, so the model falls back to the
    standard attention implementation (needed on hardware/stacks where
    flash attention is unavailable or misbehaves).

    Args:
        model: The RxTBeta model to patch in place.
    """
    # Module-level flags on decoder and encoder.
    model.decoder.model.use_flash_attention = False
    model.encoder.model.use_flash_attention = False

    # Decoder: memory-conditioned layers carry both self- and cross-attention.
    for layer in model.decoder.model.layers:
        layer.attention.use_flash_attention = False
        layer.memory_cross_attention.use_flash_attention = False
    for layer in model.decoder.model.stateless_layers:
        layer.attention.use_flash_attention = False
    for layer in model.decoder.model.final_stateless_layers:
        layer.attention.use_flash_attention = False

    # Encoder layers.
    for layer in model.encoder.model.layers:
        layer.attention.use_flash_attention = False

    # Memory attention layers (the original code iterated attention_layers
    # twice — the duplicate loop was removed).
    for layer in model.memory_attention.model.attention_layers:
        layer.use_flash_attention = False
    for layer in model.memory_attention.model.mean_attention_layers:
        layer.use_flash_attention = False
# ---- One-time startup: load tokenizer + model and prepare inference state ----

# Hugging Face access token for the gated ReactiveAI repos (set in Space secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
# Flash attention is disabled on all attention modules (see helper above).
disable_flash_attention(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# bfloat16 weights; inference below also runs under a bfloat16 autocast context.
model.to(device, dtype=torch.bfloat16)
model.init_model(device=device)
# Single-conversation (non-batched) interaction mode.
model.set_batch_mode(False)
# Snapshot of the freshly-initialized short-term memory (STM); kept on CPU and
# cloned per-session so "Clear & Reset STM" can restore a pristine state.
initial_stm = model.export_stm_state().cpu()
# Maximum token length for both the tokenized query and generation.
seq_len = 8192
@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float, thinking_mode: Literal['auto', 'extended', 'fast']):
    """Stream a model response for one chat turn.

    Generator-style Gradio handler: yields ``(history, stm_state)`` pairs so
    the chatbot updates token-by-token. ``[T]``/``[A]`` special tokens are
    rendered as visible THINKING markers instead of being emitted verbatim.

    Args:
        message: The user's query text.
        history: Chat history in Gradio 'messages' format (list of role dicts).
        stm_state: Short-term memory tensor carried across turns via gr.State.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        thinking_mode: Reasoning mode passed through to ``model.interact``.

    Yields:
        Updated (history, stm_state) after each generated token, and a final
        pair carrying the post-turn STM exported from the model.
    """
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    # Restore this session's memory before generating.
    model.load_stm_state(stm_state)
    response = ""
    is_thinking = False
    # Build a new list instead of mutating the component value in place.
    history = history + [
        {
            'role': 'user',
            'content': message
        }
    ]
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, thinking_mode=thinking_mode, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            next_token = model.stringify_token(token_id, show_memory_update=False, skip_special_tokens=False)
            if next_token == '[T]':
                response += '[START THINKING]\n\n'
                is_thinking = True
            elif next_token == '[A]':
                if is_thinking:
                    is_thinking = False
                    response += '\n\n[END THINKING]\n\n'
            else:
                response += next_token
            # Intermediate yields keep the pre-turn STM; the state only
            # changes once at the end of the turn.
            yield history + [{ 'role': 'assistant', 'content': response }], stm_state
    # BUG FIX: Gradio ignores the `return` value of a generator handler, so the
    # updated STM previously returned here never reached the gr.State component
    # and memory was silently reset every turn. Yield the final pair instead.
    yield history + [{ 'role': 'assistant', 'content': response }], model.export_stm_state().cpu()
# ---- Gradio UI definition and event wiring ----
with gr.Blocks(title="RxT-Beta 3B A190M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta 3B A190M Supervised Demo
    Demo for supervised version of first real-scale Reactive Transformer model with 3B total params and 190M active in decoder.
    Work in progress - fixing generation errors
    ## Limitations
    Supervised version of the model is still in intermediate stage and will be further improved
    in Direct Memory and Preference Optimization (DMPO) stage (demo will be constantly updated).
    """)
    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT-Beta Chat')
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        thinking_mode = gr.Radio(choices=['auto', 'extended', 'fast'], label="Thinking Mode", value='auto')
        send_btn = gr.Button("Send", scale=1)
    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        clear = gr.Button("Clear & Reset STM", scale=1)
    # Per-session short-term memory; each browser session starts from a clone
    # of the pristine initial state.
    stm_state = gr.State(initial_stm.clone())
    # Enter key and Send button trigger the same streaming handler; after the
    # stream finishes, the query textbox is cleared.
    msg.submit(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    # Reset wipes the visible chat and restores the initial STM snapshot.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])
if __name__ == "__main__":
    # Queuing is required for streaming (generator) handlers; queue() returns
    # the Blocks instance, so the server launch can be chained directly.
    demo.queue().launch()