"""
REINFORCE for Symbolic Regression

Simple policy gradient (REINFORCE) with R^2 as reward.
This is a more natural fit for symbolic regression than GRPO/PPO
since we have absolute rewards, not relative preferences.

Algorithm:
1. Generate expressions using current policy (model)
2. Compute R^2 reward for each expression
3. Compute policy gradient: grad = reward * grad(log_prob)
4. Update model parameters

With baseline subtraction for variance reduction:
    grad = (reward - baseline) * grad(log_prob)
"""

import os
import sys
import json
import argparse
import logging
import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Put the project root and its "classes" directory at the front of sys.path
# before the local imports below (presumably where `expression` and `dataset`
# live -- confirm against the repository layout).
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model

from expression import Expression
from dataset import RegressionDataset

# Configure logging at import time: INFO level and above, with a timestamp and
# the level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)


class REINFORCESymbolicRegression:
    """REINFORCE (policy-gradient) search for an expression fitting ``(X, y)``.

    A causal language model samples candidate expressions from a fixed JSON
    prompt. Each candidate is scored by its R^2 on the dataset, and the
    model's LoRA weights are updated with the loss
    ``-(reward - baseline) * log_prob`` averaged over the valid candidates.
    """

    def __init__(
        self,
        model_path: str,
        X: np.ndarray,
        y: np.ndarray,
        output_dir: str = "./output/reinforce",
        learning_rate: float = 1e-5,
        device: Optional[str] = None,
    ):
        """Load the model, build the prompt and create the optimizer.

        Args:
            model_path: Location of a LoRA adapter or of a standalone causal
                LM (both are tried, see ``_load_model``).
            X: Feature matrix; ``X.shape[1]`` is taken as the variable count.
            y: Target values the expressions are scored against.
            output_dir: Directory created (including parents) on construction.
            learning_rate: Learning rate passed to Adam.
            device: Torch device string. When falsy, CUDA is used if
                available and CPU otherwise.
        """
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.learning_rate = learning_rate

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._load_model(model_path)

        # The prompt is identical for every sample, so tokenize it once.
        self.prompt = self._build_prompt()
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)

        # Only the parameters left trainable by the LoRA wrapper are optimized.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

        # Search state.
        self.best_r2 = -np.inf
        self.best_expression = None
        self.history = []
        self.baseline = 0.0  # exponential moving average of valid rewards

    def _load_model(self, model_path: str, base_model_name: str = "gpt2"):
        """Load the tokenizer and model, then wrap the model with fresh LoRA weights.

        ``model_path`` is first treated as a LoRA adapter on top of
        ``base_model_name`` (adapter merged into the base weights). If that
        fails for any reason, it is loaded as a standalone causal LM instead.

        Args:
            model_path: Adapter or model location.
            base_model_name: Base checkpoint used for the adapter attempt.
        """
        logger.info(f"Loading model from {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Keep a pad token the tokenizer already defines; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            logger.info("Attempting to load as LoRA adapter...")
            base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
            # The tokenizer may carry extra tokens; match the embedding size.
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))
                logger.info(f"Resized embeddings to {len(self.tokenizer)}")

            model_with_lora = PeftModel.from_pretrained(base_model, model_path)
            self.model = model_with_lora.merge_and_unload()
            logger.info("LoRA adapter loaded and merged successfully")
        except Exception as e:
            # Deliberate best-effort fallback: anything that prevents the
            # adapter path from working leads to a plain model load.
            logger.info(f"LoRA load failed ({e}), loading as standalone model...")
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        # New LoRA weights that REINFORCE will train.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["c_attn"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model = self.model.to(self.device)
        self.model.train()

        logger.info(f"Model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)} trainable params")

    def _build_prompt(self) -> str:
        """Return the JSON prompt, left open right after ``"expr": "``.

        Variables are named ``x_1`` .. ``x_n`` with ``n = self.n_vars``.
        """
        vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
        ops_list = ["+", "-", "*", "sin", "cos"]

        prompt = json.dumps({
            "vars": vars_list,
            "ops": ops_list,
            "cons": "C",
            "expr": ""
        })
        # Drop the closing '"}' so the model continues inside the expr string.
        return prompt[:-2]

    def extract_expression(self, text: str) -> str:
        """Extract the expression string from generated text.

        Expected shape: ``{"vars": [...], "expr": "sin(x_1) + C<|endoftext|>``.
        Everything from the first EOS marker on is discarded, then the text
        after ``"expr": "`` up to the next double quote (or the end of the
        string) is returned. Looser fallbacks handle a missing opening quote
        or a differently spaced key; with no ``"expr"`` key at all the
        stripped text is returned.

        Args:
            text: Decoded model output (prompt included or not).

        Returns:
            The extracted expression, stripped of surrounding whitespace.
        """
        text = text.partition("<|endoftext|>")[0]

        quoted_key = '"expr": "'
        if quoted_key in text:
            remaining = text.split(quoted_key, 1)[1]
            # Stop at the first closing quote so trailing JSON fields never
            # leak into the expression.
            return remaining.partition('"')[0].strip()

        bare_key = '"expr": '
        if bare_key in text:
            remaining = text.split(bare_key, 1)[1]
            if '"}' in remaining:
                return remaining.partition('"}')[0].strip()
            return remaining.strip(' "')

        if '"expr"' in text:
            return text.split('"expr"')[-1].strip(' ":{}')
        return text.strip()

    def compute_r2(self, expression_str: str) -> float:
        """Score an expression by its R^2 on ``(self.X, self.y)``.

        Returns:
            ``-1.0`` for empty, invalid, non-finite or failing expressions;
            ``0.0`` when ``y`` has zero variance; otherwise R^2 clipped to
            ``[-1, 1]``. Note that a valid but very poor fit also clips to
            ``-1.0`` and is then indistinguishable from an invalid one.
        """
        if not expression_str or expression_str.isspace():
            return -1.0

        # "C" is the constant placeholder from the prompt; evaluate with C = 1.
        expression_str = expression_str.replace('C', '1')

        try:
            expr = Expression(expression_str, is_prefix=False)
            if not expr.is_valid_on_dataset(self.X):
                return -1.0

            y_pred = expr.evaluate(self.X)
            if not np.all(np.isfinite(y_pred)):
                return -1.0

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)

            if ss_tot == 0:
                return 0.0

            r2 = 1 - (ss_res / ss_tot)
            return float(np.clip(r2, -1.0, 1.0))
        except Exception:
            # Best effort: any parsing or evaluation failure counts as invalid.
            return -1.0

    def generate_with_logprobs(self, max_new_tokens: int = 50, temperature: float = 0.7) -> Tuple[str, torch.Tensor]:
        """Sample one continuation of the prompt and return its log-probability.

        Sampling runs without gradients. The log-probability of the sampled
        tokens is then recomputed with gradients in a single forward pass
        over the whole sequence.

        Args:
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Softmax temperature; must be positive.

        Returns:
            ``(text, total_log_prob)``: the decoded full sequence (prompt
            included) and a scalar tensor with the summed log-probability of
            the sampled tokens. When nothing was sampled it is a zero scalar.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        prompt_len = self.prompt_ids.shape[1]
        generated_ids = self.prompt_ids.clone()
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self.model(generated_ids).logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_ids = torch.cat([generated_ids, next_token], dim=1)

                if next_token.item() == eos_id:
                    break

                # Stop as soon as the model closes the JSON object.
                text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                if '"}' in text[len(self.prompt):]:
                    break

        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        if generated_ids.shape[1] == prompt_len:
            # Nothing was sampled (max_new_tokens <= 0).
            return text, torch.tensor(0.0, device=self.device, requires_grad=True)

        # Logits at position t score the token at position t + 1, so positions
        # prompt_len-1 .. end-1 score exactly the sampled tokens.
        logits = self.model(generated_ids).logits[:, prompt_len - 1:-1, :] / temperature
        log_probs = F.log_softmax(logits, dim=-1)
        sampled = generated_ids[:, prompt_len:].unsqueeze(-1)
        total_log_prob = log_probs.gather(2, sampled).sum()

        return text, total_log_prob

    def train_step(self, batch_size: int = 8) -> dict:
        """Sample a batch of expressions and apply one REINFORCE update.

        Samples with reward ``-1.0`` are treated as invalid and contribute
        neither to the baseline nor to the loss. When the batch has no valid
        sample the optimizer step is skipped.

        Args:
            batch_size: Number of expressions to sample; must be at least 1.

        Returns:
            Dict with ``valid_count``, ``valid_rate``, ``mean_reward``,
            ``max_reward``, ``baseline`` and ``policy_loss``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        self.model.train()

        rewards = []
        log_probs = []

        for _ in range(batch_size):
            text, log_prob = self.generate_with_logprobs()
            expr_str = self.extract_expression(text)
            r2 = self.compute_r2(expr_str)

            rewards.append(r2)
            log_probs.append(log_prob)

            if r2 > self.best_r2:
                self.best_r2 = r2
                self.best_expression = expr_str

        rewards_array = np.array(rewards, dtype=float)
        valid_mask = rewards_array > -1.0
        valid_count = int(valid_mask.sum())

        policy_loss_value = 0.0
        if valid_count > 0:
            # The moving-average baseline is updated first, so the advantages
            # below already include this batch's mean reward.
            self.baseline = 0.9 * self.baseline + 0.1 * float(rewards_array[valid_mask].mean())

            loss_terms = [
                -(log_prob * (reward - self.baseline))
                for log_prob, reward in zip(log_probs, rewards)
                if reward > -1.0
            ]
            policy_loss = torch.stack(loss_terms).sum() / valid_count

            self.optimizer.zero_grad()
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            policy_loss_value = policy_loss.item()

        return {
            "valid_count": valid_count,
            "valid_rate": float(valid_mask.mean()),
            "mean_reward": float(rewards_array[valid_mask].mean()) if valid_count > 0 else 0.0,
            "max_reward": float(rewards_array.max()),
            "baseline": float(self.baseline),
            "policy_loss": policy_loss_value,
        }

    def run(
        self,
        n_epochs: int = 50,
        batch_size: int = 8,
        target_r2: float = 0.99,
    ):
        """Run training until ``n_epochs`` is reached or ``target_r2`` is hit.

        Args:
            n_epochs: Maximum number of training steps.
            batch_size: Expressions sampled per step.
            target_r2: Training stops once the best R^2 reaches this value.

        Returns:
            Dict with ``best_r2``, ``best_expression`` and the per-epoch
            ``history`` list.
        """
        logger.info("=" * 60)
        logger.info("REINFORCE SYMBOLIC REGRESSION")
        logger.info("=" * 60)
        logger.info(f"Epochs: {n_epochs}")
        logger.info(f"Batch size: {batch_size}")
        logger.info(f"Target R^2: {target_r2}")
        logger.info("=" * 60)

        for epoch in range(n_epochs):
            stats = self.train_step(batch_size)

            self.history.append({
                "epoch": epoch + 1,
                **stats,
                "best_r2": self.best_r2,
            })

            # Progress line on the first epoch and every fifth one.
            if (epoch + 1) % 5 == 0 or epoch == 0:
                logger.info(
                    f"Epoch {epoch+1:3d} | "
                    f"Valid: {stats['valid_count']}/{batch_size} | "
                    f"Mean R^2: {stats['mean_reward']:.4f} | "
                    f"Best: {self.best_r2:.4f} | "
                    f"Baseline: {self.baseline:.4f}"
                )

            if self.best_r2 >= target_r2:
                logger.info(f"Target R^2 {target_r2} reached at epoch {epoch+1}!")
                break

        logger.info("\n" + "=" * 60)
        logger.info("FINAL RESULTS")
        logger.info("=" * 60)
        logger.info(f"Best R^2: {self.best_r2:.4f}")
        logger.info(f"Best expression: {self.best_expression}")

        return {
            "best_r2": self.best_r2,
            "best_expression": self.best_expression,
            "history": self.history,
        }


def main():
    """Command-line entry point: train on one dataset file and save the results.

    Parses the CLI arguments, loads the dataset, runs REINFORCE and writes
    the result dict returned by ``run`` to
    ``<output_dir>/results_<timestamp>.json``. When the dataset file does not
    exist, an error is logged and the function returns without training.
    """
    parser = argparse.ArgumentParser(description="REINFORCE Symbolic Regression")
    parser.add_argument("--model_path", type=str, default="gpt2")
    parser.add_argument("--dataset", type=str, default="./data/ppo_test/sin_x1.csv")
    parser.add_argument("--output_dir", type=str, default="./output/reinforce")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--target_r2", type=float, default=0.99)
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        logger.error(f"Dataset not found: {dataset_path}")
        return

    # RegressionDataset is given the directory and the file name separately.
    reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
    X, y = reg.get_numpy()

    experiment = REINFORCESymbolicRegression(
        model_path=args.model_path,
        X=X,
        y=y,
        output_dir=args.output_dir,
        learning_rate=args.lr,
        device="cpu" if args.cpu else None,
    )

    results = experiment.run(
        n_epochs=args.epochs,
        batch_size=args.batch_size,
        target_r2=args.target_r2,
    )

    # The output directory was already created by the experiment constructor.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = Path(args.output_dir) / f"results_{timestamp}.json"
    with open(results_file, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Results saved to: {results_file}")


if __name__ == "__main__":
    main()