Spaces:
Runtime error
Runtime error
streaming token support
Browse filespass mapping
make sure to call different models
fix fill value on iter
app.py
CHANGED
|
@@ -7,8 +7,12 @@ import re
|
|
| 7 |
import traceback
|
| 8 |
import uuid
|
| 9 |
import datetime
|
|
|
|
|
|
|
|
|
|
| 10 |
from collections import defaultdict
|
| 11 |
from time import sleep
|
|
|
|
| 12 |
|
| 13 |
import boto3
|
| 14 |
import gradio as gr
|
|
@@ -56,7 +60,7 @@ class Pipeline:
|
|
| 56 |
"stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
|
| 57 |
}
|
| 58 |
|
| 59 |
-
def __call__(self, prompt):
|
| 60 |
input = self.generation_config.copy()
|
| 61 |
input["prompt"] = prompt
|
| 62 |
|
|
@@ -71,12 +75,26 @@ class Pipeline:
|
|
| 71 |
|
| 72 |
if response.status_code == 200:
|
| 73 |
data = response.json()
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
def poll_for_status(self, task_id):
|
| 82 |
url = f"https://api.runpod.ai/v2/{self.endpoint_id}/status/{task_id}"
|
|
@@ -134,6 +152,19 @@ def user(message, nudge_msg, history1, history2):
|
|
| 134 |
return "", nudge_msg, history1, history2
|
| 135 |
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def chat(history1, history2, system_msg):
|
| 138 |
history1 = history1 or []
|
| 139 |
history2 = history2 or []
|
|
@@ -151,34 +182,17 @@ def chat(history1, history2, system_msg):
|
|
| 151 |
messages1 = messages1.rstrip()
|
| 152 |
messages2 = messages2.rstrip()
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# If desired, you can check for exceptions here...
|
| 163 |
-
if future.exception() is not None:
|
| 164 |
-
print('Exception: {}'.format(future.exception()))
|
| 165 |
-
traceback.print_exception(type(future.exception()), future.exception(), future.exception().__traceback__)
|
| 166 |
-
|
| 167 |
-
tokens_model1 = re.findall(r'\s*\S+\s*', futures[0].result()[0]['generated_text'])
|
| 168 |
-
tokens_model2 = re.findall(r'\s*\S+\s*', futures[1].result()[0]['generated_text'])
|
| 169 |
-
len_tokens_model1 = len(tokens_model1)
|
| 170 |
-
len_tokens_model2 = len(tokens_model2)
|
| 171 |
-
max_tokens = max(len_tokens_model1, len_tokens_model2)
|
| 172 |
-
for i in range(0, max_tokens):
|
| 173 |
-
if i < len_tokens_model1:
|
| 174 |
-
answer1 = tokens_model1[i]
|
| 175 |
-
history1[-1][1] += answer1
|
| 176 |
-
if i < len_tokens_model2:
|
| 177 |
-
answer2 = tokens_model2[i]
|
| 178 |
-
history2[-1][1] += answer2
|
| 179 |
# stream the response
|
| 180 |
yield history1, history2, "", gr.update(value=random_battle[0]), gr.update(value=random_battle[1]), {"models": [model1.name, model2.name]}
|
| 181 |
-
sleep(0.
|
| 182 |
|
| 183 |
|
| 184 |
def chosen_one(label, choice1_history, choice2_history, system_msg, nudge_msg, rlhf_persona, state):
|
|
|
|
| 7 |
import traceback
|
| 8 |
import uuid
|
| 9 |
import datetime
|
| 10 |
+
from collections import deque
|
| 11 |
+
import itertools
|
| 12 |
+
|
| 13 |
from collections import defaultdict
|
| 14 |
from time import sleep
|
| 15 |
+
from typing import Generator, Tuple
|
| 16 |
|
| 17 |
import boto3
|
| 18 |
import gradio as gr
|
|
|
|
| 60 |
"stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
|
| 61 |
}
|
| 62 |
|
| 63 |
+
def __call__(self, prompt) -> Generator[str, None, None]:
|
| 64 |
input = self.generation_config.copy()
|
| 65 |
input["prompt"] = prompt
|
| 66 |
|
|
|
|
| 75 |
|
| 76 |
if response.status_code == 200:
|
| 77 |
data = response.json()
|
| 78 |
+
task_id = data.get('id')
|
| 79 |
+
return self.stream_output(task_id)
|
| 80 |
+
|
| 81 |
+
def stream_output(self,task_id) -> Generator[str, None, None]:
|
| 82 |
+
url = f"https://api.runpod.ai/v2/{self.endpoint_id}/stream/{task_id}"
|
| 83 |
+
headers = {
|
| 84 |
+
"Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
while True:
|
| 88 |
+
response = requests.get(url, headers=headers)
|
| 89 |
+
if response.status_code == 200:
|
| 90 |
+
data = response.json()
|
| 91 |
+
yield [{"generated_text": "".join([s["output"] for s in data["stream"]])}]
|
| 92 |
+
if data.get('status') == 'COMPLETED':
|
| 93 |
+
return
|
| 94 |
+
elif response.status_code >= 400:
|
| 95 |
+
logging.error(response.json())
|
| 96 |
+
# Sleep for 0.5 seconds between each request
|
| 97 |
+
sleep(0.5)
|
| 98 |
|
| 99 |
def poll_for_status(self, task_id):
|
| 100 |
url = f"https://api.runpod.ai/v2/{self.endpoint_id}/status/{task_id}"
|
|
|
|
| 152 |
return "", nudge_msg, history1, history2
|
| 153 |
|
| 154 |
|
| 155 |
+
def token_generator(generator1, generator2, mapping_fn=None, fillvalue=None):
|
| 156 |
+
if not fillvalue:
|
| 157 |
+
fillvalue = ''
|
| 158 |
+
if not mapping_fn:
|
| 159 |
+
mapping_fn = lambda x: x
|
| 160 |
+
for output1, output2 in itertools.zip_longest(generator1, generator2, fillvalue=fillvalue):
|
| 161 |
+
tokens1 = re.findall(r'\s*\S+\s*', mapping_fn(output1))
|
| 162 |
+
tokens2 = re.findall(r'\s*\S+\s*', mapping_fn(output2))
|
| 163 |
+
|
| 164 |
+
for token1, token2 in itertools.zip_longest(tokens1, tokens2, fillvalue=''):
|
| 165 |
+
yield token1, token2
|
| 166 |
+
|
| 167 |
+
|
| 168 |
def chat(history1, history2, system_msg):
|
| 169 |
history1 = history1 or []
|
| 170 |
history2 = history2 or []
|
|
|
|
| 182 |
messages1 = messages1.rstrip()
|
| 183 |
messages2 = messages2.rstrip()
|
| 184 |
|
| 185 |
+
model1_res = model1(messages1) # type: Generator[str, None, None]
|
| 186 |
+
model2_res = model2(messages2) # type: Generator[str, None, None]
|
| 187 |
+
res = token_generator(model1_res, model2_res, lambda x: x[0]['generated_text'], fillvalue=[{'generated_text': ''}]) # type: Generator[Tuple[str, str], None, None]
|
| 188 |
+
for t1, t2 in res:
|
| 189 |
+
if t1 is not None:
|
| 190 |
+
history1[-1][1] += t1
|
| 191 |
+
if t2 is not None:
|
| 192 |
+
history2[-1][1] += t2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# stream the response
|
| 194 |
yield history1, history2, "", gr.update(value=random_battle[0]), gr.update(value=random_battle[1]), {"models": [model1.name, model2.name]}
|
| 195 |
+
sleep(0.2)
|
| 196 |
|
| 197 |
|
| 198 |
def chosen_one(label, choice1_history, choice2_history, system_msg, nudge_msg, rlhf_persona, state):
|