Spaces: Running on Zero
| # app.py | |
| import os | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from typing import Literal | |
| from rxlm.rxt.models import RxTBeta | |
| from rxlm.llm.models import DecoderOnlyTransformer | |
| from rxlm.training.tokenizer import load_tokenizer_from_hf_hub | |
def disable_flash_attention(model: "RxTBeta") -> None:
    """Turn off flash attention on every attention module of an RxT model.

    Used when the runtime does not support flash-attention kernels (e.g. the
    CPU / ZeroGPU fallback path). Mutates ``model`` in place and returns None.

    Args:
        model: loaded RxT model whose decoder, encoder and memory-attention
            sub-modules expose ``use_flash_attention`` flags.
    """
    # Top-level flags on the decoder/encoder wrappers.
    model.decoder.model.use_flash_attention = False
    model.encoder.model.use_flash_attention = False
    # Decoder layers carry both self-attention and memory cross-attention.
    for layer in model.decoder.model.layers:
        layer.attention.use_flash_attention = False
        layer.memory_cross_attention.use_flash_attention = False
    for layer in model.decoder.model.stateless_layers:
        layer.attention.use_flash_attention = False
    for layer in model.decoder.model.final_stateless_layers:
        layer.attention.use_flash_attention = False
    for layer in model.encoder.model.layers:
        layer.attention.use_flash_attention = False
    # Memory-attention modules expose the flag directly on each layer.
    # NOTE: the original iterated ``attention_layers`` twice (copy/paste
    # duplicate); a single pass is sufficient and behaviorally identical.
    for layer in model.memory_attention.model.attention_layers:
        layer.use_flash_attention = False
    for layer in model.memory_attention.model.mean_attention_layers:
        layer.use_flash_attention = False
# Hugging Face auth token for gated-repo access (set as a Space secret).
HF_TOKEN = os.environ.get("HF_TOKEN")
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Beta', token=HF_TOKEN)
model = RxTBeta.from_pretrained('ReactiveAI/RxT-Beta-Supervised', token=HF_TOKEN, tokenizer=tokenizer)
# Fall back to standard attention everywhere — presumably flash-attention
# kernels are unavailable/unstable on this runtime (TODO confirm).
disable_flash_attention(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device, dtype=torch.bfloat16)
model.init_model(device=device)
# Single-conversation (non-batched) inference mode.
model.set_batch_mode(False)
# Snapshot of the freshly initialized Short-Term Memory, kept on CPU;
# cloned per session so every chat starts from the same clean state.
initial_stm = model.export_stm_state().cpu()
# Maximum sequence length used for both tokenization and generation.
seq_len = 8192
# NOTE(review): the file imports ``spaces`` but never uses it; on a ZeroGPU
# Space this function likely needs the ``@spaces.GPU`` decorator — confirm.
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float, thinking_mode: Literal['auto', 'extended', 'fast']):
    """Stream a model reply for ``message``, yielding (history, stm_state).

    Generator used as a Gradio event handler: each yield streams the partial
    assistant message plus the session's Short-Term Memory tensor. '[T]' /
    '[A]' control tokens emitted by the model are rendered as
    '[START THINKING]' / '[END THINKING]' markers instead of raw tokens.

    Args:
        message: user query text.
        history: chat history in Gradio 'messages' format (role/content dicts).
        stm_state: per-session STM tensor (CPU) restored into the model.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        thinking_mode: reasoning mode forwarded to ``model.interact``.

    Yields:
        (updated history, STM tensor) tuples; the final yield carries the
        post-interaction STM exported from the model.
    """
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
    model.load_stm_state(stm_state)
    response = ""
    is_thinking = False
    # Build a new list instead of mutating the caller's history in place.
    history = history + [{'role': 'user', 'content': message}]
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        for token_id in model.interact(**tokenized_query, thinking_mode=thinking_mode, max_seq_len=seq_len, temperature=temperature, top_p=top_p):
            next_token = model.stringify_token(token_id, show_memory_update=False, skip_special_tokens=False)
            if next_token == '[T]':
                response += '[START THINKING]\n\n'
                is_thinking = True
            elif next_token == '[A]':
                # '[A]' marks the answer start; only emit the end-of-thinking
                # marker if a thinking section was actually opened.
                if is_thinking:
                    is_thinking = False
                    response += '\n\n[END THINKING]\n\n'
            else:
                response += next_token
            yield history + [{'role': 'assistant', 'content': response}], stm_state
    # BUG FIX: the original ended with ``return history + ..., model.export_stm_state().cpu()``.
    # Gradio discards a generator's return value, so the updated STM never
    # reached gr.State and every turn reused the stale input state. Yield the
    # final (history, updated STM) pair instead so memory updates persist.
    yield history + [{'role': 'assistant', 'content': response}], model.export_stm_state().cpu()
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="RxT-Beta 3B A190M (Supervised) Demo") as demo:
    gr.Markdown("""
# RxT-Beta 3B A190M Supervised Demo
Demo for supervised version of first real-scale Reactive Transformer model with 3B total params and 190M active in decoder.
Work in progress - fixing generation errors
## Limitations
Supervised version of the model is still in intermediate stage and will be further improved
in Direct Memory and Preference Optimization (DMPO) stage (demo will be constantly updated).
""")
    with gr.Row():
        chatbot = gr.Chatbot(height=600, label='RxT-Beta Chat')
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        thinking_mode = gr.Radio(choices=['auto', 'extended', 'fast'], label="Thinking Mode", value='auto')
        send_btn = gr.Button("Send", scale=1)
    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        clear = gr.Button("Clear & Reset STM", scale=1)
    # Per-session Short-Term Memory tensor; each browser session starts from
    # its own clone of the clean initial state.
    stm_state = gr.State(initial_stm.clone())
    # Enter key and Send button trigger the same streaming handler, updating
    # (chatbot, stm_state); afterwards the textbox is cleared.
    msg.submit(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    send_btn.click(chat, [msg, chatbot, stm_state, temp, top_p, thinking_mode], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )
    # Reset both the visible history and the session STM to the clean snapshot.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])

if __name__ == "__main__":
    # Queueing is required for generator handlers to stream to the client.
    demo.queue()
    demo.launch()