# app.py
"""Gradio demo for RxT-Beta 3B A190M (Supervised).

Loads the Reactive Transformer once at module import, then serves a
streaming chat UI. The model's short-term memory (STM) is carried
across turns through a per-session ``gr.State`` tensor, so each browser
session gets its own conversation memory.
"""
import os
from typing import Literal

import gradio as gr
import spaces
import torch

from rxlm.rxt.models import RxTBeta
from rxlm.llm.models import DecoderOnlyTransformer  # NOTE(review): unused here — confirm before removing
from rxlm.training.tokenizer import load_tokenizer_from_hf_hub


def disable_flash_attention(model: RxTBeta) -> None:
    """Force standard (non-flash) attention on every attention module.

    Walks the decoder, encoder and memory-attention sub-modules and
    clears their ``use_flash_attention`` flags in place.
    """
    model.decoder.model.use_flash_attention = False
    model.encoder.model.use_flash_attention = False
    for layer in model.decoder.model.layers:
        layer.attention.use_flash_attention = False
        layer.memory_cross_attention.use_flash_attention = False
    for layer in model.decoder.model.stateless_layers:
        layer.attention.use_flash_attention = False
    for layer in model.decoder.model.final_stateless_layers:
        layer.attention.use_flash_attention = False
    for layer in model.encoder.model.layers:
        layer.attention.use_flash_attention = False
    # FIX: the original iterated ``attention_layers`` twice; the duplicate
    # (idempotent but redundant) pass has been removed.
    for layer in model.memory_attention.model.attention_layers:
        layer.use_flash_attention = False
    for layer in model.memory_attention.model.mean_attention_layers:
        layer.use_flash_attention = False


# --- One-time model setup (runs at import; downloads from the HF Hub) ---
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
disable_flash_attention(model)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device, dtype=torch.bfloat16)
model.init_model(device=device)
model.set_batch_mode(False)

# Pristine STM snapshot, kept on CPU; cloned per session and on reset.
initial_stm = model.export_stm_state().cpu()

seq_len = 8192


@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor,
         temperature: float, top_p: float,
         thinking_mode: Literal['auto', 'extended', 'fast']):
    """Stream one chat turn.

    Yields ``(messages, stm_state)`` pairs: intermediate yields carry the
    incoming (pre-turn) STM so an interrupted generation does not corrupt
    the stored state; the final yield carries the updated STM.
    """
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    # Restore this session's memory before generating.
    model.load_stm_state(stm_state)

    response = ""
    is_thinking = False
    history += [{'role': 'user', 'content': message}]

    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, thinking_mode=thinking_mode,
                                       max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            next_token = model.stringify_token(token_id, show_memory_update=False, skip_special_tokens=False)
            # '[T]' / '[A]' are control tokens delimiting the thinking span;
            # they are rendered as readable markers instead of raw tokens.
            if next_token == '[T]':
                response += '[START THINKING]\n\n'
                is_thinking = True
            elif next_token == '[A]':
                if is_thinking:
                    is_thinking = False
                    response += '\n\n[END THINKING]\n\n'
            else:
                response += next_token
            yield history + [{'role': 'assistant', 'content': response}], stm_state

    # FIX: Gradio ignores a generator's ``return`` value, so the original
    # ``return ..., model.export_stm_state().cpu()`` never reached the
    # ``stm_state`` output and memory was lost every turn. A final yield
    # is required to persist the updated STM into the session state.
    yield history + [{'role': 'assistant', 'content': response}], model.export_stm_state().cpu()


with gr.Blocks(title="RxT-Beta 3B A190M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta 3B A190M Supervised Demo
    Demo for supervised version of first real-scale Reactive Transformer model with 3B total params and 190M active in decoder.

    Work in progress - fixing generation errors

    ## Limitations
    Supervised version of the model is still in intermediate stage and will be further improved in Direct Memory and Preference Optimization (DMPO) stage (demo will be constantly updated).
    """)

    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT-Beta Chat')
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        thinking_mode = gr.Radio(choices=['auto', 'extended', 'fast'], label="Thinking Mode", value='auto')
        send_btn = gr.Button("Send", scale=1)
    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        clear = gr.Button("Clear & Reset STM", scale=1)

    # Per-session STM tensor, seeded from the pristine snapshot.
    stm_state = gr.State(initial_stm.clone())

    msg.submit(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode],
               [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode],
                   [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    # Reset wipes the transcript and restores the pristine STM.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])

if __name__ == "__main__":
    demo.queue()
    demo.launch()