Spaces: Running on T4
Update app.py
Browse files
app.py
CHANGED
@@ -31,6 +31,16 @@ pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
 
 args = model_v6.args
 
+_, _ = model.forward([0], None)
+state = model.generate_zero_state()
+static_input = torch.empty((model.n_embd), device="cuda", dtype=torch.half)
+static_state_in = [torch.empty_like(x, device="cuda") for x in state]
+static_state_out = [torch.empty_like(x, device="cuda") for x in state]
+static_output = torch.empty((model.args.vocab_size), device="cuda", dtype=torch.half)
+graph = torch.cuda.CUDAGraph()
+with torch.cuda.graph(graph):
+    static_output, static_state_out = model.forward_one_alt(static_input, static_state_in)
+
 penalty_decay = 0.996
 
 def generate_prompt(instruction, input=""):