"""
PPO Experiment using Legacy TRL API (v0.11.0 or earlier)

This script uses the old PPOTrainer.step() API, which accepts custom rewards
directly. It is the fallback approach if the modern TRL API doesn't work.

REQUIRES: pip install trl==0.11.0

Usage:
    pip install trl==0.11.0  # Downgrade TRL first
    python scripts/ppo_experiment_legacy.py --dataset ./data/ppo_test/sin_x1.csv
"""

import sys
import json
import argparse
import logging
import datetime
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

# Make the project root and its classes/ directory importable before the
# local imports below.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from expression import Expression
from dataset import RegressionDataset

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


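# The legacy step() API is only available in TRL 0.11.x and earlier; newer
# releases replaced the PPOTrainer implementation, so this check can only
# warn rather than guarantee compatibility.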
def check_trl_version():
    """Check if TRL version supports legacy API."""
    import trl

    version = trl.__version__
    major, minor = map(int, version.split('.')[:2])

    if major > 0 or minor >= 12:
        logger.warning(f"TRL version {version} may not support legacy step() API")
        logger.warning("Consider: pip install trl==0.11.0")
        return False
    return True


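# The reward signal is the clipped R² of each generated expression on the
# target dataset, computed in compute_reward() below.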
class LegacyPPOSymbolicRegression:
    """PPO-based symbolic regression using legacy TRL API."""

    def __init__(
        self,
        model_path: str,
        dataset_path: str,
        output_dir: str = "./output/ppo_legacy",
        batch_size: int = 16,
        learning_rate: float = 1e-5,
    ):
        self.model_path = model_path
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._load_dataset()
        self._load_model()
        self._build_prompt()
        self._setup_ppo()

        self.best_r2 = -np.inf
        self.best_expression = None
        self.history = []

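    # X has shape (n_samples, n_vars); y is the regression target.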
    def _load_dataset(self):
        """Load regression dataset."""
        logger.info(f"Loading dataset from {self.dataset_path}")
        reg = RegressionDataset(str(self.dataset_path.parent), self.dataset_path.name)
        self.X, self.y = reg.get_numpy()
        self.n_vars = self.X.shape[1]
        logger.info(f"Dataset: {self.X.shape[0]} samples, {self.n_vars} variables")

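    # model_path is expected to hold a LoRA adapter (and tokenizer) trained
    # on top of GPT-2; if PEFT loading fails, it is treated as a full model.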
    def _load_model(self):
        """Load the JSON format model with LoRA adapters."""
        logger.info(f"Loading model from {self.model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # The adapter checkpoint is assumed to sit on a GPT-2 base model.
        base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)

        # Resize embeddings if fine-tuning added tokens to the vocabulary.
        if len(self.tokenizer) != base_model.config.vocab_size:
            base_model.resize_token_embeddings(len(self.tokenizer))

        try:
            model_with_lora = PeftModel.from_pretrained(base_model, self.model_path)
            merged_model = model_with_lora.merge_and_unload()
            logger.info("LoRA adapter loaded and merged")
        except Exception as e:
            logger.warning(f"Could not load as PEFT model: {e}")
            merged_model = AutoModelForCausalLM.from_pretrained(self.model_path)

        # Import the legacy TRL modules lazily so the error message can point
        # at the required version.
        try:
            from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
            self.ppo_modules = {
                'PPOConfig': PPOConfig,
                'PPOTrainer': PPOTrainer,
                'AutoModelForCausalLMWithValueHead': AutoModelForCausalLMWithValueHead,
            }
        except ImportError:
            logger.error("Could not import legacy TRL modules")
            logger.error("Try: pip install trl==0.11.0")
            raise

        # Wrap the policy and a frozen reference copy with value heads, as
        # the legacy PPOTrainer expects.
        ValueHeadModel = self.ppo_modules['AutoModelForCausalLMWithValueHead']
        self.model = ValueHeadModel.from_pretrained(merged_model)
        self.ref_model = ValueHeadModel.from_pretrained(merged_model)

        self.model = self.model.to(self.device)
        self.ref_model = self.ref_model.to(self.device)

        logger.info("Model loaded successfully")

    def _build_prompt(self):
        """Build JSON format prompt."""
        vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
        ops_list = ["+", "-", "*", "sin", "cos"]

        # Drop the trailing '""}' so the prompt ends at '"expr": ' and the
        # model's continuation is the expression itself.
        self.prompt = json.dumps({
            "vars": vars_list,
            "ops": ops_list,
            "cons": None,
            "expr": ""
        })[:-3]

        logger.info(f"Prompt: {self.prompt}...")

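    # Note: legacy TRL generally expects batch_size to be a multiple of
    # mini_batch_size, which holds for the default batch_size=16 with
    # mini_batch_size=4.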
    def _setup_ppo(self):
        """Set up the legacy PPO trainer."""
        PPOConfig = self.ppo_modules['PPOConfig']
        PPOTrainer = self.ppo_modules['PPOTrainer']

        self.ppo_config = PPOConfig(
            learning_rate=self.learning_rate,
            batch_size=self.batch_size,
            mini_batch_size=min(4, self.batch_size),
            ppo_epochs=4,
            log_with=None,
        )

        self.ppo_trainer = PPOTrainer(
            config=self.ppo_config,
            model=self.model,
            ref_model=self.ref_model,
            tokenizer=self.tokenizer,
        )

        logger.info("Legacy PPO trainer ready")

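    # Generation often yields malformed JSON, so extraction is tolerant: it
    # looks for the '"expr"' key rather than parsing the whole object.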
    def extract_expression(self, text: str) -> str:
        """Extract expression from JSON output."""
        try:
            # Preferred case: the model produced '"expr": "<expression>"'.
            if '"expr": "' in text:
                start = text.index('"expr": "') + len('"expr": "')
                remaining = text[start:]
                if '"}' in remaining:
                    return remaining[:remaining.index('"}')].strip()
                if '"' in remaining:
                    return remaining[:remaining.index('"')].strip()
                return remaining.strip()

            # Fallback: the value was emitted without an opening quote.
            if '"expr": ' in text:
                start = text.index('"expr": ') + len('"expr": ')
                remaining = text[start:]
                if '"}' in remaining:
                    return remaining[:remaining.index('"}')].strip()
                return remaining.strip()

        except (ValueError, IndexError):
            pass

        # Last resort: take whatever follows the final '"expr"' marker.
        return text.split('"expr"')[-1].strip(' ":}')

    def compute_reward(self, expression_str: str) -> float:
        """Compute R² reward for an expression (-1.0 for invalid ones)."""
        if not expression_str or expression_str.isspace():
            return -1.0

        # Substitute a literal value for the constant placeholder.
        if 'C' in expression_str:
            expression_str = expression_str.replace('C', '1')

        try:
            expr = Expression(expression_str, is_prefix=False)

            if not expr.is_valid_on_dataset(self.X):
                return -1.0

            y_pred = expr.evaluate(self.X)

            if not np.all(np.isfinite(y_pred)):
                return -1.0

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)

            if ss_tot == 0:
                return 0.0

            r2 = 1 - (ss_res / ss_tot)
            # Clip so one catastrophically bad expression cannot dominate.
            return float(np.clip(r2, -1.0, 1.0))
        except Exception:
            return -1.0

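    # Core legacy-API loop: sample responses for a batch of identical
    # prompts, score each with R², then call PPOTrainer.step().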
    def train_epoch(self, epoch: int):
        """Run one epoch of PPO training using legacy step() API."""
        logger.info(f"\n{'='*60}\nEPOCH {epoch + 1}\n{'='*60}")

        # Every query uses the same prompt; tokenize the batch once.
        inputs = self.tokenizer(
            [self.prompt] * self.batch_size,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        queries = [inputs["input_ids"][i] for i in range(self.batch_size)]

        responses = []
        expressions = []
        rewards = []

        for i in tqdm(range(self.batch_size), desc="Generating"):
            output = self.model.generate(
                input_ids=inputs["input_ids"][i:i+1],
                attention_mask=inputs["attention_mask"][i:i+1],
                max_new_tokens=50,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # Keep only the generated continuation for PPO; score the full
            # decoded text so the prompt's JSON prefix is available.
            response_ids = output[0][inputs["input_ids"].shape[1]:]
            full_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            expr_str = self.extract_expression(full_text)
            reward = self.compute_reward(expr_str)

            responses.append(response_ids)
            expressions.append(expr_str)
            rewards.append(reward)

        # The legacy step() API expects one scalar reward tensor per sample.
        reward_tensors = [torch.tensor(r, dtype=torch.float32, device=self.device) for r in rewards]

        try:
            stats = self.ppo_trainer.step(queries, responses, reward_tensors)
            logger.info("PPO step completed")
        except Exception as e:
            logger.error(f"PPO step failed: {e}")
            stats = {}

        valid_count = sum(1 for r in rewards if r > 0)
        rewards_array = np.array(rewards)

        epoch_result = {
            "epoch": epoch + 1,
            "valid_count": valid_count,
            "valid_rate": valid_count / len(rewards),
            "mean_reward": float(np.mean(rewards_array)),
            "max_reward": float(np.max(rewards_array)),
            "top_expressions": [],
        }

        # Record the five best-scoring expressions and track the overall best.
        sorted_idx = np.argsort(rewards)[::-1]
        for i in sorted_idx[:5]:
            if rewards[i] > -1.0:
                epoch_result["top_expressions"].append({
                    "expression": expressions[i],
                    "r2": rewards[i],
                })

                if rewards[i] > self.best_r2:
                    self.best_r2 = rewards[i]
                    self.best_expression = expressions[i]

        self.history.append(epoch_result)

        logger.info(f"Valid: {valid_count}/{len(rewards)} ({epoch_result['valid_rate']:.1%})")
        logger.info(f"Mean R²: {epoch_result['mean_reward']:.4f}")
        logger.info(f"Max R²: {epoch_result['max_reward']:.4f}")

        if epoch_result["top_expressions"]:
            logger.info("Top expressions:")
            for i, expr in enumerate(epoch_result["top_expressions"][:3]):
                logger.info(f"  {i+1}. {expr['expression']} (R²={expr['r2']:.4f})")

        return epoch_result

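    # Driver: one PPO step per epoch, JSON checkpoints after each epoch,
    # early stopping once R² exceeds 0.99.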
    def run(self, n_epochs: int = 10):
        """Run PPO training."""
        logger.info("="*60)
        logger.info("LEGACY PPO SYMBOLIC REGRESSION")
        logger.info("="*60)
        logger.info(f"Dataset: {self.dataset_path}")
        logger.info(f"Model: {self.model_path}")
        logger.info(f"Epochs: {n_epochs}")

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        for epoch in range(n_epochs):
            self.train_epoch(epoch)

            # Checkpoint the running results after every epoch.
            checkpoint = {
                "epoch": epoch + 1,
                "best_r2": self.best_r2,
                "best_expression": self.best_expression,
                "history": self.history,
            }

            with open(self.output_dir / f"checkpoint_{epoch+1}.json", 'w') as f:
                json.dump(checkpoint, f, indent=2)

            if self.best_r2 > 0.99:
                logger.info("Early stopping: R² > 0.99")
                break

        logger.info("\n" + "="*60)
        logger.info("TRAINING COMPLETE")
        logger.info("="*60)
        logger.info(f"Best R²: {self.best_r2:.4f}")
        logger.info(f"Best expression: {self.best_expression}")

        final_file = self.output_dir / f"final_results_{timestamp}.json"
        with open(final_file, 'w') as f:
            json.dump({
                "best_r2": self.best_r2,
                "best_expression": self.best_expression,
                "history": self.history,
            }, f, indent=2)

        logger.info(f"Results saved to: {final_file}")

        return self.history


def main():
    parser = argparse.ArgumentParser(description="Legacy PPO Symbolic Regression")
    parser.add_argument("--model_path", type=str, default="./output/exp_a_json")
    parser.add_argument("--dataset", type=str, default="./data/ppo_test/sin_x1.csv")
    parser.add_argument("--output_dir", type=str, default="./output/ppo_legacy")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-5)

    args = parser.parse_args()

    check_trl_version()

    experiment = LegacyPPOSymbolicRegression(
        model_path=args.model_path,
        dataset_path=args.dataset,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        learning_rate=args.lr,
    )

    experiment.run(n_epochs=args.epochs)


if __name__ == "__main__":
    main()