camdog920 committed on
Commit
3238d40
·
verified ·
1 Parent(s): c779308

Upload train_aether_job.py

Browse files
Files changed (1) hide show
  1. train_aether_job.py +209 -0
train_aether_job.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AETHER Training with TRL GRPO - Production Job Script.
4
+ Self-evolving neuro-symbolic agent training.
5
+ """
6
+
7
+ # /// script
8
+ # dependencies = [
9
+ # "trl>=0.15.0", "transformers>=4.45.0", "datasets>=3.0.0",
10
+ # "accelerate>=1.0.0", "peft>=0.13.0", "trackio>=0.1.0",
11
+ # "torch>=2.0.0", "networkx>=3.0", "numpy>=1.24.0",
12
+ # "huggingface-hub>=0.26.0", "sentencepiece>=0.2.0"
13
+ # ]
14
+ # ///
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import logging
20
+ from typing import List
21
+
22
+ import torch
23
+
24
+ from datasets import load_dataset
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+ from trl import GRPOTrainer, GRPOConfig
27
+ from trl.rewards import accuracy_reward, think_format_reward
28
+
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
32
+ )
33
+ logger = logging.getLogger("AETHER.Train")
34
+
35
+
36
# Line prefixes that count as an enumerated reasoning step.
_STEP_PREFIXES = ("1.", "2.", "3.", "4.", "5.", "Step", "Phase")
# Keyword groups matched case-insensitively anywhere in the completion text.
_CAUSAL_KEYWORDS = ("therefore", "because", "implies", "consequently")
_PLANNING_KEYWORDS = ("sub-goal", "blueprint", "plan", "phase")
_REFLECTION_KEYWORDS = ("reflect", "evaluate", "improve", "evolve")


def _completion_text(completion) -> str:
    """Return the plain text carried by a single completion.

    Accepts a plain string, a single chat message dict, or a list/tuple of
    chat message dicts (their ``content`` fields are joined with newlines).
    Anything else is converted with ``str()``.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        return str(completion.get("content") or "")
    if isinstance(completion, (list, tuple)):
        parts = [
            str(message.get("content") or "")
            for message in completion
            if isinstance(message, dict)
        ]
        if parts:
            return "\n".join(parts)
    return str(completion)


def aether_reward(completions: List[object], **kwargs) -> List[float]:
    """Score completions with the AETHER structural heuristics.

    Each completion earns additive credit, capped at 1.0:

    * 0.30 if it contains both ``<think>`` and ``</think>``;
    * 0.05 per line starting with an enumerated-step prefix, up to 0.25;
    * 0.20 if it contains a causal keyword;
    * 0.15 if it contains a planning keyword;
    * 0.10 if it contains a self-reflection keyword.

    Args:
        completions: One entry per generated completion. Entries may be plain
            strings or chat-style message lists; message ``content`` is
            extracted before scoring so that newlines are real line breaks
            rather than escaped characters inside a ``repr``.
        **kwargs: Extra columns supplied by the trainer; ignored.

    Returns:
        One score in ``[0.0, 1.0]`` per completion, in input order.
    """
    rewards = []
    for completion in completions:
        text = _completion_text(completion)
        lowered = text.lower()  # computed once for the three keyword checks
        score = 0.0

        # 1. Reasoning structure: <think> tags
        if "<think>" in text and "</think>" in text:
            score += 0.3

        # 2. Step enumeration
        steps = sum(
            1 for line in text.split("\n")
            if line.strip().startswith(_STEP_PREFIXES)
        )
        score += min(steps * 0.05, 0.25)

        # 3. Knowledge references (causal reasoning)
        if any(keyword in lowered for keyword in _CAUSAL_KEYWORDS):
            score += 0.2

        # 4. Sub-goal / blueprint structure (HiMAC-style)
        if any(keyword in lowered for keyword in _PLANNING_KEYWORDS):
            score += 0.15

        # 5. Self-reflection / meta-cognition
        if any(keyword in lowered for keyword in _REFLECTION_KEYWORDS):
            score += 0.1

        rewards.append(min(score, 1.0))
    return rewards
65
+
66
+
67
def _load_prompt_dataset():
    """Load the training data, trying progressively simpler sources.

    Tries ``trl-lib/DeepMath-103K`` first, then ``trl-lib/Capybara``. If both
    fail, builds an in-memory dataset from ten built-in prompts repeated 100
    times. Each failure is logged as a warning before the next source is tried.
    """
    for repo_id, short_name in (
        ("trl-lib/DeepMath-103K", "DeepMath"),
        ("trl-lib/Capybara", "Capybara"),
    ):
        try:
            dataset = load_dataset(repo_id, split="train")
        except Exception as exc:  # deliberate best-effort: fall through to the next source
            logger.warning(f"{short_name} failed: {exc}")
            continue
        logger.info(f"Loaded {repo_id.split('/')[-1]}: {len(dataset)} examples")
        return dataset

    from datasets import Dataset

    seed_prompts = [
        "Think step by step and solve: If a train travels 240 km in 3 hours, what is its average speed?",
        "Plan and reason: You have 5 shelves and need to store 150 books evenly. How many per shelf?",
        "Analyze and explain: Why does recursive self-improvement require safety constraints?",
        "Break down into phases: How would you build a self-evolving AI system?",
        "Reflect and improve: A previous solution had an error in step 3. How would you fix it?",
        "Think about this: What are the trade-offs between symbolic and neural reasoning?",
        "Plan a hierarchy: Design a multi-agent system with a manager and workers.",
        "Evolve this solution: Start with a simple sorting algorithm and improve it iteratively.",
        "Knowledge reasoning: Given that all birds can fly and penguins are birds, what can you conclude?",
        "Meta-cognitive analysis: Evaluate your own reasoning process and identify biases.",
    ]
    # Built directly in chat format so the TRL reward functions, which index
    # completions as message lists, can score the resulting generations.
    rows = [
        {"prompt": [{"role": "user", "content": text}]} for text in seed_prompts
    ] * 100
    dataset = Dataset.from_list(rows)
    logger.info(f"Created fallback dataset: {len(dataset)} examples")
    return dataset


def _first_user_message(messages):
    """Return the content of the first ``user`` message, or ``str(messages)`` if none."""
    for message in messages:
        if message.get("role") == "user":
            return message.get("content", "")
    return str(messages)


def _prepare_prompts(dataset):
    """Return ``dataset`` with a chat-format ``prompt`` column.

    A missing ``prompt`` column is derived from ``text``, ``messages`` or
    ``question`` (checked in that order). Plain-string prompts are wrapped as
    a single user message.

    Raises:
        ValueError: if none of the recognised columns is present.
    """
    columns = dataset.column_names
    if "prompt" not in columns:
        if "text" in columns:
            dataset = dataset.rename_column("text", "prompt")
        elif "messages" in columns:
            def extract_prompt(examples):
                return {
                    "prompt": [
                        [{"role": "user", "content": _first_user_message(msgs)}]
                        for msgs in examples["messages"]
                    ]
                }

            dataset = dataset.map(extract_prompt, batched=True, remove_columns=columns)
        elif "question" in columns:
            dataset = dataset.rename_column("question", "prompt")
        else:
            # Fail here with a clear message instead of deep inside the trainer.
            raise ValueError(
                f"Dataset has no 'prompt', 'text', 'messages' or 'question' column: {columns}"
            )

    # Only the first row is inspected; the column is assumed to be homogeneous.
    if len(dataset) > 0 and isinstance(dataset[0]["prompt"], str):
        dataset = dataset.map(
            lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]}
        )
    return dataset


def main():
    """Run the AETHER GRPO training job end to end.

    Reads configuration from environment variables (``AETHER_MODEL``,
    ``AETHER_OUTPUT``, ``AETHER_HUB_MODEL_ID``, ``TRACKIO_SPACE_ID``,
    ``TRACKIO_PROJECT``), loads the model and dataset, trains with
    ``GRPOTrainer``, then saves the model, tokenizer and a small JSON metadata
    file into the output directory.
    """
    MODEL_NAME = os.environ.get("AETHER_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    OUTPUT_DIR = os.environ.get("AETHER_OUTPUT", "./aether-output")
    model_short_name = MODEL_NAME.split("/")[-1]
    # Default keeps the original repository id; override to push elsewhere.
    hub_model_id = os.environ.get(
        "AETHER_HUB_MODEL_ID", f"camdog920/aether-{model_short_name}-grpo"
    )

    trackio_space_id = os.environ.get("TRACKIO_SPACE_ID")
    trackio_project = os.environ.get("TRACKIO_PROJECT", "aether-evolution")

    logger.info("=" * 60)
    logger.info("AETHER TRAINING - GRPO with Neuro-Symbolic Rewards")
    logger.info("=" * 60)
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Output: {OUTPUT_DIR}")

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    logger.info(f"Device: {device}")

    logger.info("Loading model...")
    dtype = torch.bfloat16 if has_cuda else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto" if has_cuda else None,
        trust_remote_code=True,
    )
    # NOTE(review): this tokenizer is not passed to GRPOTrainer below; it is
    # only saved next to the model at the end of the run.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset...")
    dataset = _prepare_prompts(_load_prompt_dataset())

    # Fixed seed so the train/eval partition is reproducible across runs.
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]
    logger.info(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    # NOTE(review): evaluation runs every 50 steps over the whole 10% split;
    # confirm that cost is intended for the large primary dataset.
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        logging_steps=10,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=has_cuda,
        max_completion_length=512,
        num_generations=4,
        report_to="trackio" if trackio_space_id else [],
        run_name=f"aether-grpo-{model_short_name}",
        project=trackio_project,
        trackio_space_id=trackio_space_id,
        disable_tqdm=True,
        logging_first_step=True,
        push_to_hub=True,
        hub_model_id=hub_model_id,
    )

    # accuracy_reward compares against a reference answer, so it is only
    # registered when the dataset actually provides a "solution" column (the
    # fallback datasets do not).
    reward_funcs = [aether_reward]
    if "solution" in train_ds.column_names:
        reward_funcs.append(accuracy_reward)
    else:
        logger.warning("Dataset has no 'solution' column; skipping accuracy_reward")
    reward_funcs.append(think_format_reward)

    logger.info("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    metadata = {
        "aether_version": "0.1.0",
        "training_method": "GRPO",
        "model_name": MODEL_NAME,
        # Record the reward functions that were actually used for this run.
        "reward_functions": [func.__name__ for func in reward_funcs],
    }
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with open(os.path.join(OUTPUT_DIR, "aether_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    logger.info("=" * 60)
    logger.info("Training complete!")
    logger.info(f"Model: https://huggingface.co/{training_args.hub_model_id}")
    if trackio_space_id:
        logger.info(f"Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
    logger.info("=" * 60)
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()