Upload handler.py
Browse files- handler.py +30 -22
handler.py
CHANGED
|
@@ -64,7 +64,7 @@ class EndpointHandler():
|
|
| 64 |
# Postprocess
|
| 65 |
prediction = output
|
| 66 |
LOGGER.info(f"Generated text: {prediction}")
|
| 67 |
-
return
|
| 68 |
|
| 69 |
def generate(
|
| 70 |
tokenizer,
|
|
@@ -121,31 +121,39 @@ def generate(
|
|
| 121 |
top_k=top_k,
|
| 122 |
temperature=temperature,
|
| 123 |
num_beams=1,
|
| 124 |
-
repetition_penalty=repetition_penalty
|
| 125 |
)
|
| 126 |
model.generate(**generate_kwargs)
|
| 127 |
|
| 128 |
outputs = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
for text in streamer:
|
|
|
|
|
|
|
| 130 |
outputs.append(text)
|
| 131 |
generated_text = "".join(outputs)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
| 64 |
# Postprocess
|
| 65 |
prediction = output
|
| 66 |
LOGGER.info(f"Generated text: {prediction}")
|
| 67 |
+
return prediction
|
| 68 |
|
| 69 |
def generate(
|
| 70 |
tokenizer,
|
|
|
|
| 121 |
top_k=top_k,
|
| 122 |
temperature=temperature,
|
| 123 |
num_beams=1,
|
| 124 |
+
repetition_penalty=repetition_penalty
|
| 125 |
)
|
| 126 |
model.generate(**generate_kwargs)
|
| 127 |
|
| 128 |
outputs = []
|
| 129 |
+
generated_text = ""
|
| 130 |
+
streaming = True
|
| 131 |
+
conclusion_found = None
|
| 132 |
+
context_numbers = []
|
| 133 |
for text in streamer:
|
| 134 |
+
if not streaming:
|
| 135 |
+
break
|
| 136 |
outputs.append(text)
|
| 137 |
generated_text = "".join(outputs)
|
| 138 |
+
for end_sequence in end_sequences:
|
| 139 |
+
if end_sequence in generated_text:
|
| 140 |
+
streaming = False
|
| 141 |
+
generated_text = generated_text.replace(end_sequence, "")
|
| 142 |
+
|
| 143 |
+
# Check for conclusion keys in the generated text
|
| 144 |
+
if conclusions:
|
| 145 |
+
for conclusion_key, _ in conclusions:
|
| 146 |
+
if conclusion_key in generated_text:
|
| 147 |
+
conclusion_found = conclusion_key
|
| 148 |
+
break
|
| 149 |
+
|
| 150 |
+
# Extract context numbers from the generated text
|
| 151 |
+
context_pattern = r"\[\d+\]"
|
| 152 |
+
context_matches = re.findall(context_pattern, generated_text)
|
| 153 |
+
context_numbers = [int(match.strip("[]")) for match in context_matches]
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
"generated_text": generated_text.strip(),
|
| 157 |
+
"conclusion": conclusion_found,
|
| 158 |
+
"context": context_numbers
|
| 159 |
+
}
|