Spaces:
Build error
Build error
Commit
·
b644119
1
Parent(s):
92fcbd4
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,7 +23,11 @@ def to_md(text):
|
|
| 23 |
def get_model():
|
| 24 |
model = None
|
| 25 |
model = RWKV(
|
| 26 |
-
"https://huggingface.co/
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
)
|
| 28 |
return model
|
| 29 |
|
|
@@ -118,10 +122,11 @@ def chat(
|
|
| 118 |
torch.cuda.empty_cache()
|
| 119 |
model = get_model()
|
| 120 |
|
| 121 |
-
if len(history) == 0:
|
| 122 |
# no history, so lets reset chat state
|
| 123 |
model.resetState()
|
| 124 |
-
|
|
|
|
| 125 |
max_new_tokens = int(max_new_tokens)
|
| 126 |
temperature = float(temperature)
|
| 127 |
top_p = float(top_p)
|
|
@@ -143,8 +148,8 @@ def chat(
|
|
| 143 |
model.loadContext(newctx=prompt)
|
| 144 |
generated_text = ""
|
| 145 |
done = False
|
| 146 |
-
|
| 147 |
-
|
| 148 |
generated_text = generated_text.lstrip("\n ")
|
| 149 |
print(f"{generated_text}")
|
| 150 |
|
|
@@ -154,8 +159,8 @@ def chat(
|
|
| 154 |
generated_text = generated_text[:generated_text.find(stop_word)]
|
| 155 |
|
| 156 |
gc.collect()
|
| 157 |
-
history.append((prompt, generated_text))
|
| 158 |
-
return history,history
|
| 159 |
|
| 160 |
|
| 161 |
examples = [
|
|
|
|
def get_model():
    """Construct and return the RWKV-4 Pile 1B5 instruct model.

    The checkpoint is streamed from the Hugging Face Hub URL by the RWKV
    loader. Runs on GPU when CUDA is available, otherwise on CPU; weights
    and runtime computation both use float32.

    Returns:
        The loaded RWKV model instance.
    """
    # NOTE(review): the original assigned `model = None` and immediately
    # overwrote it — the dead assignment is removed here.
    model = RWKV(
        "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
        "pytorch(cpu/gpu)",
        runtimedtype=torch.float32,
        useGPU=torch.cuda.is_available(),
        dtype=torch.float32,
    )
    return model
|
| 33 |
|
|
|
|
| 122 |
torch.cuda.empty_cache()
|
| 123 |
model = get_model()
|
| 124 |
|
| 125 |
+
if len(history[0]) == 0:
|
| 126 |
# no history, so lets reset chat state
|
| 127 |
model.resetState()
|
| 128 |
+
else:
|
| 129 |
+
model.setState(history[1])
|
| 130 |
max_new_tokens = int(max_new_tokens)
|
| 131 |
temperature = float(temperature)
|
| 132 |
top_p = float(top_p)
|
|
|
|
| 148 |
model.loadContext(newctx=prompt)
|
| 149 |
generated_text = ""
|
| 150 |
done = False
|
| 151 |
+
gen = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)
|
| 152 |
+
generated_text = gen["output"]
|
| 153 |
generated_text = generated_text.lstrip("\n ")
|
| 154 |
print(f"{generated_text}")
|
| 155 |
|
|
|
|
| 159 |
generated_text = generated_text[:generated_text.find(stop_word)]
|
| 160 |
|
| 161 |
gc.collect()
|
| 162 |
+
history[0].append((prompt, generated_text))
|
| 163 |
+
return history[0],[history[0],gen["state"]]
|
| 164 |
|
| 165 |
|
| 166 |
examples = [
|