import transformers
import torch
import re
import requests
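
# NOTE (an assumption based on the endpoint used below): this script expects
# the Search-R1 local retrieval server to be running and reachable at
# http://127.0.0.1:8000/retrieve before inference starts.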

# Example multi-hop (HotpotQA-style) question
question = "Mike Barnett negotiated many contracts including which player that went on to become general manager of CSKA Moscow of the Kontinental Hockey League?"

# Model ID and device
model_id = "PeterJinGo/SearchR1-nq_hotpotqa_train-qwen2.5-7b-em-ppo"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Normalize the question and define the template for splicing search results
# back into the prompt
question = question.strip()
if question[-1] != '?':
    question += '?'
curr_eos = [151645, 151643]  # <|im_end|> and <|endoftext|> for Qwen2.5 models
curr_search_template = '\n\n{output_text}<information>{search_results}</information>\n\n'
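
# One search turn, end to end (illustrative text, not real model output):
# generation halts right after "</search>", the query between the tags goes to
# the retriever, and the results are spliced back via curr_search_template:
#
#   <think> I need to find Mike Barnett's clients. </think>
#   <search> Mike Barnett sports agent clients </search>
#   <information>Doc 1(Title: ...) ...</information>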

# Build the instruction prompt (Search-R1 format)
prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first every time you get new information. \
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
You can search as many times as you want. \
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""

# Initialize the tokenizer and model
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Custom stopping criterion: halt generation as soon as the tail of the
# sequence matches any of the encoded target strings (e.g. "</search>")
class StopOnSequence(transformers.StoppingCriteria):
    def __init__(self, target_sequences, tokenizer):
        # Encode each target string to get its exact token-ID pattern
        self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
        self.target_lengths = [len(target_id) for target_id in self.target_ids]
        self._tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Keep the target IDs on the same device as the generated sequence
        targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]

        # Not enough tokens generated yet to match any target
        if input_ids.shape[1] < min(self.target_lengths):
            return False

        # Compare the tail of input_ids against each target pattern
        for i, target in enumerate(targets):
            if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
                return True

        return False
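
# Quick sanity check (a sketch; exact token boundaries depend on the
# tokenizer, which is why several "</search>" variants are registered further
# below):
#
#   crit = StopOnSequence(["</search>", " </search>"], tokenizer)
#   ids = tokenizer.encode("<search> test </search>", return_tensors="pt")
#   print(crit(ids, None))  # True once the tail matches an encoded variant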

def get_query(text):
    # Return the most recent <search> query in the text, or None if absent
    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    matches = pattern.findall(text)
    if matches:
        return matches[-1]
    else:
        return None
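
# Example: get_query("<search> Mike Barnett agent </search>") returns
# " Mike Barnett agent " (surrounding whitespace is kept; the retrieval server
# is assumed to tolerate it).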

def search(query: str):
    # Ask the local retrieval server for the top-3 passages for this query
    payload = {
        "queries": [query],
        "topk": 3,
        "return_scores": True
    }
    results = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()['result']

    def _passages2string(retrieval_result):
        # Flatten the hits into the "Doc i(Title: ...) ..." reference format
        format_reference = ''
        for idx, doc_item in enumerate(retrieval_result):
            content = doc_item['document']['contents']
            title = content.split("\n")[0]
            text = "\n".join(content.split("\n")[1:])
            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
        return format_reference

    return _passages2string(results[0])
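
# Assumed response shape, inferred from the parsing above: "result" holds one
# list of hits per query; each hit's document "contents" starts with a title
# line, and a "score" field is present because return_scores=True:
#
#   {"result": [[{"document": {"contents": "Some Title\nPassage text ..."},
#                 "score": 0.87},
#                ...]]}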

# Register several surface forms of "</search>", since leading spaces and
# trailing newlines tokenize into different ID sequences (see the sanity-check
# sketch above)
target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])

cnt = 0  # number of completed search turns

# Wrap the prompt in the model's chat template when one is available
if tokenizer.chat_template:
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)

print('\n\n################# [Start Reasoning + Searching] ##################\n\n')
print(prompt)

# Main loop: alternate between generating and retrieving until the model emits
# a final answer followed by an EOS token
while True:
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones_like(input_ids)

    # Generate, stopping early whenever a </search> tag is emitted
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7
    )

    # Generation ended with EOS: the model produced its final <answer>
    if outputs[0][-1].item() in curr_eos:
        generated_tokens = outputs[0][input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print(output_text)
        break

    # Otherwise generation paused at </search>: extract the latest query and
    # run retrieval
    generated_tokens = outputs[0][input_ids.shape[1]:]
    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    tmp_query = get_query(tokenizer.decode(outputs[0], skip_special_tokens=True))
    if tmp_query:
        search_results = search(tmp_query)
    else:
        search_results = ''

    # Splice the model output and retrieved passages back into the prompt
    search_text = curr_search_template.format(output_text=output_text, search_results=search_results)
    prompt += search_text
    cnt += 1
    print(search_text)