AdamF92 commited on
Commit
813cef7
·
verified ·
1 Parent(s): 2d11a0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -9
app.py CHANGED
@@ -15,38 +15,57 @@ model.share_components()
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  model.to(device)
17
 
 
 
18
  seq_len = 1024
19
 
20
  @spaces.GPU
21
- def chat(message: str, history: list):
22
  tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)
 
 
23
 
24
  response = ""
25
- for token_id in model.interact(**tokenized_query, max_seq_len=seq_len, temperature=0.5):
26
  response += model.stringify_token(token_id, show_memory_update=True)
27
  yield history + [[message, response]]
28
 
29
- return history + [[message, response]]
30
 
31
  with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
32
  gr.Markdown("""
 
33
  Experimental Reactive Transformer model fine-tuned for AI/Data Science knowledge based chats
34
  and interactive Reactive AI documentation.
35
 
 
36
  Supervised version of the model is still in intermediate stage and will be further improved
37
  in Reinforcement Learning stages (demo will be constantly updated), so model could generate
38
- inaccurate answers and memory is weak. However, it should still demonstate the architecture
39
  advantages, especially infinite context and no delays.
40
  """)
41
 
42
- chatbot = gr.Chatbot(height=600)
43
- msg = gr.Textbox(placeholder="Ask RxT...", label="Query")
44
- clear = gr.Button("Clear")
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- msg.submit(chat, [msg, chatbot], chatbot, queue=True).then(
47
  lambda: gr.update(value=""), outputs=msg
48
  )
49
- clear.click(lambda: [], None, chatbot)
 
50
 
51
 
52
  if __name__ == "__main__":
 
# Pick the fastest available device and move the shared model onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Snapshot the freshly initialized short-term memory (STM) once, kept on CPU,
# so every new conversation/session can be reset to the same clean state.
initial_stm = model.export_stm_state().cpu()

# Upper bound on tokenized query length and generated sequence length.
seq_len = 1024
21
 
22
@spaces.GPU
def chat(message: str, history: list, stm_state: torch.Tensor, temperature: float, top_p: float):
    """Stream a chat response for *message*, restoring and updating session memory.

    Args:
        message: The user's query text.
        history: Chat history as a list of [user, assistant] message pairs.
        stm_state: Short-term memory (STM) tensor for this session; loaded into
            the model before interaction so memory persists across turns.
        temperature: Sampling temperature for generation.
        top_p: Nucleus (top-p) sampling threshold.

    Yields:
        2-tuples of (updated_history, stm_state) — one per generated token and
        a final one carrying the post-interaction STM snapshot. Each yield must
        match the two Gradio outputs wired to this handler ([chatbot, stm_state]).
    """
    tokenized_query = model.tokenize_query(message, max_seq_len=seq_len, device=device)

    # Restore this session's memory before interacting with the model.
    model.load_stm_state(stm_state)

    response = ""
    for token_id in model.interact(**tokenized_query, max_seq_len=seq_len,
                                   temperature=temperature, top_p=top_p):
        response += model.stringify_token(token_id, show_memory_update=True)
        # FIX: with two outputs registered, every streamed yield must be a
        # 2-tuple; the original yielded only the history list. The STM state is
        # passed through unchanged while tokens stream.
        yield history + [[message, response]], stm_state

    # FIX: Gradio ignores `return` values inside generator handlers, so the
    # original's `return ..., model.export_stm_state().cpu()` never reached the
    # gr.State and memory was silently lost. Yield the final snapshot instead.
    yield history + [[message, response]], model.export_stm_state().cpu()
34
 
35
# Build the Gradio UI: a streaming chatbot with per-session STM state,
# sampling controls, and a reset button that restores the clean memory snapshot.
with gr.Blocks(title="RxT-Beta-Micro-AI 270M (Supervised) Demo") as demo:
    gr.Markdown("""
    # RxT-Beta-Micro-AI 270M (Supervised) Demo
    Experimental Reactive Transformer model fine-tuned for AI/Data Science knowledge based chats
    and interactive Reactive AI documentation.

    ## Limitations
    Supervised version of the model is still in intermediate stage and will be further improved
    in Reinforcement Learning stages (demo will be constantly updated), so model could generate
    inaccurate answers and memory retention is weak. However, it should still demonstrate the architecture
    advantages, especially infinite context and no delays.
    """)

    chatbot = gr.Chatbot(height=600, type='tuples')

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask RxT...", label="Query", scale=4)
        send_btn = gr.Button("Send", scale=1)
        clear = gr.Button("Clear & Reset STM", scale=1)

    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Per-session short-term memory; each browser session starts from its own
    # clone of the clean initial snapshot.
    stm_state = gr.State(initial_stm.clone())

    # Submitting via Enter streams through chat(), then clears the textbox.
    msg.submit(chat, [msg, chatbot, stm_state, temp, top_p], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )

    # The Send button mirrors the Enter-key behavior.
    send_btn.click(chat, [msg, chatbot, stm_state, temp, top_p], [chatbot, stm_state], queue=True).then(
        lambda: gr.update(value=""), outputs=msg
    )

    # Reset both the visible history and the session memory to the clean state.
    clear.click(lambda: ([], initial_stm.clone()), None, [chatbot, stm_state])
69
 
70
 
71
  if __name__ == "__main__":