"""
Enhanced GRPO for Symbolic Regression with Epoch Tracking

Saves all expressions and metrics per epoch for analysis.
"""

import os
import sys
import json
import argparse
import logging
import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Make the project root and its classes/ directory importable before the
# local import below.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from expression import Expression

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

class EnhancedGRPO:
    """Enhanced GRPO with comprehensive per-epoch tracking."""

    def __init__(
        self,
        model_path: str,
        X: np.ndarray,
        y: np.ndarray,
        output_dir: str = "./output/grpo_enhanced",
        learning_rate: float = 5e-5,
        device: str = None,
        group_size: int = 8,
        kl_coef: float = 0.01,
        clip_range: float = 0.2,
        save_all_expressions: bool = True,
        is_prefix: bool = True,
        custom_prompt: str = None,
    ):
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.learning_rate = learning_rate
        self.group_size = group_size
        # NOTE: kl_coef and clip_range are stored but not used by the
        # simplified update in train(), which applies neither PPO-style
        # ratio clipping nor a KL penalty against a reference model.
        self.kl_coef = kl_coef
        self.clip_range = clip_range

        self.save_all_expressions = save_all_expressions
        self.is_prefix = is_prefix
        self.custom_prompt = custom_prompt
        self.epoch_history = []
        self.best_expression = None
        self.best_r2 = -float('inf')
        self.best_epoch = -1

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._load_model(model_path)

        # Build (or take) the prompt once and pre-tokenize it.
        if custom_prompt:
            self.prompt = custom_prompt
        else:
            self.prompt = self._build_prompt()
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2
        )

        # Sampling temperature, reused when recomputing log-probs in train().
        self.temperature = 0.8

    def _load_model(self, model_path: str):
        """Load model and tokenizer, resolving PEFT adapters if present."""
        logger.info(f"Loading model from {model_path}")

        if "/" in model_path and not os.path.exists(model_path):
            # Treat the path as a Hugging Face Hub model ID.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            ).to(self.device)
        else:
            # Local checkpoint.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)

            adapter_config_path = Path(model_path) / "adapter_config.json"
            if adapter_config_path.exists():
                # LoRA/PEFT adapter: load its base model first.
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                base_model_name = adapter_config.get("base_model_name_or_path", "gpt2")

                logger.info(f"Loading base model: {base_model_name}")
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name, torch_dtype=torch.float32
                ).to(self.device)

                # Required so gradients flow back through the input embeddings.
                base_model.enable_input_require_grads()

                # is_trainable=True keeps the adapter weights unfrozen so the
                # GRPO update can train them.
                self.model = PeftModel.from_pretrained(
                    base_model, model_path, is_trainable=True
                ).to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path, torch_dtype=torch.float32
                ).to(self.device)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.train()

    def _build_prompt(self) -> str:
        """Build a unified prompt listing variables, operations, and constants."""
        all_ops = ["*", "+", "-", "/", "sin", "cos", "tan", "exp", "log", "sqrt", "abs"]

        var_list = [f"x_{i+1}" for i in range(self.n_vars)]

        if self.is_prefix:
            prompt = f"vars: {', '.join(var_list)}\noper: {', '.join(all_ops)}\ncons: C\nexpr: "
        else:
            # Serialize to JSON, then drop the trailing '"}' so the prompt
            # ends with an open "expr" string for the model to complete.
            prompt = json.dumps({
                "vars": var_list,
                "ops": all_ops,
                "cons": "C",
                "expr": ""
            })[:-2]

        return prompt
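
    # For illustration, with n_vars == 2 in prefix mode the prompt renders as:
    #
    #   vars: x_1, x_2
    #   oper: *, +, -, /, sin, cos, tan, exp, log, sqrt, abs
    #   cons: C
    #   expr:
    #
    # (with a trailing space after "expr:").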

    def extract_expression(self, text: str) -> str:
        """Extract expression from generated text."""
        if self.is_prefix:
            if "expr:" in text:
                text = text.split("expr:")[-1].strip()
            if "\n" in text:
                text = text.split("\n")[0].strip()
            for marker in ['"}"', '"}', '"cons"', '"vars"', '"ops"']:
                if marker in text:
                    text = text.split(marker)[0].strip()
            return text.strip()
        else:
            if '"expr":' in text:
                start = text.find('"expr":') + len('"expr":')
                text = text[start:].strip()
                if text.startswith('"'):
                    text = text[1:]
                end = text.find('"')
                if end > 0:
                    text = text[:end]
            return text.strip()
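
    # Illustrative behavior in prefix mode:
    #   extract_expression("... expr: + x_1 sin x_2\nvars: ...") -> "+ x_1 sin x_2"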

    def compute_r2(self, expression_str: str) -> Tuple[float, bool, str]:
        """Compute the R^2 score, returning (r2, is_valid, error_message)."""
        if not expression_str or expression_str.isspace():
            return -1.0, False, "Empty expression"

        # Constant placeholders are evaluated as 1.
        if 'C' in expression_str:
            expression_str = expression_str.replace('C', '1')

        try:
            expr = Expression(expression_str, is_prefix=self.is_prefix)
            if not expr.is_valid_on_dataset(self.X):
                return -1.0, False, "Invalid on dataset"

            y_pred = expr.evaluate(self.X)
            if not np.all(np.isfinite(y_pred)):
                return -1.0, False, "Non-finite predictions"

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)

            if ss_tot == 0:
                return 0.0, True, ""

            # Clip so a single catastrophic expression cannot dominate rewards.
            r2 = 1 - (ss_res / ss_tot)
            return float(np.clip(r2, -1.0, 1.0)), True, ""
        except Exception as e:
            return -1.0, False, str(e)
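
    # Worked example: y = [1, 2, 3], y_pred = [1, 2, 4]
    #   ss_res = 0 + 0 + 1 = 1, ss_tot = 1 + 0 + 1 = 2, so R^2 = 1 - 1/2 = 0.5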

    def generate_sample(self, max_new_tokens: int = 50) -> Dict:
        """Sample one completion token by token, recording per-token log-probs."""
        generated_ids = self.prompt_ids.clone()
        generated_tokens = []
        log_probs_list = []

        with torch.no_grad():
            for step in range(max_new_tokens):
                outputs = self.model(generated_ids)
                logits = outputs.logits[:, -1, :] / self.temperature

                probs = F.softmax(logits, dim=-1)
                log_probs = F.log_softmax(logits, dim=-1)

                next_token = torch.multinomial(probs, num_samples=1)
                token_log_prob = log_probs[0, next_token.item()].item()

                generated_tokens.append(next_token.item())
                log_probs_list.append(token_log_prob)

                generated_ids = torch.cat([generated_ids, next_token], dim=1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Stop early once the model moves past the expression field.
                text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                generated_text = text[len(self.prompt):]
                if self.is_prefix and ("\n" in generated_text or "vars:" in generated_text):
                    break
                if not self.is_prefix and '"}' in generated_text:
                    break

        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        expr_str = self.extract_expression(text)
        r2, is_valid, error_msg = self.compute_r2(expr_str)

        return {
            "text": text,
            "expression": expr_str,
            "r2": r2,
            "is_valid": is_valid,
            "error": error_msg,
            "tokens": generated_tokens,
            "log_probs": log_probs_list,
            "total_log_prob": sum(log_probs_list),
        }

    def train(self, epochs: int = 20, samples_per_group: int = 8, groups_per_epoch: int = 4):
        """Training loop with comprehensive tracking."""
        logger.info("Starting GRPO training")
        logger.info(f"Epochs: {epochs}, Groups per epoch: {groups_per_epoch}, Samples per group: {samples_per_group}")

        for epoch in range(epochs):
            epoch_data = {
                "epoch": epoch,
                "timestamp": datetime.datetime.now().isoformat(),
                "expressions": [],
                "metrics": {},
            }

            all_samples = []
            epoch_loss = 0

            for group_idx in range(groups_per_epoch):
                # Sample a group of candidate expressions.
                group = []
                self.model.eval()

                for _ in range(samples_per_group):
                    sample = self.generate_sample()
                    group.append(sample)
                    all_samples.append(sample)

                    if self.save_all_expressions:
                        epoch_data["expressions"].append({
                            "expression": sample["expression"],
                            "r2": sample["r2"],
                            "is_valid": sample["is_valid"],
                            "error": sample.get("error", ""),
                            "group": group_idx,
                        })

                # Group-relative advantages: the reward is R^2 (or -1 for
                # invalid expressions), standardized within the group.
                group_rewards = [s["r2"] if s["is_valid"] else -1.0 for s in group]
                group_mean = np.mean(group_rewards)
                group_std = np.std(group_rewards) + 1e-8

                advantages = [(r - group_mean) / group_std for r in group_rewards]
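
                # Illustrative arithmetic: for group rewards [1.0, 0.0],
                # mean = 0.5 and std = 0.5, so the advantages are roughly
                # [+1.0, -1.0] (up to the 1e-8 stabilizer).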

                # Policy update: advantage-weighted log-likelihood (a
                # simplified, REINFORCE-style objective) plus an entropy bonus.
                self.model.train()
                group_loss = 0

                for sample, advantage in zip(group, advantages):
                    if not sample["is_valid"]:
                        continue

                    # Re-run the full sequence (prompt + sampled tokens).
                    prompt_len = len(self.prompt_ids[0])
                    token_ids = torch.cat([
                        self.prompt_ids[0],
                        torch.tensor(sample["tokens"], device=self.device),
                    ]).unsqueeze(0)

                    outputs = self.model(token_ids)
                    # Logits at the positions that predict the generated tokens.
                    logits = outputs.logits[:, prompt_len - 1:-1, :]

                    log_probs = F.log_softmax(logits / self.temperature, dim=-1)
                    token_tensor = torch.tensor(sample["tokens"], device=self.device).unsqueeze(0)
                    new_log_probs = log_probs.gather(2, token_tensor.unsqueeze(-1)).squeeze(-1)

                    loss = -new_log_probs.sum() * advantage

                    # Entropy bonus to keep sampling diverse.
                    entropy = -(torch.exp(log_probs) * log_probs).sum()
                    loss = loss - 0.01 * entropy
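
                    # For reference (a sketch, unused here): a clipped GRPO/PPO
                    # surrogate would compare against the sampling-time log-probs:
                    #   old_lp = torch.tensor(sample["log_probs"], device=self.device)
                    #   ratio = torch.exp(new_log_probs.squeeze(0) - old_lp)
                    #   clipped = torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
                    #   loss = -torch.min(ratio * advantage, clipped * advantage).sum()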

                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()

                    group_loss += loss.item()

                epoch_loss += group_loss

            self.scheduler.step()

            # Epoch-level statistics.
            valid_samples = [s for s in all_samples if s["is_valid"]]
            valid_r2s = [s["r2"] for s in valid_samples]

            if valid_r2s:
                best_idx = np.argmax([s["r2"] if s["is_valid"] else -2 for s in all_samples])
                epoch_best = all_samples[best_idx]

                if epoch_best["r2"] > self.best_r2:
                    self.best_r2 = epoch_best["r2"]
                    self.best_expression = epoch_best["expression"]
                    self.best_epoch = epoch
            else:
                epoch_best = {"r2": -1.0, "expression": ""}

            epoch_data["metrics"] = {
                "valid_rate": len(valid_samples) / len(all_samples) if all_samples else 0,
                "mean_r2": np.mean(valid_r2s) if valid_r2s else -1.0,
                "max_r2": max(valid_r2s) if valid_r2s else -1.0,
                "min_r2": min(valid_r2s) if valid_r2s else -1.0,
                "std_r2": np.std(valid_r2s) if valid_r2s else 0.0,
                "best_expression": epoch_best["expression"],
                "best_r2": epoch_best["r2"],
                "unique_expressions": len(set(s["expression"] for s in all_samples if s["expression"])),
                "total_samples": len(all_samples),
                "loss": epoch_loss / max(len(valid_samples), 1),
            }

            self.epoch_history.append(epoch_data)

            logger.info(f"Epoch {epoch}: Valid={len(valid_samples)}/{len(all_samples)} "
                        f"Best R²={epoch_best['r2']:.4f} "
                        f"Mean R²={epoch_data['metrics']['mean_r2']:.4f}")

            # Periodic checkpointing.
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch)

        self.save_results()

        return self.epoch_history

    def save_checkpoint(self, epoch: int):
        """Save model checkpoint."""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        if hasattr(self.model, 'save_pretrained'):
            self.model.save_pretrained(checkpoint_dir)

        self.tokenizer.save_pretrained(checkpoint_dir)

        with open(checkpoint_dir / "metrics.json", "w") as f:
            json.dump({
                "epoch": epoch,
                "best_r2": self.best_r2,
                "best_expression": self.best_expression,
                "best_epoch": self.best_epoch,
            }, f, indent=2)

    def save_results(self):
        """Save all results."""
        with open(self.output_dir / "full_history.json", "w") as f:
            json.dump(self.epoch_history, f, indent=2)

        summary = {
            "best_expression": self.best_expression,
            "best_r2": self.best_r2,
            "best_epoch": self.best_epoch,
            "total_epochs": len(self.epoch_history),
            "final_valid_rate": self.epoch_history[-1]["metrics"]["valid_rate"] if self.epoch_history else 0,
            "timestamp": datetime.datetime.now().isoformat(),
        }

        with open(self.output_dir / "summary.json", "w") as f:
            json.dump(summary, f, indent=2)

        logger.info(f"Results saved to {self.output_dir}")
        logger.info(f"Best expression: {self.best_expression} (R²={self.best_r2:.4f})")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True, help="Path to model")
    parser.add_argument("--dataset", required=True, help="Path to CSV dataset")
    parser.add_argument("--output_dir", default="./output/grpo_enhanced")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--samples_per_group", type=int, default=8)
    parser.add_argument("--groups_per_epoch", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--is_prefix", action="store_true", help="Model uses prefix notation")
    parser.add_argument("--custom_prompt", type=str, help="Custom unified prompt")
    args = parser.parse_args()

    # Expected CSV layout: feature columns first, target in the last column.
    df = pd.read_csv(args.dataset)
    X = df.iloc[:, :-1].values
    y = df.iloc[:, -1].values

    grpo = EnhancedGRPO(
        model_path=args.model_path,
        X=X,
        y=y,
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        is_prefix=args.is_prefix,
        custom_prompt=args.custom_prompt,
        save_all_expressions=True,
    )

    grpo.train(
        epochs=args.epochs,
        samples_per_group=args.samples_per_group,
        groups_per_epoch=args.groups_per_epoch,
    )


if __name__ == "__main__":
    main()