haifei committed on
Commit
1482463
·
1 Parent(s): 5168b2e

code and checkpoint

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/MATH/pangu/best_checkpoint.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ processed_data/
2
+ checkpoints/
3
+ results/
4
+ automr/__pycache__/
AMC.SH ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Evaluate the MATH-trained AutoMR checkpoint on the AMC test split.
# NPU device 4 is selected via ASCEND_RT_VISIBLE_DEVICES; adjust for your host.
# Fix: removed the dangling trailing backslash after the last flag — a bare
# line continuation at end of file silently swallows any line appended later.
ASCEND_RT_VISIBLE_DEVICES=4 python main.py --mode eval \
    --device npu \
    --model_name FreedomIntelligence/openPangu-Embedded-7B \
    --test_data processed_data/AMC/test.jsonl \
    --load_checkpoint checkpoints/MATH/pangu/best_checkpoint.pt \
    --task_type math \
    --results_dir results/AMC/ \
    --token_budget 4096
automr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
"""AutoMR package public API.

Re-exports the main entry points so callers can write
``from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig``.
"""
from .model import AutoMR
from .trainer import AutoMRTrainer
from .evaluator import AutoMREvaluator
from .config import AutoMRConfig

# Package version; bump when the public interface changes.
__version__ = "1.0.0"
__all__ = ["AutoMR", "AutoMRTrainer", "AutoMREvaluator", "AutoMRConfig"]
automr/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+
5
@dataclass
class AutoMRConfig:
    """Configuration for AutoMR framework.

    Flat dataclass so every knob can be overridden individually; defaults
    target the openPangu-Embedded-7B setup.
    """

    # Model settings
    model_name: str = "FreedomIntelligence/openPangu-Embedded-7B"
    device: str = "npu"  # torch device string for the trainable strategy modules
    token_budget: int = 4096  # hard cap on generated tokens per reasoning trajectory
    hidden_size: int = 4096  # To be set according to the model used

    # Training settings
    learning_rate: float = 5e-4
    num_epochs: int = 5
    batch_size: int = 8
    num_samples_per_query: int = 4  # M in paper
    gradient_clip: float = 1.0  # max grad norm for strategy MLP + embeddings

    # Validation settings
    val_every_n_steps: int = 100  # Alpha in the requirement - validate every N steps
    val_batch_size: int = 10  # Number of validation samples to evaluate
    early_stopping_patience: int = 5  # Stop if no improvement for N validations

    # Generation settings
    max_new_tokens: int = 1024  # per-node generation cap (further bounded by remaining budget)
    temperature: float = 0.01  # near-greedy decoding
    top_p: float = 0.9

    # Paths
    train_data_path: str = "data/train.json"
    val_data_path: str = "data/val.json"
    test_data_path: str = "data/test.json"
    checkpoint_dir: str = "checkpoints"
    results_dir: str = "results"

    # Evaluation settings
    save_predictions: bool = True
    save_skeletons: bool = True
    save_best_only: bool = True

    # Task type
    task_type: str = "math"  # "math" or "multiple_choice"
automr/dag.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ import torch
4
+
5
+
6
@dataclass
class ReasoningNode:
    """Represents a node in the meta-reasoning DAG"""
    index: int  # position in the DAG's node list (0 is the query node)
    content: str  # reasoning text (the query itself for node 0)
    num_tokens: int  # completion tokens spent producing `content` (0 for the query)
    content_repr: torch.Tensor  # pooled embedding of `content`, shape [hidden_size]


@dataclass
class ReasoningEdge:
    """Represents a directed edge (from_node -> to_node) in the meta-reasoning DAG"""
    from_node: int
    to_node: int
    strategy: str  # meta-strategy name attached to the edge


class MetaReasoningDAG:
    """Represents the meta-reasoning skeleton as a DAG.

    Node 0 always holds the original query; reasoning nodes are appended
    sequentially, so node indices double as a topological order.
    """

    def __init__(self, query: str, query_repr: torch.Tensor, query_num_tokens: int):
        self.nodes: List[ReasoningNode] = [ReasoningNode(0, query, query_num_tokens, query_repr)]
        self.edges: List[ReasoningEdge] = []
        self.current_index = 0  # index of the most recently added node

    def add_node(self, content: str, num_tokens: int, content_repr: torch.Tensor) -> int:
        """Append a new node to the DAG and return its index."""
        self.current_index += 1
        node = ReasoningNode(self.current_index, content, num_tokens, content_repr)
        self.nodes.append(node)
        return self.current_index

    def add_edge(self, from_idx: int, to_idx: int, strategy: str):
        """Add an edge between two existing nodes."""
        self.edges.append(ReasoningEdge(from_idx, to_idx, strategy))

    def get_node_content_repr(self, idx: int) -> torch.Tensor:
        """Get content representation of a specific node."""
        return self.nodes[idx].content_repr

    def get_node_content(self, idx: int) -> str:
        """Get content of a specific node."""
        return self.nodes[idx].content

    def get_context_repr_up_to(self, idx: int) -> torch.Tensor:
        """Mean of the content representations of nodes 0..idx (inclusive).

        Bug fix: the previous implementation aliased node 0's stored tensor
        (`context_repr = self.nodes[0].content_repr`) and then applied
        in-place `+=` / `/=`, silently mutating the query representation and
        corrupting every subsequent context computation. Stacking into a
        fresh tensor leaves the stored representations untouched.
        """
        return torch.stack(
            [node.content_repr for node in self.nodes[:idx + 1]], dim=0
        ).mean(dim=0)

    def get_context_up_to(self, idx: int) -> str:
        """Get all node contents up to index idx, newline-joined."""
        return "\n".join([node.content for node in self.nodes[:idx + 1]])

    def total_tokens(self) -> int:
        """Total tokens generated (excluding source node)"""
        return sum(node.num_tokens for node in self.nodes[1:])

    def to_dict(self) -> dict:
        """Convert DAG to dictionary for serialization (tensors are dropped)."""
        return {
            "nodes": [
                {"index": n.index, "content": n.content, "num_tokens": n.num_tokens}
                for n in self.nodes
            ],
            "edges": [
                {"from": e.from_node, "to": e.to_node, "strategy": e.strategy}
                for e in self.edges
            ]
        }
automr/data_loader.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Dict
3
+
4
+
5
class DataLoader:
    """Load and validate datasets from .json / .jsonl files."""

    @staticmethod
    def load_data(file_path: str) -> List[Dict[str, str]]:
        """
        Load data from a JSON (.json) or JSON-lines (.jsonl) file.

        Expected format: [{"query": "...", "answer": "..."}, ...]

        Raises:
            ValueError: if any record lacks a 'query' or 'answer' field.
        """
        # Explicit UTF-8: dataset files are shared across machines, so don't
        # depend on the platform's default locale encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_path.endswith('.jsonl'):
                # Skip blank lines (editor-added trailing newlines, padding);
                # json.loads("") would otherwise raise JSONDecodeError.
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)

        # Validate data format
        for item in data:
            if 'query' not in item or 'answer' not in item:
                raise ValueError("Each data item must have 'query' and 'answer' fields")

        return data

    @staticmethod
    def load_math_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MATH or GSM8K format dataset"""
        return DataLoader.load_data(file_path)

    @staticmethod
    def load_mmlu_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MMLU-Pro format dataset"""
        return DataLoader.load_data(file_path)
automr/evaluator.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple
2
+ from tqdm import tqdm
3
+ import os
4
+
5
+ from .model import AutoMR
6
+ from .config import AutoMRConfig
7
+ from .utils import check_answer_match, save_json, ensure_dir
8
+
9
+
10
class AutoMREvaluator:
    """Evaluator for AutoMR: batched accuracy measurement over a test set."""

    def __init__(self, model: AutoMR, config: AutoMRConfig):
        # model: fully constructed AutoMR (API clients + strategy modules ready).
        self.model = model
        self.config = config
        # Create the results directory up front so evaluate() can write to it.
        ensure_dir(config.results_dir)

    def evaluate(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
        """
        Evaluate model on test data.

        Args:
            test_data: list of {"query": ..., "answer": ...} records.
        Returns: (accuracy, detailed_results)

        NOTE(review): detailed_results is currently always empty — the code
        that populated it is commented out below — so both the returned list
        and the saved JSON's 'detailed_results' field are empty lists.
        """
        print(f"\nEvaluating on {len(test_data)} samples...")

        # Inference only: disable dropout in the trainable strategy modules.
        self.model.strategy_mlp.eval()
        self.model.strategy_embeddings.eval()

        correct = 0
        total = 0
        detailed_results = []
        batch_size = self.config.batch_size
        pbar = tqdm(
            range(0, len(test_data), batch_size), desc="Evaluating"
        )

        # for item in tqdm(test_data, desc="Evaluating"):
        for i in pbar:
            batch = test_data[i:i + batch_size]
            queries = [item['query'] for item in batch]
            ground_truths = [item['answer'] for item in batch]

            # Run inference
            # One sampled trajectory per query at eval time (M=1).
            pred_answers, dags = self.model.inference(queries,M=1)

            for query, ground_truth, pred_answer, dag in zip(queries, ground_truths, pred_answers, dags):
                # Check correctness
                is_correct = check_answer_match(
                    pred_answer,
                    ground_truth,
                    self.config.task_type
                )

                if is_correct:
                    correct += 1
                total += 1

                # Show running accuracy in the progress bar.
                pbar.set_postfix({
                    'Acc': f'{correct} / {total}',
                })

                # Store detailed result
                # result = {
                #     'query': query,
                #     'ground_truth': ground_truth,
                #     'prediction': pred_answer,
                #     'correct': is_correct
                # }

                # if self.config.save_skeletons:
                #     result['skeleton'] = dag.to_dict()

                # detailed_results.append(result)

        accuracy = correct / total if total > 0 else 0.0

        print(f"\nEvaluation Results:")
        print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")

        # Save results
        if self.config.save_predictions:
            results_path = os.path.join(
                self.config.results_dir,
                'evaluation_results.json'
            )
            save_json({
                'accuracy': accuracy,
                'correct': correct,
                'total': total,
                'detailed_results': detailed_results
            }, results_path)
            print(f"Results saved to {results_path}")

        return accuracy, detailed_results
automr/model.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import List, Tuple
6
+ import random
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from vllm import LLM
9
+ from vllm import SamplingParams
10
+ from .config import AutoMRConfig
11
+ from .strategies import META_STRATEGIES, STRATEGY_LIST
12
+ from .dag import MetaReasoningDAG
13
+ from .utils import extract_answer
14
+ from typing import Dict
15
+ from openai import OpenAI
16
+
17
class StrategyMLP(nn.Module):
    """Three-layer MLP that scores meta-reasoning strategies for an edge.

    The input is the concatenation of a node representation, the mean of the
    already-sampled strategy embeddings, and a context representation.
    Attribute names fc1/fc2/fc3/dropout are kept stable because they form the
    state_dict keys stored in checkpoints.
    """

    def __init__(self, hidden_size: int, num_strategies: int):
        super().__init__()
        # Input: [node_repr, strategy_repr, context_repr]
        self.fc1 = nn.Linear(hidden_size * 3, hidden_size * 2)
        self.fc2 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_strategies)
        self.dropout = nn.Dropout(0.1)

    def forward(self, node_repr, strategy_repr, context_repr):
        """
        Args:
            node_repr: [batch, hidden_size]
            strategy_repr: [batch, hidden_size]
            context_repr: [batch, hidden_size]
        Returns:
            logits: [batch, num_strategies]
        """
        hidden = torch.cat((node_repr, strategy_repr, context_repr), dim=-1)
        # Two ReLU+dropout hidden layers, then a plain linear readout.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)
44
+
45
+
46
class AutoMR:
    """AutoMR Framework for Meta-Reasoning Skeleton Search.

    The base LLM is reached over HTTP (two OpenAI-compatible vLLM servers:
    one for generation, one for embeddings); the only locally trainable
    parameters are the strategy embedding table and the StrategyMLP.
    """

    def __init__(self, config: AutoMRConfig):
        self.config = config
        self.device = config.device
        self.token_budget = config.token_budget
        # Model id sent with every OpenAI-compatible API request.
        self.model_name_for_api = config.model_name

        # --- earlier in-process loading paths, kept for reference ---
        # # Load LLM
        # print(f"Loading Tokenizer and Config: {config.model_name}")
        # self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

        # print(f"Loading vLLM generator: {config.model_name}")
        # self.llm = LLM(
        #     config.model_name,
        #     dtype=torch.float16,
        #     trust_remote_code=True,
        #     tensor_parallel_size=config.tensor_parallel_size,
        #     gpu_memory_utilization="0.8"
        # )

        # print(f"Loading LLM Embbedder: {config.model_name}")
        # self.llm_embedder = AutoModelForCausalLM.from_pretrained(
        #     config.model_name,
        #     torch_dtype=torch.float16,
        #     trust_remote_code=True,
        #     device_map=None
        # ).to(self.device)
        # self.llm_embedder.eval()

        # print(f"Loading vLLM Embbedder: {config.model_name}")
        # self.llm_embedder = LLM(
        #     config.model_name,
        #     dtype=torch.float16,
        #     trust_remote_code=True,
        #     tensor_parallel_size=config.tensor_parallel_size,
        #     gpu_memory_utilization="0.8"
        #     task="embed",
        # )

        # NOTE(review): server host/ports are hardcoded — consider lifting
        # them into AutoMRConfig so multi-host setups don't need code edits.
        print("Connecting to vLLM Generator Server (port 8000)...")
        self.generator_client = OpenAI(
            api_key="vllm",
            base_url="http://localhost:8000/v1"
        )

        print("Connecting to Custom Embedder Server (port 8001)...")
        self.embed_client = OpenAI(
            api_key="vllm",
            base_url="http://localhost:8001/v1"
        )


        # Strategy components (the trainable part of AutoMR).
        self.num_strategies = len(STRATEGY_LIST)
        hidden_size = config.hidden_size

        self.strategy_embeddings = nn.Embedding(self.num_strategies, hidden_size).to(self.device)
        self.strategy_mlp = StrategyMLP(hidden_size, self.num_strategies).to(self.device)

        # Strategy mappings (index order follows STRATEGY_LIST).
        self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
        self.idx_to_strategy = {i: s for i, s in enumerate(STRATEGY_LIST)}

        # Optimizer over embeddings + MLP only; the LLM itself is frozen/remote.
        self.optimizer = torch.optim.Adam(
            list(self.strategy_embeddings.parameters()) +
            list(self.strategy_mlp.parameters()),
            lr=config.learning_rate
        )

        print("AutoMR initialized successfully")

    def get_text_representation(self, texts: List[str]) -> torch.Tensor:
        """
        Get pooled hidden state representations for texts in batch using the
        remote LLM embedder.

        Args:
            texts: List of input texts
        Returns:
            pooled: Tensor of shape [batch_size, hidden_size]

        NOTE(review): embeddings come back as plain floats over HTTP, so no
        autograd graph flows through the LLM — gradients only reach the
        strategy embeddings/MLP.
        """

        # --- earlier in-process embedding paths, kept for reference ---
        # self.tokenizer.padding_side = "left"
        # inputs = self.tokenizer(
        #     texts,
        #     return_tensors="pt",
        #     padding=True,
        #     truncation=True,
        # ).to(self.device)

        # with torch.no_grad():
        #     outputs = self.llm(**inputs, output_hidden_states=True)
        #     hidden_states = outputs.hidden_states[-1]  # [bsz, len, dim]
        #     pooled = hidden_states[:, -1, :]

        # batch_outputs = self.llm_embedder.encode(texts)
        # pooled = []
        # for outputs in batch_outputs:
        #     last_hidden_state = outputs.outputs.data[-1,:]  # [seq_len, hidden_size]
        #     pooled.append(last_hidden_state)
        # pooled = torch.stack(pooled, dim=0).to(self.device)  # [batch_size, hidden_size]

        batch_outputs = self.embed_client.embeddings.create(
            input=texts,
            model=self.model_name_for_api
        )

        batch_reprs = [
            torch.tensor(data.embedding, device=self.device, dtype=torch.float16)
            for data in batch_outputs.data
        ]

        pooled = torch.stack(batch_reprs, dim=0)  # [batch_size, hidden_size]
        return pooled

    def sample_strategy(
        self,
        batch_node_content_repr: torch.Tensor,
        batch_sampled_strategies: Dict[int, List[int]],
        batch_context_repr: torch.Tensor
    ) -> Tuple[List[int], torch.Tensor]:
        """
        Sample a strategy for each edge (j, i) in batch.

        Args:
            batch_node_content_repr: Tensor of shape [batch_size, hidden_size]
            batch_sampled_strategies: Dict of lists of sampled strategy indices
            batch_context_repr: Tensor of shape [batch_size, hidden_size]
        Returns:
            batch_strategy_idx: List of sampled strategy indices
            batch_log_prob: Tensor of log probabilities, shape [batch_size]

        NOTE(review): relies on dict insertion order of
        batch_sampled_strategies matching the row order of the two tensors
        (the caller builds both from active_indices) — confirm if callers
        change.
        """


        batch_strategy_repr = []

        for sampled_strategies in batch_sampled_strategies.values():
            if sampled_strategies:
                # Mean of the embeddings of strategies already sampled for
                # this node in the current step.
                sampled_strategies = torch.tensor(sampled_strategies).to(self.device)
                strategy_repr = self.strategy_embeddings(sampled_strategies).mean(dim=0, keepdim=True)
            else:
                # No strategies sampled yet: zero vector placeholder.
                strategy_repr = torch.zeros(1, self.config.hidden_size).to(self.device)
            batch_strategy_repr.append(strategy_repr)

        batch_strategy_repr = torch.cat(batch_strategy_repr, dim=0)  # Combine all batch representations
        batch_logits = self.strategy_mlp(batch_node_content_repr, batch_strategy_repr, batch_context_repr)
        batch_probs = F.softmax(batch_logits, dim=-1)

        # Categorical sampling keeps log-probs differentiable for REINFORCE.
        dist = torch.distributions.Categorical(batch_probs)
        batch_strategy_idx = dist.sample()
        batch_log_prob = dist.log_prob(batch_strategy_idx).to(self.device)

        return batch_strategy_idx.cpu().tolist(), batch_log_prob

    def generate_content(
        self,
        batch_query: List[str],
        batch_context: List[str],
        batch_strategies: List[List[str]],
        batch_remaining_budget: List[int]
    ) -> Tuple[List[str], List[int], torch.Tensor]:
        """
        Generate reasoning content based on selected strategies.

        Args:
            batch_query: List of query strings
            batch_context: List of context strings
            batch_strategies: List of lists of strategy names
            batch_remaining_budget: List of remaining token budgets
        Returns:
            batch_generated_texts: List of generated content strings
            batch_num_tokens: List of number of tokens generated
            batch_content_reprs: Tensor of content representations, shape [batch_size, hidden_size]

        NOTE(review): batch_query is only used for its length here — the
        prompt is context + strategy phrases (the query is node 0 of the
        context).
        """
        batch_strategy_prompts = [[] for _ in batch_query]
        batch_full_prompt: List[str] = []
        for i, strategies in enumerate(batch_strategies):
            # One phrasing is sampled per strategy to add prompt diversity.
            for s in strategies:
                prompt = random.choice(META_STRATEGIES[s])
                batch_strategy_prompts[i].append(prompt)

            batch_full_prompt.append(f"{batch_context[i]}\n{' '.join(batch_strategy_prompts[i])}\n")
        params_list = []

        for i in range(len(batch_query)):
            # Never request more tokens than the trajectory has budget left.
            remaining_budget = batch_remaining_budget[i]
            current_max_tokens = min(self.config.max_new_tokens, remaining_budget)
            params_list.append({
                "prompt": batch_full_prompt[i],
                "max_tokens": current_max_tokens,
            })

        # Fan out the per-trajectory completions; the server batches them.
        with ThreadPoolExecutor(max_workers=None) as executor:
            batch_outputs = list(executor.map(self.make_api_call, params_list))

        batch_generated_texts = [output.choices[0].text.strip() for output in batch_outputs]
        batch_num_tokens = [output.usage.completion_tokens for output in batch_outputs]
        batch_content_reprs = self.get_text_representation(batch_generated_texts)

        return batch_generated_texts, batch_num_tokens, batch_content_reprs

    def dynamic_skeleton_sampling(self, queries: List[str], M: int) -> Tuple[List[MetaReasoningDAG], torch.Tensor]:
        """
        Algorithm 1: Dynamic Skeleton Sampling at inference time.

        Args:
            queries: List of input query strings
            M: Number of trajectories per query
        Returns:
            batch_dags: List of generated MetaReasoningDAGs
                (query-major order: M consecutive DAGs per query)
            total_log_probs: Tensor of total log probabilities for each trajectory
        """
        # === 1. Initialize M*N DAGs ===
        N = len(queries)
        batch_size = N * M
        batch_dags: List[MetaReasoningDAG] = []
        query_reprs = self.get_text_representation(queries)
        for i in range(N):
            for _ in range(M):
                batch_dags.append(
                    MetaReasoningDAG(queries[i], query_reprs[i], 0)  # we don't count query tokens, set 0
                )

        total_log_probs = torch.zeros(batch_size).to(self.device)
        # the idx of trajectories that are still active
        active_indices = list(range(batch_size))
        i = 0  # Current topology step (i=1 is the first new node)
        while active_indices:
            i += 1
            # Per-trajectory state for this step: strategies sampled so far
            # (conditioning for later samples) and the resulting edges.
            sampled_strategies = {dag_idx: [] for dag_idx in active_indices}
            incoming_edges = {dag_idx: [] for dag_idx in active_indices}

            # Step 1: Determine incoming edges (traverse in reverse order)
            for j in range(i-1, -1, -1):
                node_j_content_reprs = torch.stack([batch_dags[idx].get_node_content_repr(j) for idx in active_indices], dim=0)
                context_reprs = torch.stack([batch_dags[idx].get_context_repr_up_to(i-1) for idx in active_indices], dim=0)

                strategy_idx, log_prob = self.sample_strategy(
                    node_j_content_reprs,
                    sampled_strategies,
                    context_reprs
                )

                for k, dag_idx in enumerate(active_indices):
                    sampled_strategies[dag_idx].append(strategy_idx[k])

                # Accumulate trajectory log-probs for REINFORCE.
                total_log_probs[active_indices] += log_prob

                # "zero" means no edge from node j to the new node i.
                for dag_idx in active_indices:
                    strategy_idx = sampled_strategies[dag_idx][-1]
                    strategy_name = self.idx_to_strategy[strategy_idx]
                    if strategy_name != "zero":
                        incoming_edges[dag_idx].append((j, strategy_name))

            # Step 2: Check which DAGs are still active
            # (all-"zero" edges == the trajectory chose to terminate).
            for dag_idx in active_indices.copy():
                if not incoming_edges[dag_idx]:
                    active_indices.remove(dag_idx)

            if not active_indices:
                break

            # Step 3: Generate base reasoning content
            batch_strategies = []
            batch_context = []
            batch_query = []
            batch_remaining_budget = []
            for dag_idx in active_indices:
                dag = batch_dags[dag_idx]
                strategies = [edge[1] for edge in incoming_edges[dag_idx]]
                batch_strategies.append(strategies)
                context = dag.get_context_up_to(i-1)
                batch_context.append(context)
                batch_query.append(dag.nodes[0].content)
                batch_remaining_budget.append(self.token_budget - dag.total_tokens())

            batch_content, batch_num_tokens, batch_content_repr = self.generate_content(
                batch_query,
                batch_context,
                batch_strategies,
                batch_remaining_budget
            )

            # Step 4: Update DAGs with new nodes and edges
            for k, dag_idx in enumerate(active_indices):
                dag = batch_dags[dag_idx]
                content = batch_content[k]
                num_tokens = batch_num_tokens[k]
                content_repr = batch_content_repr[k]
                dag.add_node(content, num_tokens, content_repr)

            # Trajectories dropped in Step 2 have empty edge lists, so this
            # loop only adds edges for DAGs that actually got node i.
            for dag_idx in incoming_edges:
                dag = batch_dags[dag_idx]
                for from_j, strategy in incoming_edges[dag_idx]:
                    dag.add_edge(from_j, i, strategy)
            # Step 5: Check stopping criteria
            # Stop on empty generation, a final boxed answer, or exhausted budget.
            for dag_idx in active_indices.copy():
                dag = batch_dags[dag_idx]
                content = dag.get_node_content(i)
                if not content or "boxed" in content.lower() or dag.total_tokens() >= self.token_budget:
                    active_indices.remove(dag_idx)

        return batch_dags, total_log_probs

    def extract_answer(self, batch_dags: List[MetaReasoningDAG]) -> List[str]:
        """Extract final answer from the reasoning DAG via an 'Answer' prompt."""
        batch_answer_prompts = []  # NOTE(review): unused leftover variable
        params_list = []
        for dag in batch_dags:
            # Full context = every node generated for this trajectory.
            full_context = dag.get_context_up_to(len(dag.nodes) - 1)

            answer_prompt = f"{full_context}\n{random.choice(META_STRATEGIES['Answer'])}\n"
            params_list.append( {
                "prompt": answer_prompt,
                "max_tokens": self.config.max_new_tokens,
            })

        with ThreadPoolExecutor(max_workers=None) as executor:
            batch_outputs = list(executor.map(self.make_api_call, params_list))

        batch_answers = [output.choices[0].text.strip() for output in batch_outputs]

        return batch_answers

    def inference(self, batch_queries: List[str], M: int) -> Tuple[List[str], List[MetaReasoningDAG]]:
        """Run no-grad inference: sample skeletons, then extract answers.

        Returns one answer and one DAG per trajectory (len == len(batch_queries) * M).
        """
        self.strategy_mlp.eval()
        self.strategy_embeddings.eval()

        with torch.no_grad():
            batch_dags, _ = self.dynamic_skeleton_sampling(batch_queries, M)
            batch_answers = self.extract_answer(batch_dags)

        return batch_answers, batch_dags

    def save_checkpoint(self, path: str):
        """Save model checkpoint (trainable modules + optimizer state only)."""
        torch.save({
            'strategy_embeddings': self.strategy_embeddings.state_dict(),
            'strategy_mlp': self.strategy_mlp.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Load model checkpoint.

        NOTE(review): torch.load without weights_only=True unpickles
        arbitrary objects — only load checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.strategy_embeddings.load_state_dict(checkpoint['strategy_embeddings'])
        self.strategy_mlp.load_state_dict(checkpoint['strategy_mlp'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print(f"Checkpoint loaded from {path}")

    def make_api_call(self,params):
        """Make a single completion API call to the vLLM generator server.

        Args:
            params: dict with 'prompt' and 'max_tokens' keys.
        """
        return self.generator_client.completions.create(
            model=self.model_name_for_api,
            prompt=params["prompt"],
            max_tokens=params["max_tokens"],
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )
automr/strategies.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Meta Reasoning Strategy Prompts (from Table 2 in paper)
# Each strategy maps to several interchangeable phrasings; generate_content()
# picks one at random per sampled strategy.
# NOTE(review): several phrasings contain grammar slips ("if there anything
# missing", "can captures the essence", "knowledge is related"). They are
# left verbatim: a trained checkpoint's behavior may depend on exact wording,
# so edits here should be coordinated with retraining.
META_STRATEGIES = {
    "Next": [
        "Next,",
        "Then,",
        "Now, let me move on to the next step."
    ],
    "Reflect": [
        "Let me consider what part of the reasoning feels least certain, and how can it be examined.",
        "Wait, let me think if there anything missing in the current reasoning.",
        "Let me think does the current line of thought have any error."
    ],
    "Explore": [
        "Let me consider which direction of thinking I should explore.",
        "Let me think what potential strategy has not yet been considered that could be the next solution path.",
        "Let me think what possible solution could be tried next."
    ],
    "Decompose": [
        "This question is a bit complex, let me think how to decompose it into sub-questions that I can solve.",
        "The question feels too broad, let me think what smaller version could I tackle first.",
        "Let me think if I can express the problem in terms of simpler components or modules.",
        "Let me consider the options one by one."  # For multiple choice
    ],
    "Summarize": [
        "Let me summarize what have I established so far.",
        "Let me summarize the current state of reasoning process, what's known, unknown, and assumed?",
        "Let me consider if I can captures the essence of the reasoning so far with single sentence."
    ],
    "Recall": [
        "Let me think if I have encountered similar problems or if learned knowledge and previous intermediate step can be used here.",
        "Let me think what prior reasoning steps are directly relevant here or this question connect to earlier results.",
        "Let me recall which theorems, rules, or principles from earlier knowledge is related to this question."
    ],
    "Answer": [
        "Let me give the answer according to current reasoning context."
    ]
}

# All strategies including the special "zero" edge type
# ("zero" means "no edge" during skeleton sampling and is never prompted).
# Order matters: nn.Embedding rows and checkpoint indices follow this list.
STRATEGY_LIST = ["Next", "Reflect", "Explore", "Decompose", "Summarize", "Recall", "Answer", "zero"]
automr/trainer.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from typing import List, Dict, Tuple
4
+ from tqdm import tqdm
5
+
6
+ from .model import AutoMR
7
+ from .config import AutoMRConfig
8
+ from .utils import check_answer_match, ensure_dir, save_json
9
+ import os
10
+
11
+
12
+ class AutoMRTrainer:
13
+ """Trainer for AutoMR using REINFORCE (Algorithm 2 from paper)"""
14
+
15
    def __init__(self, model: AutoMR, config: AutoMRConfig):
        """Wire the trainer to a model/config pair and set up bookkeeping."""
        self.model = model
        self.config = config
        # Checkpoints (and training_history.json) are written here.
        ensure_dir(config.checkpoint_dir)

        # Track training progress
        self.global_step = 0  # optimizer updates performed so far
        self.best_val_reward = -float('inf')  # best validation reward observed
        self.patience_counter = 0  # validations since the last improvement
        self.training_history = {
            'train_loss': [],
            'train_reward': [],
            'val_reward': [],
            'val_accuracy': [],
            'steps': []
        }
31
+
32
    def compute_reward_batch(self, queries: List[str], answers: List[str]) -> Tuple[float, float]:
        """
        Compute average reward and accuracy on a batch.
        Returns: (avg_reward, accuracy)

        NOTE(review): leaves the strategy modules in eval mode; train_step()
        re-enables train mode itself, so the current call pattern is safe.
        """
        total_reward = 0.0
        correct = 0
        total = len(queries)

        # Reward computation is pure inference: no dropout, no gradients.
        self.model.strategy_mlp.eval()
        self.model.strategy_embeddings.eval()

        with torch.no_grad():
            dags, _ = self.model.dynamic_skeleton_sampling(queries, M=1)
            pred_answers = self.model.extract_answer(dags)

        is_correct = [check_answer_match(
            pred_answer, answer, self.config.task_type
        ) for pred_answer, answer in zip(pred_answers, answers)]

        # Binary reward: +1 for a correct final answer, -1 otherwise.
        rewards = [1.0 if correct else -1.0 for correct in is_correct]
        total_reward += sum(rewards)
        correct += sum(is_correct)

        avg_reward = total_reward / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        return avg_reward, accuracy
60
+
61
    def validate(self, val_data: List[Dict[str, str]]) -> Tuple[float, float]:
        """
        Run validation on validation set.
        Returns: (avg_reward, accuracy)

        Uses a random subsample of val_data to keep validation cheap, so
        repeated calls on the same data give different estimates.
        """
        # Sample validation batch
        val_batch_size = min(self.config.val_batch_size, len(val_data))
        val_batch = random.sample(val_data, val_batch_size)

        val_queries = [item['query'] for item in val_batch]
        val_answers = [item['answer'] for item in val_batch]

        avg_reward, accuracy = self.compute_reward_batch(val_queries, val_answers)

        return avg_reward, accuracy
76
+
77
    def train_step(self, batch_queries: List[str], batch_answers: List[str]) -> Tuple[float, float]:
        """
        Single training step using REINFORCE (Equation 4 from paper).
        Returns: (loss, avg_reward)
        """
        self.model.strategy_mlp.train()
        self.model.strategy_embeddings.train()

        M = self.config.num_samples_per_query
        loss = []
        rewards_list = []
        # expand answers for M samples per query
        # (dynamic_skeleton_sampling emits DAGs query-major, M per query,
        # so this repetition aligns answers with trajectories)
        batch_answers = [answer for answer in batch_answers for _ in range(M)]
        batch_dags, batch_log_probs = self.model.dynamic_skeleton_sampling(batch_queries, M)
        # Get prediction
        batch_pred_answers = self.model.extract_answer(batch_dags)
        # Compute reward
        for pred_answer, answer, log_prob in zip(batch_pred_answers, batch_answers, batch_log_probs):
            # Binary reward: +1 correct, -1 incorrect.
            reward = 1.0 if check_answer_match(
                pred_answer, answer, self.config.task_type
            ) else -1.0
            rewards_list.append(reward)

            # Accumulate gradient (REINFORCE): maximize reward-weighted
            # trajectory log-probability, i.e. minimize -reward * log_prob.
            loss.append(-reward * log_prob)

        # Compute average reward for this batch
        avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0

        # Update parameters
        self.model.optimizer.zero_grad()
        loss = torch.stack(loss).mean()

        loss.backward()
        # Clip gradients of all trainable parameters jointly.
        torch.nn.utils.clip_grad_norm_(
            list(self.model.strategy_embeddings.parameters()) +
            list(self.model.strategy_mlp.parameters()),
            max_norm=self.config.gradient_clip
        )
        self.model.optimizer.step()

        return loss.item(), avg_reward
119
+
120
+ def should_stop_early(self) -> bool:
121
+ """Check if training should stop early"""
122
+ return self.patience_counter >= self.config.early_stopping_patience
123
+
124
+ def save_history(self):
125
+ history_path = os.path.join(self.config.checkpoint_dir, "training_history.json")
126
+ save_json(self.training_history, history_path)
127
+
128
+ def save_checkpoint(self, epoch: int, is_best: bool = False):
129
+ """Save checkpoint"""
130
+ if self.config.save_best_only and not is_best:
131
+ return
132
+
133
+ checkpoint_name = f"checkpoint_epoch_{epoch}_step_{self.global_step}.pt"
134
+ if is_best:
135
+ checkpoint_name = "best_checkpoint.pt"
136
+
137
+ checkpoint_path = os.path.join(self.config.checkpoint_dir, checkpoint_name)
138
+
139
+ self.model.save_checkpoint(checkpoint_path)
140
+
141
+ # Also save training history
142
+ self.save_history()
143
+
144
+ if is_best:
145
+ print(f" 💾 Best checkpoint saved: {checkpoint_path}")
146
+
147
    def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
        """Training loop with validation (Algorithm 2 from paper with validation).

        Runs REINFORCE updates over shuffled mini-batches, validates every
        ``val_every_n_steps`` global steps, tracks the best validation reward,
        and returns early when patience is exhausted.
        NOTE(review): ``train_data`` is shuffled in place each epoch, so the
        caller's list order is mutated.
        """
        print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
        print(f"Training samples: {len(train_data)}")
        print(f"Validation samples: {len(val_data)}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Samples per query: {self.config.num_samples_per_query}")
        print(f"Validation every {self.config.val_every_n_steps} steps")
        print(f"Early stopping patience: {self.config.early_stopping_patience}\n")

        for epoch in range(self.config.num_epochs):
            random.shuffle(train_data)
            epoch_loss = 0.0
            epoch_reward = 0.0
            num_batches = 0

            pbar = tqdm(
                range(0, len(train_data), self.config.batch_size),
                desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
            )

            for i in pbar:
                # The last batch may be shorter than batch_size.
                batch = train_data[i:i+self.config.batch_size]
                batch_queries = [item['query'] for item in batch]
                batch_answers = [item['answer'] for item in batch]

                # Training step
                loss, avg_reward = self.train_step(batch_queries, batch_answers)
                epoch_loss += loss
                epoch_reward += avg_reward
                num_batches += 1
                self.global_step += 1

                # History is flushed to disk after every step so a crash
                # loses at most one step of logging.
                # NOTE(review): 'train_reward' also receives a per-epoch
                # average below, so the list mixes two granularities.
                self.training_history['train_reward'].append(avg_reward)
                self.save_history()

                pbar.set_postfix({
                    'loss': f'{loss:.4f}',
                    'reward': f'{avg_reward:.3f}',
                    'step': self.global_step
                })

                # Validation
                if self.global_step % self.config.val_every_n_steps == 0:
                    print(f"\n{'='*80}")
                    print(f"Validation at Step {self.global_step}")
                    print(f"{'='*80}")

                    val_reward, val_accuracy = self.validate(val_data)

                    print(f"Validation Reward: {val_reward:.4f}")
                    print(f"Validation Accuracy: {val_accuracy:.2%}")

                    # Record history
                    self.training_history['val_reward'].append(val_reward)
                    self.training_history['val_accuracy'].append(val_accuracy)
                    self.training_history['steps'].append(self.global_step)

                    # Check if this is the best model (strict improvement only).
                    is_best = val_reward > self.best_val_reward
                    if is_best:
                        print(f"✨ New best validation reward: {val_reward:.4f} (previous: {self.best_val_reward:.4f})")
                        self.best_val_reward = val_reward
                        self.patience_counter = 0
                        self.save_checkpoint(epoch + 1, is_best=True)
                    else:
                        self.patience_counter += 1
                        print(f"No improvement. Patience: {self.patience_counter}/{self.config.early_stopping_patience}")

                    print(f"{'='*80}\n")

                # Check early stopping; returns without an end-of-epoch save
                # (the best checkpoint was already written above).
                if self.should_stop_early():
                    print(f"\n Early stopping triggered after {self.global_step} steps")
                    print(f"Best validation reward: {self.best_val_reward:.4f}")
                    return

            # End of epoch
            # NOTE(review): raises ZeroDivisionError if train_data is empty.
            avg_epoch_loss = epoch_loss / num_batches
            avg_epoch_reward = epoch_reward / num_batches

            self.training_history['train_loss'].append(avg_epoch_loss)
            self.training_history['train_reward'].append(avg_epoch_reward)

            print(f"\n{'='*80}")
            print(f"Epoch {epoch+1} Summary")
            print(f"{'='*80}")
            print(f"Average Loss: {avg_epoch_loss:.4f}")
            print(f"Average Reward: {avg_epoch_reward:.4f}")
            print(f"Best Val Reward: {self.best_val_reward:.4f}")
            print(f"{'='*80}\n")

            # Save checkpoint at end of epoch (if not save_best_only)
            if not self.config.save_best_only:
                self.save_checkpoint(epoch + 1)

        print("Training completed!")
        print(f"Best validation reward achieved: {self.best_val_reward:.4f}")
automr/utils.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import json
4
+ from typing import Any
5
+ import re
6
+ import regex
7
+ from latex2sympy2 import latex2sympy
8
+ from word2number import w2n
9
+
10
+
11
def extract_math_answer(text: str) -> str:
    """Extract the final answer from a math solution.

    Preference order:
      1. the last ``\\boxed{...}`` expression (nested braces supported),
      2. text following trigger phrases such as "answer is",
      3. the whole stripped text as a fallback.

    Note: the trigger-phrase fallback matches on ``text.lower()``, so those
    answers are returned lower-cased (original behavior, preserved).
    """
    boxed = _last_boxed_content(text)
    if boxed is not None:
        return boxed.strip()

    # Try to find answer after "answer is" or similar phrases
    answer_patterns = [
        r'answer is[:\s]+([^\n.]+)',
        r'final answer[:\s]+([^\n.]+)',
        r'therefore[,:\s]+([^\n.]+)'
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, text.lower())
        if matches:
            return matches[-1].strip()

    return text.strip()


def _last_boxed_content(text: str):
    """Return the content of the last balanced ``\\boxed{...}`` or None.

    Fix: the previous regex ``\\boxed\\{([^}]*)\\}`` truncated nested
    expressions like ``\\boxed{\\frac{1}{2}}`` at the first ``}``.
    """
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    open_idx = start + len("\\boxed{") - 1  # index of the opening brace
    depth = 0
    for i in range(open_idx, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[open_idx + 1:i]
    # Unbalanced braces: treat as no boxed answer.
    return None
32
+
33
+
34
def extract_multiple_choice_answer(text: str) -> str:
    """Pull the last standalone option letter (A-D) out of *text*.

    Falls back to the stripped raw text when no option letter is found.
    """
    letters = re.findall(r'\b([A-D])\b', text.upper())
    return letters[-1] if letters else text.strip()
43
+
44
+
45
def normalize_answer(answer: str) -> str:
    """Lower-case, drop '$' and backslashes, and collapse whitespace for comparison."""
    cleaned = answer.strip().lower()
    # Remove common mathematical notation markers.
    for token in ('$', '\\'):
        cleaned = cleaned.replace(token, '')
    # Collapse any internal runs of whitespace to single spaces.
    return ' '.join(cleaned.split())
53
+
54
+
55
def check_answer_match(pred: str, ground_truth: str, task_type: str = "math") -> bool:
    """Loose answer comparison: exact match, or substring in either direction.

    ``task_type`` is kept for interface compatibility but the comparison is
    purely string-based.  Fix: an empty prediction previously matched every
    ground truth ('' is a substring of any string), silently inflating
    accuracy; empty strings now match only empty strings.
    """
    pred = str(pred).strip()
    ground_truth = str(ground_truth).strip()
    if not pred or not ground_truth:
        return pred == ground_truth
    return pred == ground_truth or pred in ground_truth or ground_truth in pred
71
+
72
+
73
def ensure_dir(directory: str):
    """Create *directory* (and any missing parents) if it does not exist.

    Fix: the exists()/makedirs() pair was race-prone and crashed on '' (the
    dirname of a bare filename); ``exist_ok=True`` plus the emptiness guard
    makes this safe and idempotent.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)
77
+
78
+
79
def save_json(data: Any, path: str):
    """Serialize *data* as pretty-printed JSON at *path*, creating parent dirs.

    Fix: a bare filename (empty dirname) previously crashed inside the
    directory-creation step; parent creation is now skipped when the path
    has no directory component.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)
84
+
85
+
86
def load_json(path: str) -> Any:
    """Read and deserialize the JSON file at *path*."""
    with open(path, 'r') as handle:
        raw = handle.read()
    return json.loads(raw)
90
+
91
def _fix_fracs(string):
    """Normalize shorthand LaTeX fractions: '\\frac12' -> '\\frac{1}{2}'.

    Only single-character numerator/denominator shorthands are rewritten;
    chunks whose numerator is already braced are kept as-is.  On malformed
    input the original string is returned unchanged.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if len(substr) > 0 and substr[0] == "{":
                # Numerator already braced: keep the chunk untouched.
                new_str += substr
            else:
                try:
                    # Need at least numerator + start of denominator.
                    # NOTE(review): bare except around an assert — stripped
                    # under -O; flagged but left as-is here.
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    # '\frac ab...' -> '\frac{a}{b}...'
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    # '\frac a{...}' -> '\frac{a}{...}'
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
        string = new_str
    return string
121
+
122
+
123
+ def _fix_a_slash_b(string):
124
+ if len(string.split("/")) != 2:
125
+ return string
126
+ a = string.split("/")[0]
127
+ b = string.split("/")[1]
128
+ try:
129
+ if "sqrt" not in a:
130
+ a = int(a)
131
+ if "sqrt" not in b:
132
+ b = int(b)
133
+ assert string == "{}/{}".format(a, b)
134
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
135
+ return new_string
136
+ except:
137
+ return string
138
+
139
+
140
+ def _fix_sqrt(string):
141
+ _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
142
+ return _string
143
+
144
+
145
def convert_word_number(text: str) -> str:
    """Convert an English number word ('twelve') to digits ('12').

    Non-number text is passed through unchanged.  Fix: the bare ``except:``
    also swallowed KeyboardInterrupt/SystemExit; narrowed to ``Exception``
    (w2n raises ValueError for unknown words, but other input-dependent
    errors are possible — hence Exception rather than ValueError alone).
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        return text
151
+
152
+
153
# units mainly from MathQA
# Unit tokens stripped from candidate answers by strip_string(); each entry
# is removed only when it occurs as a standalone token (bounded by start/end
# of string or a non-word character).
# NOTE(review): the list contains duplicates ("kmph", "sec") and aggressive
# single-letter entries ("a", "s", "t", "m", ...) — confirm these are
# intentional, as they can strip meaningful symbols from answers.
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]

# Also strip the plural form of every unit.
unit_texts.extend([t + "s" for t in unit_texts])
290
+
291
+
292
def strip_string(string, skip_unit=False):
    """Canonicalize a LaTeX/plain-text math answer for string comparison.

    Applies a long pipeline of normalizations (fraction/sqrt rewriting, unit
    and currency removal, word-number conversion, whitespace stripping, ...).
    ``skip_unit`` disables the unit-token removal pass for datasets whose
    answers legitimately carry units.

    Fixes in this revision:
      * the quote-removal step previously discarded the ``str.replace``
        results (a no-op); the results are now assigned back,
      * a duplicate percent-removal line using the invalid escape ``"\\%"``
        (SyntaxWarning on recent Pythons) was removed — behavior unchanged.
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # matrix: normalize array environments to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge"))

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} unit (e.g. miles, dollars) if present
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit tokens; run twice so units exposed by the first pass
        # (e.g. after a neighboring unit is removed) are caught as well.
        for _ in range(2):
            for unit_text in unit_texts:
                # The prefix must be the start of the string or a non-word
                # character, the suffix the end or a non-word character.
                _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" with "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage (both escaped and bare forms)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " ." and "{." are shorthand for " 0." / "{0."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # NOTE(review): isalnum() is False for any string containing braces or
    # parens, so this bracket-stripping condition can never fire; probably
    # string[1:-1].isalnum() was intended — left unchanged to preserve
    # behavior.
    if (string.startswith("{") and string.endswith("}") and string.isalnum() or
        string.startswith("(") and string.endswith(")") and string.isalnum() or
        string.startswith("[") and string.endswith("]") and string.isalnum()):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quotes (fix: results were previously discarded)
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j: imaginary unit spelled 'j' when 'i' is absent
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not a digit (or end-of-string) with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # drop short variable prefixes such as "k = " or "q = "
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}
    string = _fix_fracs(string)

    # simple X/Y --> \frac{X}{Y}
    string = _fix_a_slash_b(string)

    return string
419
+
420
+
421
# Phrases that introduce the final answer in few-shot prompts; used by
# choice_answer_clean() both to detect in-context exemplars and to locate
# the answer span.
direct_answer_trigger_for_fewshot = ("choice is", "answer is")
422
+
423
+
424
def choice_answer_clean(pred: str):
    """Extract a clean multiple-choice answer (letter A-E) from a prediction.

    If no standalone option letter is found, a stripped form of the text
    after the last trigger phrase is returned instead.
    """
    pred = pred.strip("\n")

    # Few-shot (ICL) transcripts repeat a trigger phrase once per exemplar;
    # in that case only the first chunk (before the first blank line) is the
    # model's own answer.
    is_icl = any(pred.count(trigger) > 1 for trigger in direct_answer_trigger_for_fewshot)
    if is_icl:
        pred = pred.split("\n\n")[0]

    # Keep only the text after the last trigger phrase, if one is present.
    pieces = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
    answer_flag = len(pieces) > 1
    if answer_flag:
        pred = pieces[-1]

    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    # Prefer explicit standalone option letters.
    letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    candidates = letters if letters else [pred.strip().strip(".")]

    if len(candidates) == 0:
        pred = ""
    elif answer_flag:
        # Right after an explicit trigger, the first letter is the answer.
        pred = candidates[0]
    else:
        # Otherwise assume the last mention is the final answer.
        pred = candidates[-1]

    # Remove the period at the end, again!
    pred = pred.rstrip(".").rstrip("/")

    return pred
466
+
467
+
468
def find_box(pred_str: str):
    """Return the content of the last ``boxed{...}`` expression in *pred_str*.

    Nested braces are handled with a depth counter.  If 'boxed' is not
    followed by an opening brace, everything up to the next '$' is returned.
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        return tail.split("$")[0].strip()
    depth = 1
    content = ""
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        content += ch
    return content
489
+
490
+
491
def clean_units(pred_str: str):
    """Strip currency/degree symbols and rewrite pi as the numeral 3.14."""

    def _pi_to_number(expr):
        expr = expr.replace("\\pi", "π")
        # Standalone π (not preceded by a digit or '}') becomes 3.14.
        expr = re.sub(r"(?<![\d}])\\?π", "3.14", expr)
        # Implicit multiplication such as '3π' becomes explicit: '3*3.14'.
        expr = re.sub(r"(\d)(\\?π)", r"\1*3.14", expr)
        # Brace-wrapped '{π}' and explicitly multiplied '*π' forms.
        expr = re.sub(r"\{(\\?π)\}", "3.14", expr)
        expr = re.sub(r"\*(\\?π)", "*3.14", expr)
        return expr

    cleaned = _pi_to_number(pred_str)
    # Order matters: '°C' must be handled before the bare '°'.
    for old, new in (("%", "/100"), ("$", ""), ("¥", ""),
                     ("°C", ""), (" C", ""), ("°", "")):
        cleaned = cleaned.replace(old, new)
    return cleaned
514
+
515
+
516
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a TheoremQA-style prediction.

    Boolean answers are mapped to 'True'/'False', lettered options are kept
    as-is, and (when ``answer_flag`` is set) numeric answers are extracted
    and evaluated; on failure the last number in the string is used.
    """
    if any([option in pred.lower() for option in ["yes", "true"]]):
        pred = "True"
    elif any([option in pred.lower() for option in ["no", "false"]]):
        pred = "False"
    elif any([option in pred.lower() for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]]):
        pass
    else:
        # Some of the models somehow get used to boxed output from pre-training
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numbers out of the string
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                # HACK/security: eval() on latex2sympy output executes
                # arbitrary expressions derived from model text — do not use
                # on untrusted input without sandboxing.
                tmp = str(latex2sympy(pred))
                pred = str(eval(tmp))
            except Exception:
                # 'number unit' forms: keep the leading number only.
                if re.match(r"-?[\d\.]+\s\D+$", pred):
                    pred = pred.split(" ")[0]
                elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
                    pred = pred.split(" ")[0]
                else:
                    # desparate search over the last number
                    preds = re.findall(r"-?\d*\.?\d+", pred)
                    if len(preds) >= 1:
                        pred = preds[-1]
                    else:
                        pred = ""

    return pred
549
+
550
+
551
def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract the final answer from a model prediction for dataset *data_name*.

    Tries, in order: dataset-specific formats (humaneval code block, mmlu
    '(X)'), a minerva-style "final answer is $...$" span, a balanced
    ``boxed{...}`` expression, common trigger phrases (incl. Chinese
    '答案是'), and finally — when ``use_last_number`` — the last number in
    the text.  The result is canonicalized with ``strip_string``.

    Fix: the last-number pattern was a non-raw string with invalid escapes
    (SyntaxWarning on recent Pythons); now a raw string, same regex.
    """
    if data_name.lower() == "humaneval":
        pattern = r"### Function Body:\s*\n```python\n(.*?)\n```"
        matches = re.findall(pattern, pred_str, re.DOTALL)
        try:
            return matches[0]
        except IndexError:
            return ""
    elif data_name.lower() == "mmlu":
        if len(pred_str) >= 3 and pred_str[0] == '(' and pred_str[2] == ')':
            return pred_str[1]
    # Strip a known garbled Cyrillic artifact.
    pred_str = pred_str.replace("\u043a\u0438", "")

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        ans = pred_str.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        elif ans[0] == "{":
            # Balanced-brace scan so nested fractions survive.
            stack = 1
            a = ""
            for c in ans[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = ans.split("$")[0].strip()
        pred = a
    elif "he answer is" in pred_str:
        # Matches both 'The answer is' and 'the answer is'.
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    else:  # use the last number
        if use_last_number:
            pattern = r"-?\d*\.?\d+"
            pred = re.findall(pattern, pred_str.replace(",", ""))
            if len(pred) >= 1:
                pred = pred[-1]
            else:
                pred = ""
        else:
            pred = ""

    # Collapse line breaks and trim decorative punctuation.
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva"])

    if data_name == 'GPQA' or data_name == 'MMLU':
        if len(pred) >= 3 and pred[0] == '(' and pred[2] == ')':
            pred = pred[1]
    return pred
622
+
623
+
624
# Datasets whose answers must keep their units (skip_unit=True in strip_string).
# NOTE(review): extract_answer() hardcodes this same list instead of using
# the constant — keep the two in sync.
STRIP_EXCEPTIONS = ["carp_en", "minerva"]
625
+
626
+
627
def parse_ground_truth(groudtruth_solution: str, data_name):
    """Extract the final gold answer from a ground-truth solution string.

    Thin wrapper around extract_answer().  (The parameter keeps the original
    'groudtruth' spelling to stay call-compatible with keyword callers.)
    """
    gt_ans = extract_answer(groudtruth_solution, data_name)
    return gt_ans
630
+
631
+
632
def parse_question(example, data_name):
    """Build the question prompt string for record *example* of *data_name*.

    Handles dataset-specific field layouts (asdiv, svamp, tabmwp, carp_en,
    mmlu_stem, sat_math, aqua, gaokao_math_qa); any other dataset falls back
    to the first available generic field.  Yes/No and True/False questions
    get an answer-format hint appended.
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (f'regarding "{example["table_title"]}" ' if example["table_title"] else "")
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (f' Please select from the following options: {example["choices"]}')
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        options = example["choices"]
        assert len(options) == 4
        for i, (label, option) in enumerate(zip("ABCD", options)):
            options[i] = f"({label}) {str(option).strip()}"
        options = " ".join(options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            if f" {ch}) " in options:
                # rf-string: the original f" {ch}\) " carried an invalid escape.
                options = regex.sub(rf" {ch}\) ", f" ({ch}) ", options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = []
        for key in options_dict:
            options.append(f"({key}) {options_dict[key]}")
        options = " ".join(options)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Tag Yes/No and True/False questions so the model answers in kind.
    # Fix: parse_ground_truth() returns a single string (the old code
    # unpacked two values from it) and expects the gold solution *text*,
    # not the raw record dict.  Guarded so a missing/malformed answer field
    # cannot crash prompt construction.
    try:
        gt_ans = parse_ground_truth(str(example.get("answer", "")), data_name)
    except Exception:
        gt_ans = None
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()
embedder_server.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch a vLLM OpenAI-compatible server in embedding mode on Ascend NPU 3.
# Uses a small GPU-memory share (0.4) so it can co-exist with the generator
# server on the same host.
export ASCEND_RT_VISIBLE_DEVICES=3
export VLLM_USE_V1=1
python -m vllm.entrypoints.openai.api_server \
    --model "FreedomIntelligence/openPangu-Embedded-7B" \
    --tensor-parallel-size 1 \
    --port 8001 \
    --host localhost \
    --gpu-memory-utilization 0.4 \
    --trust-remote-code \
    --task embed \
    --dtype bfloat16
generator_server.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Launch a vLLM OpenAI-compatible generation server on Ascend NPU 1.
export ASCEND_RT_VISIBLE_DEVICES=1
export VLLM_USE_V1=1
python -m vllm.entrypoints.openai.api_server \
    --model "FreedomIntelligence/openPangu-Embedded-7B" \
    --tensor-parallel-size 1 \
    --port 8000 \
    --host localhost \
    --trust-remote-code \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.90
main.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch_npu
4
+ from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig
5
+ from automr.data_loader import DataLoader
6
+
7
+
8
def parse_args(argv=None):
    """Parse command-line arguments for AutoMR.

    Args:
        argv: Optional list of argument strings.  Defaults to None, which
            makes argparse read ``sys.argv[1:]`` — existing zero-argument
            callers are unaffected; tests can pass an explicit list.

    Fix: the --mode help text claimed a 'train_eval' mode that ``choices``
    does not accept; the help now matches the actual choices.
    """
    parser = argparse.ArgumentParser()

    # Mode
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval'], help='Mode: train or eval')

    # Model settings
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen2.5-3B-Instruct', help='Pretrained LLM model name')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu', 'npu'], help='Device to use')
    parser.add_argument('--token_budget', type=int, default=256, help='Token budget for reasoning')
    parser.add_argument('--hidden_size', type=int, default=4096, help='Hidden size of the model')

    # Training settings
    parser.add_argument('--learning_rate', type=float, default=5e-4, help='Learning rate')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_samples', type=int, default=1, help='Number of skeletons to sample per query (M)')

    # Data paths
    parser.add_argument('--train_data', type=str, default='data/train.json', help='Path to training data')
    parser.add_argument('--val_data', type=str, default='data/val.json', help='Path to validation data')
    parser.add_argument('--test_data', type=str, default='data/test.json', help='Path to test data')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--results_dir', type=str, default='results', help='Directory to save results')

    # Checkpoint
    parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint to load')

    # Task type
    parser.add_argument('--task_type', type=str, default='math', choices=['math', 'multiple_choice'], help='Task type')

    return parser.parse_args(argv)
40
+
41
+
42
def main():
    """Entry point: build the config from CLI args, then train or evaluate."""
    args = parse_args()

    # Create configuration
    config = AutoMRConfig(
        model_name=args.model_name,
        device=args.device,
        token_budget=args.token_budget,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        num_samples_per_query=args.num_samples,
        train_data_path=args.train_data,
        val_data_path=args.val_data,
        test_data_path=args.test_data,
        checkpoint_dir=args.checkpoint_dir,
        results_dir=args.results_dir,
        task_type=args.task_type,
        hidden_size=args.hidden_size,
    )

    print("="*80)
    print("AutoMR: Automatic Meta-Reasoning Skeleton Search")
    print("="*80)
    print(f"\nConfiguration:")
    print(f"  Model: {config.model_name}")
    print(f"  Device: {config.device}")
    print(f"  Token Budget: {config.token_budget}")
    print(f"  Task Type: {config.task_type}")
    print(f"  Mode: {args.mode}")
    print("="*80)

    # Initialize model
    model = AutoMR(config)

    # Load checkpoint if specified
    # NOTE(review): a --load_checkpoint path that does not exist is silently
    # ignored here; consider warning so eval never runs with random weights.
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        model.load_checkpoint(args.load_checkpoint)

    # Training mode
    if args.mode == 'train':
        print(f"\n{'='*80}")
        print("TRAINING")
        print("="*80)

        # Load training data
        train_data = DataLoader.load_data(config.train_data_path)
        val_data = DataLoader.load_data(config.val_data_path)
        print(f"Loaded {len(train_data)} training samples from {config.train_data_path}")

        # Train
        trainer = AutoMRTrainer(model, config)
        trainer.train(train_data, val_data)

    # Evaluation mode
    elif args.mode == 'eval':
        print(f"\n{'='*80}")
        print("EVALUATION")
        print("="*80)

        # Load test data
        test_data = DataLoader.load_data(config.test_data_path)
        print(f"Loaded {len(test_data)} test samples from {config.test_data_path}")

        # Evaluate
        evaluator = AutoMREvaluator(model, config)
        accuracy, results = evaluator.evaluate(test_data)

        print(f"\n{'='*80}")
        print(f"Final Accuracy: {accuracy:.2%}")
        print("="*80)

    else:
        # Unreachable in practice: argparse restricts --mode to train/eval.
        raise NotImplementedError
116
+
117
+
118
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
math_train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Train AutoMR on the MATH dataset using Ascend NPU 0.
VLLM_USE_V1=1 ASCEND_RT_VISIBLE_DEVICES=0 python main.py --mode train \
    --device npu \
    --model_name "FreedomIntelligence/openPangu-Embedded-7B" \
    --train_data processed_data/MATH/train.jsonl \
    --val_data processed_data/MATH/val.jsonl \
    --num_epochs 5 \
    --batch_size 8 \
    --num_samples 4 \
    --token_budget 4096 \
    --checkpoint_dir checkpoints/MATH/pangu \
    --task_type math