haifei
committed on
Commit
·
1482463
1
Parent(s):
5168b2e
code and checkpoint
Browse files- .gitattributes +1 -0
- .gitignore +4 -0
- AMC.SH +8 -0
- automr/__init__.py +7 -0
- automr/config.py +45 -0
- automr/dag.py +77 -0
- automr/data_loader.py +35 -0
- automr/evaluator.py +93 -0
- automr/model.py +404 -0
- automr/strategies.py +40 -0
- automr/trainer.py +244 -0
- automr/utils.py +693 -0
- embedder_server.sh +11 -0
- generator_server.sh +10 -0
- main.py +119 -0
- math_train.sh +12 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
checkpoints/MATH/pangu/best_checkpoint.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
processed_data/
|
| 2 |
+
checkpoints/
|
| 3 |
+
results/
|
| 4 |
+
automr/__pycache__/
|
AMC.SH
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Evaluate the MATH-trained checkpoint on the AMC test split (NPU device 4).
# Fix: the original ended with a trailing "\" after --token_budget, leaving a
# dangling line continuation at end of file.
ASCEND_RT_VISIBLE_DEVICES=4 python main.py --mode eval \
    --device npu \
    --model_name FreedomIntelligence/openPangu-Embedded-7B \
    --test_data processed_data/AMC/test.jsonl \
    --load_checkpoint checkpoints/MATH/pangu/best_checkpoint.pt \
    --task_type math \
    --results_dir results/AMC/ \
    --token_budget 4096
|
automr/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AutoMR package: public interface.

Re-exports the main entry points so callers can simply do
``from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig``.
"""
from .model import AutoMR
from .trainer import AutoMRTrainer
from .evaluator import AutoMREvaluator
from .config import AutoMRConfig

__version__ = "1.0.0"
__all__ = ["AutoMR", "AutoMRTrainer", "AutoMREvaluator", "AutoMRConfig"]
|
automr/config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
class AutoMRConfig:
    """Configuration for the AutoMR framework.

    A flat dataclass grouping model, training, validation, generation, path,
    and evaluation settings; every field has a usable default. Field order is
    part of the positional-constructor interface — do not reorder.
    """

    # Model settings
    model_name: str = "FreedomIntelligence/openPangu-Embedded-7B"
    device: str = "npu"  # torch device string (e.g. "npu", "cuda", "cpu")
    token_budget: int = 4096  # total completion-token budget per reasoning skeleton
    hidden_size: int = 4096  # To be set according to the model used (embedding dim)

    # Training settings
    learning_rate: float = 5e-4
    num_epochs: int = 5
    batch_size: int = 8
    num_samples_per_query: int = 4  # M in paper (trajectories sampled per query)
    gradient_clip: float = 1.0

    # Validation settings
    val_every_n_steps: int = 100  # Alpha in the requirement - validate every N steps
    val_batch_size: int = 10  # Number of validation samples to evaluate
    early_stopping_patience: int = 5  # Stop if no improvement for N validations

    # Generation settings
    max_new_tokens: int = 1024  # per-API-call generation cap
    temperature: float = 0.01  # near-greedy decoding
    top_p: float = 0.9

    # Paths
    train_data_path: str = "data/train.json"
    val_data_path: str = "data/val.json"
    test_data_path: str = "data/test.json"
    checkpoint_dir: str = "checkpoints"
    results_dir: str = "results"

    # Evaluation settings
    save_predictions: bool = True
    save_skeletons: bool = True
    save_best_only: bool = True

    # Task type
    task_type: str = "math"  # "math" or "multiple_choice"
|
automr/dag.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class ReasoningNode:
    """Represents a node in the meta-reasoning DAG."""
    index: int  # topological position; 0 is the query node
    content: str  # query text (node 0) or generated reasoning text
    num_tokens: int  # completion tokens spent generating this node
    content_repr: torch.Tensor  # pooled embedding of `content`, shape [hidden_size]


@dataclass
class ReasoningEdge:
    """Represents an edge in the meta-reasoning DAG."""
    from_node: int
    to_node: int
    strategy: str  # name of the meta-reasoning strategy labelling this edge


class MetaReasoningDAG:
    """Represents the meta-reasoning skeleton as a DAG.

    Node 0 always holds the query; subsequent nodes are generated reasoning
    steps appended in topological order via :meth:`add_node`.
    """

    def __init__(self, query: str, query_repr: torch.Tensor, query_num_tokens: int):
        self.nodes: List[ReasoningNode] = [ReasoningNode(0, query, query_num_tokens, query_repr)]
        self.edges: List[ReasoningEdge] = []
        self.current_index = 0

    def add_node(self, content: str, num_tokens: int, content_repr: torch.Tensor) -> int:
        """Add a new node to the DAG and return its index."""
        self.current_index += 1
        node = ReasoningNode(self.current_index, content, num_tokens, content_repr)
        self.nodes.append(node)
        return self.current_index

    def add_edge(self, from_idx: int, to_idx: int, strategy: str):
        """Add a strategy-labelled edge between two nodes."""
        self.edges.append(ReasoningEdge(from_idx, to_idx, strategy))

    def get_node_content_repr(self, idx: int) -> torch.Tensor:
        """Get the content representation of a specific node."""
        return self.nodes[idx].content_repr

    def get_node_content(self, idx: int) -> str:
        """Get the content of a specific node."""
        return self.nodes[idx].content

    def get_context_repr_up_to(self, idx: int) -> torch.Tensor:
        """Return the mean representation of nodes 0..idx (inclusive).

        BUG FIX: the original started from ``self.nodes[0].content_repr`` and
        accumulated with in-place ``+=`` / ``/=``, which silently mutated the
        stored query representation on every call. Averaging into a fresh
        tensor preserves the stored node embeddings.
        """
        reprs = torch.stack([node.content_repr for node in self.nodes[:idx + 1]], dim=0)
        return reprs.mean(dim=0)

    def get_context_up_to(self, idx: int) -> str:
        """Get all node contents up to index idx, joined by newlines."""
        return "\n".join([node.content for node in self.nodes[:idx + 1]])

    def total_tokens(self) -> int:
        """Total tokens generated (excluding the source/query node)."""
        return sum(node.num_tokens for node in self.nodes[1:])

    def to_dict(self) -> dict:
        """Convert the DAG to a plain dictionary for JSON serialization."""
        return {
            "nodes": [
                {"index": n.index, "content": n.content, "num_tokens": n.num_tokens}
                for n in self.nodes
            ],
            "edges": [
                {"from": e.from_node, "to": e.to_node, "strategy": e.strategy}
                for e in self.edges
            ]
        }
|
automr/data_loader.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import List, Dict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DataLoader:
    """Load and validate query/answer datasets from JSON or JSONL files."""

    @staticmethod
    def load_data(file_path: str) -> List[Dict[str, str]]:
        """
        Load data from a JSON or JSONL file.

        Expected format: [{"query": "...", "answer": "..."}, ...]

        Raises:
            ValueError: if any item lacks a 'query' or 'answer' field.

        Fixes vs. original: opens with an explicit UTF-8 encoding (the default
        is platform-dependent), and skips blank lines in .jsonl files — a
        trailing newline previously crashed json.loads with an empty string.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_path.endswith('.jsonl'):
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)

        # Validate data format
        for item in data:
            if 'query' not in item or 'answer' not in item:
                raise ValueError("Each data item must have 'query' and 'answer' fields")

        return data

    @staticmethod
    def load_math_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MATH or GSM8K format dataset."""
        return DataLoader.load_data(file_path)

    @staticmethod
    def load_mmlu_dataset(file_path: str) -> List[Dict[str, str]]:
        """Load MMLU-Pro format dataset."""
        return DataLoader.load_data(file_path)
|
automr/evaluator.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Tuple
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from .model import AutoMR
|
| 6 |
+
from .config import AutoMRConfig
|
| 7 |
+
from .utils import check_answer_match, save_json, ensure_dir
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AutoMREvaluator:
    """Evaluator for AutoMR: runs batched inference and reports accuracy."""

    def __init__(self, model: AutoMR, config: AutoMRConfig):
        self.model = model
        self.config = config
        ensure_dir(config.results_dir)

    def evaluate(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
        """
        Evaluate the model on ``test_data`` (items with 'query'/'answer' keys).

        Returns:
            (accuracy, detailed_results) — detailed_results is currently always
            empty; per-sample result collection is disabled.
        """
        print(f"\nEvaluating on {len(test_data)} samples...")

        # Freeze the trainable strategy components for deterministic scoring.
        self.model.strategy_mlp.eval()
        self.model.strategy_embeddings.eval()

        correct = 0
        total = 0
        detailed_results = []
        step = self.config.batch_size
        pbar = tqdm(range(0, len(test_data), step), desc="Evaluating")

        for start in pbar:
            chunk = test_data[start:start + step]
            queries = [sample['query'] for sample in chunk]
            ground_truths = [sample['answer'] for sample in chunk]

            # One skeleton per query (M=1) at evaluation time.
            pred_answers, dags = self.model.inference(queries, M=1)

            for query, truth, prediction, dag in zip(queries, ground_truths, pred_answers, dags):
                if check_answer_match(prediction, truth, self.config.task_type):
                    correct += 1
                total += 1
                # Live running accuracy on the progress bar.
                pbar.set_postfix({
                    'Acc': f'{correct} / {total}',
                })

        accuracy = correct / total if total > 0 else 0.0

        print(f"\nEvaluation Results:")
        print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")

        if self.config.save_predictions:
            results_path = os.path.join(
                self.config.results_dir,
                'evaluation_results.json'
            )
            save_json({
                'accuracy': accuracy,
                'correct': correct,
                'total': total,
                'detailed_results': detailed_results
            }, results_path)
            print(f"Results saved to {results_path}")

        return accuracy, detailed_results
|
automr/model.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
import random
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
+
from vllm import LLM
|
| 9 |
+
from vllm import SamplingParams
|
| 10 |
+
from .config import AutoMRConfig
|
| 11 |
+
from .strategies import META_STRATEGIES, STRATEGY_LIST
|
| 12 |
+
from .dag import MetaReasoningDAG
|
| 13 |
+
from .utils import extract_answer
|
| 14 |
+
from typing import Dict
|
| 15 |
+
from openai import OpenAI
|
| 16 |
+
|
| 17 |
+
class StrategyMLP(nn.Module):
    """Three-layer MLP that scores meta-reasoning strategies for an edge.

    The input is the concatenation of three equal-width representations:
    the candidate node, the history of already-sampled strategies, and the
    reasoning context so far.
    """

    def __init__(self, hidden_size: int, num_strategies: int):
        super().__init__()
        # 3*h -> 2*h -> h -> num_strategies, with dropout between layers.
        self.fc1 = nn.Linear(hidden_size * 3, hidden_size * 2)
        self.fc2 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_strategies)
        self.dropout = nn.Dropout(0.1)

    def forward(self, node_repr, strategy_repr, context_repr):
        """Score each strategy for a batch of candidate edges.

        Args:
            node_repr: [batch, hidden_size]
            strategy_repr: [batch, hidden_size]
            context_repr: [batch, hidden_size]
        Returns:
            logits: [batch, num_strategies] (unnormalized scores)
        """
        hidden = torch.cat((node_repr, strategy_repr, context_repr), dim=-1)
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AutoMR:
|
| 47 |
+
"""AutoMR Framework for Meta-Reasoning Skeleton Search"""
|
| 48 |
+
|
| 49 |
+
def __init__(self, config: AutoMRConfig):
|
| 50 |
+
self.config = config
|
| 51 |
+
self.device = config.device
|
| 52 |
+
self.token_budget = config.token_budget
|
| 53 |
+
self.model_name_for_api = config.model_name
|
| 54 |
+
|
| 55 |
+
# # Load LLM
|
| 56 |
+
# print(f"Loading Tokenizer and Config: {config.model_name}")
|
| 57 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
| 58 |
+
|
| 59 |
+
# print(f"Loading vLLM generator: {config.model_name}")
|
| 60 |
+
# self.llm = LLM(
|
| 61 |
+
# config.model_name,
|
| 62 |
+
# dtype=torch.float16,
|
| 63 |
+
# trust_remote_code=True,
|
| 64 |
+
# tensor_parallel_size=config.tensor_parallel_size,
|
| 65 |
+
# gpu_memory_utilization="0.8"
|
| 66 |
+
# )
|
| 67 |
+
|
| 68 |
+
# print(f"Loading LLM Embbedder: {config.model_name}")
|
| 69 |
+
# self.llm_embedder = AutoModelForCausalLM.from_pretrained(
|
| 70 |
+
# config.model_name,
|
| 71 |
+
# torch_dtype=torch.float16,
|
| 72 |
+
# trust_remote_code=True,
|
| 73 |
+
# device_map=None
|
| 74 |
+
# ).to(self.device)
|
| 75 |
+
# self.llm_embedder.eval()
|
| 76 |
+
|
| 77 |
+
# print(f"Loading vLLM Embbedder: {config.model_name}")
|
| 78 |
+
# self.llm_embedder = LLM(
|
| 79 |
+
# config.model_name,
|
| 80 |
+
# dtype=torch.float16,
|
| 81 |
+
# trust_remote_code=True,
|
| 82 |
+
# tensor_parallel_size=config.tensor_parallel_size,
|
| 83 |
+
# gpu_memory_utilization="0.8"
|
| 84 |
+
# task="embed",
|
| 85 |
+
# )
|
| 86 |
+
|
| 87 |
+
print("Connecting to vLLM Generator Server (port 8000)...")
|
| 88 |
+
self.generator_client = OpenAI(
|
| 89 |
+
api_key="vllm",
|
| 90 |
+
base_url="http://localhost:8000/v1"
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
print("Connecting to Custom Embedder Server (port 8001)...")
|
| 94 |
+
self.embed_client = OpenAI(
|
| 95 |
+
api_key="vllm",
|
| 96 |
+
base_url="http://localhost:8001/v1"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Strategy components
|
| 101 |
+
self.num_strategies = len(STRATEGY_LIST)
|
| 102 |
+
hidden_size = config.hidden_size
|
| 103 |
+
|
| 104 |
+
self.strategy_embeddings = nn.Embedding(self.num_strategies, hidden_size).to(self.device)
|
| 105 |
+
self.strategy_mlp = StrategyMLP(hidden_size, self.num_strategies).to(self.device)
|
| 106 |
+
|
| 107 |
+
# Strategy mappings
|
| 108 |
+
self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
|
| 109 |
+
self.idx_to_strategy = {i: s for i, s in enumerate(STRATEGY_LIST)}
|
| 110 |
+
|
| 111 |
+
# Optimizer
|
| 112 |
+
self.optimizer = torch.optim.Adam(
|
| 113 |
+
list(self.strategy_embeddings.parameters()) +
|
| 114 |
+
list(self.strategy_mlp.parameters()),
|
| 115 |
+
lr=config.learning_rate
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
print("AutoMR initialized successfully")
|
| 119 |
+
|
| 120 |
+
def get_text_representation(self, texts: List[str]) -> Tuple[torch.Tensor]:
|
| 121 |
+
"""
|
| 122 |
+
Get pooled hidden state representations for texts in batch using LLM embedder.
|
| 123 |
+
Args:
|
| 124 |
+
texts: List of input texts
|
| 125 |
+
Returns:
|
| 126 |
+
pooled: Tensor of shape [batch_size, hidden_size]
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
# self.tokenizer.padding_side = "left"
|
| 130 |
+
# inputs = self.tokenizer(
|
| 131 |
+
# texts,
|
| 132 |
+
# return_tensors="pt",
|
| 133 |
+
# padding=True,
|
| 134 |
+
# truncation=True,
|
| 135 |
+
# ).to(self.device)
|
| 136 |
+
|
| 137 |
+
# with torch.no_grad():
|
| 138 |
+
# outputs = self.llm(**inputs, output_hidden_states=True)
|
| 139 |
+
# hidden_states = outputs.hidden_states[-1] # [bsz, len, dim]
|
| 140 |
+
# pooled = hidden_states[:, -1, :]
|
| 141 |
+
|
| 142 |
+
# batch_outputs = self.llm_embedder.encode(texts)
|
| 143 |
+
# pooled = []
|
| 144 |
+
# for outputs in batch_outputs:
|
| 145 |
+
# last_hidden_state = outputs.outputs.data[-1,:] # [seq_len, hidden_size]
|
| 146 |
+
# pooled.append(last_hidden_state)
|
| 147 |
+
# pooled = torch.stack(pooled, dim=0).to(self.device) # [batch_size, hidden_size]
|
| 148 |
+
|
| 149 |
+
batch_outputs = self.embed_client.embeddings.create(
|
| 150 |
+
input=texts,
|
| 151 |
+
model=self.model_name_for_api
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
batch_reprs = [
|
| 155 |
+
torch.tensor(data.embedding, device=self.device, dtype=torch.float16)
|
| 156 |
+
for data in batch_outputs.data
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
pooled = torch.stack(batch_reprs, dim=0) # [batch_size, hidden_size]
|
| 160 |
+
return pooled
|
| 161 |
+
|
| 162 |
+
def sample_strategy(
|
| 163 |
+
self,
|
| 164 |
+
batch_node_content_repr: torch.Tensor,
|
| 165 |
+
batch_sampled_strategies: Dict[int, List[int]],
|
| 166 |
+
batch_context_repr: torch.Tensor
|
| 167 |
+
) -> Tuple[List[int], torch.Tensor]:
|
| 168 |
+
"""
|
| 169 |
+
Sample a strategy for each edge (j, i) in batch
|
| 170 |
+
Args:
|
| 171 |
+
batch_node_content_repr: Tensor of shape [batch_size, hidden_size]
|
| 172 |
+
batch_sampled_strategies: Dict of lists of sampled strategy indices
|
| 173 |
+
batch_context_repr: Tensor of shape [batch_size, hidden_size]
|
| 174 |
+
Returns:
|
| 175 |
+
batch_strategy_idx: List of sampled strategy indices
|
| 176 |
+
batch_log_prob: Tensor of log probabilities, shape [batch_size]
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
batch_strategy_repr = []
|
| 181 |
+
|
| 182 |
+
for sampled_strategies in batch_sampled_strategies.values():
|
| 183 |
+
if sampled_strategies:
|
| 184 |
+
sampled_strategies = torch.tensor(sampled_strategies).to(self.device)
|
| 185 |
+
strategy_repr = self.strategy_embeddings(sampled_strategies).mean(dim=0, keepdim=True)
|
| 186 |
+
else:
|
| 187 |
+
strategy_repr = torch.zeros(1, self.config.hidden_size).to(self.device)
|
| 188 |
+
batch_strategy_repr.append(strategy_repr)
|
| 189 |
+
|
| 190 |
+
batch_strategy_repr = torch.cat(batch_strategy_repr, dim=0) # Combine all batch representations
|
| 191 |
+
batch_logits = self.strategy_mlp(batch_node_content_repr, batch_strategy_repr, batch_context_repr)
|
| 192 |
+
batch_probs = F.softmax(batch_logits, dim=-1)
|
| 193 |
+
|
| 194 |
+
dist = torch.distributions.Categorical(batch_probs)
|
| 195 |
+
batch_strategy_idx = dist.sample()
|
| 196 |
+
batch_log_prob = dist.log_prob(batch_strategy_idx).to(self.device)
|
| 197 |
+
|
| 198 |
+
return batch_strategy_idx.cpu().tolist(), batch_log_prob
|
| 199 |
+
|
| 200 |
+
def generate_content(
|
| 201 |
+
self,
|
| 202 |
+
batch_query: List[str],
|
| 203 |
+
batch_context: List[str],
|
| 204 |
+
batch_strategies: List[List[str]],
|
| 205 |
+
batch_remaining_budget: List[int]
|
| 206 |
+
) -> Tuple[List[str], List[int], torch.Tensor]:
|
| 207 |
+
"""
|
| 208 |
+
Generate reasoning content based on selected strategies
|
| 209 |
+
Args:
|
| 210 |
+
batch_query: List of query strings
|
| 211 |
+
batch_context: List of context strings
|
| 212 |
+
batch_strategies: List of lists of strategy names
|
| 213 |
+
batch_remaining_budget: List of remaining token budgets
|
| 214 |
+
Returns:
|
| 215 |
+
batch_generated_texts: List of generated content strings
|
| 216 |
+
batch_num_tokens: List of number of tokens generated
|
| 217 |
+
batch_content_reprs: Tensor of content representations, shape [batch_size, hidden_size]
|
| 218 |
+
"""
|
| 219 |
+
batch_strategy_prompts = [[] for _ in batch_query]
|
| 220 |
+
batch_full_prompt: List[str] = []
|
| 221 |
+
for i, strategies in enumerate(batch_strategies):
|
| 222 |
+
for s in strategies:
|
| 223 |
+
prompt = random.choice(META_STRATEGIES[s])
|
| 224 |
+
batch_strategy_prompts[i].append(prompt)
|
| 225 |
+
|
| 226 |
+
batch_full_prompt.append(f"{batch_context[i]}\n{' '.join(batch_strategy_prompts[i])}\n")
|
| 227 |
+
params_list = []
|
| 228 |
+
|
| 229 |
+
for i in range(len(batch_query)):
|
| 230 |
+
remaining_budget = batch_remaining_budget[i]
|
| 231 |
+
current_max_tokens = min(self.config.max_new_tokens, remaining_budget)
|
| 232 |
+
params_list.append({
|
| 233 |
+
"prompt": batch_full_prompt[i],
|
| 234 |
+
"max_tokens": current_max_tokens,
|
| 235 |
+
})
|
| 236 |
+
|
| 237 |
+
with ThreadPoolExecutor(max_workers=None) as executor:
|
| 238 |
+
batch_outputs = list(executor.map(self.make_api_call, params_list))
|
| 239 |
+
|
| 240 |
+
batch_generated_texts = [output.choices[0].text.strip() for output in batch_outputs]
|
| 241 |
+
batch_num_tokens = [output.usage.completion_tokens for output in batch_outputs]
|
| 242 |
+
batch_content_reprs = self.get_text_representation(batch_generated_texts)
|
| 243 |
+
|
| 244 |
+
return batch_generated_texts, batch_num_tokens, batch_content_reprs
|
| 245 |
+
|
| 246 |
+
def dynamic_skeleton_sampling(self, queries: List[str], M: int) -> Tuple[List[MetaReasoningDAG], torch.Tensor]:
|
| 247 |
+
"""
|
| 248 |
+
Algorithm 1: Dynamic Skeleton Sampling at inference time
|
| 249 |
+
Args:
|
| 250 |
+
queries: List of input query strings
|
| 251 |
+
M: Number of trajectories per query
|
| 252 |
+
Returns:
|
| 253 |
+
batch_dags: List of generated MetaReasoningDAGs
|
| 254 |
+
total_log_probs: Tensor of total log probabilities for each trajectory
|
| 255 |
+
"""
|
| 256 |
+
# === 1. Initialize M*N DAGs ===
|
| 257 |
+
N = len(queries)
|
| 258 |
+
batch_size = N * M
|
| 259 |
+
batch_dags: List[MetaReasoningDAG] = []
|
| 260 |
+
query_reprs = self.get_text_representation(queries)
|
| 261 |
+
for i in range(N):
|
| 262 |
+
for _ in range(M):
|
| 263 |
+
batch_dags.append(
|
| 264 |
+
MetaReasoningDAG(queries[i], query_reprs[i], 0) # we don't count query tokens, set 0
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
total_log_probs = torch.zeros(batch_size).to(self.device)
|
| 268 |
+
# the idx of trajectories that are still active
|
| 269 |
+
active_indices = list(range(batch_size))
|
| 270 |
+
i = 0 # Current topology step (i=1 is the first new node)
|
| 271 |
+
while active_indices:
|
| 272 |
+
i += 1
|
| 273 |
+
sampled_strategies = {dag_idx: [] for dag_idx in active_indices}
|
| 274 |
+
incoming_edges = {dag_idx: [] for dag_idx in active_indices}
|
| 275 |
+
|
| 276 |
+
# Step 1: Determine incoming edges (traverse in reverse order)
|
| 277 |
+
for j in range(i-1, -1, -1):
|
| 278 |
+
node_j_content_reprs = torch.stack([batch_dags[idx].get_node_content_repr(j) for idx in active_indices], dim=0)
|
| 279 |
+
context_reprs = torch.stack([batch_dags[idx].get_context_repr_up_to(i-1) for idx in active_indices], dim=0)
|
| 280 |
+
|
| 281 |
+
strategy_idx, log_prob = self.sample_strategy(
|
| 282 |
+
node_j_content_reprs,
|
| 283 |
+
sampled_strategies,
|
| 284 |
+
context_reprs
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
for k, dag_idx in enumerate(active_indices):
|
| 288 |
+
sampled_strategies[dag_idx].append(strategy_idx[k])
|
| 289 |
+
|
| 290 |
+
total_log_probs[active_indices] += log_prob
|
| 291 |
+
|
| 292 |
+
for dag_idx in active_indices:
|
| 293 |
+
strategy_idx = sampled_strategies[dag_idx][-1]
|
| 294 |
+
strategy_name = self.idx_to_strategy[strategy_idx]
|
| 295 |
+
if strategy_name != "zero":
|
| 296 |
+
incoming_edges[dag_idx].append((j, strategy_name))
|
| 297 |
+
|
| 298 |
+
# Step 2: Check which DAGs are still active
|
| 299 |
+
for dag_idx in active_indices.copy():
|
| 300 |
+
if not incoming_edges[dag_idx]:
|
| 301 |
+
active_indices.remove(dag_idx)
|
| 302 |
+
|
| 303 |
+
if not active_indices:
|
| 304 |
+
break
|
| 305 |
+
|
| 306 |
+
# Step 3: Generate base reasoning content
|
| 307 |
+
batch_strategies = []
|
| 308 |
+
batch_context = []
|
| 309 |
+
batch_query = []
|
| 310 |
+
batch_remaining_budget = []
|
| 311 |
+
for dag_idx in active_indices:
|
| 312 |
+
dag = batch_dags[dag_idx]
|
| 313 |
+
strategies = [edge[1] for edge in incoming_edges[dag_idx]]
|
| 314 |
+
batch_strategies.append(strategies)
|
| 315 |
+
context = dag.get_context_up_to(i-1)
|
| 316 |
+
batch_context.append(context)
|
| 317 |
+
batch_query.append(dag.nodes[0].content)
|
| 318 |
+
batch_remaining_budget.append(self.token_budget - dag.total_tokens())
|
| 319 |
+
|
| 320 |
+
batch_content, batch_num_tokens, batch_content_repr = self.generate_content(
|
| 321 |
+
batch_query,
|
| 322 |
+
batch_context,
|
| 323 |
+
batch_strategies,
|
| 324 |
+
batch_remaining_budget
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Step 4: Update DAGs with new nodes and edges
|
| 328 |
+
for k, dag_idx in enumerate(active_indices):
|
| 329 |
+
dag = batch_dags[dag_idx]
|
| 330 |
+
content = batch_content[k]
|
| 331 |
+
num_tokens = batch_num_tokens[k]
|
| 332 |
+
content_repr = batch_content_repr[k]
|
| 333 |
+
dag.add_node(content, num_tokens, content_repr)
|
| 334 |
+
|
| 335 |
+
for dag_idx in incoming_edges:
|
| 336 |
+
dag = batch_dags[dag_idx]
|
| 337 |
+
for from_j, strategy in incoming_edges[dag_idx]:
|
| 338 |
+
dag.add_edge(from_j, i, strategy)
|
| 339 |
+
# Step 5: Check stopping criteria
|
| 340 |
+
for dag_idx in active_indices.copy():
|
| 341 |
+
dag = batch_dags[dag_idx]
|
| 342 |
+
content = dag.get_node_content(i)
|
| 343 |
+
if not content or "boxed" in content.lower() or dag.total_tokens() >= self.token_budget:
|
| 344 |
+
active_indices.remove(dag_idx)
|
| 345 |
+
|
| 346 |
+
return batch_dags, total_log_probs
|
| 347 |
+
|
| 348 |
+
def extract_answer(self, batch_dags: List[MetaReasoningDAG]) -> List[str]:
|
| 349 |
+
"""Extract final answer from the reasoning DAG"""
|
| 350 |
+
batch_answer_prompts = []
|
| 351 |
+
params_list = []
|
| 352 |
+
for dag in batch_dags:
|
| 353 |
+
full_context = dag.get_context_up_to(len(dag.nodes) - 1)
|
| 354 |
+
|
| 355 |
+
answer_prompt = f"{full_context}\n{random.choice(META_STRATEGIES['Answer'])}\n"
|
| 356 |
+
params_list.append( {
|
| 357 |
+
"prompt": answer_prompt,
|
| 358 |
+
"max_tokens": self.config.max_new_tokens,
|
| 359 |
+
})
|
| 360 |
+
|
| 361 |
+
with ThreadPoolExecutor(max_workers=None) as executor:
|
| 362 |
+
batch_outputs = list(executor.map(self.make_api_call, params_list))
|
| 363 |
+
|
| 364 |
+
batch_answers = [output.choices[0].text.strip() for output in batch_outputs]
|
| 365 |
+
|
| 366 |
+
return batch_answers
|
| 367 |
+
|
| 368 |
+
def inference(self, batch_queries: List[str], M: int) -> Tuple[str, MetaReasoningDAG]:
|
| 369 |
+
"""Run inference on a single query"""
|
| 370 |
+
self.strategy_mlp.eval()
|
| 371 |
+
self.strategy_embeddings.eval()
|
| 372 |
+
|
| 373 |
+
with torch.no_grad():
|
| 374 |
+
batch_dags, _ = self.dynamic_skeleton_sampling(batch_queries, M)
|
| 375 |
+
batch_answers = self.extract_answer(batch_dags)
|
| 376 |
+
|
| 377 |
+
return batch_answers, batch_dags
|
| 378 |
+
|
| 379 |
+
def save_checkpoint(self, path: str):
|
| 380 |
+
"""Save model checkpoint"""
|
| 381 |
+
torch.save({
|
| 382 |
+
'strategy_embeddings': self.strategy_embeddings.state_dict(),
|
| 383 |
+
'strategy_mlp': self.strategy_mlp.state_dict(),
|
| 384 |
+
'optimizer': self.optimizer.state_dict()
|
| 385 |
+
}, path)
|
| 386 |
+
print(f"Checkpoint saved to {path}")
|
| 387 |
+
|
| 388 |
+
    def load_checkpoint(self, path: str):
        """Load trainable components from a checkpoint written by save_checkpoint.

        NOTE(review): torch.load uses pickle under the hood — only load
        checkpoints from trusted sources (consider passing weights_only=True
        on torch versions that support it).
        """
        # map_location keeps the tensors on this model's configured device.
        checkpoint = torch.load(path, map_location=self.device)
        self.strategy_embeddings.load_state_dict(checkpoint['strategy_embeddings'])
        self.strategy_mlp.load_state_dict(checkpoint['strategy_mlp'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print(f"Checkpoint loaded from {path}")
|
| 395 |
+
|
| 396 |
+
def make_api_call(self,params):
|
| 397 |
+
"""Make API call to vLLM server"""
|
| 398 |
+
return self.generator_client.completions.create(
|
| 399 |
+
model=self.model_name_for_api,
|
| 400 |
+
prompt=params["prompt"],
|
| 401 |
+
max_tokens=params["max_tokens"],
|
| 402 |
+
temperature=self.config.temperature,
|
| 403 |
+
top_p=self.config.top_p,
|
| 404 |
+
)
|
automr/strategies.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Meta Reasoning Strategy Prompts (from Table 2 in paper)
# Maps each strategy name to its prompt variants; callers pick one variant at
# random per reasoning step (see model.extract_answer for the "Answer" case).
# NOTE(review): several prompt strings contain grammatical slips ("if there
# anything missing", "I can captures"). They are runtime prompt text that the
# trained checkpoint was optimized against, so they are deliberately left
# byte-identical here.
META_STRATEGIES = {
    "Next": [
        "Next,",
        "Then,",
        "Now, let me move on to the next step."
    ],
    "Reflect": [
        "Let me consider what part of the reasoning feels least certain, and how can it be examined.",
        "Wait, let me think if there anything missing in the current reasoning.",
        "Let me think does the current line of thought have any error."
    ],
    "Explore": [
        "Let me consider which direction of thinking I should explore.",
        "Let me think what potential strategy has not yet been considered that could be the next solution path.",
        "Let me think what possible solution could be tried next."
    ],
    "Decompose": [
        "This question is a bit complex, let me think how to decompose it into sub-questions that I can solve.",
        "The question feels too broad, let me think what smaller version could I tackle first.",
        "Let me think if I can express the problem in terms of simpler components or modules.",
        "Let me consider the options one by one."  # For multiple choice
    ],
    "Summarize": [
        "Let me summarize what have I established so far.",
        "Let me summarize the current state of reasoning process, what's known, unknown, and assumed?",
        "Let me consider if I can captures the essence of the reasoning so far with single sentence."
    ],
    "Recall": [
        "Let me think if I have encountered similar problems or if learned knowledge and previous intermediate step can be used here.",
        "Let me think what prior reasoning steps are directly relevant here or this question connect to earlier results.",
        "Let me recall which theorems, rules, or principles from earlier knowledge is related to this question."
    ],
    "Answer": [
        "Let me give the answer according to current reasoning context."
    ]
}

# All strategies including the special "zero" edge type
# NOTE: "zero" has no prompt entry in META_STRATEGIES above — presumably it
# denotes the absence of an edge in the DAG; confirm against automr/dag.py.
STRATEGY_LIST = ["Next", "Reflect", "Explore", "Decompose", "Summarize", "Recall", "Answer", "zero"]
|
automr/trainer.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from .model import AutoMR
|
| 7 |
+
from .config import AutoMRConfig
|
| 8 |
+
from .utils import check_answer_match, ensure_dir, save_json
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AutoMRTrainer:
    """Trainer for AutoMR using REINFORCE (Algorithm 2 from paper)"""

    def __init__(self, model: AutoMR, config: AutoMRConfig):
        """Wrap an AutoMR model with REINFORCE training state and bookkeeping."""
        self.model = model
        self.config = config
        ensure_dir(config.checkpoint_dir)

        # Track training progress
        self.global_step = 0  # number of optimizer updates performed
        self.best_val_reward = -float('inf')  # best validation reward so far
        self.patience_counter = 0  # validations since last improvement
        self.training_history = {
            'train_loss': [],
            'train_reward': [],
            'val_reward': [],
            'val_accuracy': [],
            'steps': []
        }

    def compute_reward_batch(self, queries: List[str], answers: List[str]) -> Tuple[float, float]:
        """
        Compute average reward and accuracy on a batch
        Returns: (avg_reward, accuracy)

        Rewards are binary: +1 for a matching answer, -1 otherwise.
        """
        # NOTE(review): total_reward/correct are initialized then accumulated
        # exactly once below; the += form looks like a leftover from a loop.
        total_reward = 0.0
        correct = 0
        total = len(queries)

        # Evaluation mode for the trainable policy components.
        self.model.strategy_mlp.eval()
        self.model.strategy_embeddings.eval()

        with torch.no_grad():
            # Single skeleton sample per query (M=1) for evaluation.
            dags, _ = self.model.dynamic_skeleton_sampling(queries, M=1)
            pred_answers = self.model.extract_answer(dags)

        is_correct = [check_answer_match(
            pred_answer, answer, self.config.task_type
        ) for pred_answer, answer in zip(pred_answers, answers)]

        # Binary reward: +1 for a match, -1 otherwise.
        rewards = [1.0 if correct else -1.0 for correct in is_correct]
        total_reward += sum(rewards)
        correct += sum(is_correct)

        avg_reward = total_reward / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        return avg_reward, accuracy

    def validate(self, val_data: List[Dict[str, str]]) -> Tuple[float, float]:
        """
        Run validation on validation set
        Returns: (avg_reward, accuracy)

        Evaluates on a random subsample of at most config.val_batch_size
        items, so successive validations are not directly comparable on
        identical data.
        """
        # Sample validation batch
        val_batch_size = min(self.config.val_batch_size, len(val_data))
        val_batch = random.sample(val_data, val_batch_size)

        val_queries = [item['query'] for item in val_batch]
        val_answers = [item['answer'] for item in val_batch]

        avg_reward, accuracy = self.compute_reward_batch(val_queries, val_answers)

        return avg_reward, accuracy

    def train_step(self, batch_queries: List[str], batch_answers: List[str]) -> Tuple[float, float]:
        """
        Single training step using REINFORCE (Equation 4 from paper)
        Returns: (loss, avg_reward)

        Samples M reasoning skeletons per query, scores each final answer
        with a binary reward, and updates the policy with the REINFORCE
        loss -reward * log_prob averaged over all samples.
        """
        self.model.strategy_mlp.train()
        self.model.strategy_embeddings.train()

        M = self.config.num_samples_per_query
        loss = []
        rewards_list = []
        # expand answers for M samples per query
        batch_answers = [answer for answer in batch_answers for _ in range(M)]
        batch_dags, batch_log_probs = self.model.dynamic_skeleton_sampling(batch_queries, M)
        # Get prediction
        batch_pred_answers = self.model.extract_answer(batch_dags)
        # Compute reward
        for pred_answer, answer, log_prob in zip(batch_pred_answers, batch_answers, batch_log_probs):
            reward = 1.0 if check_answer_match(
                pred_answer, answer, self.config.task_type
            ) else -1.0
            rewards_list.append(reward)

            # Accumulate gradient (REINFORCE)
            loss.append(-reward * log_prob)

        # Compute average reward for this batch
        avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0

        # Update parameters
        self.model.optimizer.zero_grad()
        loss = torch.stack(loss).mean()

        loss.backward()
        # Clip gradients over both trainable parameter groups jointly.
        torch.nn.utils.clip_grad_norm_(
            list(self.model.strategy_embeddings.parameters()) +
            list(self.model.strategy_mlp.parameters()),
            max_norm=self.config.gradient_clip
        )
        self.model.optimizer.step()

        return loss.item(), avg_reward

    def should_stop_early(self) -> bool:
        """Check if training should stop early"""
        return self.patience_counter >= self.config.early_stopping_patience

    def save_history(self):
        """Persist the training-history dict as JSON in the checkpoint dir."""
        history_path = os.path.join(self.config.checkpoint_dir, "training_history.json")
        save_json(self.training_history, history_path)

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save checkpoint

        Honors config.save_best_only: non-best checkpoints are skipped when
        that flag is set. The best checkpoint always overwrites
        "best_checkpoint.pt".
        """
        if self.config.save_best_only and not is_best:
            return

        checkpoint_name = f"checkpoint_epoch_{epoch}_step_{self.global_step}.pt"
        if is_best:
            checkpoint_name = "best_checkpoint.pt"

        checkpoint_path = os.path.join(self.config.checkpoint_dir, checkpoint_name)

        self.model.save_checkpoint(checkpoint_path)

        # Also save training history
        self.save_history()

        if is_best:
            print(f" 💾 Best checkpoint saved: {checkpoint_path}")

    def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
        """Training loop with validation (Algorithm 2 from paper with validation)

        Runs num_epochs over shuffled train_data, validating every
        val_every_n_steps optimizer steps, tracking the best validation
        reward, and stopping early after early_stopping_patience validations
        without improvement.
        """
        print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
        print(f"Training samples: {len(train_data)}")
        print(f"Validation samples: {len(val_data)}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Samples per query: {self.config.num_samples_per_query}")
        print(f"Validation every {self.config.val_every_n_steps} steps")
        print(f"Early stopping patience: {self.config.early_stopping_patience}\n")

        for epoch in range(self.config.num_epochs):
            random.shuffle(train_data)
            epoch_loss = 0.0
            epoch_reward = 0.0
            num_batches = 0

            pbar = tqdm(
                range(0, len(train_data), self.config.batch_size),
                desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
            )

            for i in pbar:
                batch = train_data[i:i+self.config.batch_size]
                batch_queries = [item['query'] for item in batch]
                batch_answers = [item['answer'] for item in batch]

                # Training step
                loss, avg_reward = self.train_step(batch_queries, batch_answers)
                epoch_loss += loss
                epoch_reward += avg_reward
                num_batches += 1
                self.global_step += 1

                # NOTE(review): 'train_reward' receives a per-step entry here
                # and a per-epoch entry below — the list mixes granularities.
                self.training_history['train_reward'].append(avg_reward)
                self.save_history()

                pbar.set_postfix({
                    'loss': f'{loss:.4f}',
                    'reward': f'{avg_reward:.3f}',
                    'step': self.global_step
                })

                # Validation
                if self.global_step % self.config.val_every_n_steps == 0:
                    print(f"\n{'='*80}")
                    print(f"Validation at Step {self.global_step}")
                    print(f"{'='*80}")

                    val_reward, val_accuracy = self.validate(val_data)

                    print(f"Validation Reward: {val_reward:.4f}")
                    print(f"Validation Accuracy: {val_accuracy:.2%}")

                    # Record history
                    self.training_history['val_reward'].append(val_reward)
                    self.training_history['val_accuracy'].append(val_accuracy)
                    self.training_history['steps'].append(self.global_step)

                    # Check if this is the best model
                    is_best = val_reward > self.best_val_reward
                    if is_best:
                        print(f"✨ New best validation reward: {val_reward:.4f} (previous: {self.best_val_reward:.4f})")
                        self.best_val_reward = val_reward
                        self.patience_counter = 0
                        self.save_checkpoint(epoch + 1, is_best=True)
                    else:
                        self.patience_counter += 1
                        print(f"No improvement. Patience: {self.patience_counter}/{self.config.early_stopping_patience}")

                    print(f"{'='*80}\n")

                    # Check early stopping
                    if self.should_stop_early():
                        print(f"\n Early stopping triggered after {self.global_step} steps")
                        print(f"Best validation reward: {self.best_val_reward:.4f}")
                        return

            # End of epoch
            avg_epoch_loss = epoch_loss / num_batches
            avg_epoch_reward = epoch_reward / num_batches

            self.training_history['train_loss'].append(avg_epoch_loss)
            self.training_history['train_reward'].append(avg_epoch_reward)

            print(f"\n{'='*80}")
            print(f"Epoch {epoch+1} Summary")
            print(f"{'='*80}")
            print(f"Average Loss: {avg_epoch_loss:.4f}")
            print(f"Average Reward: {avg_epoch_reward:.4f}")
            print(f"Best Val Reward: {self.best_val_reward:.4f}")
            print(f"{'='*80}\n")

            # Save checkpoint at end of epoch (if not save_best_only)
            if not self.config.save_best_only:
                self.save_checkpoint(epoch + 1)

        print("Training completed!")
        print(f"Best validation reward achieved: {self.best_val_reward:.4f}")
|
automr/utils.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from typing import Any
|
| 5 |
+
import re
|
| 6 |
+
import regex
|
| 7 |
+
from latex2sympy2 import latex2sympy
|
| 8 |
+
from word2number import w2n
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def extract_math_answer(text: str) -> str:
    """Extract the final answer from a math solution string.

    Prefers the last \\boxed{...} span, using brace matching so nested braces
    such as \\boxed{\\frac{1}{2}} are handled correctly (the previous regex
    [^}]* stopped at the first closing brace). Falls back to common
    answer-trigger phrases, then to the raw stripped text.
    """
    # Find every \boxed{ opener and brace-match its contents.
    boxed_contents = []
    for m in re.finditer(r'\\boxed\{', text):
        depth = 1
        start = m.end()
        for i in range(start, len(text)):
            ch = text[i]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    boxed_contents.append(text[start:i])
                    break
    if boxed_contents:
        return boxed_contents[-1].strip()

    # Try to find answer after "answer is" or similar phrases.
    answer_patterns = [
        r'answer is[:\s]+([^\n.]+)',
        r'final answer[:\s]+([^\n.]+)',
        r'therefore[,:\s]+([^\n.]+)'
    ]

    for pattern in answer_patterns:
        matches = re.findall(pattern, text.lower())
        if matches:
            return matches[-1].strip()

    return text.strip()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def extract_multiple_choice_answer(text: str) -> str:
    """Extract a multiple-choice letter (A-D) from *text*.

    Returns the last standalone letter found; if none is present, returns the
    stripped input unchanged.
    """
    letters = re.findall(r'\b([A-D])\b', text.upper())
    return letters[-1] if letters else text.strip()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def normalize_answer(answer: str) -> str:
    """Normalize an answer for comparison.

    Lowercases, strips surrounding whitespace, removes '$' and backslashes
    (common LaTeX noise), and collapses internal whitespace runs to single
    spaces.
    """
    cleaned = answer.strip().lower().replace('$', '').replace('\\', '')
    return ' '.join(cleaned.split())
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def check_answer_match(pred: str, ground_truth: str, task_type: str = "math") -> bool:
    """Loose answer match: exact equality or substring containment either way.

    Args:
        pred: predicted answer (coerced to str).
        ground_truth: reference answer (coerced to str).
        task_type: kept for interface compatibility; currently unused.

    Fix: "" is a substring of every string, so an empty prediction previously
    matched any ground truth (and vice versa). Empty operands now match only
    by exact equality.
    """
    pred = str(pred)
    ground_truth = str(ground_truth)
    if not pred or not ground_truth:
        return pred == ground_truth
    return pred == ground_truth or pred in ground_truth or ground_truth in pred
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def ensure_dir(directory: str):
    """Create *directory* (including parents) if it doesn't exist.

    Fix: uses os.makedirs(exist_ok=True) instead of the exists-then-create
    pattern, which raced between the check and the creation; an empty path is
    now a no-op instead of raising from os.makedirs("").
    """
    if directory:
        os.makedirs(directory, exist_ok=True)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def save_json(data: Any, path: str):
    """Save *data* as pretty-printed JSON at *path*, creating parent dirs.

    Fix: a bare filename (no directory component) previously crashed because
    the directory helper ended up calling os.makedirs("") — the directory
    step is now skipped when there is no parent. exist_ok avoids a race
    between the existence check and creation.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def load_json(path: str) -> Any:
    """Load and return the JSON content of the file at *path*."""
    with open(path, 'r') as handle:
        return json.load(handle)
|
| 90 |
+
|
| 91 |
+
def _fix_fracs(string):
|
| 92 |
+
substrs = string.split("\\frac")
|
| 93 |
+
new_str = substrs[0]
|
| 94 |
+
if len(substrs) > 1:
|
| 95 |
+
substrs = substrs[1:]
|
| 96 |
+
for substr in substrs:
|
| 97 |
+
new_str += "\\frac"
|
| 98 |
+
if len(substr) > 0 and substr[0] == "{":
|
| 99 |
+
new_str += substr
|
| 100 |
+
else:
|
| 101 |
+
try:
|
| 102 |
+
assert len(substr) >= 2
|
| 103 |
+
except:
|
| 104 |
+
return string
|
| 105 |
+
a = substr[0]
|
| 106 |
+
b = substr[1]
|
| 107 |
+
if b != "{":
|
| 108 |
+
if len(substr) > 2:
|
| 109 |
+
post_substr = substr[2:]
|
| 110 |
+
new_str += "{" + a + "}{" + b + "}" + post_substr
|
| 111 |
+
else:
|
| 112 |
+
new_str += "{" + a + "}{" + b + "}"
|
| 113 |
+
else:
|
| 114 |
+
if len(substr) > 2:
|
| 115 |
+
post_substr = substr[2:]
|
| 116 |
+
new_str += "{" + a + "}" + b + post_substr
|
| 117 |
+
else:
|
| 118 |
+
new_str += "{" + a + "}" + b
|
| 119 |
+
string = new_str
|
| 120 |
+
return string
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _fix_a_slash_b(string):
|
| 124 |
+
if len(string.split("/")) != 2:
|
| 125 |
+
return string
|
| 126 |
+
a = string.split("/")[0]
|
| 127 |
+
b = string.split("/")[1]
|
| 128 |
+
try:
|
| 129 |
+
if "sqrt" not in a:
|
| 130 |
+
a = int(a)
|
| 131 |
+
if "sqrt" not in b:
|
| 132 |
+
b = int(b)
|
| 133 |
+
assert string == "{}/{}".format(a, b)
|
| 134 |
+
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
| 135 |
+
return new_string
|
| 136 |
+
except:
|
| 137 |
+
return string
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _fix_sqrt(string):
|
| 141 |
+
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
| 142 |
+
return _string
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def convert_word_number(text: str) -> str:
    """Convert an English number word (e.g. "three") to digits via word2number.

    On any conversion failure the input text is returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        return text
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# units mainly from MathQA
# Unit tokens stripped from answers during normalization (see strip_string).
# Order matters: strip_string iterates this list in order, so multi-word and
# longer tokens should appear before their substrings (e.g. "km" before "m").
# NOTE(review): a few entries are deliberate artifacts of the source data
# ("m sqaure", "v â € ™", single letters like "a"/"t"), and "kmph"/"sec"
# appear twice — harmless for the removal loop, left as-is.
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]

# Also strip naive plurals of every unit above.
unit_texts.extend([t + "s" for t in unit_texts])
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def strip_string(string, skip_unit=False):
    """Normalize a math answer string for comparison.

    Applies a long, strictly ordered pipeline of LaTeX/text cleanups:
    linebreak and punctuation stripping, matrix-environment canonicalization,
    \\frac / sqrt shorthand fixes, unit removal (unless skip_unit), dollar and
    percent stripping, word-to-number conversion, and misc. normalizations.
    The step ORDER is significant — do not reorder without re-validating.

    Args:
        string: raw answer text (coerced to str).
        skip_unit: when True, skip the unit_texts removal passes.

    Returns:
        The normalized string (may be empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    # replace \\ with \
    string = string.replace("\\!", "")
    # string = string.replace("\\ ", "")
    # string = string.replace("\\\\", "\\")

    # matrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge"))

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove unit: miles, dollars if after is not none
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
        string = _string

    if not skip_unit:
        # Remove unit: texts
        # Two passes so units exposed by the first removal are also caught.
        for _ in range(2):
            for unit_text in unit_texts:
                # use regex, the prefix should be either the start of the string or a non-alphanumeric character
                # the suffix should be either the end of the string or a non-alphanumeric character
                _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    # Drop leading variable bindings like "x=", "y\in", etc.
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # cdot
    # string = string.replace("\\cdot", "")
    # NOTE(review): this branch is unreachable as written — a string that
    # starts with "{"/"("/"[" can never satisfy .isalnum(); the intent was
    # presumably string[1:-1].isalnum(). Left unchanged to preserve behavior.
    if (string.startswith("{") and string.endswith("}") and string.isalnum() or
        string.startswith("(") and string.endswith(")") and string.isalnum() or
        string.startswith("[") and string.endswith("]") and string.isalnum()):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")

    # and
    # NOTE(review): this removes "and" anywhere, including inside words
    # (e.g. "random" -> "rom") — confirm intended.
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quote
    # NOTE(review): both calls below discard their result (str.replace is not
    # in-place), so quotes are NOT actually stripped. Left as-is to preserve
    # the behavior the checkpoints were evaluated with.
    string.replace("'", "")
    string.replace('"', "")

    # i, j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# Phrases that mark a direct final answer in (few-shot) model output.
direct_answer_trigger_for_fewshot = ("choice is", "answer is")
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def choice_answer_clean(pred: str):
|
| 425 |
+
pred = pred.strip("\n")
|
| 426 |
+
|
| 427 |
+
# Determine if this is ICL, if so, use \n\n to split the first chunk.
|
| 428 |
+
ICL = False
|
| 429 |
+
for trigger in direct_answer_trigger_for_fewshot:
|
| 430 |
+
if pred.count(trigger) > 1:
|
| 431 |
+
ICL = True
|
| 432 |
+
if ICL:
|
| 433 |
+
pred = pred.split("\n\n")[0]
|
| 434 |
+
|
| 435 |
+
# Split the trigger to find the answer.
|
| 436 |
+
preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
|
| 437 |
+
if len(preds) > 1:
|
| 438 |
+
answer_flag = True
|
| 439 |
+
pred = preds[-1]
|
| 440 |
+
else:
|
| 441 |
+
answer_flag = False
|
| 442 |
+
|
| 443 |
+
pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
|
| 444 |
+
|
| 445 |
+
# Clean the answer based on the dataset
|
| 446 |
+
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
| 447 |
+
if tmp:
|
| 448 |
+
pred = tmp
|
| 449 |
+
else:
|
| 450 |
+
pred = [pred.strip().strip(".")]
|
| 451 |
+
|
| 452 |
+
if len(pred) == 0:
|
| 453 |
+
pred = ""
|
| 454 |
+
else:
|
| 455 |
+
if answer_flag:
|
| 456 |
+
# choose the first element in list ...
|
| 457 |
+
pred = pred[0]
|
| 458 |
+
else:
|
| 459 |
+
# choose the last e
|
| 460 |
+
pred = pred[-1]
|
| 461 |
+
|
| 462 |
+
# Remove the period at the end, again!
|
| 463 |
+
pred = pred.rstrip(".").rstrip("/")
|
| 464 |
+
|
| 465 |
+
return pred
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def find_box(pred_str: str):
    """Return the content of the last ``boxed{...}`` expression in *pred_str*.

    If the text after "boxed" is not brace-delimited, everything up to the
    next "$" is returned instead. Returns "" when nothing follows "boxed".
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        # No braces: take the plain text up to the next math delimiter.
        return tail.split("$")[0].strip()
    # Scan for the brace matching the opening one, honoring nesting.
    depth = 1
    content = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        content.append(ch)
    # Unbalanced braces fall through and return everything collected.
    return "".join(content)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def clean_units(pred_str: str):
    """Normalize units/symbols in an answer string (pi -> 3.14, % -> /100, strip currency/degrees)."""

    def _numeric_pi(expr):
        # Canonicalize \pi to the unicode symbol first, then substitute 3.14.
        expr = expr.replace("\\pi", "π")
        # π not preceded by a digit or "}" becomes a bare 3.14.
        expr = re.sub(r"(?<![\d}])\\?π", "3.14", expr)
        # Implicit multiplication such as "3π" becomes "3*3.14".
        expr = re.sub(r"(\d)(\\?π)", r"\1*3.14", expr)
        # Braced "{π}" and explicit "*π" forms.
        expr = re.sub(r"\{(\\?π)\}", "3.14", expr)
        expr = re.sub(r"\*(\\?π)", "*3.14", expr)
        return expr

    result = _numeric_pi(pred_str)
    # Drop percentage, currency and temperature/degree markers.
    for old, new in (("%", "/100"), ("$", ""), ("¥", ""), ("°C", ""), (" C", ""), ("°", "")):
        result = result.replace(old, new)
    return result
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a TheoremQA prediction into a canonical answer string.

    Maps yes/true -> "True", no/false -> "False", keeps "(a)"-"(f)" option
    strings as-is, and otherwise tries to reduce the text to a number.

    Args:
        pred: Raw model prediction text.
        answer_flag: When True, aggressively extract a numeric value
            (strip units, evaluate simple LaTeX expressions).

    Returns:
        The cleaned answer string ("" when no number could be found).
    """
    lowered = pred.lower()
    # IDIOM: any(...) with a generator instead of building throwaway lists.
    if any(option in lowered for option in ("yes", "true")):
        pred = "True"
    elif any(option in lowered for option in ("no", "false")):
        pred = "False"
    elif any(option in lowered for option in ("(a)", "(b)", "(c)", "(d)", "(e)", "(f)")):
        pass
    else:
        # Some models emit \boxed{...} answers from pre-training habits.
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numeric value out of the string.
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # SECURITY NOTE: eval() on model-derived text is unsafe for
                # untrusted input; prefer sympy's own evaluation if possible.
                pred = str(eval(tmp))
            except Exception:
                if re.match(r"-?[\d\.]+\s\D+$", pred):
                    pred = pred.split(" ")[0]
                elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
                    pred = pred.split(" ")[0]
                else:
                    # Desperate fallback: take the last number in the string.
                    preds = re.findall(r"-?\d*\.?\d+", pred)
                    pred = preds[-1] if preds else ""

    return pred
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract the final answer from a model's raw output for a given dataset.

    Handles several answer formats in priority order: fenced code (humaneval),
    "(X)" options (mmlu), minerva-style "final answer is $...$. I hope",
    \\boxed{...}, "the answer is ..."/"final answer is ...", the Chinese
    trigger "答案是", and finally the last number in the text.

    Args:
        pred_str: Raw prediction text.
        data_name: Dataset name; selects the extraction strategy.
        use_last_number: Fall back to the last number when no trigger matches.

    Returns:
        The cleaned answer string ("" when nothing could be extracted).
    """
    if data_name.lower() == "humaneval":
        # HumanEval answers are a fenced python function body.
        pattern = r"### Function Body:\s*\n```python\n(.*?)\n```"
        matches = re.findall(pattern, pred_str, re.DOTALL)
        try:
            return matches[0]
        except IndexError:
            return ""
    elif data_name.lower() == "mmlu":
        # Already-normalized "(X)..." predictions: return the bare letter.
        if len(pred_str) >= 3 and pred_str[0] == '(' and pred_str[2] == ')':
            return pred_str[1]

    # Strip a known mis-encoded artifact ("ки") sometimes emitted by models.
    pred_str = pred_str.replace("\u043a\u0438", "")

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math answer sentence.
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        # NOTE(review): this duplicates find_box() but returns early from
        # extract_answer on an empty tail; kept inline to preserve that.
        ans = pred_str.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        elif ans[0] == "{":
            stack = 1
            a = ""
            for c in ans[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = ans.split("$")[0].strip()
        pred = a
    elif "he answer is" in pred_str:
        # Matches both "The answer is" and "the answer is".
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Chinese few-shot multiple-choice answer trigger.
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    else:  # use the last number
        if use_last_number:
            # BUGFIX: raw string; "\d" in a plain literal is an invalid
            # escape sequence (SyntaxWarning on Python 3.12+).
            pattern = r"-?\d*\.?\d+"
            pred = re.findall(pattern, pred_str.replace(",", ""))
            pred = pred[-1] if len(pred) >= 1 else ""
        else:
            pred = ""

    # Collapse newlines and trim stray leading/trailing punctuation.
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva"])

    if data_name == 'GPQA' or data_name == 'MMLU':
        # Reduce "(A)"-style answers to the bare letter.
        if len(pred) >= 3 and pred[0] == '(' and pred[2] == ')':
            pred = pred[1]
    return pred
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
# Datasets whose answers keep their original units/formatting.
# NOTE(review): appears unused in this file; it mirrors the skip_unit list
# hard-coded inside extract_answer -- confirm whether callers rely on it.
STRIP_EXCEPTIONS = ["carp_en", "minerva"]
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def parse_ground_truth(groudtruth_solution: str, data_name):
    """Extract the gold answer from a ground-truth solution string.

    Thin wrapper around extract_answer; returns a single cleaned answer
    string. NOTE: the misspelled parameter name "groudtruth_solution" is
    part of the public signature and is kept for caller compatibility.
    """
    gt_ans = extract_answer(groudtruth_solution, data_name)
    return gt_ans
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def parse_question(example, data_name):
    """Build the natural-language question text for one dataset example.

    Dataset-specific fields (tables, answer choices, bodies) are stitched
    together; for true/false or yes/no gold answers the corresponding hint
    is appended to the question.

    Args:
        example: One dataset record (dict-like).
        data_name: Dataset name; selects which fields to read.

    Returns:
        The assembled question string (stripped).
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (f'regarding "{example["table_title"]}" ' if example["table_title"] else "")
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (f' Please select from the following options: {example["choices"]}')
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        options = example["choices"]
        assert len(options) == 4
        for i, (label, option) in enumerate(zip("ABCD", options)):
            options[i] = f"({label}) {str(option).strip()}"
        options = " ".join(options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            if f" {ch}) " in options:
                # BUGFIX: raw f-string -- "\)" in a plain literal is an
                # invalid escape sequence (SyntaxWarning on Python 3.12+).
                options = regex.sub(rf" {ch}\) ", f" ({ch}) ", options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = []
        for key in options_dict:
            options.append(f"({key}) {options_dict[key]}")
        options = " ".join(options)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Yes/No or True/False questions get an explicit hint appended.
    # BUGFIX: parse_ground_truth returns a single string, so the old
    # "_, gt_ans = parse_ground_truth(...)" tuple unpacking raised
    # ValueError (or silently mis-unpacked a 2-character answer). The call
    # is also guarded: parse_ground_truth expects a solution *string* while
    # `example` is a record -- TODO confirm the intended argument; a failed
    # gold-answer lookup must not break question construction.
    try:
        gt_ans = parse_ground_truth(example, data_name)
    except Exception:
        gt_ans = None
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()
|
embedder_server.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Serve openPangu-Embedded-7B as an OpenAI-compatible *embedding* endpoint
# (--task embed) on Ascend NPU 3, port 8001.
export ASCEND_RT_VISIBLE_DEVICES=3
export VLLM_USE_V1=1
# BUGFIX: removed the dangling trailing "\" after the last option -- it
# silently continues the command onto whatever line follows.
python -m vllm.entrypoints.openai.api_server \
    --model "FreedomIntelligence/openPangu-Embedded-7B" \
    --tensor-parallel-size 1 \
    --port 8001 \
    --host localhost \
    --gpu-memory-utilization 0.4 \
    --trust-remote-code \
    --task embed \
    --dtype bfloat16
|
generator_server.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Serve openPangu-Embedded-7B as an OpenAI-compatible *generation* endpoint
# on Ascend NPU 1, port 8000.
export ASCEND_RT_VISIBLE_DEVICES=1
export VLLM_USE_V1=1
# BUGFIX: removed the dangling trailing "\" after the last option -- it
# silently continues the command onto whatever line follows.
python -m vllm.entrypoints.openai.api_server \
    --model "FreedomIntelligence/openPangu-Embedded-7B" \
    --tensor-parallel-size 1 \
    --port 8000 \
    --host localhost \
    --trust-remote-code \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.90
|
main.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch_npu
|
| 4 |
+
from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig
|
| 5 |
+
from automr.data_loader import DataLoader
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def parse_args(argv=None):
    """Parse command-line arguments for AutoMR.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case sys.argv[1:] is used -- added (backward-compatibly) so the
            parser can be exercised in tests.

    Returns:
        argparse.Namespace with all AutoMR settings.
    """
    parser = argparse.ArgumentParser()

    # Mode
    # BUGFIX: the old help text advertised a "train_eval" mode that is not
    # among the declared choices.
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval'], help='Mode: train or eval')

    # Model settings
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen2.5-3B-Instruct', help='Pretrained LLM model name')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu', 'npu'], help='Device to use')
    parser.add_argument('--token_budget', type=int, default=256, help='Token budget for reasoning')
    parser.add_argument('--hidden_size', type=int, default=4096, help='Hidden size of the model')

    # Training settings
    parser.add_argument('--learning_rate', type=float, default=5e-4, help='Learning rate')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_samples', type=int, default=1, help='Number of skeletons to sample per query (M)')

    # Data paths
    parser.add_argument('--train_data', type=str, default='data/train.json', help='Path to training data')
    parser.add_argument('--val_data', type=str, default='data/val.json', help='Path to validation data')
    parser.add_argument('--test_data', type=str, default='data/test.json', help='Path to test data')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--results_dir', type=str, default='results', help='Directory to save results')

    # Checkpoint
    parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint to load')

    # Task type
    parser.add_argument('--task_type', type=str, default='math', choices=['math', 'multiple_choice'], help='Task type')

    return parser.parse_args(argv)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def main():
    """Entry point: build config and model, then train or evaluate per --mode."""
    args = parse_args()

    # Create configuration
    config = AutoMRConfig(
        model_name=args.model_name,
        device=args.device,
        token_budget=args.token_budget,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        num_samples_per_query=args.num_samples,
        train_data_path=args.train_data,
        val_data_path=args.val_data,
        test_data_path=args.test_data,
        checkpoint_dir=args.checkpoint_dir,
        results_dir=args.results_dir,
        task_type=args.task_type,
        hidden_size=args.hidden_size,
    )

    print("="*80)
    print("AutoMR: Automatic Meta-Reasoning Skeleton Search")
    print("="*80)
    print(f"\nConfiguration:")
    print(f" Model: {config.model_name}")
    print(f" Device: {config.device}")
    print(f" Token Budget: {config.token_budget}")
    print(f" Task Type: {config.task_type}")
    print(f" Mode: {args.mode}")
    print("="*80)

    # Initialize model
    model = AutoMR(config)

    # Load checkpoint if specified
    if args.load_checkpoint:
        if os.path.exists(args.load_checkpoint):
            model.load_checkpoint(args.load_checkpoint)
        else:
            # ROBUSTNESS: previously a missing checkpoint path was silently
            # ignored and evaluation ran on an untrained model.
            print(f"WARNING: checkpoint not found, skipping load: {args.load_checkpoint}")

    # Training mode
    if args.mode == 'train':
        print(f"\n{'='*80}")
        print("TRAINING")
        print("="*80)

        # Load training data
        train_data = DataLoader.load_data(config.train_data_path)
        val_data = DataLoader.load_data(config.val_data_path)
        print(f"Loaded {len(train_data)} training samples from {config.train_data_path}")

        # Train
        trainer = AutoMRTrainer(model, config)
        trainer.train(train_data, val_data)

    # Evaluation mode
    elif args.mode == 'eval':
        print(f"\n{'='*80}")
        print("EVALUATION")
        print("="*80)

        # Load test data
        test_data = DataLoader.load_data(config.test_data_path)
        print(f"Loaded {len(test_data)} test samples from {config.test_data_path}")

        # Evaluate
        evaluator = AutoMREvaluator(model, config)
        accuracy, results = evaluator.evaluate(test_data)

        print(f"\n{'='*80}")
        print(f"Final Accuracy: {accuracy:.2%}")
        print("="*80)

    else:
        # Unreachable while argparse restricts --mode to train/eval; kept as
        # a guard against future choices being added without a branch here.
        raise NotImplementedError
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Standard script entry guard: run main() only when executed directly.
if __name__ == "__main__":
    main()
|
math_train.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train AutoMR on the MATH dataset with the openPangu-Embedded-7B backbone
# on Ascend NPU 0.
# BUGFIX: removed the dangling trailing "\" after the last option -- it
# silently continues the command onto whatever line follows.
VLLM_USE_V1=1 ASCEND_RT_VISIBLE_DEVICES=0 python main.py --mode train \
    --device npu \
    --model_name "FreedomIntelligence/openPangu-Embedded-7B" \
    --train_data processed_data/MATH/train.jsonl \
    --val_data processed_data/MATH/val.jsonl \
    --num_epochs 5 \
    --batch_size 8 \
    --num_samples 4 \
    --token_budget 4096 \
    --checkpoint_dir checkpoints/MATH/pangu \
    --task_type math