|
|
|
|
|
""" |
|
|
Debug version of REINFORCE that saves ALL expressions (valid and invalid). |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import List, Dict |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
PROJECT_ROOT = Path(__file__).parent.parent |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
sys.path.insert(0, str(PROJECT_ROOT / "classes")) |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from peft import PeftModel, LoraConfig, get_peft_model |
|
|
from expression import Expression |
|
|
|
|
|
|
|
|
class DebugREINFORCE:
    """REINFORCE trainer that logs ALL sampled expressions (valid and invalid).

    A causal-LM policy (LoRA-adapted GPT-2) completes a JSON prompt of the
    form ``{"vars": [...], "ops": [...], "cons": "C", "expr": "`` and the
    completion is parsed into a symbolic expression whose R^2 on the dataset
    (X, y) is the reward.  Every sampled expression, with validity and error
    diagnostics, is collected in ``self.all_expressions`` for offline debugging.
    """

    # Sampling temperature.  Used BOTH when sampling tokens and when
    # re-scoring them for the policy gradient -- the two must stay in sync,
    # otherwise the scored log-probs would not match the sampling distribution
    # (previously the literal 0.7 was duplicated in two places).
    TEMPERATURE = 0.7

    def __init__(self, model_path: str, X: np.ndarray, y: np.ndarray, device: str = None):
        """Load the policy model and set up tokenizer, prompt and optimizer.

        Args:
            model_path: checkpoint directory; tried first as a LoRA adapter
                on top of base GPT-2, then as a full standalone model.
            X: input features, shape (n_samples, n_vars) — n_vars is read
                from ``X.shape[1]``.
            y: regression targets aligned with the rows of X.
            device: torch device string; auto-detects CUDA when None.
        """
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # GPT-2 ships without a pad token; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Prefer interpreting model_path as a LoRA adapter over base GPT-2 and
        # merge it; fall back to loading it as a full model.
        try:
            base_model = AutoModelForCausalLM.from_pretrained("gpt2")
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))
            model_with_lora = PeftModel.from_pretrained(base_model, model_path)
            self.model = model_with_lora.merge_and_unload()
        except Exception:  # narrowed from bare `except:` so Ctrl-C propagates
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        # Fresh LoRA adapter for the RL fine-tuning phase.
        lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], lora_dropout=0.05, bias="none")
        self.model = get_peft_model(self.model, lora_config)
        self.model = self.model.to(self.device)
        self.model.train()

        # Serialized JSON prompt; [:-2] strips the closing `"}` so generation
        # continues inside the (empty) "expr" value.
        vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
        ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
        self.prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)

        # Exponential-moving-average reward baseline for variance reduction.
        self.baseline = 0.0
        self.baseline_decay = 0.9

        # Every sampled expression (valid or not) is appended here; see
        # generate_batch() -- log_prob is stored detached as a float so we do
        # not retain autograd graphs for the whole run.
        self.all_expressions = []

    def extract_expression(self, text: str) -> str:
        """Extract the "expr" value from generated JSON-like text.

        Returns the whole stripped text when no ``"expr": "`` marker is found
        or extraction fails.
        """
        try:
            if '"expr": "' in text:
                start = text.index('"expr": "') + len('"expr": "')
                remaining = text[start:]
                # Prefer the full close `"}`; fall back to a lone quote.
                for terminator in ['"}', '"']:
                    if terminator in remaining:
                        return remaining[:remaining.index(terminator)].strip()
        except Exception:  # narrowed from bare `except:`; extraction is best-effort
            pass
        return text.strip()

    def compute_r2(self, expression_str: str) -> Dict:
        """Evaluate an expression string and return R^2 plus error diagnostics.

        Returns a dict (the original ``-> tuple`` annotation was wrong) with
        keys: expression, r2 (clipped to [-1, 1], -1.0 when invalid),
        is_valid, error_type, error_message.
        """
        result = {
            "expression": expression_str,
            "r2": -1.0,
            "is_valid": False,
            "error_type": None,
            "error_message": None,
        }

        if not expression_str or expression_str.isspace():
            result["error_type"] = "empty"
            return result

        # Constants are placeholders 'C'; substitute 1 for a quick validity check.
        test_expr = expression_str.replace('C', '1')

        try:
            expr = Expression(test_expr, is_prefix=False)

            if not expr.is_valid_on_dataset(self.X):
                result["error_type"] = "invalid_on_dataset"
                result["error_message"] = "NaN/Inf on dataset"
                return result

            y_pred = expr.evaluate(self.X)

            if not np.all(np.isfinite(y_pred)):
                result["error_type"] = "non_finite_output"
                return result

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)

            # Constant target -> R^2 undefined; define it as 0.
            if ss_tot == 0:
                r2 = 0.0
            else:
                r2 = 1 - (ss_res / ss_tot)

            result["r2"] = float(np.clip(r2, -1.0, 1.0))
            result["is_valid"] = True

        except Exception as e:
            result["error_type"] = "parse_error"
            result["error_message"] = str(e)[:100]

        return result

    def generate_batch(self, batch_size: int = 16, max_new_tokens: int = 50):
        """Sample ``batch_size`` expressions and evaluate each.

        Returns a list of compute_r2() dicts, each augmented with a
        ``log_prob`` tensor (sum of per-token log-probs, with grad) used by
        train_step().  A detached copy of every result is also appended to
        ``self.all_expressions``.
        """
        results = []

        for _ in range(batch_size):
            generated_ids = self.prompt_ids.clone()
            generated_tokens = []

            # --- sampling (no grad) ---
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    outputs = self.model(generated_ids)
                    logits = outputs.logits[:, -1, :] / self.TEMPERATURE

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                    generated_tokens.append(next_token.item())
                    generated_ids = torch.cat([generated_ids, next_token], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Stop early once the JSON object is closed.
                    text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                    if '"}' in text[len(self.prompt):]:
                        break

            text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            expr_str = self.extract_expression(text)

            eval_result = self.compute_r2(expr_str)

            # --- re-score the sampled tokens WITH grad for REINFORCE ---
            if len(generated_tokens) > 0:
                full_ids = torch.cat([self.prompt_ids, torch.tensor([generated_tokens], device=self.device)], dim=1)
                outputs = self.model(full_ids[:, :-1])
                # Same temperature as sampling (see TEMPERATURE).
                logits = outputs.logits / self.TEMPERATURE
                prompt_len = self.prompt_ids.shape[1]
                # Logit at position t predicts token t+1, hence the -1 offset.
                gen_logits = logits[:, prompt_len-1:, :]
                log_probs_all = F.log_softmax(gen_logits, dim=-1)
                target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
                selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
                total_log_prob = selected_log_probs.sum()
            else:
                total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)

            eval_result["log_prob"] = total_log_prob
            results.append(eval_result)

            # Log a copy with log_prob detached to a plain float: the original
            # stored the live tensor, pinning every sample's autograd graph in
            # memory for the entire run.
            logged = eval_result.copy()
            logged["log_prob"] = float(total_log_prob.detach())
            self.all_expressions.append(logged)

        return results

    def train_step(self, batch_size: int = 16):
        """One REINFORCE update over a freshly sampled batch; returns stats."""
        results = self.generate_batch(batch_size)

        # Reward: clipped R^2 for valid expressions, small penalty otherwise.
        rewards = [r["r2"] if r["is_valid"] else -0.1 for r in results]

        # Update the EMA baseline from valid samples only.  (Fixed: the old
        # filter `r > -0.1` also discarded *valid* expressions whose R^2
        # happened to be <= -0.1.)
        valid_rewards = [r["r2"] for r in results if r["is_valid"]]
        if valid_rewards:
            mean_reward = np.mean(valid_rewards)
            self.baseline = self.baseline_decay * self.baseline + (1 - self.baseline_decay) * mean_reward

        advantages = [r - self.baseline for r in rewards]

        self.optimizer.zero_grad()
        policy_loss = torch.tensor(0.0, device=self.device)

        # Only reinforce samples that produced something parseable or valid.
        for result, advantage in zip(results, advantages):
            if result["is_valid"] or result["error_type"] == "parse_error":
                policy_loss = policy_loss - result["log_prob"] * advantage

        # Fixed: when NO sample contributed above, policy_loss stayed a plain
        # constant tensor and backward() raised "does not require grad".
        # Skip the optimizer step in that case.
        if len(results) > 0 and policy_loss.requires_grad:
            policy_loss = policy_loss / len(results)
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

        valid_count = sum(1 for r in results if r["is_valid"])
        valid_r2 = [r["r2"] for r in results if r["is_valid"]]

        return {
            "valid_count": valid_count,
            "total_count": len(results),
            "mean_r2": np.mean(valid_r2) if valid_r2 else -1.0,
            # default guards the (theoretical) empty-batch case.
            "max_r2": max((r["r2"] for r in results), default=-1.0),
            "baseline": self.baseline,
        }

    def run(self, epochs: int = 10):
        """Train for ``epochs`` steps, dump all expressions, print a summary."""
        print(f"Running debug REINFORCE for {epochs} epochs...")
        print()

        for epoch in range(1, epochs + 1):
            stats = self.train_step()
            print(f"Epoch {epoch:2d} | Valid: {stats['valid_count']}/{stats['total_count']} | Mean R²: {stats['mean_r2']:.4f} | Max R²: {stats['max_r2']:.4f}")

        # default=str stringifies anything json can't encode natively.
        output_file = "debug_expressions.json"
        with open(output_file, "w") as f:
            json.dump({"all_expressions": self.all_expressions}, f, indent=2, default=str)

        print()
        print(f"Saved {len(self.all_expressions)} expressions to {output_file}")

        valid = [e for e in self.all_expressions if e["is_valid"]]
        invalid = [e for e in self.all_expressions if not e["is_valid"]]

        print()
        print("SUMMARY:")
        print(f" Total: {len(self.all_expressions)}")
        # Guard the percentage prints against an empty log (e.g. epochs=0);
        # the original divided by zero here.
        if self.all_expressions:
            print(f" Valid: {len(valid)} ({100*len(valid)/len(self.all_expressions):.1f}%)")
            print(f" Invalid: {len(invalid)} ({100*len(invalid)/len(self.all_expressions):.1f}%)")

        if invalid:
            error_types = {}
            for e in invalid:
                et = e.get("error_type", "unknown")
                error_types[et] = error_types.get(et, 0) + 1

            print()
            print("Invalid expression types:")
            for et, count in sorted(error_types.items(), key=lambda x: -x[1]):
                print(f" {et}: {count} ({100*count/len(invalid):.1f}%)")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load a CSV dataset and run the debug REINFORCE loop."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    cli = parser.parse_args()

    # Imported lazily so the module itself can be imported without pandas.
    import pandas as pd

    frame = pd.read_csv(cli.dataset)
    feature_cols = [name for name in frame.columns if name.startswith('x_')]
    X = frame[feature_cols].values
    y = frame['y'].values

    print(f"Dataset: {cli.dataset}")
    print(f" Samples: {len(frame)}, Variables: {len(feature_cols)}")
    print()

    trainer = DebugREINFORCE(cli.model_path, X, y)
    trainer.run(epochs=cli.epochs)


if __name__ == "__main__":
    main()
|
|
|