# RSA — Recursive Self-Aggregation Gradio demo (Hugging Face Space).
# (The "Spaces / Sleeping" lines here were Space status-page residue, not code.)
| import re | |
| import torch | |
| import random | |
| import gradio as gr | |
| from collections import Counter | |
| from typing import List | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Default checkpoint shown in the UI; any HF causal-LM name can be typed in.
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

# Chain-of-thought prompt used to sample the initial candidate population.
# The '#### <integer>' sentinel is what extract_answer_from_text() parses.
# Fixes vs. the original text: instruction numbering skipped "2." (it ran
# 1, 3, 4, 5, 6) and "ANITHING" was a typo for "ANYTHING".
COT_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant. Solve the problem step by step.
Problem: {question}
INSTRUCTIONS (READ CAREFULLY):
1. Show your step-by-step reasoning in **numbered steps**, e.g.:
Step 1: ...
Step 2: ...
Final answer: ...
2. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER. NO UNITS, NO WORDS, NO SYMBOLS, NO LATEX.
3. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
Final answer: #### <final_number>
Replace <final_number> with only the INTEGER NUMERIC ANSWER WITHOUT ANYTHING ELSE.
4. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.
5. IF YOU CANNOT CALCULATE, OUTPUT ONLY '#### NO ANSWER' AS PLACEHOLDER.
BEGIN STEP-BY-STEP REASONING:
"""
# Prompt used in each RSA refinement round: the model sees K candidate
# solutions (formatted by rsa_for_question) and must merge them into one
# improved solution ending with the same '#### <integer>' sentinel that
# extract_answer_from_text() parses.
AGGREGATION_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant.
You are given a math problem and several candidate solutions. Some candidates may be incorrect, incomplete, or contain reasoning errors.
Your task is to aggregate the useful ideas from the candidates and produce a single, high-quality solution.
Problem:
{question}
Candidate solutions:
{candidates}
INSTRUCTIONS (READ CAREFULLY):
1. Analyze all candidate solutions step by step.
2. Identify correct, useful, or partially correct reasoning steps.
3. If candidates disagree, **select the logically correct path** and discard incorrect reasoning.
4. Combine the selected reasoning into **one coherent, numbered step-by-step solution**, e.g.:
Step 1: ...
Step 2: ...
Final answer: ...
5. If **all candidate solutions are incorrect or inconsistent**, abandon them and solve the problem using a correct alternative strategy.
6. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER.
- NO UNITS
- NO WORDS
- NO SYMBOLS
- NO LATEX
7. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
Final answer: #### <final_number>
8. Replace <final_number> with ONLY the integer numeric answer.
9. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.
BEGIN AGGREGATED STEP-BY-STEP REASONING:
"""
def load_model(model_name: str):
    """Load a causal-LM checkpoint plus its tokenizer.

    The tokenizer pads on the left so that, in a batched generate() call,
    every prompt ends at the same index and the generated suffix can be
    sliced off uniformly (see generate_responses).
    """
    tok = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side="left",
    )
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,  # half precision to fit on small GPUs
        trust_remote_code=True,
    )
    return tok, lm
def generate_responses(
    model,
    tokenizer,
    prompts: List[str],
    max_new_tokens: int = 512,
    *,
    do_sample: bool = True,
    temperature: float = 0.7,
):
    """Generate one completion per prompt in a single left-padded batch.

    Sampling is enabled by default: RSA relies on a *diverse* candidate
    population, and the previous greedy default made all N samples of the
    same prompt identical, which reduced majority voting to a no-op.

    Args:
        model / tokenizer: a causal LM and its tokenizer (left padding).
        prompts: user-turn texts, one per desired completion.
        max_new_tokens: generation budget per completion.
        do_sample / temperature: decoding controls (keyword-only additions).

    Returns:
        List of decoded completion strings, prompt text stripped.
    """
    messages = [{"role": "user", "content": p} for p in prompts]
    if hasattr(tokenizer, "apply_chat_template"):
        texts = [
            tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
    else:
        # Plain-text fallback for tokenizers without a chat template.
        texts = [m["content"] for m in messages]
    # Some tokenizers ship without a pad token; batch padding would raise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model_inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    device = next(model.parameters()).device
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    with torch.inference_mode():  # no autograd state needed for generation
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
        )
    # With left padding every prompt ends at the same column, so the newly
    # generated tokens for all rows start at input_len.
    input_len = model_inputs["input_ids"].shape[1]
    gen_tokens = generated_ids[:, input_len:]
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
def extract_answer_from_text(text: str) -> str:
    """Extract the final integer answer (or 'NO ANSWER') from model output.

    Prefers the explicit '#### <int>' / '#### NO ANSWER' sentinel; otherwise
    falls back to the last integer anywhere in the text.

    Returns:
        The answer as a string ('NO ANSWER' upper-cased), or '' if none found.
    """
    m = re.search(r"####\s*(NO ANSWER|-?\d+)", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).strip().upper()
    # Fallback: last integer in the WHOLE text. The original pattern
    # r"(-?\d+)(?!.*-?\d+)" used '.' without DOTALL, so the lookahead stopped
    # at the first newline and returned the last number of the first
    # digit-bearing line instead of the last number overall.
    nums = re.findall(r"-?\d+", text)
    if nums:
        return nums[-1]
    return ""
def majority_vote(population: List[str]) -> str:
    """Return the most frequent extracted answer, or '' when none parse.

    Ties break by first occurrence, matching Counter.most_common ordering.
    """
    answers = [a for a in map(extract_answer_from_text, population) if a]
    if not answers:
        return ""
    return Counter(answers).most_common(1)[0][0]
def rsa_for_question(model, tokenizer, question: str, N=4, K=2, T=2):
    """Recursive Self-Aggregation over one question.

    Samples N chain-of-thought candidates, then for T rounds replaces the
    population with N aggregations of random K-subsets, and finally takes a
    majority vote over the surviving population.

    Returns:
        (population joined with '---' separators, voted answer string).
    """
    seed_prompts = [COT_PROMPT_TEMPLATE.format(question=question) for _ in range(N)]
    population = generate_responses(model, tokenizer, seed_prompts)

    for _ in range(T):
        round_prompts = []
        for _ in range(N):
            # Prefer sampling without replacement; fall back to sampling
            # with replacement when the population is smaller than K.
            if len(population) >= K:
                subset = random.sample(population, K)
            else:
                subset = random.choices(population, k=K)
            blocks = [f"CANDIDATE #{i+1}:\n{c}" for i, c in enumerate(subset)]
            round_prompts.append(
                AGGREGATION_PROMPT_TEMPLATE.format(
                    question=question,
                    candidates="\n\n".join(blocks),
                )
            )
        population = generate_responses(model, tokenizer, round_prompts)

    return "\n\n---\n\n".join(population), majority_vote(population)
# Cache of loaded (tokenizer, model) pairs keyed by model name. The original
# code re-downloaded / re-sharded the checkpoint on EVERY button click.
_MODEL_CACHE = {}


def _get_model(model_name: str):
    """Load a model once per name and reuse it across UI clicks."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = load_model(model_name)
    return _MODEL_CACHE[model_name]


def run_and_format(model_name, question, N, K, T, compare_baseline):
    """Gradio callback: run RSA and (optionally) a single-sample baseline.

    Args come straight from the UI widgets (sliders deliver floats, hence
    the int() coercion).

    Returns:
        (population text, RSA answer, baseline text, baseline answer) —
        baseline fields are "(skipped)" when the checkbox is off.
    """
    tokenizer, model = _get_model(model_name)
    N, K, T = int(N), int(K), int(T)
    final_population, final_answer = rsa_for_question(model, tokenizer, question, N, K, T)
    if bool(compare_baseline):
        # N=1, K=1, T=0 → one plain chain-of-thought sample, no aggregation.
        baseline_output, baseline_answer = rsa_for_question(model, tokenizer, question, 1, 1, 0)
    else:
        baseline_output = baseline_answer = "(skipped)"
    return final_population, final_answer, baseline_output, baseline_answer
# --- Gradio UI wiring (runs at import time; launch guarded below) ---
with gr.Blocks(title="RSA") as demo:
    gr.Markdown("### Recursive Self-Aggregation for Mathematical Reasoning\nEnter a problem whose answer is a single integer.")
    # Free-text model name so any HF causal-LM checkpoint can be tried.
    model_name = gr.Textbox(label="Model name", value=DEFAULT_MODEL, lines=1)
    with gr.Row():
        # N: population size; K: candidates aggregated per prompt;
        # T: aggregation rounds (T=0 degenerates to self-consistency voting).
        N = gr.Slider(1, 16, value=4, step=1, label="N")
        K = gr.Slider(1, 16, value=2, step=1, label="K")
        T = gr.Slider(0, 10, value=2, step=1, label="T")
    compare_baseline = gr.Checkbox(label="Compare with baseline", value=True)
    question = gr.Textbox(label="Question", lines=2, value="How many positive divisors does number 120 have?")
    run_btn = gr.Button("Run")
    # Output widgets, filled by run_and_format's 4-tuple return.
    final_population = gr.Textbox(label="Final Population", lines=10)
    final_answer = gr.Textbox(label="Final answer")
    baseline_output = gr.Textbox(label="Baseline output", lines=5)
    baseline_answer = gr.Textbox(label="Baseline answer")
    run_btn.click(run_and_format, inputs=[model_name, question, N, K, T, compare_baseline],
                  outputs=[final_population, final_answer, baseline_output, baseline_answer])

if __name__ == "__main__":
    demo.launch()