| | """ |
| | Example usage of Online mode with warmup |
| | |
| | This demonstrates: |
| | 1. Warmup phase (generate N sequences to calibrate threshold) |
| | 2. Threshold computation (DeepConf-low or DeepConf-high) |
| | 3. Final generation with calibrated early stopping |
| | """ |
| |
|
| | from typing import Optional |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig |
| |
|
| |
|
def extract_answer(text: str) -> Optional[str]:
    """
    Pull the answer out of the last ``\\boxed`` marker in generated text.

    Everything after the final occurrence of the substring "boxed" is examined:

    - If it starts with "{", characters are collected until the matching
      closing brace. Nested braces are tracked and kept, and an unclosed brace
      keeps everything to the end.
    - Otherwise the text up to the first "$" is used.

    Returns:
        The stripped answer string, "" when nothing follows "boxed", or None
        when "boxed" does not occur at all.
    """
    if "boxed" not in text:
        return None

    tail = text.split("boxed")[-1]
    if not tail:
        return ""

    if tail[0] != "{":
        return tail.split("$")[0].strip()

    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching brace for the opening "{" found; stop collecting.
                break
        collected.append(ch)
    return "".join(collected).strip()
| |
|
| |
|
def compute_least_grouped(confs: list, group_size: int) -> list:
    """
    Compute sliding-window mean confidences.

    Args:
        confs: Per-token confidence values.
        group_size: Width of the sliding window; must be positive.

    Returns:
        One mean per window position (``len(confs) - group_size + 1`` values),
        each rounded to 3 decimals. If ``confs`` is shorter than the window,
        the result is a single-element list holding the (unrounded) mean of all
        values. An empty ``confs`` yields ``[0]``.

    Raises:
        ValueError: If ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")

    if len(confs) < group_size:
        return [sum(confs) / len(confs)] if confs else [0]

    # Slide a running sum instead of re-summing every window: O(len(confs))
    # rather than O(len(confs) * group_size), which matters for large windows.
    window_sum = sum(confs[:group_size])
    sliding_means = [round(window_sum / group_size, 3)]
    for i in range(group_size, len(confs)):
        window_sum += confs[i] - confs[i - group_size]
        sliding_means.append(round(window_sum / group_size, 3))
    return sliding_means
| |
|
| |
|
def process_single_output(
    sequence, confidences, tokenizer, window_size: int, threshold: Optional[float] = None
) -> dict:
    """
    Build a trace dictionary for one generated sequence.

    Args:
        sequence: Token IDs of the sequence (tensor or iterable of ints). It is
            decoded as-is, so any prompt tokens it contains appear in ``text``.
        confidences: Per-token confidence values (tensor, array or iterable).
        tokenizer: Tokenizer whose ``decode`` renders ``sequence``.
        window_size: Width of the sliding window used for group confidence.
        threshold: Optional threshold. The first window whose mean falls
            below it is reported as the point where early stopping would occur.

    Returns:
        Dictionary with keys:
            - text: decoded sequence (special tokens skipped)
            - confs: per-token confidences as a list
            - group_confs: sliding-window means
            - min_conf: smallest sliding-window mean
            - stopped_early: whether any window fell below ``threshold``
            - stop_position: exclusive end index (into ``confs``) of the first
              window below ``threshold``, capped at ``len(confs)``; None when
              no stop is detected
            - extracted_answer: ``extract_answer(text)``
            - num_tokens: ``len(confs)``
            - token_ids: ``sequence`` as a list
    """
    # Normalize tensors/arrays (anything exposing .tolist()) and plain iterables.
    if hasattr(confidences, "tolist"):
        confs = confidences.tolist()
    else:
        confs = list(confidences)

    text = tokenizer.decode(sequence, skip_special_tokens=True)

    sliding_window = compute_least_grouped(confs, window_size)
    min_conf = min(sliding_window) if sliding_window else 0

    stopped_early = False
    stop_position = None
    if threshold is not None:
        for pos, window_mean in enumerate(sliding_window):
            if window_mean < threshold:
                stopped_early = True
                # Window `pos` covers tokens [pos, pos + window_size). When the
                # trace is shorter than one window, compute_least_grouped returns
                # a single mean over all tokens. In that case cap the end at the
                # trace length instead of pointing past it.
                stop_position = min(pos + window_size, len(confs))
                break

    extracted_answer = extract_answer(text)

    return {
        "text": text,
        "confs": confs,
        "group_confs": sliding_window,
        "min_conf": min_conf,
        "stopped_early": stopped_early,
        "stop_position": stop_position,
        "extracted_answer": extracted_answer,
        "num_tokens": len(confs),
        "token_ids": sequence.tolist() if hasattr(sequence, "tolist") else list(sequence),
    }
| |
|
| |
|
def process_batch_results(outputs, tokenizer, window_size: int = 2048, threshold: Optional[float] = None) -> dict:
    """
    Post-process the outputs of a batched ``model.generate()`` call.

    Each row of ``outputs.sequences`` is paired with the same row of
    ``outputs.confidences`` and turned into a trace dictionary via
    ``process_single_output``. This lets confidence patterns be analysed after
    generation, including where early stopping would have fired if a
    ``threshold`` is supplied.

    Args:
        outputs: Generation output exposing ``sequences`` (indexable, with a
            ``shape``) and ``confidences`` attributes.
        tokenizer: Tokenizer used to decode each sequence.
        window_size: Sliding-window width for group confidence.
        threshold: Optional threshold used to locate hypothetical early stops.

    Returns:
        Dictionary containing:
            - traces: per-sequence trace dictionaries
            - min_confs: the ``min_conf`` of each trace, in order
            - total_tokens: sum of ``num_tokens`` over all traces
            - num_traces: number of traces processed

    Raises:
        ValueError: If ``outputs`` has no ``sequences`` attribute, or its
            ``confidences`` attribute is missing or None.
    """
    if not hasattr(outputs, "sequences"):
        raise ValueError("outputs must have 'sequences' attribute")
    if getattr(outputs, "confidences", None) is None:
        raise ValueError("outputs must have 'confidences' attribute. Set output_confidences=True in generation_config")

    seqs, conf_rows = outputs.sequences, outputs.confidences
    traces = [
        process_single_output(seqs[row], conf_rows[row], tokenizer, window_size, threshold)
        for row in range(seqs.shape[0])
    ]

    return {
        "traces": traces,
        "min_confs": [trace["min_conf"] for trace in traces],
        "total_tokens": sum(trace["num_tokens"] for trace in traces),
        "num_traces": len(traces),
    }
| |
|
| |
|
def compute_warmup_threshold(min_confs: list, variant: str = "low", eta: Optional[float] = None) -> float:
    """
    Derive an early-stopping threshold from warmup confidences.

    The threshold is the ``100 * (1 - eta)`` percentile of ``min_confs``:

    - "low" uses eta=0.1, so the 90th percentile (a high, aggressive threshold).
    - "high" uses eta=0.9, so the 10th percentile (a low, permissive threshold).
    - Any other variant falls back to eta=0.5 (the median).

    Args:
        min_confs: Minimum (sliding-window) confidences from warmup sequences.
        variant: "low" (aggressive) or "high" (permissive).
        eta: Optional manual eta value; overrides the variant default. The
            resulting percentile is clamped to [0, 100].

    Returns:
        The computed threshold as a Python float.

    Raises:
        ValueError: If ``min_confs`` is empty.
    """
    if eta is None:
        eta = {"low": 0.1, "high": 0.9}.get(variant, 0.5)

    confs = np.asarray(min_confs, dtype=np.float32)
    if confs.size == 0:
        # np.percentile on an empty array gives an obscure error (or a nan,
        # depending on the numpy version); fail loudly with a clear message.
        raise ValueError("min_confs must contain at least one value to compute a threshold")

    # Clamp so that an eta outside [0, 1] still maps to a valid percentile.
    pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
    return float(np.percentile(confs, pct))
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def prepare_prompt(question: str, tokenizer):
    """
    Render ``question`` as a single user turn through the tokenizer's chat template.

    Returns the untokenized template output, with the generation prompt
    appended so the model continues as the assistant.
    """
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
| |
|
| |
|
def _print_section(title: str) -> None:
    """Print ``title`` between two 80-character ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _decode_completion(tokenizer, token_ids, prompt_len: int) -> str:
    """Decode the ids after the first ``prompt_len`` (skipping special tokens) and strip whitespace."""
    return tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True).strip()


def run_online_mode_example(
    question: str,
    ground_truth: Optional[str] = None,
    warmup_traces: int = 8,
    confidence_variant: str = "low",
    window_size: int = 10,
    max_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.95,
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
):
    """
    Run DeepConf in online mode: warmup, threshold calibration, then final generation.

    Args:
        question: Question to answer.
        ground_truth: Optional ground-truth answer. When given, extracted
            answers are compared to it after stripping whitespace.
        warmup_traces: Number of warmup sequences (default: 8, must be >= 1).
        confidence_variant: "low" (aggressive, eta=0.1) or "high" (permissive,
            eta=0.9). Any other value uses eta=0.5.
        window_size: Sliding window size for confidence.
        max_tokens: Max new tokens per generation.
        temperature: Sampling temperature.
        top_p: Top-p sampling.
        model_name: Model to load. Only locally cached files are used
            (``local_files_only=True``).

    Raises:
        ValueError: If ``warmup_traces`` is less than 1 (the threshold and the
            average warmup length would be undefined).
    """
    if warmup_traces < 1:
        raise ValueError(f"warmup_traces must be at least 1, got {warmup_traces}")

    # Setup: load model/tokenizer and build the prompt.
    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        local_files_only=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    prompt = prepare_prompt(question, tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generated sequences start with these prompt ids (the final token count
    # below relies on that too). Slicing at the token level therefore isolates
    # the model's completion.
    prompt_len = inputs.input_ids.shape[1]

    _print_section("DEEPCONF ONLINE MODE - FOLLOWING OFFICIAL PATTERN")
    print(f"\nQuestion: {question}")
    if ground_truth:
        print(f"Ground truth: {ground_truth}")
    print("\nConfiguration:")
    print(f" - Warmup traces: {warmup_traces}")
    print(f" - Variant: DeepConf-{confidence_variant}")
    print(f" - Window size: {window_size}")
    print(f" - Max tokens: {max_tokens}")
    print(f" - Temperature: {temperature}")
    print(f" - Top-p: {top_p}")

    # Phase 1: warmup generation (no early stopping) to collect confidences.
    _print_section(f"PHASE 1: WARMUP (Generating {warmup_traces} sequences for calibration)")

    warmup_config = GenerationConfig(
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=False,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Generate all warmup traces in one batch by repeating the single prompt.
    expanded_ids = inputs.input_ids.repeat(warmup_traces, 1)
    if "attention_mask" in inputs and inputs.attention_mask is not None:
        expanded_mask = inputs.attention_mask.repeat(warmup_traces, 1)
    else:
        expanded_mask = None

    print(f"Generating {warmup_traces} warmup sequences...")
    warmup_outputs = model.generate(
        input_ids=expanded_ids,
        attention_mask=expanded_mask,
        generation_config=warmup_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=window_size)

    print("\nWarmup complete!")
    print(f" - Total tokens: {warmup_results['total_tokens']}")
    print(f" - Min confidences: {[round(c, 3) for c in warmup_results['min_confs']]}")

    print("\nWarmup Traces:")
    print("-" * 80)
    for i, trace in enumerate(warmup_results["traces"], start=1):
        # trace["text"] was decoded with skip_special_tokens=True, so it does
        # not line up character-for-character with the raw chat-template
        # prompt. Slicing it by len(prompt) would cut in the wrong place, so
        # decode the completion tokens directly instead.
        text = _decode_completion(tokenizer, trace["token_ids"], prompt_len)
        answer = extract_answer(text)
        print(f"\nTrace {i}:")
        print(f" Tokens: {trace['num_tokens']}, Min conf: {trace['min_conf']:.3f}")
        print(f" Text: {text[:80]}..." if len(text) > 80 else f" Text: {text}")
        if answer:
            print(f" Answer: {answer}")
            if ground_truth:
                correct = answer.strip() == ground_truth.strip()
                print(f" Correct: {'✓' if correct else '✗'}")

    # Phase 2: calibrate the early-stopping threshold from warmup confidences.
    _print_section("PHASE 2: THRESHOLD COMPUTATION")

    # Resolve eta once (the same mapping compute_warmup_threshold applies by
    # default) and pass it explicitly. This keeps the reported eta and
    # percentile consistent with the threshold actually computed, including
    # for variants other than "low"/"high".
    eta = {"low": 0.1, "high": 0.9}.get(confidence_variant, 0.5)
    threshold = compute_warmup_threshold(warmup_results["min_confs"], variant=confidence_variant, eta=eta)
    percentile = (1.0 - eta) * 100

    print("\nComputed threshold from warmup:")
    print(f" - Variant: DeepConf-{confidence_variant} (eta={eta})")
    print(f" - Percentile: {percentile:.0f}th")
    print(f" - Threshold: {threshold:.3f}")
    print("\nInterpretation:")
    if confidence_variant == "low":
        print(" DeepConf-low is AGGRESSIVE - stops early to save tokens")
    elif confidence_variant == "high":
        print(" DeepConf-high is PERMISSIVE - allows longer generation")
    else:
        print(f" Custom variant '{confidence_variant}' - median threshold (eta={eta})")

    # Phase 3: final generation with calibrated early stopping.
    _print_section("PHASE 3: FINAL GENERATION (With calibrated early stopping)")

    final_config = GenerationConfig(
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=True,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    print(f"Generating with DeepConf-{confidence_variant} (threshold={threshold:.3f})...")
    final_output = model.generate(
        **inputs,
        generation_config=final_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    final_tokens = final_output.sequences.shape[1] - prompt_len
    # Decode only the completion. The prompt itself may contain a "\boxed{...}"
    # template, which extract_answer would otherwise report as the answer.
    final_text = _decode_completion(tokenizer, final_output.sequences[0], prompt_len)
    final_answer = extract_answer(final_text)

    if hasattr(final_output, "confidences") and final_output.confidences is not None:
        min_conf = final_output.confidences.min().item()
        mean_conf = final_output.confidences.mean().item()
    else:
        min_conf = None
        mean_conf = None

    print("\nFinal generation complete!")
    print(f" - Tokens generated: {final_tokens}")
    if min_conf is not None:
        print(f" - Min confidence: {min_conf:.3f}")
        print(f" - Mean confidence: {mean_conf:.3f}")

    print("\nGenerated text:")
    print("-" * 80)
    print(final_text)
    print("-" * 80)

    if final_answer:
        print(f"\nExtracted answer: {final_answer}")
        if ground_truth:
            correct = final_answer.strip() == ground_truth.strip()
            print(f"Correct: {'✓' if correct else '✗'}")

    # Summary: token accounting across warmup and final generation.
    _print_section("SUMMARY")

    total_warmup_tokens = warmup_results["total_tokens"]
    total_tokens = total_warmup_tokens + final_tokens

    print(f"Total tokens: {total_tokens}")
    print(f" - Warmup: {total_warmup_tokens} ({warmup_traces} sequences)")
    print(f" - Final: {final_tokens}")

    # warmup_traces >= 1 is guaranteed by the check at the top.
    avg_warmup_tokens = total_warmup_tokens / warmup_traces
    potential_savings = avg_warmup_tokens - final_tokens
    if potential_savings > 0:
        print("\nToken savings from early stopping:")
        print(f" - Average warmup length: {avg_warmup_tokens:.1f} tokens")
        print(f" - Final length: {final_tokens} tokens")
        print(f" - Saved: {potential_savings:.1f} tokens ({potential_savings / avg_warmup_tokens * 100:.1f}%)")

    print("\n" + "=" * 80)
    print("Example complete!")
    print("=" * 80)
| |
|
| |
|
if __name__ == "__main__":
    # Banner uses U+2588 FULL BLOCK. It is written as an escape so the source
    # stays ASCII-only and the character cannot be garbled into mojibake by an
    # encoding mix-up (the literal previously read "â–ˆ").
    banner = "\u2588" * 80

    # Each entry: (banner title, keyword arguments for run_online_mode_example).
    examples = [
        (
            "EXAMPLE 1: Simple Math Problem",
            {
                "question": "What is 15 * 8? Show your work step by step.",
                "ground_truth": "120",
                "warmup_traces": 4,
                "confidence_variant": "low",
                "window_size": 5,
                "max_tokens": 64,
            },
        ),
        (
            "EXAMPLE 2: Square Root Problem",
            {
                "question": "What is the square root of 144? Express your answer in the form \\boxed{answer}.",
                "ground_truth": "12",
                "warmup_traces": 4,
                "confidence_variant": "high",
                "window_size": 5,
                "max_tokens": 64,
            },
        ),
        (
            "EXAMPLE 3: Word Problem",
            {
                "question": "If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
                "ground_truth": "150",
                "warmup_traces": 4,
                "confidence_variant": "low",
                "window_size": 5,
                "max_tokens": 96,
            },
        ),
    ]

    for title, example_kwargs in examples:
        print("\n\n" + banner)
        print(title)
        print(banner)
        run_online_mode_example(**example_kwargs)
| |
|