Update handler.py
Browse files - handler.py +10 -11
handler.py
CHANGED
|
@@ -16,22 +16,21 @@ class EndpointHandler():
|
|
| 16 |
messages = request_inputs["messages"]
|
| 17 |
char_name = request_inputs["char_name"]
|
| 18 |
user_name = request_inputs["user_name"]
|
|
|
|
| 19 |
template = self.default_template
|
| 20 |
-
user_input =
|
| 21 |
"{name}: {message}".format(
|
| 22 |
name = char_name if (id["role"] == "AI") else user_name,
|
| 23 |
message = id["message"].strip()
|
| 24 |
) for id in messages
|
| 25 |
-
]
|
| 26 |
-
|
| 27 |
-
char_name = char_name,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
return_tensors = "pt"
|
| 34 |
-
).to("cuda")
|
| 35 |
encoded_output = self.model.generate(
|
| 36 |
input_ids["input_ids"],
|
| 37 |
max_new_tokens = 50,
|
|
|
|
| 16 |
messages = request_inputs["messages"]
|
| 17 |
char_name = request_inputs["char_name"]
|
| 18 |
user_name = request_inputs["user_name"]
|
| 19 |
+
chats_curled = request_inputs["chats_curled"]
|
| 20 |
template = self.default_template
|
| 21 |
+
user_input = [
|
| 22 |
"{name}: {message}".format(
|
| 23 |
name = char_name if (id["role"] == "AI") else user_name,
|
| 24 |
message = id["message"].strip()
|
| 25 |
) for id in messages
|
| 26 |
+
]
|
| 27 |
+
while True:
|
| 28 |
+
prompt = template.format(char_name = char_name, user_name = user_name, user_input = "\n".join([user_input]))
|
| 29 |
+
input_ids = self.tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
|
| 30 |
+
if input_ids.input_ids.size(1) > 2048:
|
| 31 |
+
chats_curled += 1
|
| 32 |
+
user_input = user_input[chats_curled*2:]
|
| 33 |
+
else: break
|
|
|
|
|
|
|
| 34 |
encoded_output = self.model.generate(
|
| 35 |
input_ids["input_ids"],
|
| 36 |
max_new_tokens = 50,
|