RSA / app.py
Маликов Дмитрий Романович
Update Application
d08c34e
import re
import torch
import random
import gradio as gr
from collections import Counter
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face model id pre-filled in the UI's "Model name" textbox.
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
# Chain-of-thought prompt used to generate the initial candidate solutions.
# `{question}` is substituted via str.format. The prompt asks the model to end
# with "Final answer: #### <integer>" (or "#### NO ANSWER"), which is the
# marker extract_answer_from_text looks for.
# NOTE(review): the instruction list jumps from item 1 to item 3 and contains
# the typo "ANITHING"; both are left as-is because the prompt text is runtime
# behavior and changing it may change model output.
COT_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant. Solve the problem step by step.
Problem: {question}
INSTRUCTIONS (READ CAREFULLY):
1. Show your step-by-step reasoning in **numbered steps**, e.g.:
Step 1: ...
Step 2: ...
Final answer: ...
3. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER. NO UNITS, NO WORDS, NO SYMBOLS, NO LATEX.
4. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
Final answer: #### <final_number>
Replace <final_number> with only the INTEGER NUMERIC ANSWER WITHOUT ANITHING ELSE.
5. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.
6. IF YOU CANNOT CALCULATE, OUTPUT ONLY '#### NO ANSWER' AS PLACEHOLDER.
BEGIN STEP-BY-STEP REASONING:
"""
# Aggregation prompt used in every refinement round of rsa_for_question.
# Placeholders filled via str.format:
#   {question}   - the original problem statement
#   {candidates} - the sampled candidate solutions, each prefixed with
#                  "CANDIDATE #<i>:" and separated by blank lines
# The model is asked to merge/correct the candidates and to finish with the
# same "Final answer: #### <integer>" marker as the chain-of-thought prompt.
AGGREGATION_PROMPT_TEMPLATE = """You are a highly precise reasoning assistant.
You are given a math problem and several candidate solutions. Some candidates may be incorrect, incomplete, or contain reasoning errors.
Your task is to aggregate the useful ideas from the candidates and produce a single, high-quality solution.
Problem:
{question}
Candidate solutions:
{candidates}
INSTRUCTIONS (READ CAREFULLY):
1. Analyze all candidate solutions step by step.
2. Identify correct, useful, or partially correct reasoning steps.
3. If candidates disagree, **select the logically correct path** and discard incorrect reasoning.
4. Combine the selected reasoning into **one coherent, numbered step-by-step solution**, e.g.:
Step 1: ...
Step 2: ...
Final answer: ...
5. If **all candidate solutions are incorrect or inconsistent**, abandon them and solve the problem using a correct alternative strategy.
6. THE FINAL ANSWER MUST BE STRICTLY AN INTEGER NUMBER.
- NO UNITS
- NO WORDS
- NO SYMBOLS
- NO LATEX
7. THE FINAL ANSWER MUST BE OUTPUT IN THIS EXACT FORMAT:
Final answer: #### <final_number>
8. Replace <final_number> with ONLY the integer numeric answer.
9. DO NOT INCLUDE ANYTHING ELSE AFTER '####'.
BEGIN AGGREGATED STEP-BY-STEP REASONING:
"""
# Single-entry cache: {model_name: (tokenizer, model)} for the most recently
# loaded model. Keeping only one entry avoids holding several sets of weights
# in memory when the user switches models.
_LOADED_MODEL = {}


def load_model(model_name: str):
    """Load (or reuse) the tokenizer and causal LM for ``model_name``.

    The tokenizer is created with left padding, which generate_responses
    relies on when it slices the prompt off the generated sequences. The model
    is loaded in float16 with ``device_map="auto"``.

    Loading is expensive, so the most recently loaded pair is cached and
    returned directly when the same ``model_name`` is requested again.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    cached = _LOADED_MODEL.get(model_name)
    if cached is not None:
        return cached

    # Drop the previous model *before* loading the new one so both sets of
    # weights never need to be resident at the same time.
    _LOADED_MODEL.clear()

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side='left',
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    _LOADED_MODEL[model_name] = (tokenizer, model)
    return tokenizer, model
def generate_responses(model, tokenizer, prompts: List[str], max_new_tokens: int = 512):
    """Generate one completion per prompt in a single padded batch.

    Each prompt is wrapped as a single user message and rendered with the
    tokenizer's chat template when the tokenizer offers one; otherwise the raw
    prompt text is used. Only the newly generated tokens are decoded, so the
    returned strings do not contain the prompt.

    Args:
        model: Causal LM exposing ``parameters()`` and ``generate()``.
        tokenizer: Matching tokenizer.
        prompts: Prompt strings; may be empty.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        A list of decoded completions, in the same order as ``prompts``.
        An empty ``prompts`` yields an empty list.
    """
    if not prompts:
        # Nothing to generate; skip the tokenizer and the model entirely.
        return []

    if hasattr(tokenizer, "apply_chat_template"):
        texts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for prompt in prompts
        ]
    else:
        texts = list(prompts)

    # padding=True needs a pad token. Tokenizers that define none fall back
    # to their EOS token so batched generation still works.
    if (
        getattr(tokenizer, "pad_token", None) is None
        and getattr(tokenizer, "eos_token", None) is not None
    ):
        tokenizer.pad_token = tokenizer.eos_token

    model_inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    device = next(model.parameters()).device
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Skip the first `input_len` positions of every row so that only the
    # tokens produced after the (padded) prompt are decoded.
    input_len = model_inputs["input_ids"].shape[1]
    gen_tokens = generated_ids[:, input_len:]
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
# Explicit answer marker requested by the prompts:
# "#### <integer>" or "#### NO ANSWER" (case-insensitive).
_MARKED_ANSWER_RE = re.compile(r"####\s*(NO ANSWER|-?\d+)", flags=re.IGNORECASE)
# Any optionally-negative integer; fallback when no marker is present.
_INTEGER_RE = re.compile(r"-?\d+")


def extract_answer_from_text(text: str) -> str:
    """Extract the final answer from a model response.

    Resolution order:
      1. The LAST ``#### <integer>`` / ``#### NO ANSWER`` marker in the text.
         The last one is used because a response may quote markers from
         candidate solutions before stating its own final answer.
      2. Otherwise, the last integer that appears anywhere in the text,
         including on later lines.
      3. Otherwise, the empty string.

    Args:
        text: Raw model output.

    Returns:
        The integer as a string, ``"NO ANSWER"`` (upper-cased), or ``""``
        when nothing could be extracted.
    """
    marked = _MARKED_ANSWER_RE.findall(text)
    if marked:
        return marked[-1].strip().upper()

    integers = _INTEGER_RE.findall(text)
    if integers:
        return integers[-1]

    return ""
def majority_vote(population: List[str]) -> str:
    """Return the most frequent extracted answer across ``population``.

    Every response is passed through extract_answer_from_text; responses that
    yield an empty string are ignored. Returns ``""`` when no response yields
    an answer.
    """
    extracted = (extract_answer_from_text(response) for response in population)
    tally = Counter(answer for answer in extracted if answer)
    if not tally:
        return ""
    winner, _votes = tally.most_common(1)[0]
    return winner
def rsa_for_question(model, tokenizer, question: str, N=4, K=2, T=2):
    """Run Recursive Self-Aggregation on a single question.

    An initial population of ``N`` chain-of-thought responses is generated.
    Then, for each of ``T`` rounds, ``N`` aggregation prompts are built, each
    from ``K`` randomly chosen members of the current population, and the
    responses to those prompts become the new population.

    Args:
        model: Causal LM passed through to generate_responses.
        tokenizer: Tokenizer passed through to generate_responses.
        question: Problem statement inserted into the prompts.
        N: Population size (responses per round).
        K: Number of candidates shown in each aggregation prompt.
        T: Number of aggregation rounds; 0 returns the initial population.

    Returns:
        ``(population_text, final_answer)`` where ``population_text`` is the
        final population joined by ``"\\n\\n---\\n\\n"`` and ``final_answer``
        is the majority-voted answer over that population.
    """
    initial_prompts = [COT_PROMPT_TEMPLATE.format(question=question)] * N
    population = generate_responses(model, tokenizer, initial_prompts)

    for _round in range(T):
        round_prompts = []
        for _slot in range(N):
            # Sample without replacement when possible; fall back to sampling
            # with replacement if the population is smaller than K.
            picks = (
                random.sample(population, K)
                if len(population) >= K
                else random.choices(population, k=K)
            )
            rendered = "\n\n".join(
                f"CANDIDATE #{number}:\n{candidate}"
                for number, candidate in enumerate(picks, start=1)
            )
            round_prompts.append(
                AGGREGATION_PROMPT_TEMPLATE.format(question=question, candidates=rendered)
            )
        population = generate_responses(model, tokenizer, round_prompts)

    voted = majority_vote(population)
    return "\n\n---\n\n".join(population), voted
def run_and_format(model_name, question, N, K, T, compare_baseline):
    """Gradio click handler: run RSA and, optionally, a single-shot baseline.

    Loads the model via load_model, coerces ``N``, ``K`` and ``T`` with
    ``int()``, and runs rsa_for_question. When ``compare_baseline`` is truthy
    the same question is also answered with ``N=1, K=1, T=0`` (one response,
    no aggregation); otherwise both baseline outputs are ``"(skipped)"``.

    Returns:
        ``(final_population, final_answer, baseline_output, baseline_answer)``.
    """
    tokenizer, model = load_model(model_name)
    pop_size, subset_size, rounds = int(N), int(K), int(T)

    rsa_text, rsa_answer = rsa_for_question(
        model, tokenizer, question, pop_size, subset_size, rounds
    )

    if not compare_baseline:
        return rsa_text, rsa_answer, "(skipped)", "(skipped)"

    base_text, base_answer = rsa_for_question(model, tokenizer, question, 1, 1, 0)
    return rsa_text, rsa_answer, base_text, base_answer
# Gradio UI definition. Components are created inside the Blocks context and
# the "Run" button is wired to run_and_format.
# NOTE(review): the original indentation was lost; the Row is assumed to hold
# only the three sliders N, K and T - confirm against the intended layout.
with gr.Blocks(title="RSA") as demo:
    gr.Markdown("### Recursive Self-Aggregation for Mathematical Reasoning\nEnter a problem whose answer is a single integer.")
    model_name = gr.Textbox(label="Model name", value=DEFAULT_MODEL, lines=1)
    with gr.Row():
        # N: population size, K: candidates per aggregation prompt,
        # T: number of aggregation rounds (see rsa_for_question).
        N = gr.Slider(1, 16, value=4, step=1, label="N")
        K = gr.Slider(1, 16, value=2, step=1, label="K")
        T = gr.Slider(0, 10, value=2, step=1, label="T")
    compare_baseline = gr.Checkbox(label="Compare with baseline", value=True)
    question = gr.Textbox(label="Question", lines=2, value="How many positive divisors does number 120 have?")
    run_btn = gr.Button("Run")
    # Output widgets, in the same order as run_and_format's return tuple.
    final_population = gr.Textbox(label="Final Population", lines=10)
    final_answer = gr.Textbox(label="Final answer")
    baseline_output = gr.Textbox(label="Baseline output", lines=5)
    baseline_answer = gr.Textbox(label="Baseline answer")
    run_btn.click(run_and_format, inputs=[model_name, question, N, K, T, compare_baseline],
                  outputs=[final_population, final_answer, baseline_output, baseline_answer])
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()