File size: 10,649 Bytes
d28330f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import asyncio
import re
import time
import traceback
from copy import deepcopy
from typing import List

from slime.rollout.rm_hub import batched_async_rm
from slime.utils.http_utils import post
from slime.utils.types import Sample

from .prompts import SOLVER_PROMPT_TEMPLATE, generate_rewriter_template, generate_select_template


async def generate_response(args, prompt, key):
    try:
        sampling_params = args.sampling_params
        tokenizer = args.tokenizer
        max_context_length = args.rollout_max_context_len
        sample = deepcopy(args.sample)

        url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

        prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
        sample.tokens = prompt_token_ids
        sample.prompt = prompt
        input_token_ids = prompt_token_ids
        prompt_length = len(input_token_ids)
        current_sampling_params = deepcopy(sampling_params)
        current_sampling_params["max_new_tokens"] = min(
            sampling_params["max_new_tokens"], max_context_length - prompt_length
        )

        if current_sampling_params["max_new_tokens"] <= 0:
            return None

        payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}

        output = await post(url, payload)

        # Extract new response tokens
        if "output_token_logprobs" in output["meta_info"]:
            new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
        else:
            # abort
            new_response_tokens = []

        # Update sample with tokens directly - avoiding re-tokenization
        sample.tokens = sample.tokens + new_response_tokens
        sample.response_length += len(new_response_tokens)
        sample.response = output["text"]

        match output["meta_info"]["finish_reason"]["type"]:
            case "length":
                sample.status = Sample.Status.TRUNCATED
            # case "abort":
            #     sample.status = Sample.Status.ABORTED
            case "stop":
                sample.status = Sample.Status.COMPLETED

        args.results_dict[key].append(sample)

        final = output["text"].replace("<|user|>", "")
        if "</think>" in final:
            contents = final.split("</think>")
            if len(contents) == 2 and contents[1] != "":
                reason_content = contents[0].strip()
                response_content = contents[1].strip()
                sample.reason_content = reason_content
                sample.response_content = response_content
                return response_content
        sample.reason_content = None
        sample.response_content = None
        return None
    except Exception as e:
        print(f"Error generating response: {e}")
        return None


class Agent:
    """A base class for our AI agents."""

    def __init__(self):
        pass

    async def run(self, args, prompt, max_retries: int = 1, key: str = None) -> str:
        """Runs the agent by sending a prompt to the LLM."""
        for i in range(max_retries):
            try:
                response = await generate_response(args, prompt, key=key)
                return response
            except Exception as e:
                print(f"Error querying LLM: {e}")
                time.sleep(1)
        print(f"Failed to query LLM after {max_retries} retries")
        return None


class SolverAgent(Agent):
    """Agent that drafts a first solution for a problem statement."""

    def __init__(self):
        super().__init__()

    async def generate_initial_solution(self, args, problem_statement) -> str:
        """Fill the solver template with the problem and query the LLM.

        The request is issued under the ``"solver"`` key with up to three
        attempts; the value returned by ``run`` is passed through unchanged.
        """
        solver_prompt = SOLVER_PROMPT_TEMPLATE.format(problem_statement=problem_statement)
        first_attempt = await self.run(args, solver_prompt, max_retries=3, key="solver")
        return first_attempt


class RewriterAgent(Agent):
    """Agent that produces a rewritten solution from earlier attempts."""

    def __init__(self):
        super().__init__()

    async def rewrite(self, args, problem_statement, previous_solutions: List[str]) -> str:
        """Build the rewriter prompt from the earlier solutions and query the LLM.

        The template is generated for the number of solutions supplied and is
        filled with ``problem_statement`` plus ``solution1`` .. ``solutionN``.
        The request is issued once under the ``"rewriter"`` key.
        """
        # The template has one slot per previous solution.
        template = generate_rewriter_template(len(previous_solutions))

        # Slots are numbered from 1, in the order the solutions were given.
        numbered_solutions = {
            f"solution{position}": text for position, text in enumerate(previous_solutions, start=1)
        }
        template_fields = {"problem_statement": problem_statement, **numbered_solutions}

        rewriter_prompt = template.format(**template_fields)
        return await self.run(args, rewriter_prompt, max_retries=1, key="rewriter")


class SelectorAgent(Agent):
    """The agent responsible for selecting solutions."""

    # Matches the verdict line, e.g. "Judgment: 2". Compiled once as a raw string so
    # that \s and \d reach the regex engine instead of being invalid string escapes.
    _JUDGMENT_PATTERN = re.compile(r"Judgment:\s*(\d+)")

    def __init__(self):
        super().__init__()

    async def select(self, args, problem_statement, candidate_solutions: List[str]) -> str:
        """Ask the LLM to pick the best of ``candidate_solutions``.

        The selection template is generated for the number of candidates and
        filled with ``problem_statement`` plus ``solution1`` .. ``solutionN``.
        The request is issued under the ``"selector"`` key with up to ten
        attempts; the value returned by ``run`` is passed through unchanged.
        """
        # Dynamically generate template
        template = generate_select_template(len(candidate_solutions))

        # Build parameter dictionary
        format_params = {"problem_statement": problem_statement}
        for i, solution in enumerate(candidate_solutions):
            format_params[f"solution{i+1}"] = solution

        prompt = template.format(**format_params)
        return await self.run(args, prompt, max_retries=10, key="selector")

    def extract_selected_solution_idx(self, response: str, candidate_solutions: List[str]) -> int | None:
        """Extract the zero-based index of the solution chosen in ``response``.

        The first ``Judgment: <n>`` occurrence is used; ``n`` is one-based in
        the response text.

        Returns:
            The zero-based index, or ``None`` when ``response`` is not a
            string, carries no judgment, or names a solution outside
            ``candidate_solutions``.
        """
        if not isinstance(response, str):
            print(f"extract_selected_solution_idx error: expected str, got {type(response).__name__}")
            return None

        found = self._JUDGMENT_PATTERN.search(response)
        if found is None:
            print("extract_selected_solution_idx error: no 'Judgment: <n>' found in response")
            return None

        try:
            selected_id = int(found.group(1)) - 1
        except ValueError as e:
            # Raised for absurdly long digit runs that exceed the int conversion limit.
            print(f"extract_selected_solution_idx error: {e}")
            return None

        if 0 <= selected_id < len(candidate_solutions):
            return selected_id
        return None


async def rewrite_worker(args, previous_solutions, problem_statement, worker_id):
    """Run one rewriter pass over ``previous_solutions`` and return its result.

    ``worker_id`` is accepted for symmetry with the other workers but is not
    used here.
    """
    return await RewriterAgent().rewrite(args, problem_statement, previous_solutions)


async def solver_worker(args, problem_statement, worker_id):
    """
    Run one solver pass for ``problem_statement``.

    Returns the solver's result, or None when anything inside the pass raises;
    the exception and its traceback are printed, tagged with ``worker_id``.
    """
    try:
        return await SolverAgent().generate_initial_solution(args, problem_statement)
    except Exception as e:
        # Report the failure and let the caller treat this worker as having no solution.
        print(f"[Worker-{worker_id}] exception: {e}")
        print(f"[Worker-{worker_id}] traceback: {traceback.format_exc()}")

    return None


async def run_agent_system(args, sample):
    """
    Run the solver -> rewriter -> selector pipeline for one sample.

    ``args.num_parallel`` solver workers run concurrently, then the same number
    of rewriter workers each rewrite from all successful solver answers, and
    finally a single selector picks one rewritten solution. Solver and rewriter
    samples are scored with ``batched_async_rm``; the selector sample inherits
    the reward of the rewrite it picked (0 when it picked nothing valid).

    All rewards are multiplied by ``args.correct_reward_weight`` when the
    selector reward equals 1, and by ``args.incorrect_reward_weight`` otherwise,
    including when the pipeline stops early because a stage produced nothing.

    Args:
        args: Rollout arguments; deep-copied, so the caller's object is not modified.
        sample: Sample whose ``prompt`` is the problem statement.

    Returns:
        A list with the samples recorded by every stage that ran, in
        solver, rewriter, selector order.
    """

    args = deepcopy(args)  # Deep copy args since it may be modified in rollout_with_multi_agents
    args.sample = sample
    args.results_dict = {"solver": [], "rewriter": [], "selector": []}
    results_dict = args.results_dict

    def reward_adjustment(samples, reward_weight):
        for item in samples:
            item.reward = item.reward * reward_weight
        return samples

    def finalize(stages, reward_weight):
        # Scale the rewards of the given stages and concatenate their samples.
        collected = []
        for stage in stages:
            collected += reward_adjustment(results_dict[stage], reward_weight)
        return collected

    async def score(stage):
        # Attach reward-model scores to every sample recorded for this stage.
        rewards = await batched_async_rm(args, results_dict[stage])
        for item, reward in zip(results_dict[stage], rewards):
            item.reward = reward

    problem_statement = sample.prompt

    # Solving
    tasks = [solver_worker(args, problem_statement, worker_id) for worker_id in range(args.num_parallel)]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    await score("solver")

    # Workers yield None (or an exception object) on failure; keep only real answers.
    previous_solutions = [item for item in results if isinstance(item, str)]
    if not previous_solutions:
        return finalize(("solver",), args.incorrect_reward_weight)

    # Rewriting
    tasks = [
        rewrite_worker(args, previous_solutions, problem_statement, worker_id)
        for worker_id in range(args.num_parallel)
    ]
    rewrited_solutions_raw = await asyncio.gather(*tasks, return_exceptions=True)
    rewrited_solutions = [result for result in rewrited_solutions_raw if isinstance(result, str)]
    await score("rewriter")

    if not rewrited_solutions:
        return finalize(("solver", "rewriter"), args.incorrect_reward_weight)

    # Selection
    selector = SelectorAgent()
    response = await selector.select(args, problem_statement, rewrited_solutions)
    selector_samples = results_dict["selector"]
    if not selector_samples:
        return finalize(("solver", "rewriter"), args.incorrect_reward_weight)

    assert (
        len(selector_samples) == 1
    ), f"selector should only return one solution, but got {len(selector_samples)}"
    selector_sample = selector_samples[0]

    selected_solution_idx = None
    if response is not None:
        selected_solution_idx = selector.extract_selected_solution_idx(response, rewrited_solutions)

    # Default to 0 so the reward is always defined, then inherit the reward of the
    # chosen rewrite. Matching is by equality: a substring test could pick up the
    # reward of a different, longer rewrite that merely contains the chosen text.
    selector_sample.reward = 0
    if selected_solution_idx is not None:
        selected_solution = rewrited_solutions[selected_solution_idx]
        for item in results_dict["rewriter"]:
            if getattr(item, "response_content", None) == selected_solution:
                selector_sample.reward = item.reward
                break

    ## If final answer is correct, add positive reward to all; if incorrect, add negative penalty to all.
    if selector_sample.reward == 1:
        final_weight = args.correct_reward_weight
    else:
        final_weight = args.incorrect_reward_weight

    return finalize(("solver", "rewriter", "selector"), final_weight)