BlinkDL committed on
Commit
b6bc5c9
·
verified ·
1 Parent(s): fa6b961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -31,6 +31,16 @@ pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
31
 
32
  args = model_v6.args
33
 
 
 
 
 
 
 
 
 
 
 
34
  penalty_decay = 0.996
35
 
36
  def generate_prompt(instruction, input=""):
 
31
 
32
  args = model_v6.args
33
 
34
+ _, _ = model.forward([0], None)
35
+ state = model.generate_zero_state()
36
+ static_input = torch.empty((model.n_embd), device="cuda", dtype=torch.half)
37
+ static_state_in = [torch.empty_like(x, device="cuda") for x in state]
38
+ static_state_out = [torch.empty_like(x, device="cuda") for x in state]
39
+ static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)
40
+ graph = torch.cuda.CUDAGraph()
41
+ with torch.cuda.graph(graph):
42
+ static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)
43
+
44
  penalty_decay = 0.996
45
 
46
  def generate_prompt(instruction, input=""):