"""
Improved REINFORCE for Symbolic Regression

Improvements over basic REINFORCE:
1. Larger batch size with gradient accumulation
2. Entropy bonus for exploration
3. Better baseline (exponential moving average with warmup)
4. Reward shaping (softer penalty for invalid expressions)
5. Best-of-N sampling to find good expressions faster
6. Learning rate scheduling
7. Gradient clipping
8. Detailed logging per epoch
"""
|
|
| import os |
| import sys |
| import json |
| import argparse |
| import logging |
| import datetime |
| from pathlib import Path |
| from typing import List, Tuple, Dict |
| from collections import deque |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| |
| PROJECT_ROOT = Path(__file__).parent.parent |
| sys.path.insert(0, str(PROJECT_ROOT)) |
| sys.path.insert(0, str(PROJECT_ROOT / "classes")) |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel, LoraConfig, get_peft_model |
|
|
| from expression import Expression |
| from dataset import RegressionDataset |
|
|
| |
# Module-level logging setup; basicConfig is a no-op if the root logger
# was already configured by an importing application.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
|
|
|
| class ImprovedREINFORCE: |
| """Improved REINFORCE algorithm for symbolic regression.""" |
|
|
    def __init__(
        self,
        model_path: str,
        X: np.ndarray,
        y: np.ndarray,
        output_dir: str = "./output/reinforce",
        learning_rate: float = 5e-5,
        device: str = None,
        entropy_coef: float = 0.01,
        baseline_decay: float = 0.95,
    ):
        """Set up data, model, optimizer, scheduler, and training state.

        Args:
            model_path: HF model id or local path; tried as a LoRA adapter first.
            X: Feature matrix of shape (n_samples, n_vars).
            y: Target vector aligned with X.
            output_dir: Directory for outputs; created if it does not exist.
            learning_rate: AdamW learning rate.
            device: Explicit torch device string; auto-detects CUDA when None.
            entropy_coef: Weight of the entropy bonus in the policy loss.
            baseline_decay: Decay factor of the EMA reward baseline.
        """
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]  # number of input variables, from X's columns
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.learning_rate = learning_rate
        self.entropy_coef = entropy_coef
        self.baseline_decay = baseline_decay

        # Device selection: explicit override wins, otherwise prefer CUDA.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._load_model(model_path)

        # Fixed prompt shared by all samples; tokenized once up front.
        self.prompt = self._build_prompt()
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        # Cosine schedule with warm restarts: first cycle 10 steps, doubling after.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2
        )

        # Best-so-far tracking and per-epoch history for reporting.
        self.best_r2 = -np.inf
        self.best_expression = None
        self.history = []

        # Rolling reward window feeding the EMA-of-median baseline.
        self.reward_buffer = deque(maxlen=50)
        self.baseline = 0.0

        # Map of unique valid expression -> best R^2 observed for it.
        self.discovered_expressions: Dict[str, float] = {}
|
|
    def _load_model(self, model_path: str):
        """Load tokenizer and model, then wrap the model with fresh LoRA adapters.

        First tries to interpret `model_path` as a LoRA adapter on top of base
        GPT-2 (merging it into the base weights); on any failure it falls back
        to loading the path as a standalone causal LM.
        """
        logger.info(f"Loading model from {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # GPT-2 ships without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            logger.info("Attempting to load as LoRA adapter...")
            base_model = AutoModelForCausalLM.from_pretrained("gpt2")
            # Adapter checkpoints may have been trained with added tokens.
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))
                logger.info(f"Resized embeddings to {len(self.tokenizer)}")

            model_with_lora = PeftModel.from_pretrained(base_model, model_path)
            # Merge adapter weights into the base so a new LoRA can be stacked below.
            self.model = model_with_lora.merge_and_unload()
            logger.info("LoRA adapter loaded and merged successfully")
        except Exception as e:
            # Broad catch is deliberate: any adapter-load failure falls back
            # to treating the path as a full model checkpoint.
            logger.info(f"LoRA load failed ({e}), loading as standalone model...")
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        # Train only a small LoRA adapter on GPT-2's fused attention projection.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["c_attn"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model = self.model.to(self.device)
        self.model.train()

        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Model loaded with {trainable} trainable params")
|
|
| def _build_prompt(self, ops: list = None) -> str: |
| """Build JSON format prompt.""" |
| vars_list = [f"x_{i+1}" for i in range(self.n_vars)] |
|
|
| |
| if ops is None: |
| ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"] |
| else: |
| ops_list = ops |
|
|
| prompt = json.dumps({ |
| "vars": vars_list, |
| "ops": ops_list, |
| "cons": "C", |
| "expr": "" |
| }) |
| prompt = prompt[:-2] |
| return prompt |
|
|
| def extract_expression(self, text: str) -> str: |
| """Extract expression from generated text.""" |
| try: |
| eos_token = "<|endoftext|>" |
| if eos_token in text: |
| text = text[:text.index(eos_token)] |
|
|
| if '"expr": "' in text: |
| start = text.index('"expr": "') + len('"expr": "') |
| remaining = text[start:] |
| for terminator in ['"}', '"']: |
| if terminator in remaining: |
| return remaining[:remaining.index(terminator)].strip() |
| return remaining.strip() |
|
|
| if '"expr": ' in text: |
| start = text.index('"expr": ') + len('"expr": ') |
| remaining = text[start:] |
| if '"}' in remaining: |
| return remaining[:remaining.index('"}')].strip() |
| return remaining.strip(' "') |
|
|
| except (ValueError, IndexError): |
| pass |
|
|
| if '"expr"' in text: |
| return text.split('"expr"')[-1].strip(' ":{}') |
| return text.strip() |
|
|
| def compute_r2(self, expression_str: str) -> Tuple[float, bool]: |
| """Compute R^2 score. Returns (score, is_valid).""" |
| if not expression_str or expression_str.isspace(): |
| return -1.0, False |
|
|
| if 'C' in expression_str: |
| expression_str = expression_str.replace('C', '1') |
|
|
| try: |
| expr = Expression(expression_str, is_prefix=False) |
| if not expr.is_valid_on_dataset(self.X): |
| return -1.0, False |
|
|
| y_pred = expr.evaluate(self.X) |
| if not np.all(np.isfinite(y_pred)): |
| return -1.0, False |
|
|
| ss_res = np.sum((self.y - y_pred) ** 2) |
| ss_tot = np.sum((self.y - np.mean(self.y)) ** 2) |
|
|
| if ss_tot == 0: |
| return 0.0, True |
|
|
| r2 = 1 - (ss_res / ss_tot) |
| return float(np.clip(r2, -1.0, 1.0)), True |
| except Exception: |
| return -1.0, False |
|
|
| def shape_reward(self, r2: float, is_valid: bool) -> float: |
| """Shape reward to encourage exploration.""" |
| if not is_valid: |
| return -0.1 |
|
|
| |
| if r2 < 0: |
| return r2 * 0.5 |
| elif r2 < 0.5: |
| return r2 |
| elif r2 < 0.9: |
| return r2 * 1.5 |
| else: |
| return r2 * 2.0 |
|
|
    def generate_batch(
        self,
        batch_size: int,
        temperature: float = 0.7,
        max_new_tokens: int = 50
    ) -> List[Dict]:
        """Sample `batch_size` expressions and score them.

        Each sample is generated token-by-token under no_grad, then a second
        forward pass (with gradients) recomputes the log-probability of the
        sampled tokens so REINFORCE can backprop through it.

        Returns:
            List of dicts with keys: text, expression, r2, is_valid, reward,
            log_prob (grad-carrying scalar tensor), entropy.
        """
        results = []

        for _ in range(batch_size):
            generated_ids = self.prompt_ids.clone()
            generated_tokens = []

            # Sampling loop: no gradients needed while drawing tokens.
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    outputs = self.model(generated_ids)
                    logits = outputs.logits[:, -1, :] / temperature

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    generated_tokens.append(next_token.item())

                    generated_ids = torch.cat([generated_ids, next_token], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Stop once the JSON object is closed after the prompt.
                    # NOTE(review): assumes the decoded text's prompt prefix has
                    # the same length as self.prompt (holds when the prompt has
                    # no special tokens) — confirm for other tokenizers.
                    text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                    if '"}' in text[len(self.prompt):]:
                        break

            # Score the completed sample.
            text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            expr_str = self.extract_expression(text)
            r2, is_valid = self.compute_r2(expr_str)
            reward = self.shape_reward(r2, is_valid)

            # Second pass WITH gradients to get the log-probs of the sampled tokens.
            if len(generated_tokens) > 0:
                full_ids = torch.cat([
                    self.prompt_ids,
                    torch.tensor([generated_tokens], device=self.device)
                ], dim=1)

                # Feed all but the last token; logits at position t predict token t+1.
                outputs = self.model(full_ids[:, :-1])
                # Same temperature as sampling so the scored distribution matches.
                logits = outputs.logits / temperature

                # Slice so gen_logits[t] is the distribution that produced
                # generated_tokens[t].
                prompt_len = self.prompt_ids.shape[1]
                gen_logits = logits[:, prompt_len-1:, :]

                log_probs_all = F.log_softmax(gen_logits, dim=-1)
                probs_all = F.softmax(gen_logits, dim=-1)

                # Gather each sampled token's log-prob; sum over the sequence.
                target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
                selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
                total_log_prob = selected_log_probs.sum()

                # Mean per-position entropy, used as an exploration bonus.
                entropy_per_pos = -(probs_all * log_probs_all).sum(dim=-1)
                total_entropy = entropy_per_pos.mean()
            else:
                # No tokens generated: a detached-from-model zero keeps the sample inert.
                total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
                total_entropy = torch.tensor(0.0, device=self.device)

            results.append({
                "text": text,
                "expression": expr_str,
                "r2": r2,
                "is_valid": is_valid,
                "reward": reward,
                "log_prob": total_log_prob,
                "entropy": total_entropy,
            })

            # Track the best R^2 seen per unique valid expression.
            if is_valid:
                self.discovered_expressions[expr_str] = max(
                    self.discovered_expressions.get(expr_str, -np.inf), r2
                )

                if r2 > self.best_r2:
                    self.best_r2 = r2
                    self.best_expression = expr_str

        # Release cached GPU memory accumulated during generation.
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        return results
|
|
| def update_baseline(self, rewards: List[float]): |
| """Update baseline using reward buffer.""" |
| valid_rewards = [r for r in rewards if r > -0.5] |
| self.reward_buffer.extend(valid_rewards) |
|
|
| if len(self.reward_buffer) > 0: |
| |
| self.baseline = self.baseline_decay * self.baseline + \ |
| (1 - self.baseline_decay) * np.median(list(self.reward_buffer)) |
|
|
| def train_step(self, batch_size: int = 8, grad_accum_steps: int = 4) -> dict: |
| """Perform one training step with gradient accumulation.""" |
| self.model.train() |
|
|
| all_results = [] |
| total_policy_loss = 0.0 |
| total_entropy_loss = 0.0 |
|
|
| self.optimizer.zero_grad() |
|
|
| effective_batch = batch_size * grad_accum_steps |
|
|
| for accum_step in range(grad_accum_steps): |
| |
| if self.device.type == "cuda": |
| torch.cuda.empty_cache() |
|
|
| results = self.generate_batch(batch_size) |
| all_results.extend(results) |
|
|
| |
| policy_loss = torch.tensor(0.0, device=self.device) |
| entropy_loss = torch.tensor(0.0, device=self.device) |
| valid_count = 0 |
|
|
| for r in results: |
| if r["is_valid"]: |
| advantage = r["reward"] - self.baseline |
| policy_loss = policy_loss - r["log_prob"] * advantage |
| entropy_loss = entropy_loss - r["entropy"] |
| valid_count += 1 |
|
|
| if valid_count > 0: |
| policy_loss = policy_loss / valid_count |
| entropy_loss = entropy_loss / valid_count |
|
|
| |
| loss = policy_loss + self.entropy_coef * entropy_loss |
| loss = loss / grad_accum_steps |
|
|
| loss.backward() |
|
|
| total_policy_loss += policy_loss.item() |
| total_entropy_loss += entropy_loss.item() |
|
|
| |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) |
|
|
| |
| self.optimizer.step() |
| self.scheduler.step() |
|
|
| |
| rewards = [r["reward"] for r in all_results] |
| self.update_baseline(rewards) |
|
|
| |
| r2_values = [r["r2"] for r in all_results] |
| valid_mask = [r["is_valid"] for r in all_results] |
| valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v] |
|
|
| return { |
| "valid_count": sum(valid_mask), |
| "total_count": len(all_results), |
| "valid_rate": sum(valid_mask) / len(all_results), |
| "mean_r2": np.mean(valid_r2) if valid_r2 else 0.0, |
| "max_r2": max(r2_values), |
| "baseline": self.baseline, |
| "policy_loss": total_policy_loss / grad_accum_steps, |
| "entropy_loss": total_entropy_loss / grad_accum_steps, |
| "lr": self.scheduler.get_last_lr()[0], |
| } |
|
|
| def run( |
| self, |
| n_epochs: int = 50, |
| batch_size: int = 16, |
| grad_accum_steps: int = 2, |
| target_r2: float = 0.99, |
| patience: int = 20, |
| ): |
| """Run training with early stopping.""" |
| logger.info("=" * 60) |
| logger.info("IMPROVED REINFORCE SYMBOLIC REGRESSION") |
| logger.info("=" * 60) |
| logger.info(f"Epochs: {n_epochs}") |
| logger.info(f"Batch size: {batch_size} x {grad_accum_steps} = {batch_size * grad_accum_steps}") |
| logger.info(f"Entropy coef: {self.entropy_coef}") |
| logger.info(f"Target R^2: {target_r2}") |
| logger.info("=" * 60) |
|
|
| no_improvement = 0 |
| prev_best = -np.inf |
|
|
| for epoch in range(n_epochs): |
| stats = self.train_step(batch_size, grad_accum_steps) |
|
|
| self.history.append({ |
| "epoch": epoch + 1, |
| **stats, |
| "best_r2": self.best_r2, |
| }) |
|
|
| |
| if self.best_r2 > prev_best + 0.001: |
| no_improvement = 0 |
| prev_best = self.best_r2 |
| else: |
| no_improvement += 1 |
|
|
| |
| logger.info( |
| f"Epoch {epoch+1:3d} | " |
| f"Valid: {stats['valid_count']}/{stats['total_count']} | " |
| f"Mean R²: {stats['mean_r2']:.4f} | " |
| f"Best: {self.best_r2:.4f} | " |
| f"Baseline: {self.baseline:.4f} | " |
| f"LR: {stats['lr']:.2e}" |
| ) |
|
|
| |
| if self.best_r2 >= target_r2: |
| logger.info(f"Target R^2 {target_r2} reached at epoch {epoch+1}!") |
| break |
|
|
| if no_improvement >= patience: |
| logger.info(f"No improvement for {patience} epochs. Early stopping.") |
| break |
|
|
| |
| logger.info("\n" + "=" * 60) |
| logger.info("FINAL RESULTS") |
| logger.info("=" * 60) |
| logger.info(f"Best R^2: {self.best_r2:.4f}") |
| logger.info(f"Best expression: {self.best_expression}") |
| logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}") |
|
|
| |
| top_exprs = sorted(self.discovered_expressions.items(), key=lambda x: -x[1])[:5] |
| logger.info("Top 5 expressions:") |
| for expr, r2 in top_exprs: |
| logger.info(f" R²={r2:.4f}: {expr}") |
|
|
| return { |
| "best_r2": self.best_r2, |
| "best_expression": self.best_expression, |
| "history": self.history, |
| "discovered_expressions": self.discovered_expressions, |
| } |
|
|
|
|
def main():
    """CLI entry point: load a CSV dataset, run REINFORCE, save results as JSON."""
    parser = argparse.ArgumentParser(description="Improved REINFORCE Symbolic Regression")
    # Declarative table of (flag, type, default) keeps the CLI surface compact.
    for flag, arg_type, default in (
        ("--model_path", str, "gpt2"),
        ("--dataset", str, "./data/ppo_test/sin_x1.csv"),
        ("--output_dir", str, "./output/reinforce"),
        ("--epochs", int, 50),
        ("--batch_size", int, 16),
        ("--grad_accum", int, 2),
        ("--lr", float, 5e-5),
        ("--entropy_coef", float, 0.01),
        ("--patience", int, 20),
    ):
        parser.add_argument(flag, type=arg_type, default=default)
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    # Fail fast on a missing dataset instead of deep inside loading code.
    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        logger.error(f"Dataset not found: {dataset_path}")
        return

    reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
    X, y = reg.get_numpy()

    trainer = ImprovedREINFORCE(
        model_path=args.model_path,
        X=X,
        y=y,
        output_dir=args.output_dir,
        learning_rate=args.lr,
        device="cpu" if args.cpu else None,
        entropy_coef=args.entropy_coef,
    )

    results = trainer.run(
        n_epochs=args.epochs,
        batch_size=args.batch_size,
        grad_accum_steps=args.grad_accum,
        patience=args.patience,
    )

    # Timestamped output file so repeated runs never overwrite each other.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = Path(args.output_dir) / f"results_improved_{timestamp}.json"

    # Coerce numpy scalars to plain floats for JSON serialization.
    serializable = {
        "best_r2": float(results["best_r2"]),
        "best_expression": results["best_expression"],
        "history": results["history"],
        "discovered_expressions": {
            expr: float(score) for expr, score in results["discovered_expressions"].items()
        },
    }

    with open(results_file, 'w') as f:
        json.dump(serializable, f, indent=2)

    logger.info(f"Results saved to: {results_file}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|