"""Recursive Self-Aggregation (RSA) demo for integer-answer math problems.

Pipeline: generate N chain-of-thought candidate solutions, then for T rounds
ask the model to aggregate random K-subsets of the population into improved
solutions, and finally majority-vote over the extracted '#### <int>' answers.
Served through a small Gradio UI.
"""

import random
import re
from collections import Counter
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

COT_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant. Solve the problem step by step.

Problem: {question}

INSTRUCTIONS (READ CAREFULLY):
1. Show your step-by-step reasoning in **numbered steps**, e.g.:
   Step 1: ...
   Step 2: ...
   Final answer: ...
2. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER. NO UNITS, NO WORDS, NO SYMBOLS, NO LATEX.
3. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
   Final answer: ####
4. Replace with only the INTEGER NUMERIC ANSWER WITHOUT ANYTHING ELSE.
5. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.
6. IF YOU CANNOT CALCULATE, OUTPUT ONLY '#### NO ANSWER' AS PLACEHOLDER.

BEGIN STEP-BY-STEP REASONING:
"""

AGGREGATION_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant.
You are given a math problem and several candidate solutions. Some candidates may be incorrect, incomplete, or contain reasoning errors.
Your task is to aggregate the useful ideas from the candidates and produce a single, high-quality solution.

Problem: {question}

Candidate solutions:
{candidates}

INSTRUCTIONS (READ CAREFULLY):
1. Analyze all candidate solutions step by step.
2. Identify correct, useful, or partially correct reasoning steps.
3. If candidates disagree, **select the logically correct path** and discard incorrect reasoning.
4. Combine the selected reasoning into **one coherent, numbered step-by-step solution**, e.g.:
   Step 1: ...
   Step 2: ...
   Final answer: ...
5. If **all candidate solutions are incorrect or inconsistent**, abandon them and solve the problem using a correct alternative strategy.
6. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER.
   - NO UNITS
   - NO WORDS
   - NO SYMBOLS
   - NO LATEX
7. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
   Final answer: ####
8. Replace with ONLY the integer numeric answer.
9. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.

BEGIN AGGREGATED STEP-BY-STEP REASONING:
"""

# Cache so repeated UI runs don't reload multi-GB weights on every click.
_MODEL_CACHE: dict = {}


def load_model(model_name: str):
    """Return (tokenizer, model) for `model_name`, loading at most once.

    The tokenizer is left-padded (required for batched decoder-only
    generation) and the model is loaded in fp16 with `device_map="auto"`.
    """
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, padding_side="left"
    )
    # Some tokenizers ship without a pad token; batched padding needs one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    _MODEL_CACHE[model_name] = (tokenizer, model)
    return tokenizer, model


def generate_responses(model, tokenizer, prompts: List[str], max_new_tokens: int = 512) -> List[str]:
    """Generate one completion per prompt in a single left-padded batch.

    Uses the tokenizer's chat template when available, otherwise the raw
    prompt text. Returns only the newly generated text (prompt stripped).
    """
    messages = [{"role": "user", "content": prompt} for prompt in prompts]
    if hasattr(tokenizer, "apply_chat_template"):
        texts = [
            tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
    else:
        texts = [m["content"] for m in messages]
    model_inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    device = next(model.parameters()).device
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    # Inference only: skip autograd bookkeeping to save memory/time.
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # With left padding, the first `input_len` columns are the prompt; drop them.
    input_len = model_inputs["input_ids"].shape[1]
    gen_tokens = generated_ids[:, input_len:]
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)


def extract_answer_from_text(text: str) -> str:
    """Extract the final answer from a model completion.

    Prefers the explicit '#### <int>' / '#### NO ANSWER' marker (case-
    insensitive, normalized to upper case); otherwise falls back to the last
    integer appearing anywhere in the text. Returns "" when nothing matches.
    """
    m = re.search(r"####\s*(NO ANSWER|-?\d+)", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).strip().upper()
    # Fallback: last integer in the whole text. A lookahead regex without
    # re.DOTALL would only guarantee "last number on its line", which can
    # pick an earlier number — findall + [-1] is unambiguous.
    numbers = re.findall(r"-?\d+", text)
    if numbers:
        return numbers[-1]
    return ""


def majority_vote(population: List[str]) -> str:
    """Return the most common extracted answer across `population`.

    Completions with no extractable answer are ignored; returns "" when no
    completion yields an answer. Ties resolve to the first-counted answer.
    """
    answers = [ans for content in population if (ans := extract_answer_from_text(content))]
    if not answers:
        return ""
    return Counter(answers).most_common(1)[0][0]


def rsa_for_question(model, tokenizer, question: str, N=4, K=2, T=2) -> Tuple[str, str]:
    """Run recursive self-aggregation on one question.

    N: population size, K: candidates aggregated per new sample, T: rounds.
    Returns (all final-population completions joined by '---', voted answer).
    """
    prompts = [COT_PROMPT_TEMPLATE.format(question=question) for _ in range(N)]
    population = generate_responses(model, tokenizer, prompts)
    for _ in range(T):
        agg_prompts = []
        for _ in range(N):
            # Sample without replacement when possible, with replacement otherwise.
            if len(population) >= K:
                chosens = random.sample(population, K)
            else:
                chosens = random.choices(population, k=K)
            candidates_text = "\n\n".join(
                f"CANDIDATE #{i + 1}:\n{c}" for i, c in enumerate(chosens)
            )
            agg_prompts.append(
                AGGREGATION_PROMPT_TEMPLATE.format(question=question, candidates=candidates_text)
            )
        population = generate_responses(model, tokenizer, agg_prompts)
    final_answer = majority_vote(population)
    return "\n\n---\n\n".join(population), final_answer


def run_and_format(model_name, question, N, K, T, compare_baseline):
    """Gradio callback: run RSA and (optionally) a single-shot baseline.

    Returns (final population text, voted answer, baseline text, baseline answer).
    """
    tokenizer, model = load_model(model_name)
    N, K, T = int(N), int(K), int(T)
    final_population, final_answer = rsa_for_question(model, tokenizer, question, N, K, T)
    if bool(compare_baseline):
        # Baseline = one sample, no aggregation rounds.
        baseline_output, baseline_answer = rsa_for_question(model, tokenizer, question, 1, 1, 0)
    else:
        baseline_output = baseline_answer = "(skipped)"
    return final_population, final_answer, baseline_output, baseline_answer


with gr.Blocks(title="RSA") as demo:
    gr.Markdown(
        "### Recursive Self-Aggregation for Mathematical Reasoning\n"
        "Enter a problem whose answer is a single integer."
    )
    model_name = gr.Textbox(label="Model name", value=DEFAULT_MODEL, lines=1)
    with gr.Row():
        N = gr.Slider(1, 16, value=4, step=1, label="N")
        K = gr.Slider(1, 16, value=2, step=1, label="K")
        T = gr.Slider(0, 10, value=2, step=1, label="T")
    compare_baseline = gr.Checkbox(label="Compare with baseline", value=True)
    question = gr.Textbox(
        label="Question", lines=2, value="How many positive divisors does number 120 have?"
    )
    run_btn = gr.Button("Run")
    final_population = gr.Textbox(label="Final Population", lines=10)
    final_answer = gr.Textbox(label="Final answer")
    baseline_output = gr.Textbox(label="Baseline output", lines=5)
    baseline_answer = gr.Textbox(label="Baseline answer")
    run_btn.click(
        run_and_format,
        inputs=[model_name, question, N, K, T, compare_baseline],
        outputs=[final_population, final_answer, baseline_output, baseline_answer],
    )

if __name__ == "__main__":
    demo.launch()