codys12 commited on
Commit
dd45a31
·
1 Parent(s): 3d5eb81

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -22
handler.py CHANGED
@@ -64,7 +64,7 @@ class EndpointHandler():
64
  # Postprocess
65
  prediction = output
66
  LOGGER.info(f"Generated text: {prediction}")
67
- return {"generated_text": prediction}
68
 
69
  def generate(
70
  tokenizer,
@@ -121,31 +121,39 @@ def generate(
121
  top_k=top_k,
122
  temperature=temperature,
123
  num_beams=1,
124
- repetition_penalty=repetition_penalty,
125
  )
126
  model.generate(**generate_kwargs)
127
 
128
  outputs = []
 
 
 
 
129
  for text in streamer:
 
 
130
  outputs.append(text)
131
  generated_text = "".join(outputs)
132
- conclusion_found = None
133
- context_numbers = []
134
-
135
- # Check for conclusion keys in the generated text
136
- if conclusions:
137
- for conclusion_key, _ in conclusions:
138
- if conclusion_key in generated_text:
139
- conclusion_found = conclusion_key
140
- break
141
-
142
- # Extract context numbers from the generated text
143
- context_pattern = r"\[\d+\]"
144
- context_matches = re.findall(context_pattern, generated_text)
145
- context_numbers = [int(match.strip("[]")) for match in context_matches]
146
-
147
- return {
148
- "generated_text": generated_text.strip(),
149
- "conclusion": conclusion_found,
150
- "context": context_numbers
151
- }
 
 
 
64
  # Postprocess
65
  prediction = output
66
  LOGGER.info(f"Generated text: {prediction}")
67
+ return prediction
68
 
69
  def generate(
70
  tokenizer,
 
121
  top_k=top_k,
122
  temperature=temperature,
123
  num_beams=1,
124
+ repetition_penalty=repetition_penalty
125
  )
126
  model.generate(**generate_kwargs)
127
 
128
  outputs = []
129
+ generated_text = ""
130
+ streaming = True
131
+ conclusion_found = None
132
+ context_numbers = []
133
  for text in streamer:
134
+ if not streaming:
135
+ break
136
  outputs.append(text)
137
  generated_text = "".join(outputs)
138
+ for end_sequence in end_sequences:
139
+ if end_sequence in generated_text:
140
+ streaming = False
141
+ generated_text = generated_text.replace(end_sequence, "")
142
+
143
+ # Check for conclusion keys in the generated text
144
+ if conclusions:
145
+ for conclusion_key, _ in conclusions:
146
+ if conclusion_key in generated_text:
147
+ conclusion_found = conclusion_key
148
+ break
149
+
150
+ # Extract context numbers from the generated text
151
+ context_pattern = r"\[\d+\]"
152
+ context_matches = re.findall(context_pattern, generated_text)
153
+ context_numbers = [int(match.strip("[]")) for match in context_matches]
154
+
155
+ return {
156
+ "generated_text": generated_text.strip(),
157
+ "conclusion": conclusion_found,
158
+ "context": context_numbers
159
+ }