Upload folder using huggingface_hub

- automr/config.py      +3 -1
- automr/dag.py         +1 -1
- automr/evaluator.py   +53 -60
- automr/model.py       +206 -329
- automr/strategies.py  +0 -2
- automr/trainer.py     +54 -39
automr/config.py
CHANGED
@@ -18,6 +18,8 @@ class AutoMRConfig:
     batch_size: int = 8
     num_samples_per_query: int = 4  # M in paper
     gradient_clip: float = 1.0
+    initial_baseline: float = 0.0   # Initial value for REINFORCE baseline
+    baseline_momentum: float = 0.9  # Momentum for baseline update
 
     # Validation settings
     val_every_n_steps: int = 100  # Alpha in the requirement - validate every N steps
@@ -25,7 +27,7 @@ class AutoMRConfig:
     early_stopping_patience: int = 5  # Stop if no improvement for N validations
 
     # Generation settings
-    max_new_tokens: int =
+    max_new_tokens: int = 4096
     temperature: float = 0.01
     top_p: float = 0.9
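The two new fields parameterize the running REINFORCE baseline kept by the trainer later in this commit. A minimal sketch of the intended update rule, mirroring the trainer change below rather than a separate API:

    def update_baseline(baseline: float, avg_reward: float, momentum: float = 0.9) -> float:
        # b <- momentum * b + (1 - momentum) * avg_reward, starting from initial_baseline (0.0)
        return momentum * baseline + (1.0 - momentum) * avg_reward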
automr/dag.py
CHANGED
@@ -57,7 +57,7 @@ class MetaReasoningDAG:
 
     def get_context_up_to(self, idx: int) -> str:
         """Get all node contents up to index idx"""
-        return "\n".join([node.content for node in self.nodes[:idx+1]])
+        return "\n".join([f"step: {node.index}: {node.content}" for node in self.nodes[:idx+1]])
 
     def total_tokens(self) -> int:
         """Total tokens generated (excluding source node)"""
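The context string is now step-indexed, so later prompts can refer back to "Step j". A tiny self-contained sketch of the resulting format, with illustrative values only:

    # Produces:
    #   step: 0: What is 12 * 7?
    #   step: 1: Multiply 12 by 7 to get 84.
    contents = [(0, "What is 12 * 7?"), (1, "Multiply 12 by 7 to get 84.")]
    context = "\n".join(f"step: {i}: {c}" for i, c in contents)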
automr/evaluator.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Dict, Tuple
 from tqdm import tqdm
 import os
+import asyncio
 
 from .model import AutoMR
 from .config import AutoMRConfig
@@ -15,79 +16,71 @@ class AutoMREvaluator:
         self.config = config
         ensure_dir(config.results_dir)
 
-        print(f"\nEvaluating on {len(test_data)} samples...")
-
+    async def evaluate_async(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
+        """Async evaluation: send all queries in a single batch to vLLM."""
+        print(f"\nEvaluating on {len(test_data)} samples (async, single batch)...")
+
         self.model.strategy_mlp.eval()
         self.model.strategy_embeddings.eval()
+
+        queries = [item['query'] for item in test_data]
+        ground_truths = [item['answer'] for item in test_data]
+
+        # One-shot async sampling over all queries, M=1
+        pred_answers, _ = await self.model.sample_batch(queries, M=1)
+
         correct = 0
         total = 0
-        detailed_results = []
-
-        is_correct = check_answer_match(
-            pred_answer,
-            ground_truth,
-            self.config.task_type
-        )
-        if is_correct:
-            correct += 1
-        total += 1
-
-        pbar.set_postfix({
-            'Acc': f'{correct} / {total}',
-        })
-
-        # Store detailed result
+        detailed_results: List[Dict] = []
+
+        for query, ground_truth, pred_answer in tqdm(
+            zip(queries, ground_truths, pred_answers),
+            total=len(queries),
+            desc="Evaluating",
+        ):
+            is_correct = check_answer_match(
+                pred_answer,
+                ground_truth,
+                self.config.task_type,
+            )
+            if is_correct:
+                correct += 1
+            total += 1
+
+        # Optional: collect detailed results (empty by default; keeps the file structure)
         # result = {
+        #     'query': query,
+        #     'ground_truth': ground_truth,
+        #     'prediction': pred_answer,
+        #     'correct': is_correct,
         # }
         # if self.config.save_skeletons:
+        #     result['skeleton'] = None
        # detailed_results.append(result)
+
         accuracy = correct / total if total > 0 else 0.0
+
         print(f"\nEvaluation Results:")
         print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
-
-        # Save results
+
         if self.config.save_predictions:
             results_path = os.path.join(
                 self.config.results_dir,
-                'evaluation_results.json'
+                'evaluation_results.json',
+            )
+            save_json(
+                {
+                    'accuracy': accuracy,
+                    'correct': correct,
+                    'total': total,
+                    'detailed_results': detailed_results,
+                },
+                results_path,
             )
-            save_json({
-                'accuracy': accuracy,
-                'correct': correct,
-                'total': total,
-                'detailed_results': detailed_results
-            }, results_path)
             print(f"Results saved to {results_path}")
-
-        return accuracy, detailed_results
+
+        return accuracy, detailed_results
+
+    def evaluate(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
+        """Synchronous wrapper for async evaluation, for use in main.py."""
+        return asyncio.run(self.evaluate_async(test_data))
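The new evaluate() is a thin synchronous wrapper around the async path. A minimal self-contained sketch of that pattern, using a stand-in coroutine instead of the real batched vLLM call:

    import asyncio

    async def evaluate_async(items):
        await asyncio.sleep(0)      # stand-in for the single batched model call
        return len(items), []       # placeholder (accuracy, detailed_results)

    def evaluate(items):
        # asyncio.run creates an event loop, runs the coroutine to completion, and closes the loop
        return asyncio.run(evaluate_async(items))

    print(evaluate(["q1", "q2"]))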
automr/model.py
CHANGED
@@ -1,22 +1,18 @@
-from concurrent.futures import ThreadPoolExecutor
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+import asyncio
 import random
+from typing import List, Tuple, Dict, Any
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
 from .config import AutoMRConfig
 from .strategies import META_STRATEGIES, STRATEGY_LIST
 from .dag import MetaReasoningDAG
-from .utils import extract_answer
-from typing import Dict
-from openai import OpenAI
 
 class StrategyMLP(nn.Module):
     """MLP for sampling meta-reasoning strategies"""
-
     def __init__(self, hidden_size: int, num_strategies: int):
         super().__init__()
         # Input: [node_repr, strategy_repr, context_repr]
@@ -26,14 +22,6 @@ class StrategyMLP(nn.Module):
         self.dropout = nn.Dropout(0.1)
 
     def forward(self, node_repr, strategy_repr, context_repr):
-        """
-        Args:
-            node_repr: [batch, hidden_size]
-            strategy_repr: [batch, hidden_size]
-            context_repr: [batch, hidden_size]
-        Returns:
-            logits: [batch, num_strategies]
-        """
         x = torch.cat([node_repr, strategy_repr, context_repr], dim=-1)
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
@@ -44,361 +32,250 @@ class StrategyMLP(nn.Module):
 
 
 class AutoMR:
-    """AutoMR Framework
+    """AutoMR Framework with Async vLLM Support"""
 
     def __init__(self, config: AutoMRConfig):
         self.config = config
         self.device = config.device
         self.token_budget = config.token_budget
-        self.model_name_for_api = config.model_name
 
-        # self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
-
-        # print(f"Loading vLLM generator: {config.model_name}")
-        # self.llm = LLM(
-        #     config.model_name,
-        #     dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     tensor_parallel_size=config.tensor_parallel_size,
-        #     gpu_memory_utilization="0.8"
-        # )
-
-        #     config.model_name,
-        #     torch_dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     device_map=None
-        # ).to(self.device)
-        # self.llm_embedder.eval()
-
-        # print(f"Loading vLLM Embbedder: {config.model_name}")
-        # self.llm_embedder = LLM(
-        #     config.model_name,
-        #     dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     tensor_parallel_size=config.tensor_parallel_size,
-        #     gpu_memory_utilization="0.8"
-        #     task="embed",
-        # )
-
-        print("Connecting to vLLM Generator Server (port 8000)...")
-        self.generator_client = OpenAI(
-            api_key="vllm",
-            base_url="http://localhost:8000/v1"
-        )
-
-        print("Connecting to
-        self.embed_client =
-            api_key="vllm",
-            base_url="http://localhost:8001/v1"
-        )
+        # Concurrency control: prevent overloading the client/server
+        self.semaphore = asyncio.Semaphore(128)
+
+        print(f"Connecting to vLLM Generator (Async)...")
+        self.generator_client = AsyncOpenAI(api_key="vllm", base_url="http://localhost:8000/v1")
+
+        print(f"Connecting to Embedder (Async)...")
+        self.embed_client = AsyncOpenAI(api_key="vllm", base_url="http://localhost:8001/v1")
 
+        # Components for the meta-reasoning strategy network
         self.num_strategies = len(STRATEGY_LIST)
         hidden_size = config.hidden_size
 
         self.strategy_embeddings = nn.Embedding(self.num_strategies, hidden_size).to(self.device)
         self.strategy_mlp = StrategyMLP(hidden_size, self.num_strategies).to(self.device)
 
-        self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
+        # Mapping tables between strategy indices and names
         self.idx_to_strategy = {i: s for i, s in enumerate(STRATEGY_LIST)}
+        self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
 
-        # Optimizer
+        # Optimizer for strategy embeddings and MLP
         self.optimizer = torch.optim.Adam(
             list(self.strategy_embeddings.parameters()) +
             list(self.strategy_mlp.parameters()),
             lr=config.learning_rate
         )
 
+        # Pre-allocated zero tensor to avoid repeated allocation in loops
+        self.zero_strategy_repr = torch.zeros(1, hidden_size, device=self.device)
+
+        print("AutoMR initialized successfully (Async Mode)")
+
-    def get_text_representation(self, texts: List[str]) -> Tuple[torch.Tensor]:
-        """
-        Get pooled hidden state representations for texts in batch using LLM embedder.
-        Args:
-            texts: List of input texts
-        Returns:
-            pooled: Tensor of shape [batch_size, hidden_size]
-        """
-        # self.tokenizer.padding_side = "left"
-        # inputs = self.tokenizer(
-        #     texts,
-        #     return_tensors="pt",
-        #     padding=True,
-        #     truncation=True,
-        # ).to(self.device)
-
-        # outputs = self.llm(**inputs, output_hidden_states=True)
-        # hidden_states = outputs.hidden_states[-1]  # [bsz, len, dim]
-        # pooled = hidden_states[:, -1, :]
-
-        # batch_outputs = self.llm_embedder.encode(texts)
-        # pooled = []
-        # for outputs in batch_outputs:
-        #     last_hidden_state = outputs.outputs.data[-1,:]  # [seq_len, hidden_size]
-        #     pooled.append(last_hidden_state)
-        # pooled = torch.stack(pooled, dim=0).to(self.device)  # [batch_size, hidden_size]
-
-        pooled = torch.stack(batch_reprs, dim=0)  # [batch_size, hidden_size]
-        return pooled
-
-    def sample_strategy(
-        self,
-        batch_node_content_repr: torch.Tensor,
-        batch_sampled_strategies: Dict[int, List[int]],
-        batch_context_repr: torch.Tensor
-    ) -> Tuple[List[int], torch.Tensor]:
-        """
-        Sample a strategy for each edge (j, i) in batch
-        Args:
-        Returns:
-            batch_strategy_idx: List of sampled strategy indices
-            batch_log_prob: Tensor of log probabilities, shape [batch_size]
-        """
-        for sampled_strategies in batch_sampled_strategies.values():
-            if sampled_strategies:
-                sampled_strategies = torch.tensor(sampled_strategies).to(self.device)
-                strategy_repr = self.strategy_embeddings(sampled_strategies).mean(dim=0, keepdim=True)
-            else:
-                strategy_repr = torch.zeros(1, self.config.hidden_size).to(self.device)
-            batch_strategy_repr.append(strategy_repr)
-
-        batch_strategy_repr = torch.cat(batch_strategy_repr, dim=0)  # Combine all batch representations
-        batch_logits = self.strategy_mlp(batch_node_content_repr, batch_strategy_repr, batch_context_repr)
-        batch_probs = F.softmax(batch_logits, dim=-1)
-
-        dist = torch.distributions.Categorical(batch_probs)
-        batch_strategy_idx = dist.sample()
-        batch_log_prob = dist.log_prob(batch_strategy_idx).to(self.device)
-
-        return batch_strategy_idx.cpu().tolist(), batch_log_prob
-
-    def generate_content(
-        self,
-        batch_query: List[str],
-        batch_context: List[str],
-        batch_strategies: List[List[str]],
-        batch_remaining_budget: List[int]
-    ) -> Tuple[List[str], List[int], torch.Tensor]:
-        """
-        Generate reasoning content based on selected strategies
-        Args:
-            batch_query: List of query strings
-            batch_context: List of context strings
-            batch_strategies: List of lists of strategy names
-            batch_remaining_budget: List of remaining token budgets
-        Returns:
-            batch_num_tokens: List of number of tokens generated
-            batch_content_reprs: Tensor of content representations, shape [batch_size, hidden_size]
-        """
-                prompt = random.choice(META_STRATEGIES[s])
-                batch_strategy_prompts[i].append(prompt)
-
-            batch_full_prompt.append(f"{batch_context[i]}\n{' '.join(batch_strategy_prompts[i])}\n")
-        params_list = []
-
-                "prompt": batch_full_prompt[i],
-                "max_tokens": current_max_tokens,
-            })
-
-        batch_content_reprs = self.get_text_representation(batch_generated_texts)
-
-        return batch_generated_texts, batch_num_tokens, batch_content_reprs
-
-    def dynamic_skeleton_sampling(self, queries: List[str], M: int) -> Tuple[List[MetaReasoningDAG], torch.Tensor]:
-        """
-        Algorithm 1: Dynamic Skeleton Sampling at inference time
-        Args:
-            queries: List of input query strings
-            M: Number of trajectories per query
-        Returns:
-            batch_dags: List of generated MetaReasoningDAGs
-            total_log_probs: Tensor of total log probabilities for each trajectory
-        """
-        # === 1. Initialize M*N DAGs ===
-        N = len(queries)
-        batch_size = N * M
-        batch_dags: List[MetaReasoningDAG] = []
-        query_reprs = self.get_text_representation(queries)
-        for i in range(N):
-            for _ in range(M):
-                batch_dags.append(
-                    MetaReasoningDAG(queries[i], query_reprs[i], 0)  # we don't count query tokens, set 0
-                )
-
-        total_log_probs = torch.zeros(batch_size).to(self.device)
-        # the idx of trajectories that are still active
-        active_indices = list(range(batch_size))
-        i = 0  # Current topology step (i=1 is the first new node)
-        while active_indices:
-            i += 1
-            sampled_strategies = {dag_idx: [] for dag_idx in active_indices}
-            incoming_edges = {dag_idx: [] for dag_idx in active_indices}
-
-                node_j_content_reprs,
-                sampled_strategies,
-                context_reprs
-            )
-
-            for dag_idx in active_indices.copy():
-                if not incoming_edges[dag_idx]:
-                    active_indices.remove(dag_idx)
-
-            if not active_indices:
-                break
-
-            batch_query = []
-            batch_remaining_budget = []
-            for dag_idx in active_indices:
-                dag = batch_dags[dag_idx]
-                strategies = [edge[1] for edge in incoming_edges[dag_idx]]
-                batch_strategies.append(strategies)
-                context = dag.get_context_up_to(i-1)
-                batch_context.append(context)
-                batch_query.append(dag.nodes[0].content)
-                batch_remaining_budget.append(self.token_budget - dag.total_tokens())
-
-            batch_content, batch_num_tokens, batch_content_repr = self.generate_content(
-                batch_query,
-                batch_context,
-                batch_strategies,
-                batch_remaining_budget
-            )
+    async def get_text_representation(self, text: str) -> torch.Tensor:
+        """Get the embedding vector for a single text string."""
+        if not text or not text.strip():
+            return torch.zeros(self.config.hidden_size, device=self.device, dtype=torch.float16)
+
+        try:
+            # For simplicity we do not add full retry logic here
+            resp = await self.embed_client.embeddings.create(
+                input=text,
+                model=self.config.model_name
+            )
+            # Extract embedding and move it to GPU/device
+            return torch.tensor(resp.data[0].embedding, device=self.device, dtype=torch.float16)
+        except Exception as e:
+            print(f"Embedding Error: {e}")
+            return torch.zeros(self.config.hidden_size, device=self.device, dtype=torch.float16)
+
+    async def _generate_text(self, prompt: str, max_tokens: int) -> Tuple[str, int]:
+        """Atomic text generation call, returning text and used completion tokens."""
+        if not prompt:
+            return "", 0
+        async with self.semaphore:
+            try:
+                resp = await self.generator_client.completions.create(
+                    model=self.config.model_name,
+                    prompt=prompt,
+                    max_tokens=max_tokens,
+                    temperature=self.config.temperature
+                )
+                text = resp.choices[0].text.strip()
+                # vLLM/OpenAI-style usage field; fall back to 0 if missing
+                used_tokens = resp.usage.completion_tokens
+                return text, int(used_tokens)
+            except Exception as e:
+                print(f"Generation Error: {e}")
+                return "", 0
+
+    def select_strategy(
+        self,
+        node_j_repr: torch.Tensor,
+        existing_strategy_indices: List[int],
+        context_repr: torch.Tensor
+    ) -> Tuple[int, torch.Tensor]:
+        """Decide whether to create an edge j->i according to Algorithm 1.
+
+        Args:
+            node_j_repr: Representation of candidate source node j.
+            existing_strategy_indices: Already selected k->i strategies (k > j).
+            context_repr: Global context representation.
+        """
+        # 1. Pool existing strategies (k->i)
+        if existing_strategy_indices:
+            hist_tensor = torch.tensor(existing_strategy_indices, device=self.device, dtype=torch.long)
+            strategy_repr = self.strategy_embeddings(hist_tensor).mean(dim=0, keepdim=True)
+        else:
+            strategy_repr = self.zero_strategy_repr
+
+        # 2. Forward pass through MLP (add batch dimension [1, dim])
+        logits = self.strategy_mlp(
+            node_j_repr.unsqueeze(0),
+            strategy_repr,
+            context_repr.unsqueeze(0)
+        )
+
+        # 3. Sample a strategy index from the categorical distribution
+        probs = F.softmax(logits, dim=-1)
+        dist = torch.distributions.Categorical(probs)
+        idx = dist.sample()
+        log_prob = dist.log_prob(idx)
+
+        return idx.item(), log_prob
+
+    async def run_single_trajectory(self, query: str) -> Tuple[str, torch.Tensor]:
+        """Run a single reasoning trajectory (Algorithm 1).
+
+        Returns:
+            Tuple of (final_answer, total_log_prob).
+        """
+        # 1. Initialization
+        q_repr = await self.get_text_representation(query)
+        dag = MetaReasoningDAG(query, q_repr, 0)
+        trajectory_log_prob = torch.tensor(0.0, device=self.device)
+
+        step_idx = 0
+        # Stopping condition: token budget exhausted or step limit (30)
+        while dag.total_tokens() < self.config.token_budget and step_idx < 30:
+            step_idx += 1
+
+            context_repr = dag.get_context_repr_up_to(step_idx - 1)
+
+            # === Inner loop: iterate j in reverse order (from i-1 down to 0) ===
+            strategies_k_to_i: List[int] = []  # Input to select_strategy (k > j)
+            incoming_edges_info: List[Tuple[int, str]] = []  # For prompt construction: (src_node_idx, strategy_name)
+
+            for j in range(step_idx - 1, -1, -1):
+                node_j_repr = dag.get_node_content_repr(j)
+
+                # Strategy decision
+                strat_idx, log_prob = self.select_strategy(
+                    node_j_repr, strategies_k_to_i, context_repr
+                )
+
+                # Accumulate log-probability contribution
+                trajectory_log_prob = trajectory_log_prob + log_prob
+                strategy_name = self.idx_to_strategy[strat_idx]
+
+                # If this is a non-zero (effective) strategy
+                if strategy_name != "zero":
+                    strategies_k_to_i.append(strat_idx)
+                    incoming_edges_info.append((j, strategy_name))
+                    dag.add_edge(j, step_idx, strategy_name)
+
+            # Fallback: if not the first step and there is no incoming edge, treat as reasoning interruption
+            if not incoming_edges_info and step_idx > 1:
+                break
+
+            # === Prompt construction (Algorithm 1 + Appendix A.2) ===
+            # Reverse edges back to chronological order (Step 0, Step 1, ...) for readability
+            incoming_edges_info.reverse()
+
+            prompts_list = []
+            has_answer_strategy = False
+
+            for src_node_idx, s_name in incoming_edges_info:
+                if s_name == "Answer":
+                    has_answer_strategy = True
+
+                # Get strategy template prompt
+                raw_strategy_prompt = random.choice(META_STRATEGIES.get(s_name, [""]))
+
+                # Apply Appendix A.2 template
+                formatted_prompt = f"Let me attend to Step {src_node_idx}, {raw_strategy_prompt}"
+                prompts_list.append(formatted_prompt)
+
+            # Concatenate all incoming-edge prompts
+            strategies_text = " ".join(prompts_list)
+            full_context = dag.get_context_up_to(step_idx - 1)
+
+            # === Generate content ===
+            if has_answer_strategy:
+                # Final answer generation
+                full_prompt = f"{full_context}\n{strategies_text}\nAnswer:\n"
+                remain_budget = max(1, self.config.token_budget - dag.total_tokens())
+                final_answer, used_tokens = await self._generate_text(full_prompt, remain_budget)
+                # Strictly accumulate used completion tokens before termination
+                if used_tokens > 0:
+                    dag.add_node(final_answer, used_tokens, await self.get_text_representation(final_answer))
+                return final_answer, trajectory_log_prob
+            else:
+                # Intermediate reasoning generation
+                full_prompt = f"{full_context}\n{strategies_text}\n"
+
+                # Compute token limit for current step
+                current_remain = self.config.token_budget - dag.total_tokens()
+                step_limit = min(self.config.max_new_tokens, current_remain)
+
+                if step_limit <= 0:
+                    break  # Budget exhausted
+
+                content, used_tokens = await self._generate_text(full_prompt, step_limit)
+
+                # If the service does not return usage, fall back to at least one token when content exists
+                if used_tokens <= 0 and content:
+                    used_tokens = 1
+
+                content_repr = await self.get_text_representation(content)
+                dag.add_node(content, used_tokens, content_repr)
+
+                # Check whether a boxed answer appears earlier than expected
+                if "boxed" in content:
+                    return content, trajectory_log_prob
+
+        # After the loop ends: return the content of the last node if one exists
+        if len(dag.nodes) > 0:
+            return dag.nodes[-1].content, trajectory_log_prob
+        return "", trajectory_log_prob
+
+    async def sample_batch(self, queries: List[str], M: int) -> Tuple[List[str], torch.Tensor]:
+        """Async entry: expand B queries into B*M trajectories and run concurrently."""
+        tasks = []
+        for q in queries:
+            for _ in range(M):
+                tasks.append(self.run_single_trajectory(q))
+
+        # Show generation progress: each finished trajectory updates the bar
+        results = await tqdm_asyncio.gather(*tasks)
+
+        answers = [r[0] for r in results]
+        log_probs = torch.stack([r[1] for r in results])
+        return answers, log_probs
+
+    def sample_batch_sync(self, queries: List[str], M: int) -> Tuple[List[str], torch.Tensor]:
+        """Synchronous wrapper for sample_batch, for use in trainer."""
+        return asyncio.run(self.sample_batch(queries, M))
+
+    # Compatibility interfaces
+    def save_checkpoint(self, path):
         torch.save({
             'strategy_embeddings': self.strategy_embeddings.state_dict(),
             'strategy_mlp': self.strategy_mlp.state_dict(),
             'optimizer': self.optimizer.state_dict()
         }, path)
-        print(f"Checkpoint saved to {path}")
-
-    def load_checkpoint(self, path: str):
-        """Load model checkpoint"""
-        checkpoint = torch.load(path, map_location=self.device)
-        self.strategy_embeddings.load_state_dict(checkpoint['strategy_embeddings'])
-        self.strategy_mlp.load_state_dict(checkpoint['strategy_mlp'])
-        self.optimizer.load_state_dict(checkpoint['optimizer'])
-        print(f"Checkpoint loaded from {path}")
-
-            max_tokens=params["max_tokens"],
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-        )
+
+    def load_checkpoint(self, path):
+        ckpt = torch.load(path, map_location=self.device)
+        self.strategy_embeddings.load_state_dict(ckpt['strategy_embeddings'])
+        self.strategy_mlp.load_state_dict(ckpt['strategy_mlp'])
+        self.optimizer.load_state_dict(ckpt['optimizer'])
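The generator path above fans one task out per trajectory and caps concurrency with a shared semaphore. A minimal sketch of that pattern in isolation; the endpoint, model name, and prompts below are placeholders, not values from this repo:

    import asyncio
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="vllm", base_url="http://localhost:8000/v1")
    semaphore = asyncio.Semaphore(128)      # at most 128 in-flight requests

    async def generate(prompt: str) -> str:
        async with semaphore:
            resp = await client.completions.create(
                model="placeholder-model",  # the real code uses config.model_name
                prompt=prompt,
                max_tokens=64,
            )
            return resp.choices[0].text

    async def main(prompts):
        # gather returns results in input order, one per submitted task
        return await asyncio.gather(*(generate(p) for p in prompts))

    # asyncio.run(main(["1+1=", "2+2="]))  # requires a running vLLM server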
automr/strategies.py
CHANGED
@@ -1,8 +1,6 @@
 # Meta Reasoning Strategy Prompts (from Table 2 in paper)
 META_STRATEGIES = {
     "Next": [
-        "Next,",
-        "Then,",
         "Now, let me move on to the next step."
     ],
     "Reflect": [
automr/trainer.py
CHANGED
@@ -2,15 +2,14 @@ import random
 import torch
 from typing import List, Dict, Tuple
 from tqdm import tqdm
+import os
 
 from .model import AutoMR
 from .config import AutoMRConfig
 from .utils import check_answer_match, ensure_dir, save_json
-import os
-
 
 class AutoMRTrainer:
-    """Trainer for AutoMR using REINFORCE (
+    """Trainer for AutoMR using REINFORCE (sync trainer, async model calls)"""
 
     def __init__(self, model: AutoMR, config: AutoMRConfig):
         self.model = model
@@ -21,6 +20,9 @@ class AutoMRTrainer:
         self.global_step = 0
         self.best_val_reward = -float('inf')
         self.patience_counter = 0
+        # Sliding-window baseline for REINFORCE advantage (variance reduction)
+        self.baseline = self.config.initial_baseline
+        self.baseline_momentum = self.config.baseline_momentum
         self.training_history = {
             'train_loss': [],
             'train_reward': [],
@@ -31,7 +33,7 @@ class AutoMRTrainer:
 
     def compute_reward_batch(self, queries: List[str], answers: List[str]) -> Tuple[float, float]:
         """
-        Compute average reward and accuracy on a batch
+        Compute average reward and accuracy on a batch (Async)
         Returns: (avg_reward, accuracy)
         """
         total_reward = 0.0
@@ -42,16 +44,16 @@ class AutoMRTrainer:
         self.model.strategy_embeddings.eval()
 
         with torch.no_grad():
-            pred_answers = self.model.
-            is_correct = [check_answer_match(
-                pred_answer, answer, self.config.task_type
-            ) for pred_answer, answer in zip(pred_answers, answers)]
+            # M=1 for evaluation/validation; call the async model via the sync wrapper
+            pred_answers, _ = self.model.sample_batch_sync(queries, M=1)
 
+        for pred_answer, answer in zip(pred_answers, answers):
+            is_correct = check_answer_match(pred_answer, answer, self.config.task_type)
+            if is_correct:
+                correct += 1
+                total_reward += 1.0
+            else:
+                total_reward += -1.0
 
         avg_reward = total_reward / total if total > 0 else 0.0
         accuracy = correct / total if total > 0 else 0.0
@@ -76,37 +78,48 @@ class AutoMRTrainer:
 
     def train_step(self, batch_queries: List[str], batch_answers: List[str]) -> Tuple[float, float]:
         """
         Single training step using REINFORCE
         Returns: (loss, avg_reward)
         """
         self.model.strategy_mlp.train()
         self.model.strategy_embeddings.train()
 
         M = self.config.num_samples_per_query
+        loss_list = []
         rewards_list = []
+
+        # pred_answers: [B*M], log_probs: [B*M]; sync wrapper over the async model
+        pred_answers, log_probs = self.model.sample_batch_sync(batch_queries, M)
+
+        # 2. Expand answers for comparison
+        expanded_answers = [answer for answer in batch_answers for _ in range(M)]
+
+        # 3. Compute Reward & Loss
+        for pred_answer, answer, log_prob in zip(pred_answers, expanded_answers, log_probs):
             reward = 1.0 if check_answer_match(
                 pred_answer, answer, self.config.task_type
             ) else -1.0
+
             rewards_list.append(reward)
-            loss.append(-reward * log_prob)
 
-        # Compute average reward for this batch
+        # Compute batch average reward
         avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0
+
+        # Update sliding baseline: exponential moving average
+        self.baseline = (
+            self.baseline_momentum * self.baseline
+            + (1.0 - self.baseline_momentum) * avg_reward
+        )
+
+        # Policy Gradient with advantage: -(reward - baseline) * log_prob
+        for reward, log_prob in zip(rewards_list, log_probs):
+            advantage = reward - self.baseline
+            loss_list.append(-advantage * log_prob)
 
-        # Update parameters
+        # 4. Update parameters
         self.model.optimizer.zero_grad()
+
+        loss = torch.stack(loss_list).mean()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(
             list(self.model.strategy_embeddings.parameters()) +
@@ -114,7 +127,6 @@ class AutoMRTrainer:
             max_norm=self.config.gradient_clip
         )
         self.model.optimizer.step()
-
         return loss.item(), avg_reward
 
     def should_stop_early(self) -> bool:
@@ -145,7 +157,7 @@ class AutoMRTrainer:
         print(f" 💾 Best checkpoint saved: {checkpoint_path}")
 
     def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
-        """Training loop with validation
+        """Training loop with validation"""
         print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
         print(f"Training samples: {len(train_data)}")
         print(f"Validation samples: {len(val_data)}")
@@ -160,18 +172,21 @@ class AutoMRTrainer:
             epoch_reward = 0.0
             num_batches = 0
 
+            batch_indices = list(range(0, len(train_data), self.config.batch_size))
+
             pbar = tqdm(
+                batch_indices,
                 desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
             )
 
             for i in pbar:
-                batch = train_data[i:i+self.config.batch_size]
+                batch = train_data[i : i + self.config.batch_size]
                 batch_queries = [item['query'] for item in batch]
                 batch_answers = [item['answer'] for item in batch]
 
-                # Training step
+                # Training step (sync)
                 loss, avg_reward = self.train_step(batch_queries, batch_answers)
+
                 epoch_loss += loss
                 epoch_reward += avg_reward
                 num_batches += 1
@@ -221,9 +236,9 @@ class AutoMRTrainer:
                     print(f"Best validation reward: {self.best_val_reward:.4f}")
                     return
 
-            # End of epoch
-            avg_epoch_loss = epoch_loss / num_batches
-            avg_epoch_reward = epoch_reward / num_batches
+            # End of epoch summary
+            avg_epoch_loss = epoch_loss / max(num_batches, 1)
+            avg_epoch_reward = epoch_reward / max(num_batches, 1)
 
             self.training_history['train_loss'].append(avg_epoch_loss)
             self.training_history['train_reward'].append(avg_epoch_reward)
@@ -236,7 +251,7 @@ class AutoMRTrainer:
             print(f"Best Val Reward: {self.best_val_reward:.4f}")
             print(f"{'='*80}\n")
 
             # Save checkpoint at end of epoch
             if not self.config.save_best_only:
                 self.save_checkpoint(epoch + 1)
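For reference, the objective assembled in train_step reduces to a REINFORCE loss with a moving baseline, loss = mean over i of -(r_i - b) * log_prob_i. A minimal self-contained sketch with made-up numbers:

    import torch

    log_probs = [torch.tensor(-1.2, requires_grad=True), torch.tensor(-0.7, requires_grad=True)]
    rewards = [1.0, -1.0]   # +1 correct / -1 incorrect, as scored by check_answer_match
    baseline = 0.1          # running EMA driven by initial_baseline / baseline_momentum

    loss = torch.stack([-(r - baseline) * lp for r, lp in zip(rewards, log_probs)]).mean()
    loss.backward()         # gradients flow only through the log-probabilities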