| import re | |
| import math | |
| import copy | |
| from tqdm import tqdm | |
| from typing import Optional, Tuple, List, Dict, Union, Any | |
| from ..utils import Result, batch_iterator | |
| from .base import RerankStrategy | |
# Prompt asking the LLM to break a request into sub-questions.
# ``Dev.preprocess`` fills the placeholders with ``str.format``: {NUM} is the
# number of sub-questions to request and {query} is the request text, so any
# literal brace added to this text must be doubled.
# ``Dev.postprocess`` extracts the list from between the "<START OF LIST>" and
# "<END OF LIST>" markers requested below.
TEMPLATE = """
Instruction:
Given the following request, write {NUM} diverse and non-repeating sub-questions that can help guide the creation of a focused and comprehensive report. The sub-questions should help break down the topic into key areas that need to be investigated or explained. Each sub-question should be short (ideally under 20 words) and should focus on a single aspect or dimension of the report.
Request:
{query}
Output format:
- List each sub-question on a new line. Do not number the sub-questions.
- Do not add any comment or explanation.
- Output without adding additional questions after the specified {NUM}. Begin with "<START OF LIST>" and, when you are finished, output "<END OF LIST>". Never ever add anything else after "<END OF LIST>", my life depends on it!!!
Now, generate the {NUM} sub-questions:
"""
class Dev(RerankStrategy):
    """Rerank by scoring hits against LLM-generated sub-questions.

    For every query the LLM is first asked (via ``TEMPLATE``) for a fixed
    number of sub-questions. One scoring pass is then run per sub-question,
    with the query text extended by that sub-question, and the output of each
    pass is attached to the matching result through ``append_subresult``.
    """

    # Markers the prompt asks the LLM to wrap its list in.
    _LIST_START = "<START OF LIST>"
    _LIST_END = "<END OF LIST>"
    # Marker spellings that must never be returned as a sub-question.
    _MARKER_LINES = frozenset(
        {"<START OF LIST>", "<END OF LIST>", "START OF LIST", "END OF LIST"}
    )
    # Leading list decoration: runs of "-" / "*" bullets, or an enumerator
    # such as "1." / "2)" followed by whitespace. Bare leading digits are
    # deliberately NOT matched, so "2024 election results?" keeps its year.
    _LIST_PREFIX = re.compile(r"^\s*(?:[-*]+\s*|\d+[.)]\s+)+")

    def run(
        self,
        init_results: List[Result],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
        batch_size: Optional[int] = 32,
        num_runs: int = 2,
        **kwargs
    ) -> List[Result]:
        """Score every result once per generated sub-question.

        Args:
            init_results: Results to rerank; they are deep-copied, never
                mutated.
            rank_start: Forwarded to ``run_pass`` (which currently does not
                use it).
            rank_end: Forwarded to ``run_pass``; hits are truncated to
                ``hits[:rank_end]`` there.
            batch_size: Forwarded to ``run_pass``.
            num_runs: Number of sub-questions to generate per query, and
                therefore the number of scoring passes.
            **kwargs: Forwarded to ``run_pass``.

        Returns:
            The reset copies of ``init_results`` with one sub-result appended
            per sub-question. Their ``query`` is the original query text.
        """
        num_subquestions = num_runs

        # Copy and reset so the caller's objects are left untouched.
        results = copy.deepcopy(init_results)
        for r in results:
            r.reset()
        if not results:
            # Nothing to score; avoid sending an empty batch to the LLM.
            return results

        # 1. Generate n sub-questions for all queries.
        queries = [r.query for r in init_results]
        prompts = [self.preprocess(q, num_subquestions) for q in queries]
        with self._llm.default():
            outputs = self._llm.generate(prompts)

        subquestions = []
        for output in outputs:
            parsed = self.postprocess(output, num_subquestions)
            # The LLM may return fewer sub-questions than requested. Pad with
            # empty strings so every pass has an entry instead of raising
            # IndexError; an empty entry means "score the plain query".
            parsed = parsed + [""] * (num_subquestions - len(parsed))
            subquestions.append(parsed)

        # 2. Answerability judgment for each sub-question.
        for i in range(num_subquestions):
            for j, r in enumerate(results):
                subquestion = subquestions[j][i]
                # Always extend the ORIGINAL query. Extending r.query in place
                # would make pass i carry the sub-questions of passes 0..i-1.
                if subquestion:
                    r.query = queries[j] + " " + subquestion
                else:
                    r.query = queries[j]
            subresults = self.run_pass(
                results,
                rank_start=rank_start,
                rank_end=rank_end,
                batch_size=batch_size,
                **kwargs
            )
            for j, r in enumerate(results):
                r.append_subresult(subquestions[j][i], subresults[j])

        # Hand back the original query text, not the last augmented one.
        for r, query in zip(results, queries):
            r.query = query
        return results

    def run_pass(
        self,
        init_results: List[Result],
        rank_start: int = 0,
        rank_end: Optional[int] = None,
        batch_size: Optional[int] = 32,
        **kwargs
    ):
        """Score the hits of every result with the LLM and parse the scores.

        Args:
            init_results: Results whose hits are scored. Scoring works on
                deep copies; ``init_results`` itself is what gets handed to
                the result parser.
            rank_start: Accepted for interface compatibility but unused;
                prompts are always built from rank 0 so that scores stay
                index-aligned with ``hits[:rank_end]``.
            rank_end: Only ``hits[:rank_end]`` are scored (``None`` = all).
            batch_size: Chunk size passed to ``batch_iterator``.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``self._result_parser.parse`` returns for the per-result
            score lists (``run`` indexes it by result position).
        """
        results = [copy.deepcopy(result) for result in init_results]

        # The parser name decides which scoring mode the LLM runs in; it does
        # not change between batches, so resolve the flags once.
        parser_name = self.config.result_parser_name
        generate_flags = {
            "binary_probs": parser_name == "binary_prob",
            "dist_logp": parser_name == "dist_logp",
            "rating_logp": parser_name == "rating_logp",
            "expected_rating": parser_name == "expected_rating",
        }

        # One score list per result, kept in result order. A positional list
        # (rather than a dict keyed by qid) keeps results that share a qid
        # from overwriting each other's scores.
        all_scores = []
        for result in results:
            result.hits = list(result.hits[:rank_end])
            # Placeholder scores: hits without a returned score stay at 0.
            hit_scores = [0] * len(result.hits)

            # Create prompts for all (query, hit) pairs.
            prompts = self._prompt_builder.create_prompt(
                result, rank_start=0, rank_end=rank_end
            )

            # Score the prompts batch by batch.
            scores = []
            for batch_prompts in batch_iterator(prompts, batch_size):
                scores.extend(self._llm.generate(batch_prompts, **generate_flags))

            for i, score in enumerate(scores):
                hit_scores[i] = score
            all_scores.append(hit_scores)

        # NOTE(review): the parser receives the untruncated init_results, so
        # with a rank_end the score lists can be shorter than the hit lists.
        # Confirm the parser tolerates that.
        return self._result_parser.parse(all_scores, init_results)

    # NOTE: ad-hoc adoption of LANCER method
    def preprocess(self, query, num_subquestions=2):
        """Build the chat-formatted sub-question prompt for ``query``.

        The placeholders are filled BEFORE the chat template is applied, so
        ``str.format`` only ever sees ``TEMPLATE`` itself. Formatting the
        chat-template output instead would raise on any literal brace the
        tokenizer's template emits.
        """
        user_message = TEMPLATE.format(NUM=num_subquestions, query=query)
        return self._prompt_builder._tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": user_message}],
            tokenize=False,
            add_generation_prompt=True,
        )

    def postprocess(self, llm_output, n):
        """Extract at most ``n`` sub-questions from raw LLM output.

        The text between "<START OF LIST>" and "<END OF LIST>" is used. If
        either marker is missing (for example, truncated output), that side
        is left unbounded. Each line is stripped of surrounding whitespace
        and leading bullet/enumerator decoration; blank lines and bare marker
        lines are dropped.

        Returns:
            A list of at most ``n`` non-empty strings; it may be shorter.
        """
        text = llm_output
        _, start_found, after_start = text.partition(self._LIST_START)
        if start_found:
            text = after_start
        # partition returns the whole text when the end marker is absent.
        text = text.partition(self._LIST_END)[0]

        subquestions = []
        for line in text.split("\n"):
            line = line.strip()
            if not line or line in self._MARKER_LINES:
                continue
            line = self._LIST_PREFIX.sub("", line)
            if line:
                subquestions.append(line)
        return subquestions[:n]
Xet Storage Details
- Size:
- 5.04 kB
- Xet hash:
- c4ef2ded194a845101366192e0146d990345f841495bbf7e6e7b14ce3e8891f8
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.