"""
Debug version of REINFORCE that saves ALL expressions (valid and invalid).
"""

import sys
import json
import argparse
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Make the project root and its `classes` package importable before the
# project-local `expression` import below.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model
from expression import Expression


class DebugREINFORCE:
    """REINFORCE that logs all expressions."""

    def __init__(self, model_path: str, X: np.ndarray, y: np.ndarray, device: str = None):
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prefer treating model_path as a LoRA adapter on top of GPT-2 and
        # merging it; fall back to loading it as a full standalone model.
        try:
            base_model = AutoModelForCausalLM.from_pretrained("gpt2")
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))
            model_with_lora = PeftModel.from_pretrained(base_model, model_path)
            self.model = model_with_lora.merge_and_unload()
        except Exception:
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        # Attach a fresh LoRA adapter so RL fine-tuning only trains a small
        # number of parameters.
        lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], lora_dropout=0.05, bias="none")
        self.model = get_peft_model(self.model, lora_config)
        self.model = self.model.to(self.device)
        self.model.train()

        # Build the prompt as JSON and drop the trailing '"}' so it ends
        # mid-string at '"expr": "'; the model's continuation is the expression.
        vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
        ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
        self.prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
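        # For two variables the prompt looks like (note the open quote):
        #   {"vars": ["x_1", "x_2"], "ops": ["+", ...], "cons": "C", "expr": "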

        # Shared sampling temperature (also used when recomputing log-probs).
        self.temperature = 0.7

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)

        # Exponential-moving-average baseline for variance reduction.
        self.baseline = 0.0
        self.baseline_decay = 0.9

        # Every generated expression (valid or invalid) is recorded here.
        self.all_expressions = []

    def extract_expression(self, text: str) -> str:
        """Extract the expression from generated text.

        Example: '... "expr": "sin(x_1) + C"}' -> 'sin(x_1) + C'.
        """
        try:
            if '"expr": "' in text:
                start = text.index('"expr": "') + len('"expr": "')
                remaining = text[start:]
                # Cut at the closing quote, with or without the final brace.
                for terminator in ['"}', '"']:
                    if terminator in remaining:
                        return remaining[:remaining.index(terminator)].strip()
        except Exception:
            pass
        return text.strip()
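
    # Reward definition: R^2 = 1 - SS_res / SS_tot, clipped to [-1, 1] so a
    # single wildly wrong expression cannot blow up the reward scale.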
    def compute_r2(self, expression_str: str) -> Dict:
        """Compute R^2 and detailed error info for one expression."""
        result = {
            "expression": expression_str,
            "r2": -1.0,
            "is_valid": False,
            "error_type": None,
            "error_message": None,
        }

        if not expression_str or expression_str.isspace():
            result["error_type"] = "empty"
            return result

        # Substitute a literal for the constant placeholder so the expression
        # can be evaluated numerically.
        test_expr = expression_str.replace('C', '1')

        try:
            expr = Expression(test_expr, is_prefix=False)

            if not expr.is_valid_on_dataset(self.X):
                result["error_type"] = "invalid_on_dataset"
                result["error_message"] = "NaN/Inf on dataset"
                return result

            y_pred = expr.evaluate(self.X)

            if not np.all(np.isfinite(y_pred)):
                result["error_type"] = "non_finite_output"
                return result

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)

            if ss_tot == 0:
                r2 = 0.0
            else:
                r2 = 1 - (ss_res / ss_tot)

            result["r2"] = float(np.clip(r2, -1.0, 1.0))
            result["is_valid"] = True

        except Exception as e:
            result["error_type"] = "parse_error"
            result["error_message"] = str(e)[:100]

        return result

    def generate_batch(self, batch_size: int = 16, max_new_tokens: int = 50) -> List[Dict]:
        """Generate a batch of expressions and evaluate them."""
        results = []

        for _ in range(batch_size):
            generated_ids = self.prompt_ids.clone()
            generated_tokens = []

            # Sampling loop; no gradients here, log-probs are recomputed below.
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    outputs = self.model(generated_ids)
                    logits = outputs.logits[:, -1, :] / self.temperature

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                    generated_tokens.append(next_token.item())
                    generated_ids = torch.cat([generated_ids, next_token], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Stop early once the JSON object has been closed.
                    text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                    if '"}' in text[len(self.prompt):]:
                        break

            text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            expr_str = self.extract_expression(text)

            # Score the extracted expression.
            eval_result = self.compute_r2(expr_str)
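
            # REINFORCE needs grad-tracked log pi(token) for every sampled
            # token, so re-run the model over the full sequence with gradients
            # and gather the log-probs of the tokens actually sampled.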
            if len(generated_tokens) > 0:
                full_ids = torch.cat([self.prompt_ids, torch.tensor([generated_tokens], device=self.device)], dim=1)
                outputs = self.model(full_ids[:, :-1])
                logits = outputs.logits / self.temperature
                prompt_len = self.prompt_ids.shape[1]
                # The logit at position i predicts token i + 1, hence the offset.
                gen_logits = logits[:, prompt_len - 1:, :]
                log_probs_all = F.log_softmax(gen_logits, dim=-1)
                target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
                selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
                total_log_prob = selected_log_probs.sum()
            else:
                total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)

            eval_result["log_prob"] = total_log_prob
            results.append(eval_result)

            # Log a copy without the log-prob tensor so the autograd graph is
            # not kept alive for the whole run.
            self.all_expressions.append({k: v for k, v in eval_result.items() if k != "log_prob"})

        return results
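
    # Policy-gradient update: loss = -(R - baseline) * log pi(expression),
    # averaged over the batch; the baseline is an EMA of valid rewards.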
    def train_step(self, batch_size: int = 16) -> Dict:
        """One training step."""
        results = self.generate_batch(batch_size)

        # Invalid expressions receive a small fixed penalty.
        rewards = [r["r2"] if r["is_valid"] else -0.1 for r in results]

        # Update the baseline from valid samples only.
        valid_rewards = [rw for rw, r in zip(rewards, results) if r["is_valid"]]
        if valid_rewards:
            mean_reward = np.mean(valid_rewards)
            self.baseline = self.baseline_decay * self.baseline + (1 - self.baseline_decay) * mean_reward

        advantages = [r - self.baseline for r in rewards]

        self.optimizer.zero_grad()
        policy_loss = torch.tensor(0.0, device=self.device)

        # Valid expressions and parse errors contribute to the loss; other
        # failure modes are skipped.
        for result, advantage in zip(results, advantages):
            if result["is_valid"] or result["error_type"] == "parse_error":
                policy_loss = policy_loss - result["log_prob"] * advantage

        # Backprop only if at least one sample contributed a differentiable term.
        if len(results) > 0 and policy_loss.requires_grad:
            policy_loss = policy_loss / len(results)
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

        valid_count = sum(1 for r in results if r["is_valid"])
        valid_r2 = [r["r2"] for r in results if r["is_valid"]]

        return {
            "valid_count": valid_count,
            "total_count": len(results),
            "mean_r2": np.mean(valid_r2) if valid_r2 else -1.0,
            "max_r2": max(r["r2"] for r in results),
            "baseline": self.baseline,
        }

    def run(self, epochs: int = 10):
        """Run training."""
        print(f"Running debug REINFORCE for {epochs} epochs...")
        print()

        for epoch in range(1, epochs + 1):
            stats = self.train_step()
            print(f"Epoch {epoch:2d} | Valid: {stats['valid_count']}/{stats['total_count']} | Mean R²: {stats['mean_r2']:.4f} | Max R²: {stats['max_r2']:.4f}")

        # Save every expression, valid and invalid, for offline inspection.
        output_file = "debug_expressions.json"
        with open(output_file, "w") as f:
            json.dump({"all_expressions": self.all_expressions}, f, indent=2, default=str)

        print()
        print(f"Saved {len(self.all_expressions)} expressions to {output_file}")

        # Summarize validity and failure modes.
        valid = [e for e in self.all_expressions if e["is_valid"]]
        invalid = [e for e in self.all_expressions if not e["is_valid"]]

        print()
        print("SUMMARY:")
        print(f"  Total:   {len(self.all_expressions)}")
        print(f"  Valid:   {len(valid)} ({100*len(valid)/len(self.all_expressions):.1f}%)")
        print(f"  Invalid: {len(invalid)} ({100*len(invalid)/len(self.all_expressions):.1f}%)")

        if invalid:
            error_types = {}
            for e in invalid:
                et = e.get("error_type", "unknown")
                error_types[et] = error_types.get(et, 0) + 1

            print()
            print("Invalid expression types:")
            for et, count in sorted(error_types.items(), key=lambda x: -x[1]):
                print(f"  {et}: {count} ({100*count/len(invalid):.1f}%)")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    args = parser.parse_args()

    # Load a CSV with feature columns x_1..x_n and a target column y
    # (pandas is imported at the top of the module).
    df = pd.read_csv(args.dataset)
    x_cols = [c for c in df.columns if c.startswith('x_')]
    X = df[x_cols].values
    y = df['y'].values
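
    # Illustrative CSV layout (hypothetical values):
    #   x_1,x_2,y
    #   0.5,1.2,1.86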

    print(f"Dataset: {args.dataset}")
    print(f"  Samples: {len(df)}, Variables: {len(x_cols)}")
    print()

    reinforce = DebugREINFORCE(args.model_path, X, y)
    reinforce.run(epochs=args.epochs)


if __name__ == "__main__":
    main()