feat: Include clean final answer in OpenAI stream payload
- Updated `ZIPRCSampler.openai()` to calculate the best-performing trajectory (Top-1) at the end of generation.
- The final stream chunk (action="finished") now includes a `final_text` field within the `zip_rc` payload.
- This allows clients to capture and display the coherent "winning" answer, separating it from the noisy, branching thought process visible during the stream.
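
As a rough usage sketch (assuming `sampler` is an initialized `ZIPRCSampler` and that `OpenAIObject` supports dict-style `.get()` access — neither is shown in this diff), a client could capture the winning answer like so:

```python
# Hypothetical consumer sketch: the sampler construction and dict-style
# access on OpenAIObject are assumptions, not confirmed by this commit.
final_text = None
for chunk in sampler.openai(prompt="Why is the sky blue?"):
    meta = chunk.get("zip_rc") or {}
    if meta.get("action") == "finished":
        final_text = meta.get("final_text")

print(final_text)  # the coherent Top-1 answer, free of branch/backtrack noise
```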
ziprc.py
CHANGED
@@ -342,13 +342,27 @@ class ZIPRCSampler:
             }
             yield OpenAIObject(chunk_dict)
 
+        # Calculate Final Best Answer (clean from swaps/backtracks)
+        # Include running candidates in case max_tokens was hit before EOS
+        all_trajs = finished_trajectories + candidates
+        best_traj = self.select_best_trajectory(all_trajs)
+        final_answer = ""
+        if best_traj:
+            # Decode only the generated response (exclude prompt)
+            prompt_len = input_ids.shape[1]
+            final_ids = best_traj['ids'][0][prompt_len:]
+            final_answer = self.model.tokenizer.decode(final_ids, skip_special_tokens=True)
+
         yield OpenAIObject({
             "id": chat_id,
             "object": "chat.completion.chunk",
             "created": created_ts,
             "model": model_name,
             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
-            "zip_rc": {"action": "finished"}
+            "zip_rc": {
+                "action": "finished",
+                "final_text": final_answer
+            }
         })
 
     def generate_stream(self, prompt, max_new_tokens=512, initial_samples=2):
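
`select_best_trajectory` is not part of this hunk. A minimal sketch of the Top-1 selection it implies — assuming each trajectory dict carries a `score` (e.g. a length-normalized cumulative log-prob) alongside its `ids`, which this diff does not confirm — might look like:

```python
# Minimal sketch only: assumes each trajectory is a dict with a 'score'
# field; the real selection criterion lives elsewhere in ziprc.py.
def select_best_trajectory(self, trajectories):
    if not trajectories:
        return None  # matches the `if best_traj:` guard above
    return max(trajectories, key=lambda t: t["score"])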