Alexhf825 committed on
Commit fd3b506 · verified · 1 Parent(s): ae516e5

Upload folder using huggingface_hub

automr/config.py CHANGED
@@ -18,6 +18,8 @@ class AutoMRConfig:
     batch_size: int = 8
     num_samples_per_query: int = 4 # M in paper
     gradient_clip: float = 1.0
+    initial_baseline: float = 0.0 # Initial value for REINFORCE baseline
+    baseline_momentum: float = 0.9 # Momentum for baseline update
 
     # Validation settings
     val_every_n_steps: int = 100 # Alpha in the requirement - validate every N steps
@@ -25,7 +27,7 @@ class AutoMRConfig:
     early_stopping_patience: int = 5 # Stop if no improvement for N validations
 
     # Generation settings
-    max_new_tokens: int = 1024
+    max_new_tokens: int = 4096
    temperature: float = 0.01
     top_p: float = 0.9
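The two new fields parameterize an exponential-moving-average baseline used by the REINFORCE trainer below. A minimal standalone sketch of the update they control, with illustrative reward values (the real update lives in automr/trainer.py):

# Minimal sketch of the EMA baseline the two new config fields control.
# Standalone and illustrative; not the trainer itself.
from dataclasses import dataclass

@dataclass
class BaselineConfig:
    initial_baseline: float = 0.0   # starting value of the running baseline
    baseline_momentum: float = 0.9  # weight kept on the old baseline each update

def update_baseline(baseline: float, avg_reward: float, momentum: float) -> float:
    """Exponential moving average: b <- m * b + (1 - m) * r_avg."""
    return momentum * baseline + (1.0 - momentum) * avg_reward

cfg = BaselineConfig()
b = cfg.initial_baseline
for r in [1.0, -1.0, 1.0, 1.0]:  # illustrative per-batch average rewards
    b = update_baseline(b, r, cfg.baseline_momentum)
print(f"baseline after 4 batches: {b:.3f}")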
 
automr/dag.py CHANGED
@@ -57,7 +57,7 @@ class MetaReasoningDAG:
 
     def get_context_up_to(self, idx: int) -> str:
         """Get all node contents up to index idx"""
-        return "\n".join([node.content for node in self.nodes[:idx+1]])
+        return "\n".join([f"step: {node.index}: {node.content}" for node in self.nodes[:idx+1]])
 
     def total_tokens(self) -> int:
         """Total tokens generated (excluding source node)"""
automr/evaluator.py CHANGED
@@ -1,6 +1,7 @@
 from typing import List, Dict, Tuple
 from tqdm import tqdm
 import os
+import asyncio
 
 from .model import AutoMR
 from .config import AutoMRConfig
@@ -15,79 +16,71 @@ class AutoMREvaluator:
         self.config = config
         ensure_dir(config.results_dir)
 
-    def evaluate(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
-        """
-        Evaluate model on test data
-        Returns: (accuracy, detailed_results)
-        """
-        print(f"\nEvaluating on {len(test_data)} samples...")
+    async def evaluate_async(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
+        """Async evaluation: send all queries in a single batch to vLLM."""
+        print(f"\nEvaluating on {len(test_data)} samples (async, single batch)...")
 
         self.model.strategy_mlp.eval()
         self.model.strategy_embeddings.eval()
 
+        queries = [item['query'] for item in test_data]
+        ground_truths = [item['answer'] for item in test_data]
+
+        # One-shot async sampling over all queries, M=1
+        pred_answers, _ = await self.model.sample_batch(queries, M=1)
+
         correct = 0
         total = 0
-        detailed_results = []
-        batch_size = self.config.batch_size
-        pbar = tqdm(
-            range(0, len(test_data), batch_size), desc="Evaluating"
-        )
-
-        # for item in tqdm(test_data, desc="Evaluating"):
-        for i in pbar:
-            batch = test_data[i:i + batch_size]
-            queries = [item['query'] for item in batch]
-            ground_truths = [item['answer'] for item in batch]
-
-            # Run inference
-            pred_answers, dags = self.model.inference(queries, M=1)
-
-            for query, ground_truth, pred_answer, dag in zip(queries, ground_truths, pred_answers, dags):
-                # Check correctness
-                is_correct = check_answer_match(
-                    pred_answer,
-                    ground_truth,
-                    self.config.task_type
-                )
-
-                if is_correct:
-                    correct += 1
-                total += 1
-
-                pbar.set_postfix({
-                    'Acc': f'{correct} / {total}',
-                })
-
-                # Store detailed result
+        detailed_results: List[Dict] = []
+
+        for query, ground_truth, pred_answer in tqdm(
+            zip(queries, ground_truths, pred_answers),
+            total=len(queries),
+            desc="Evaluating",
+        ):
+            is_correct = check_answer_match(
+                pred_answer,
+                ground_truth,
+                self.config.task_type,
+            )
+            if is_correct:
+                correct += 1
+            total += 1
+
+        # Optional: collect detailed results (empty by default; keeps the file structure)
         # result = {
         #     'query': query,
         #     'ground_truth': ground_truth,
         #     'prediction': pred_answer,
-        #     'correct': is_correct
+        #     'correct': is_correct,
         # }
         # if self.config.save_skeletons:
-        #     result['skeleton'] = dag.to_dict()
+        #     result['skeleton'] = None
         # detailed_results.append(result)
 
         accuracy = correct / total if total > 0 else 0.0
 
         print(f"\nEvaluation Results:")
         print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
 
-        # Save results
         if self.config.save_predictions:
             results_path = os.path.join(
                 self.config.results_dir,
-                'evaluation_results.json'
+                'evaluation_results.json',
+            )
+            save_json(
+                {
+                    'accuracy': accuracy,
+                    'correct': correct,
+                    'total': total,
+                    'detailed_results': detailed_results,
+                },
+                results_path,
             )
-            save_json({
-                'accuracy': accuracy,
-                'correct': correct,
-                'total': total,
-                'detailed_results': detailed_results
-            }, results_path)
             print(f"Results saved to {results_path}")
 
         return accuracy, detailed_results
+
+    def evaluate(self, test_data: List[Dict[str, str]]) -> Tuple[float, List[Dict]]:
+        """Synchronous wrapper for async evaluation, for use in main.py."""
+        return asyncio.run(self.evaluate_async(test_data))
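The sync wrapper keeps main.py unchanged while the body goes async. One caveat worth noting: asyncio.run() raises a RuntimeError when called from a thread that already has a running event loop (e.g. inside Jupyter). A minimal sketch of the pattern, with hypothetical stand-in names:

# Sketch of the sync-over-async wrapper pattern used by evaluate();
# the names and the sleep stand-in are hypothetical.
import asyncio
from typing import List

async def evaluate_async(queries: List[str]) -> float:
    await asyncio.sleep(0)  # stand-in for the batched vLLM calls
    return 1.0

def evaluate(queries: List[str]) -> float:
    # Safe from plain scripts; raises RuntimeError inside a running loop
    return asyncio.run(evaluate_async(queries))

if __name__ == "__main__":
    print(evaluate(["q1", "q2"]))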
automr/model.py CHANGED
@@ -1,22 +1,18 @@
-from concurrent.futures import ThreadPoolExecutor
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import List, Tuple
+import asyncio
 import random
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from vllm import LLM
-from vllm import SamplingParams
+from typing import List, Tuple, Dict, Any
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
 from .config import AutoMRConfig
 from .strategies import META_STRATEGIES, STRATEGY_LIST
 from .dag import MetaReasoningDAG
-from .utils import extract_answer
-from typing import Dict
-from openai import OpenAI
 
 class StrategyMLP(nn.Module):
     """MLP for sampling meta-reasoning strategies"""
-
     def __init__(self, hidden_size: int, num_strategies: int):
         super().__init__()
         # Input: [node_repr, strategy_repr, context_repr]
@@ -26,14 +22,6 @@ class StrategyMLP(nn.Module):
         self.dropout = nn.Dropout(0.1)
 
     def forward(self, node_repr, strategy_repr, context_repr):
-        """
-        Args:
-            node_repr: [batch, hidden_size]
-            strategy_repr: [batch, hidden_size]
-            context_repr: [batch, hidden_size]
-        Returns:
-            logits: [batch, num_strategies]
-        """
         x = torch.cat([node_repr, strategy_repr, context_repr], dim=-1)
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
@@ -44,361 +32,250 @@ class StrategyMLP(nn.Module):
 
 
 class AutoMR:
-    """AutoMR Framework for Meta-Reasoning Skeleton Search"""
+    """AutoMR Framework with Async vLLM Support"""
 
     def __init__(self, config: AutoMRConfig):
         self.config = config
         self.device = config.device
         self.token_budget = config.token_budget
-        self.model_name_for_api = config.model_name
 
-        # # Load LLM
-        # print(f"Loading Tokenizer and Config: {config.model_name}")
-        # self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
-
-        # print(f"Loading vLLM generator: {config.model_name}")
-        # self.llm = LLM(
-        #     config.model_name,
-        #     dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     tensor_parallel_size=config.tensor_parallel_size,
-        #     gpu_memory_utilization="0.8"
-        # )
-
-        # print(f"Loading LLM Embbedder: {config.model_name}")
-        # self.llm_embedder = AutoModelForCausalLM.from_pretrained(
-        #     config.model_name,
-        #     torch_dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     device_map=None
-        # ).to(self.device)
-        # self.llm_embedder.eval()
-
-        # print(f"Loading vLLM Embbedder: {config.model_name}")
-        # self.llm_embedder = LLM(
-        #     config.model_name,
-        #     dtype=torch.float16,
-        #     trust_remote_code=True,
-        #     tensor_parallel_size=config.tensor_parallel_size,
-        #     gpu_memory_utilization="0.8"
-        #     task="embed",
-        # )
-
-        print("Connecting to vLLM Generator Server (port 8000)...")
-        self.generator_client = OpenAI(
-            api_key="vllm",
-            base_url="http://localhost:8000/v1"
-        )
+        # Concurrency control: prevent overloading the client/server
+        self.semaphore = asyncio.Semaphore(128)
 
-        print("Connecting to Custom Embedder Server (port 8001)...")
-        self.embed_client = OpenAI(
-            api_key="vllm",
-            base_url="http://localhost:8001/v1"
-        )
+        print(f"Connecting to vLLM Generator (Async)...")
+        self.generator_client = AsyncOpenAI(api_key="vllm", base_url="http://localhost:8000/v1")
 
-        # Strategy components
+        print(f"Connecting to Embedder (Async)...")
+        self.embed_client = AsyncOpenAI(api_key="vllm", base_url="http://localhost:8001/v1")
+
+        # Components for the meta-reasoning strategy network
         self.num_strategies = len(STRATEGY_LIST)
         hidden_size = config.hidden_size
 
         self.strategy_embeddings = nn.Embedding(self.num_strategies, hidden_size).to(self.device)
         self.strategy_mlp = StrategyMLP(hidden_size, self.num_strategies).to(self.device)
 
-        # Strategy mappings
-        self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
+        # Mapping tables between strategy indices and names
         self.idx_to_strategy = {i: s for i, s in enumerate(STRATEGY_LIST)}
+        self.strategy_to_idx = {s: i for i, s in enumerate(STRATEGY_LIST)}
 
-        # Optimizer
+        # Optimizer for strategy embeddings and MLP
        self.optimizer = torch.optim.Adam(
             list(self.strategy_embeddings.parameters()) +
             list(self.strategy_mlp.parameters()),
             lr=config.learning_rate
         )
 
-        print("AutoMR initialized successfully")
+        # Pre-allocated zero tensor to avoid repeated allocation in loops
+        self.zero_strategy_repr = torch.zeros(1, hidden_size, device=self.device)
 
-    def get_text_representation(self, texts: List[str]) -> Tuple[torch.Tensor]:
-        """
-        Get pooled hidden state representations for texts in batch using LLM embedder.
-        Args:
-            texts: List of input texts
-        Returns:
-            pooled: Tensor of shape [batch_size, hidden_size]
-        """
-        # self.tokenizer.padding_side = "left"
-        # inputs = self.tokenizer(
-        #     texts,
-        #     return_tensors="pt",
-        #     padding=True,
-        #     truncation=True,
-        # ).to(self.device)
-
-        # with torch.no_grad():
-        #     outputs = self.llm(**inputs, output_hidden_states=True)
-        #     hidden_states = outputs.hidden_states[-1] # [bsz, len, dim]
-        #     pooled = hidden_states[:, -1, :]
-
-        # batch_outputs = self.llm_embedder.encode(texts)
-        # pooled = []
-        # for outputs in batch_outputs:
-        #     last_hidden_state = outputs.outputs.data[-1,:] # [seq_len, hidden_size]
-        #     pooled.append(last_hidden_state)
-        # pooled = torch.stack(pooled, dim=0).to(self.device) # [batch_size, hidden_size]
-
-        batch_outputs = self.embed_client.embeddings.create(
-            input=texts,
-            model=self.model_name_for_api
-        )
-
-        batch_reprs = [
-            torch.tensor(data.embedding, device=self.device, dtype=torch.float16)
-            for data in batch_outputs.data
-        ]
-
-        pooled = torch.stack(batch_reprs, dim=0) # [batch_size, hidden_size]
-        return pooled
-
-    def sample_strategy(
-        self,
-        batch_node_content_repr: torch.Tensor,
-        batch_sampled_strategies: Dict[int, List[int]],
-        batch_context_repr: torch.Tensor
-    ) -> Tuple[List[int], torch.Tensor]:
-        """
-        Sample a strategy for each edge (j, i) in batch
-        Args:
-            batch_node_content_repr: Tensor of shape [batch_size, hidden_size]
-            batch_sampled_strategies: Dict of lists of sampled strategy indices
-            batch_context_repr: Tensor of shape [batch_size, hidden_size]
-        Returns:
-            batch_strategy_idx: List of sampled strategy indices
-            batch_log_prob: Tensor of log probabilities, shape [batch_size]
-        """
-        batch_strategy_repr = []
-
-        for sampled_strategies in batch_sampled_strategies.values():
-            if sampled_strategies:
-                sampled_strategies = torch.tensor(sampled_strategies).to(self.device)
-                strategy_repr = self.strategy_embeddings(sampled_strategies).mean(dim=0, keepdim=True)
-            else:
-                strategy_repr = torch.zeros(1, self.config.hidden_size).to(self.device)
-            batch_strategy_repr.append(strategy_repr)
-
-        batch_strategy_repr = torch.cat(batch_strategy_repr, dim=0) # Combine all batch representations
-        batch_logits = self.strategy_mlp(batch_node_content_repr, batch_strategy_repr, batch_context_repr)
-        batch_probs = F.softmax(batch_logits, dim=-1)
-
-        dist = torch.distributions.Categorical(batch_probs)
-        batch_strategy_idx = dist.sample()
-        batch_log_prob = dist.log_prob(batch_strategy_idx).to(self.device)
-
-        return batch_strategy_idx.cpu().tolist(), batch_log_prob
-
-    def generate_content(
-        self,
-        batch_query: List[str],
-        batch_context: List[str],
-        batch_strategies: List[List[str]],
-        batch_remaining_budget: List[int]
-    ) -> Tuple[List[str], List[int], torch.Tensor]:
-        """
-        Generate reasoning content based on selected strategies
-        Args:
-            batch_query: List of query strings
-            batch_context: List of context strings
-            batch_strategies: List of lists of strategy names
-            batch_remaining_budget: List of remaining token budgets
-        Returns:
-            batch_generated_texts: List of generated content strings
-            batch_num_tokens: List of number of tokens generated
-            batch_content_reprs: Tensor of content representations, shape [batch_size, hidden_size]
-        """
-        batch_strategy_prompts = [[] for _ in batch_query]
-        batch_full_prompt: List[str] = []
-        for i, strategies in enumerate(batch_strategies):
-            for s in strategies:
-                prompt = random.choice(META_STRATEGIES[s])
-                batch_strategy_prompts[i].append(prompt)
-
-            batch_full_prompt.append(f"{batch_context[i]}\n{' '.join(batch_strategy_prompts[i])}\n")
-
-        params_list = []
-        for i in range(len(batch_query)):
-            remaining_budget = batch_remaining_budget[i]
-            current_max_tokens = min(self.config.max_new_tokens, remaining_budget)
-            params_list.append({
-                "prompt": batch_full_prompt[i],
-                "max_tokens": current_max_tokens,
-            })
-
-        with ThreadPoolExecutor(max_workers=None) as executor:
-            batch_outputs = list(executor.map(self.make_api_call, params_list))
-
-        batch_generated_texts = [output.choices[0].text.strip() for output in batch_outputs]
-        batch_num_tokens = [output.usage.completion_tokens for output in batch_outputs]
-        batch_content_reprs = self.get_text_representation(batch_generated_texts)
-
-        return batch_generated_texts, batch_num_tokens, batch_content_reprs
-
-    def dynamic_skeleton_sampling(self, queries: List[str], M: int) -> Tuple[List[MetaReasoningDAG], torch.Tensor]:
-        """
-        Algorithm 1: Dynamic Skeleton Sampling at inference time
-        Args:
-            queries: List of input query strings
-            M: Number of trajectories per query
-        Returns:
-            batch_dags: List of generated MetaReasoningDAGs
-            total_log_probs: Tensor of total log probabilities for each trajectory
-        """
-        # === 1. Initialize M*N DAGs ===
-        N = len(queries)
-        batch_size = N * M
-        batch_dags: List[MetaReasoningDAG] = []
-        query_reprs = self.get_text_representation(queries)
-        for i in range(N):
-            for _ in range(M):
-                batch_dags.append(
-                    MetaReasoningDAG(queries[i], query_reprs[i], 0) # we don't count query tokens, set 0
-                )
-
-        total_log_probs = torch.zeros(batch_size).to(self.device)
-        # the idx of trajectories that are still active
-        active_indices = list(range(batch_size))
-        i = 0 # Current topology step (i=1 is the first new node)
-        while active_indices:
-            i += 1
-            sampled_strategies = {dag_idx: [] for dag_idx in active_indices}
-            incoming_edges = {dag_idx: [] for dag_idx in active_indices}
-
-            # Step 1: Determine incoming edges (traverse in reverse order)
-            for j in range(i-1, -1, -1):
-                node_j_content_reprs = torch.stack([batch_dags[idx].get_node_content_repr(j) for idx in active_indices], dim=0)
-                context_reprs = torch.stack([batch_dags[idx].get_context_repr_up_to(i-1) for idx in active_indices], dim=0)
-
-                strategy_idx, log_prob = self.sample_strategy(
-                    node_j_content_reprs,
-                    sampled_strategies,
-                    context_reprs
-                )
-
-                for k, dag_idx in enumerate(active_indices):
-                    sampled_strategies[dag_idx].append(strategy_idx[k])
-
-                total_log_probs[active_indices] += log_prob
-
-                for dag_idx in active_indices:
-                    strategy_idx = sampled_strategies[dag_idx][-1]
-                    strategy_name = self.idx_to_strategy[strategy_idx]
-                    if strategy_name != "zero":
-                        incoming_edges[dag_idx].append((j, strategy_name))
-
-            # Step 2: Check which DAGs are still active
-            for dag_idx in active_indices.copy():
-                if not incoming_edges[dag_idx]:
-                    active_indices.remove(dag_idx)
-
-            if not active_indices:
-                break
-
-            # Step 3: Generate base reasoning content
-            batch_strategies = []
-            batch_context = []
-            batch_query = []
-            batch_remaining_budget = []
-            for dag_idx in active_indices:
-                dag = batch_dags[dag_idx]
-                strategies = [edge[1] for edge in incoming_edges[dag_idx]]
-                batch_strategies.append(strategies)
-                context = dag.get_context_up_to(i-1)
-                batch_context.append(context)
-                batch_query.append(dag.nodes[0].content)
-                batch_remaining_budget.append(self.token_budget - dag.total_tokens())
-
-            batch_content, batch_num_tokens, batch_content_repr = self.generate_content(
-                batch_query,
-                batch_context,
-                batch_strategies,
-                batch_remaining_budget
-            )
-
-            # Step 4: Update DAGs with new nodes and edges
-            for k, dag_idx in enumerate(active_indices):
-                dag = batch_dags[dag_idx]
-                content = batch_content[k]
-                num_tokens = batch_num_tokens[k]
-                content_repr = batch_content_repr[k]
-                dag.add_node(content, num_tokens, content_repr)
-
-            for dag_idx in incoming_edges:
-                dag = batch_dags[dag_idx]
-                for from_j, strategy in incoming_edges[dag_idx]:
-                    dag.add_edge(from_j, i, strategy)
-
-            # Step 5: Check stopping criteria
-            for dag_idx in active_indices.copy():
-                dag = batch_dags[dag_idx]
-                content = dag.get_node_content(i)
-                if not content or "boxed" in content.lower() or dag.total_tokens() >= self.token_budget:
-                    active_indices.remove(dag_idx)
-
-        return batch_dags, total_log_probs
-
-    def extract_answer(self, batch_dags: List[MetaReasoningDAG]) -> List[str]:
-        """Extract final answer from the reasoning DAG"""
-        batch_answer_prompts = []
-        params_list = []
-        for dag in batch_dags:
-            full_context = dag.get_context_up_to(len(dag.nodes) - 1)
-
-            answer_prompt = f"{full_context}\n{random.choice(META_STRATEGIES['Answer'])}\n"
-            params_list.append({
-                "prompt": answer_prompt,
-                "max_tokens": self.config.max_new_tokens,
-            })
-
-        with ThreadPoolExecutor(max_workers=None) as executor:
-            batch_outputs = list(executor.map(self.make_api_call, params_list))
-
-        batch_answers = [output.choices[0].text.strip() for output in batch_outputs]
-
-        return batch_answers
-
-    def inference(self, batch_queries: List[str], M: int) -> Tuple[str, MetaReasoningDAG]:
-        """Run inference on a single query"""
-        self.strategy_mlp.eval()
-        self.strategy_embeddings.eval()
-
-        with torch.no_grad():
-            batch_dags, _ = self.dynamic_skeleton_sampling(batch_queries, M)
-            batch_answers = self.extract_answer(batch_dags)
-
-        return batch_answers, batch_dags
+        print("AutoMR initialized successfully (Async Mode)")
+
+    async def get_text_representation(self, text: str) -> torch.Tensor:
+        """Get the embedding vector for a single text string."""
+        if not text or not text.strip():
+            return torch.zeros(self.config.hidden_size, device=self.device, dtype=torch.float16)
+
+        try:
+            # For simplicity we do not add full retry logic here
+            resp = await self.embed_client.embeddings.create(
+                input=text,
+                model=self.config.model_name
+            )
+            # Extract the embedding and move it to the GPU/device
+            return torch.tensor(resp.data[0].embedding, device=self.device, dtype=torch.float16)
+        except Exception as e:
+            print(f"Embedding Error: {e}")
+            return torch.zeros(self.config.hidden_size, device=self.device, dtype=torch.float16)
+
+    async def _generate_text(self, prompt: str, max_tokens: int) -> Tuple[str, int]:
+        """Atomic text generation call, returning text and used completion tokens."""
+        if not prompt:
+            return "", 0
+        async with self.semaphore:
+            try:
+                resp = await self.generator_client.completions.create(
+                    model=self.config.model_name,
+                    prompt=prompt,
+                    max_tokens=max_tokens,
+                    temperature=self.config.temperature
+                )
+                text = resp.choices[0].text.strip()
+                # vLLM/OpenAI-style usage field; fall back to 0 if missing
+                used_tokens = resp.usage.completion_tokens
+                return text, int(used_tokens)
+            except Exception as e:
+                print(f"Generation Error: {e}")
+                return "", 0
+
+    def select_strategy(
+        self,
+        node_j_repr: torch.Tensor,
+        existing_strategy_indices: List[int],
+        context_repr: torch.Tensor
+    ) -> Tuple[int, torch.Tensor]:
+        """Decide whether to create an edge j->i according to Algorithm 1.
+
+        Args:
+            node_j_repr: Representation of candidate source node j.
+            existing_strategy_indices: Already selected k->i strategies (k > j).
+            context_repr: Global context representation.
+        """
+        # 1. Pool existing strategies (k->i)
+        if existing_strategy_indices:
+            hist_tensor = torch.tensor(existing_strategy_indices, device=self.device, dtype=torch.long)
+            strategy_repr = self.strategy_embeddings(hist_tensor).mean(dim=0, keepdim=True)
+        else:
+            strategy_repr = self.zero_strategy_repr
+
+        # 2. Forward pass through the MLP (add batch dimension [1, dim])
+        logits = self.strategy_mlp(
+            node_j_repr.unsqueeze(0),
+            strategy_repr,
+            context_repr.unsqueeze(0)
+        )
+
+        # 3. Sample a strategy index from the categorical distribution
+        probs = F.softmax(logits, dim=-1)
+        dist = torch.distributions.Categorical(probs)
+        idx = dist.sample()
+        log_prob = dist.log_prob(idx)
+
+        return idx.item(), log_prob
+
+    async def run_single_trajectory(self, query: str) -> Tuple[str, torch.Tensor]:
+        """Run a single reasoning trajectory (Algorithm 1).
+
+        Returns:
+            Tuple of (final_answer, total_log_prob).
+        """
+        # 1. Initialization
+        q_repr = await self.get_text_representation(query)
+        dag = MetaReasoningDAG(query, q_repr, 0)
+        trajectory_log_prob = torch.tensor(0.0, device=self.device)
+
+        step_idx = 0
+        # Stopping condition: token budget exhausted or step limit (30)
+        while dag.total_tokens() < self.config.token_budget and step_idx < 30:
+            step_idx += 1
+
+            context_repr = dag.get_context_repr_up_to(step_idx - 1)
+
+            # === Inner loop: iterate j in reverse order (from i-1 down to 0) ===
+            strategies_k_to_i: List[int] = []  # Input to select_strategy (k > j)
+            incoming_edges_info: List[Tuple[int, str]] = []  # For prompt construction: (src_node_idx, strategy_name)
+
+            for j in range(step_idx - 1, -1, -1):
+                node_j_repr = dag.get_node_content_repr(j)
+
+                # Strategy decision
+                strat_idx, log_prob = self.select_strategy(
+                    node_j_repr, strategies_k_to_i, context_repr
+                )
+
+                # Accumulate log-probability contribution
+                trajectory_log_prob = trajectory_log_prob + log_prob
+                strategy_name = self.idx_to_strategy[strat_idx]
+
+                # If this is a non-zero (effective) strategy
+                if strategy_name != "zero":
+                    strategies_k_to_i.append(strat_idx)
+                    incoming_edges_info.append((j, strategy_name))
+                    dag.add_edge(j, step_idx, strategy_name)
+
+            # Fallback: if not the first step and there is no incoming edge, treat as reasoning interruption
+            if not incoming_edges_info and step_idx > 1:
+                break
+
+            # === Prompt construction (Algorithm 1 + Appendix A.2) ===
+            # Reverse edges back to chronological order (Step 0, Step 1, ...) for readability
+            incoming_edges_info.reverse()
+
+            prompts_list = []
+            has_answer_strategy = False
+
+            for src_node_idx, s_name in incoming_edges_info:
+                if s_name == "Answer":
+                    has_answer_strategy = True
+
+                # Get the strategy template prompt
+                raw_strategy_prompt = random.choice(META_STRATEGIES.get(s_name, [""]))
+
+                # Apply the Appendix A.2 template
+                formatted_prompt = f"Let me attend to Step {src_node_idx}, {raw_strategy_prompt}"
+                prompts_list.append(formatted_prompt)
+
+            # Concatenate all incoming-edge prompts
+            strategies_text = " ".join(prompts_list)
+            full_context = dag.get_context_up_to(step_idx - 1)
+
+            # === Generate content ===
+            if has_answer_strategy:
+                # Final answer generation
+                full_prompt = f"{full_context}\n{strategies_text}\nAnswer:\n"
+                remain_budget = max(1, self.config.token_budget - dag.total_tokens())
+                final_answer, used_tokens = await self._generate_text(full_prompt, remain_budget)
+                # Strictly accumulate used completion tokens before termination
+                if used_tokens > 0:
+                    dag.add_node(final_answer, used_tokens, await self.get_text_representation(final_answer))
+                return final_answer, trajectory_log_prob
+            else:
+                # Intermediate reasoning generation
+                full_prompt = f"{full_context}\n{strategies_text}\n"
+
+                # Compute the token limit for the current step
+                current_remain = self.config.token_budget - dag.total_tokens()
+                step_limit = min(self.config.max_new_tokens, current_remain)
+
+                if step_limit <= 0:
+                    break  # Budget exhausted
+
+                content, used_tokens = await self._generate_text(full_prompt, step_limit)
+
+                # If the service does not return usage, fall back to at least one token when content exists
+                if used_tokens <= 0 and content:
+                    used_tokens = 1
+
+                content_repr = await self.get_text_representation(content)
+                dag.add_node(content, used_tokens, content_repr)
+
+                # Check whether a boxed answer appears earlier than expected
+                if "boxed" in content:
+                    return content, trajectory_log_prob
+
+        # After the loop ends: return the content of the last node if one exists
+        if len(dag.nodes) > 0:
+            return dag.nodes[-1].content, trajectory_log_prob
+        return "", trajectory_log_prob
+
+    async def sample_batch(self, queries: List[str], M: int) -> Tuple[List[str], torch.Tensor]:
+        """Async entry: expand B queries into B*M trajectories and run concurrently."""
+        tasks = []
+        for q in queries:
+            for _ in range(M):
+                tasks.append(self.run_single_trajectory(q))
+
+        # Show generation progress: each finished trajectory updates the bar
+        results = await tqdm_asyncio.gather(*tasks)
+
+        answers = [r[0] for r in results]
+        log_probs = torch.stack([r[1] for r in results])
+        return answers, log_probs
+
+    def sample_batch_sync(self, queries: List[str], M: int) -> Tuple[List[str], torch.Tensor]:
+        """Synchronous wrapper for sample_batch, for use in trainer."""
+        return asyncio.run(self.sample_batch(queries, M))
 
-    def save_checkpoint(self, path: str):
-        """Save model checkpoint"""
+    # Compatibility interfaces
+    def save_checkpoint(self, path):
         torch.save({
             'strategy_embeddings': self.strategy_embeddings.state_dict(),
             'strategy_mlp': self.strategy_mlp.state_dict(),
             'optimizer': self.optimizer.state_dict()
         }, path)
-        print(f"Checkpoint saved to {path}")
 
-    def load_checkpoint(self, path: str):
-        """Load model checkpoint"""
-        checkpoint = torch.load(path, map_location=self.device)
-        self.strategy_embeddings.load_state_dict(checkpoint['strategy_embeddings'])
-        self.strategy_mlp.load_state_dict(checkpoint['strategy_mlp'])
-        self.optimizer.load_state_dict(checkpoint['optimizer'])
-        print(f"Checkpoint loaded from {path}")
-
-    def make_api_call(self, params):
-        """Make API call to vLLM server"""
-        return self.generator_client.completions.create(
-            model=self.model_name_for_api,
-            prompt=params["prompt"],
-            max_tokens=params["max_tokens"],
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-        )
+    def load_checkpoint(self, path):
+        ckpt = torch.load(path, map_location=self.device)
+        self.strategy_embeddings.load_state_dict(ckpt['strategy_embeddings'])
+        self.strategy_mlp.load_state_dict(ckpt['strategy_mlp'])
+        self.optimizer.load_state_dict(ckpt['optimizer'])
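The fan-out in sample_batch() is bounded by the Semaphore(128) created in __init__, so B*M trajectories can be scheduled at once without flooding the vLLM server. A self-contained sketch of that concurrency pattern, with a fake _generate standing in for the real AsyncOpenAI call:

# Sketch of the bounded fan-out behind sample_batch(): schedule all
# trajectories, cap in-flight requests with a semaphore, gather with a
# progress bar. _generate is a stand-in for the real AsyncOpenAI call.
import asyncio
from tqdm.asyncio import tqdm_asyncio

SEM = asyncio.Semaphore(128)  # same limit as set in __init__ above

async def _generate(prompt: str) -> str:
    async with SEM:                # at most 128 requests in flight
        await asyncio.sleep(0.01)  # stand-in for the completions call
        return f"answer to: {prompt}"

async def sample_batch(queries, M):
    tasks = [_generate(q) for q in queries for _ in range(M)]
    return await tqdm_asyncio.gather(*tasks)  # one bar tick per trajectory

if __name__ == "__main__":
    print(asyncio.run(sample_batch(["q1", "q2"], M=2)))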
 
 
 
 
automr/strategies.py CHANGED
@@ -1,8 +1,6 @@
 # Meta Reasoning Strategy Prompts (from Table 2 in paper)
 META_STRATEGIES = {
     "Next": [
-        "Next,",
-        "Then,",
         "Now, let me move on to the next step."
     ],
     "Reflect": [
automr/trainer.py CHANGED
@@ -2,15 +2,14 @@ import random
 import torch
 from typing import List, Dict, Tuple
 from tqdm import tqdm
+import os
 
 from .model import AutoMR
 from .config import AutoMRConfig
 from .utils import check_answer_match, ensure_dir, save_json
-import os
-
 
 class AutoMRTrainer:
-    """Trainer for AutoMR using REINFORCE (Algorithm 2 from paper)"""
+    """Trainer for AutoMR using REINFORCE (sync trainer, async model calls)"""
 
     def __init__(self, model: AutoMR, config: AutoMRConfig):
         self.model = model
@@ -21,6 +20,9 @@ class AutoMRTrainer:
         self.global_step = 0
         self.best_val_reward = -float('inf')
         self.patience_counter = 0
+        # Exponential-moving-average baseline for the REINFORCE advantage (variance reduction)
+        self.baseline = self.config.initial_baseline
+        self.baseline_momentum = self.config.baseline_momentum
         self.training_history = {
             'train_loss': [],
             'train_reward': [],
@@ -31,7 +33,7 @@ class AutoMRTrainer:
 
     def compute_reward_batch(self, queries: List[str], answers: List[str]) -> Tuple[float, float]:
         """
-        Compute average reward and accuracy on a batch
+        Compute average reward and accuracy on a batch (async model calls)
         Returns: (avg_reward, accuracy)
         """
         total_reward = 0.0
@@ -42,16 +44,16 @@ class AutoMRTrainer:
         self.model.strategy_embeddings.eval()
 
         with torch.no_grad():
-            dags, _ = self.model.dynamic_skeleton_sampling(queries, M=1)
-            pred_answers = self.model.extract_answer(dags)
-
-            is_correct = [check_answer_match(
-                pred_answer, answer, self.config.task_type
-            ) for pred_answer, answer in zip(pred_answers, answers)]
-
-            rewards = [1.0 if correct else -1.0 for correct in is_correct]
-            total_reward += sum(rewards)
-            correct += sum(is_correct)
+            # M=1 for validation; call the async model via the sync wrapper
+            pred_answers, _ = self.model.sample_batch_sync(queries, M=1)
+
+        for pred_answer, answer in zip(pred_answers, answers):
+            is_correct = check_answer_match(pred_answer, answer, self.config.task_type)
+            if is_correct:
+                correct += 1
+                total_reward += 1.0
+            else:
+                total_reward += -1.0
 
         avg_reward = total_reward / total if total > 0 else 0.0
         accuracy = correct / total if total > 0 else 0.0
@@ -76,37 +78,48 @@ class AutoMRTrainer:
 
     def train_step(self, batch_queries: List[str], batch_answers: List[str]) -> Tuple[float, float]:
         """
-        Single training step using REINFORCE (Equation 4 from paper)
+        Single training step using REINFORCE with an EMA baseline
         Returns: (loss, avg_reward)
         """
         self.model.strategy_mlp.train()
         self.model.strategy_embeddings.train()
 
         M = self.config.num_samples_per_query
-        loss = []
+        loss_list = []
         rewards_list = []
-        # expand answers for M samples per query
-        batch_answers = [answer for answer in batch_answers for _ in range(M)]
-        batch_dags, batch_log_probs = self.model.dynamic_skeleton_sampling(batch_queries, M)
-        # Get prediction
-        batch_pred_answers = self.model.extract_answer(batch_dags)
-        # Compute reward
-        for pred_answer, answer, log_prob in zip(batch_pred_answers, batch_answers, batch_log_probs):
+
+        # 1. Sample trajectories: pred_answers [B*M], log_probs [B*M] (sync wrapper over the async model)
+        pred_answers, log_probs = self.model.sample_batch_sync(batch_queries, M)
+
+        # 2. Expand ground-truth answers for comparison
+        expanded_answers = [answer for answer in batch_answers for _ in range(M)]
+
+        # 3. Compute rewards
+        for pred_answer, answer, log_prob in zip(pred_answers, expanded_answers, log_probs):
             reward = 1.0 if check_answer_match(
                 pred_answer, answer, self.config.task_type
             ) else -1.0
             rewards_list.append(reward)
-
-            # Accumulate gradient (REINFORCE)
-            loss.append(-reward * log_prob)
-
-        # Compute average reward for this batch
+
+        # Batch average reward
         avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0
+
+        # 4. Update the sliding baseline: exponential moving average
+        self.baseline = (
+            self.baseline_momentum * self.baseline
+            + (1.0 - self.baseline_momentum) * avg_reward
+        )
+
+        # 5. Policy gradient with advantage: -(reward - baseline) * log_prob
+        for reward, log_prob in zip(rewards_list, log_probs):
+            advantage = reward - self.baseline
+            loss_list.append(-advantage * log_prob)
 
-        # Update parameters
+        # 6. Update parameters
         self.model.optimizer.zero_grad()
-        loss = torch.stack(loss).mean()
-
+        loss = torch.stack(loss_list).mean()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(
             list(self.model.strategy_embeddings.parameters()) +
@@ -114,7 +127,6 @@ class AutoMRTrainer:
             max_norm=self.config.gradient_clip
         )
         self.model.optimizer.step()
-
         return loss.item(), avg_reward
 
     def should_stop_early(self) -> bool:
@@ -145,7 +157,7 @@ class AutoMRTrainer:
             print(f" 💾 Best checkpoint saved: {checkpoint_path}")
 
     def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
-        """Training loop with validation (Algorithm 2 from paper with validation)"""
+        """Training loop with validation"""
         print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
         print(f"Training samples: {len(train_data)}")
         print(f"Validation samples: {len(val_data)}")
@@ -160,18 +172,21 @@ class AutoMRTrainer:
             epoch_reward = 0.0
             num_batches = 0
 
+            batch_indices = list(range(0, len(train_data), self.config.batch_size))
+
             pbar = tqdm(
-                range(0, len(train_data), self.config.batch_size),
+                batch_indices,
                 desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
             )
 
             for i in pbar:
-                batch = train_data[i:i+self.config.batch_size]
+                batch = train_data[i : i + self.config.batch_size]
                 batch_queries = [item['query'] for item in batch]
                 batch_answers = [item['answer'] for item in batch]
 
-                # Training step
+                # Training step (sync)
                 loss, avg_reward = self.train_step(batch_queries, batch_answers)
+
                 epoch_loss += loss
                 epoch_reward += avg_reward
                 num_batches += 1
@@ -221,9 +236,9 @@ class AutoMRTrainer:
                 print(f"Best validation reward: {self.best_val_reward:.4f}")
                 return
 
-            # End of epoch
-            avg_epoch_loss = epoch_loss / num_batches
-            avg_epoch_reward = epoch_reward / num_batches
+            # End-of-epoch summary
+            avg_epoch_loss = epoch_loss / max(num_batches, 1)
+            avg_epoch_reward = epoch_reward / max(num_batches, 1)
 
             self.training_history['train_loss'].append(avg_epoch_loss)
             self.training_history['train_reward'].append(avg_epoch_reward)
@@ -236,7 +251,7 @@ class AutoMRTrainer:
             print(f"Best Val Reward: {self.best_val_reward:.4f}")
             print(f"{'='*80}\n")
 
-            # Save checkpoint at end of epoch (if not save_best_only)
+            # Save checkpoint at end of epoch
             if not self.config.save_best_only:
                 self.save_checkpoint(epoch + 1)
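Put together, train_step() implements REINFORCE with a baseline: loss = -(reward - baseline) * log_prob, where the baseline is an EMA of past batch rewards. A tensor-only sketch of that computation under the config values added in this commit (the rewards and log-probs are illustrative stand-ins, not real trajectories):

# Tensor-only sketch of the REINFORCE-with-baseline update in train_step().
import torch

baseline, momentum = 0.0, 0.9  # config: initial_baseline, baseline_momentum
rewards = torch.tensor([1.0, -1.0, 1.0, 1.0])   # +1 correct / -1 wrong
log_probs = torch.randn(4, requires_grad=True)  # stand-in trajectory log-probs

avg_reward = rewards.mean().item()
baseline = momentum * baseline + (1.0 - momentum) * avg_reward  # EMA update

advantages = rewards - baseline                 # variance-reduced signal
loss = (-advantages * log_probs).mean()
loss.backward()                                 # gradients flow only through log_probs
print(f"baseline={baseline:.3f}, loss={loss.item():.3f}")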
257