Spaces:
Sleeping
Sleeping
Mustafa Tag Eldeen
swap: Qwen2.5-0.5B-Instruct (faster, reliable #### format) + fix dp2 detection bug
80dbe0c | """Gradio app: LLM Reasoning Trajectories - Interactive Demo for HF Spaces""" | |
| import sys | |
| import os | |
| import gc | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from src.models.generation_output import CompleteGenerationOutput | |
| from src.models.greedy_generate_twopass import greedy_generate_with_artifacts_twopass | |
| from src.inference.complete_pipeline import process_complete_generation | |
| from src.features.windows import ( | |
| WindowConfig, | |
| compute_all_windows, | |
| classify_trajectory, | |
| compute_step_boundaries, | |
| ) | |
| from src.features.span_detection import extract_answer_after_hash | |
| from src.utils import answers_match | |
# Runtime configuration; both values can be overridden via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "128"))

# Shared state filled in by load_model(); everything stays None until then.
model = tokenizer = step_token_id = step_token_ids = None
# Built-in example problems offered in the UI dropdown.
# Each entry has a display "name", the "question" text, and the gold "answer"
# as a string.
SAMPLE_PROBLEMS = [
    dict(
        name="Duck Eggs",
        question=(
            "Janet's ducks lay 16 eggs per day. "
            "She eats three for breakfast every morning and bakes muffins "
            "for her friends every day with four. "
            "She sells the remainder at the farmers' market daily for $2 "
            "per fresh duck egg. "
            "How much in dollars does she make every day at the farmers' market?"
        ),
        answer="18",
    ),
    dict(
        name="Robe Fabric",
        question=(
            "A robe takes 2 bolts of blue fiber and half that much white fiber. "
            "How many bolts in total does it take?"
        ),
        answer="3",
    ),
    dict(
        name="House Flipping",
        question=(
            "Josh decides to try flipping a house. "
            "He buys a house for $80,000 and then puts in $50,000 in repairs. "
            "This increased the value of the house by 150%. "
            "How much profit did he make?"
        ),
        answer="70000",
    ),
    dict(
        name="Train Travel",
        question=(
            "A train leaves Station A traveling at 60 mph. "
            "Another train leaves Station B (120 miles away) traveling at "
            "40 mph toward Station A. "
            "If they leave at the same time, how long until they meet?"
        ),
        answer="1.2",
    ),
]
def load_model():
    """Populate the module-level model, tokenizer and step-token globals once.

    Returns:
        True if a model is available (already loaded, or loaded by this call);
        False if loading raised. The exception is printed, not propagated.
    """
    global model, tokenizer, step_token_id, step_token_ids

    if model is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"

            # Find the id of the token produced by "Step" when it is the first
            # thing generated after the assistant header: encode the header
            # alone, then header + "Step", and take the first token past the
            # header.
            assistant_header = "\n<|im_start|>assistant\n"
            header_len = len(
                tokenizer.encode(assistant_header, add_special_tokens=False)
            )
            header_plus_step = tokenizer.encode(
                assistant_header + "Step", add_special_tokens=False
            )
            step_token_id = header_plus_step[header_len]
            step_token_ids = {step_token_id}

            loaded = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                device_map="cpu",
            )
            model = loaded.eval()
        except Exception as err:
            print(f"Model loading error: {err}")
            return False

    return True
def generate_reasoning(question: str, gold_answer: str, max_new_tokens: int):
    """Run step-by-step generation for one question and post-process the result.

    Args:
        question: The problem text inserted into the user message.
        gold_answer: Reference answer forwarded to ``process_complete_generation``.
            Surrounding whitespace is stripped so a stray space or newline from
            the text box cannot break answer comparison.
        max_new_tokens: Generation budget; coerced to ``int`` because UI sliders
            may deliver floats.

    Returns:
        ``(output, None)`` on success, or ``(None, message)`` when the question
        is blank, the model cannot be loaded, or generation raises.
    """
    global model, tokenizer, step_token_id, step_token_ids

    # Reject blank input up front, before any model work.
    if not question or not question.strip():
        return None, "Please enter a math problem."

    if model is None and not load_model():
        return None, "Failed to load model."

    if isinstance(gold_answer, str):
        gold_answer = gold_answer.strip()

    # Build chat messages with the paper's "cot" template content
    system_msg = (
        'You are a helpful assistant that solves problems step by step with each step signified by "Step [step_number]: ". '
        'Always provide your final answer after #### at the end.'
    )
    user_msg = (
        f'Please solve this step by step, putting each step after "Step [step_number]: " '
        f'and always provide your final answer after ####.\n\nQuestion: {question}'
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=False)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    try:
        output = greedy_generate_with_artifacts_twopass(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            gold_answer_token_id=None,
            capture_hidden_states=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        win_config = WindowConfig(
            step_token_id=step_token_id, step_token_ids=step_token_ids
        )
        output = process_complete_generation(
            output,
            gold_answer=gold_answer,
            tokenizer=tokenizer,
            window_config=win_config,
        )
        return output, None
    except Exception as e:
        # Any failure is reported to the caller as an error string rather than
        # propagated.
        return None, f"Generation error: {str(e)}"
def _window_series(windows, step_keys, source, field):
    """Collect ``windows[key][source][field]`` per step; NaN marks a missing value."""
    missing = float("nan")
    series = []
    for key in step_keys:
        value = windows[key].get(source, {}).get(field)
        series.append(missing if value is None else value)
    return series


def plot_step_confidence(windows, output):
    """Draw per-step rank and probability curves for gold vs. produced answer.

    Args:
        windows: Mapping whose ``"step_<n>"`` entries are dicts that may hold
            ``"gold"`` / ``"prod"`` sub-dicts with ``"min_rank"`` and ``"max_p"``.
        output: Object exposing ``num_steps``, ``produced_answer``,
            ``gold_answer`` and a ``metadata`` mapping (used for the title).

    Returns:
        A matplotlib figure with two panels (rank on the left, probability on
        the right). When no ``step_*`` keys exist, the panels are empty and the
        title is "No steps detected".

    Missing values are plotted as NaN so the line shows a gap, instead of a
    fake 0 (which would read as the best possible rank).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    try:
        step_keys = sorted(
            (k for k in windows if k.startswith("step_")),
            key=lambda x: int(x.split("_")[1]),
        )
        if not step_keys:
            fig.suptitle("No steps detected", fontsize=14)
            fig.tight_layout()
            return fig

        steps = list(range(len(step_keys)))

        # Left panel: rank trajectory.
        gold_ranks = _window_series(windows, step_keys, "gold", "min_rank")
        prod_ranks = _window_series(windows, step_keys, "prod", "min_rank")
        ax = axes[0]
        ax.plot(steps, gold_ranks, "o-", label="Gold answer rank", color="green", markersize=8)
        ax.plot(steps, prod_ranks, "s--", label="Produced answer rank", color="orange", markersize=8)
        ax.axhline(y=10, color="gray", linestyle=":", alpha=0.5, label="High conf threshold")
        ax.set_xlabel("Step")
        ax.set_ylabel("Rank (lower = more confident)")
        ax.set_title("Confidence Trajectory per Step")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Right panel: probability trajectory.
        gold_probs = _window_series(windows, step_keys, "gold", "max_p")
        prod_probs = _window_series(windows, step_keys, "prod", "max_p")
        ax = axes[1]
        ax.plot(steps, gold_probs, "o-", label="Gold answer prob", color="green", markersize=8)
        ax.plot(steps, prod_probs, "s--", label="Produced answer prob", color="orange", markersize=8)
        ax.set_xlabel("Step")
        ax.set_ylabel("Probability")
        ax.set_title("Confidence Probability per Step")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # An empty/absent produced answer never counts as correct, even when
        # the gold answer is empty too.
        is_correct = bool(output.produced_answer) and (
            output.produced_answer == output.gold_answer
        )
        fig.suptitle(
            f"Trajectory: {output.num_steps} steps | "
            f"Correct: {is_correct} | "
            f"Trajectory type: {output.metadata.get('trajectory_type', 'N/A')}",
            fontsize=12,
        )
        fig.tight_layout()
        return fig
    finally:
        # Unregister the figure from pyplot so repeated requests do not pile up
        # open figures; the returned Figure object stays usable for rendering.
        plt.close(fig)
def analyze_question(question, gold_answer, max_new_tokens, progress=gr.Progress()):
    """Generate reasoning for a question and build every value the UI displays.

    Returns a 6-tuple: markdown summary, per-step detail text, confidence
    figure, number of steps, correctness flag, trajectory type. If generation
    reports an error, the tuple is the error message followed by five ``None``s.
    """

    def _show(value, spec=""):
        # Format a metric for display, falling back to "N/A" when it is absent.
        return "N/A" if value is None else format(value, spec)

    progress(0, desc="Starting generation...")
    output, error = generate_reasoning(question, gold_answer, max_new_tokens)
    if error:
        return error, None, None, None, None, None

    progress(0.7, desc="Analyzing trajectory...")
    windows = output.windows

    # Order the "step_<n>" window keys by their numeric suffix.
    step_keys = [name for name in windows.keys() if name.startswith("step_")]
    step_keys.sort(key=lambda name: int(name.split("_")[1]))

    trajectory_type = output.metadata.get("trajectory_type", "N/A")
    if output.produced_answer:
        is_correct = output.produced_answer == output.gold_answer
    else:
        is_correct = False

    summary = (
        f"**Generated:** {output.produced_text}\n\n"
        f"**Produced answer:** {output.produced_answer}\n"
        f"**Gold answer:** {output.gold_answer}\n"
        f"**Correct:** {is_correct}\n"
        f"**Steps detected:** {output.num_steps}\n"
        f"**Reasoning length:** {output.reasoning_length} tokens\n"
        f"**Trajectory type:** {trajectory_type}"
    )

    detail_lines = []
    for k in step_keys:
        gold_stats = windows[k].get("gold", {})
        prod_stats = windows[k].get("prod", {})
        gold_rank_str = _show(gold_stats.get("min_rank"))
        gold_prob_str = _show(gold_stats.get("max_p"), ".4f")
        prod_rank_str = _show(prod_stats.get("min_rank"))
        prod_prob_str = _show(prod_stats.get("max_p"), ".4f")
        detail_lines.append(
            f"**{k}:** gold_rank={gold_rank_str}, gold_prob={gold_prob_str}, "
            f"prod_rank={prod_rank_str}, prod_prob={prod_prob_str}\n"
        )
    steps_detail = "".join(detail_lines)

    fig = plot_step_confidence(windows, output)
    progress(1.0, desc="Done!")

    return (
        summary,
        steps_detail or "No steps detected.",
        fig,
        output.num_steps,
        is_correct,
        trajectory_type,
    )
def load_sample(sample_name):
    """Return ``(question, answer)`` for the named sample, or ``("", "")`` if unknown."""
    chosen = next(
        (problem for problem in SAMPLE_PROBLEMS if problem["name"] == sample_name),
        None,
    )
    if chosen is None:
        return "", ""
    return chosen["question"], chosen["answer"]
# Human-readable model name for the UI text, derived from MODEL_ID so the page
# always names the model that is actually loaded (repo owner prefix dropped).
model_label = MODEL_ID.rsplit("/", 1)[-1]

# NOTE(review): the nesting of layout containers below was reconstructed from a
# flattened copy of this file — confirm it matches the intended layout.
with gr.Blocks(
    title="LLM Reasoning Trajectories",
    theme=gr.themes.Soft(),
    css="""
    footer {display: none !important}
    .contain {max-width: 1200px; margin: auto}
    """,
) as demo:
    gr.Markdown(
        f"""
        # 🧠 LLM Reasoning as Trajectories: Step-Specific Representation Geometry and Correctness Signals
        *Interactive demo for ACL 2026 paper by Microsoft Research.*
        This demo runs **{model_label}** on CPU and performs chain-of-thought reasoning
        with **trajectory analysis** — detecting step boundaries, tracking confidence trajectories,
        and classifying reasoning patterns.
        """
    )

    with gr.Row():
        # Left column: problem selection and generation controls.
        with gr.Column(scale=2):
            sample_dd = gr.Dropdown(
                choices=[s["name"] for s in SAMPLE_PROBLEMS],
                label="Sample Problems",
                value="Duck Eggs",
            )
            question_input = gr.Textbox(
                label="Math Problem",
                lines=3,
                placeholder="Enter a math problem...",
                value=SAMPLE_PROBLEMS[0]["question"],
            )
            answer_input = gr.Textbox(
                label="Gold Answer",
                lines=1,
                placeholder="Expected answer",
                value=SAMPLE_PROBLEMS[0]["answer"],
            )
            max_tokens = gr.Slider(
                minimum=64,
                maximum=256,
                # Default comes from the MAX_NEW_TOKENS setting, clamped to the
                # slider's range.
                value=min(max(MAX_NEW_TOKENS, 64), 256),
                step=32,
                label="Max new tokens",
            )
            submit_btn = gr.Button("🚀 Analyze Trajectory", variant="primary", size="lg")

        # Right column: headline results.
        with gr.Column(scale=1):
            steps_output = gr.Number(label="Steps Detected", value=0)
            correct_output = gr.Checkbox(label="Correct", value=False)
            traj_type_output = gr.Textbox(label="Trajectory Type", value="")

    plot_output = gr.Plot(label="Step Confidence Trajectory")

    with gr.Tabs():
        with gr.TabItem("Generated Reasoning"):
            summary_output = gr.Markdown(label="Results")
        with gr.TabItem("Step Details"):
            steps_detail_output = gr.Textbox(label="Per-Step Features", lines=15)

    # Picking a sample fills in the question and gold-answer boxes.
    sample_dd.change(
        fn=load_sample,
        inputs=[sample_dd],
        outputs=[question_input, answer_input],
    )

    # Output order must match the tuple returned by analyze_question.
    submit_btn.click(
        fn=analyze_question,
        inputs=[question_input, answer_input, max_tokens],
        outputs=[
            summary_output,
            steps_detail_output,
            plot_output,
            steps_output,
            correct_output,
            traj_type_output,
        ],
    )

    gr.Markdown(
        f"""
        ---
        ### 📚 About this demo
        This interactive demo accompanies the paper *"LLM Reasoning as Trajectories: Step-Specific
        Representation Geometry and Correctness Signals"* (ACL 2026).
        **Key findings demonstrated:**
        - Each reasoning step occupies a distinct, linearly separable region in representation space
        - Late-step trajectories diverge between correct and incorrect solutions
        - Trajectory features predict correctness with ROC-AUC 0.87
        - Error-targeted inference-time interventions improve accuracy without retraining
        **Paper:** [arxiv.org/abs/2604.05655](https://arxiv.org/abs/2604.05655)
        | **Code:** [github.com/microsoft/reasoning-trajectory](https://github.com/microsoft/reasoning-trajectory)
        | **Model:** {model_label} (CPU)
        """
    )
# Allow at most five requests to wait in the queue.
demo.queue(max_size=5)

# Warm the model up before serving so the first request does not pay the load
# cost. load_model() reports failure via its return value, so only claim
# readiness when it actually succeeded; generate_reasoning() retries the load
# on demand otherwise.
print("Pre-loading model at startup...")
if load_model():
    print(f"Model ready. Step token IDs: {step_token_ids}")
else:
    print("Model failed to pre-load; loading will be retried on the first request.")

demo.launch(show_api=False)