|
|
|
|
|
""" |
|
|
Iterative Sampling + SFT for Symbolic Regression |
|
|
|
|
|
This approach: |
|
|
1. Generate N expressions using the current model |
|
|
2. Evaluate R^2 for each expression |
|
|
3. Filter expressions with R^2 > threshold |
|
|
4. Fine-tune the model on the best expressions |
|
|
5. Repeat |
|
|
|
|
|
This is a form of "Expert Iteration" or "Self-Play" adapted for symbolic regression. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import logging |
|
|
import datetime |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Repository root: two directory levels above this script.
PROJECT_ROOT = Path(__file__).parent.parent

# Prepend both import locations in a single step. The resulting order
# (classes/ first, then the root) is the same as two successive
# ``sys.path.insert(0, ...)`` calls for the root and then classes/.
# NOTE(review): classes/ is presumably where the `expression` and `dataset`
# modules imported below live -- confirm against the repository layout.
sys.path[:0] = [str(PROJECT_ROOT / "classes"), str(PROJECT_ROOT)]
|
|
|
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
DataCollatorForLanguageModeling, |
|
|
) |
|
|
from datasets import Dataset |
|
|
from peft import PeftModel, LoraConfig, get_peft_model |
|
|
|
|
|
from expression import Expression |
|
|
from dataset import RegressionDataset |
|
|
|
|
|
|
|
|
# Root logging configuration applied at import time: INFO level with a
# "timestamp - level - message" line format.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class IterativeSamplingSFT:
    """Iterative sampling with supervised fine-tuning for symbolic regression.

    Each round samples candidate expressions from a causal language model,
    scores every candidate by R^2 against ``(X, y)``, and fine-tunes the model
    (through a temporary LoRA adapter that is merged back afterwards) on the
    samples whose score exceeds a threshold.
    """

    def __init__(
        self,
        model_path: str,
        X: np.ndarray,
        y: np.ndarray,
        output_dir: str = "./output/iterative_sft",
        device: str = None,
    ):
        """Load the model and prepare the search state.

        Args:
            model_path: Local directory or hub identifier handed to
                ``_load_model``.
            X: Input samples; ``X.shape[1]`` is taken as the number of
                variables, so a 2-D array is expected.
            y: Target values compared against the expression predictions.
            output_dir: Directory for training artefacts; created if missing.
            device: Torch device string. When falsy, CUDA is used if
                available, otherwise the CPU.
        """
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._load_model(model_path)
        self.prompt = self._build_prompt()

        # Best candidate seen across all iterations, plus per-iteration stats.
        self.best_r2 = -np.inf
        self.best_expression = None
        self.history = []

    def _load_model(self, model_path: str):
        """Load the tokenizer and model, then move the model to ``self.device``.

        When ``model_path`` exists on disk it is first treated as a LoRA
        adapter on top of the ``gpt2`` base model (resized to the tokenizer's
        vocabulary if needed) and merged into it. If that fails, or if the
        path does not exist locally, ``model_path`` is loaded directly as a
        full causal language model.
        """
        logger.info(f"Loading model from {model_path}")

        # Both loading strategies share the same tokenizer setup; the EOS
        # token doubles as the padding token.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        if Path(model_path).exists():
            base_model = AutoModelForCausalLM.from_pretrained("gpt2")
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))

            try:
                model_with_lora = PeftModel.from_pretrained(base_model, model_path)
                self.model = model_with_lora.merge_and_unload()
                logger.info("LoRA adapter loaded and merged")
            except Exception as err:
                # Deliberate best-effort: the directory may hold a full model
                # rather than an adapter. Record why the adapter path failed
                # instead of hiding it.
                logger.warning(
                    "Could not load %s as a LoRA adapter (%s); loading it as a full model",
                    model_path,
                    err,
                )
                self.model = AutoModelForCausalLM.from_pretrained(model_path)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        self.model = self.model.to(self.device)
        logger.info("Model loaded")

    def _build_prompt(self) -> str:
        """Build the JSON-formatted generation prompt.

        The serialized object ends with ``"expr": ""}``; the last three
        characters are cut off so the prompt stops right after ``"expr": ``
        and the model has to produce the expression string itself.
        """
        vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
        ops_list = ["+", "-", "*", "sin", "cos"]

        serialized = json.dumps({
            "vars": vars_list,
            "ops": ops_list,
            "cons": None,
            "expr": "",
        })
        return serialized[:-3]

    def extract_expression(self, text: str) -> str:
        """Extract the expression string from generated text.

        Looks for the value following ``"expr": `` and cuts it at the closing
        ``"}`` (or, for the quoted form, at the next ``"``). If neither
        pattern yields a result, everything after the last ``"expr"`` is
        returned with surrounding spaces, quotes, colons and braces stripped.
        """
        quoted_key = '"expr": "'
        bare_key = '"expr": '

        if quoted_key in text:
            tail = text[text.index(quoted_key) + len(quoted_key):]
            # Prefer the full closing sequence, then a lone closing quote.
            for terminator in ('"}', '"'):
                if terminator in tail:
                    return tail[:tail.index(terminator)].strip()

        if bare_key in text:
            tail = text[text.index(bare_key) + len(bare_key):]
            if '"}' in tail:
                return tail[:tail.index('"}')].strip()

        return text.split('"expr"')[-1].strip(' ":}')

    def compute_r2(self, expression_str: str) -> float:
        """Score an expression by R^2 on the stored dataset.

        Every ``C`` in the expression is replaced with ``1`` before parsing.

        Returns:
            A built-in ``float`` (so the score stays JSON-serializable
            whatever the dtype of ``y``): the R^2 value, ``0.0`` when ``y``
            has zero variance, or ``-inf`` when the expression is empty,
            rejected by ``Expression``, yields non-finite predictions, or
            raises during parsing/evaluation.
        """
        if not expression_str or expression_str.isspace():
            return -np.inf

        expression_str = expression_str.replace('C', '1')

        try:
            # NOTE(review): is_prefix=False presumably selects infix parsing --
            # confirm against the Expression class.
            expr = Expression(expression_str, is_prefix=False)
            if not expr.is_valid_on_dataset(self.X):
                return -np.inf

            y_pred = expr.evaluate(self.X)
            if not np.all(np.isfinite(y_pred)):
                return -np.inf

            ss_res = np.sum((self.y - y_pred) ** 2)
            ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
            if ss_tot == 0:
                return 0.0

            return float(1 - (ss_res / ss_tot))
        except Exception as err:
            # Generated text is frequently not a parseable expression; treat
            # any failure as an invalid candidate but leave a debug trace.
            logger.debug("Could not score expression %r: %s", expression_str, err)
            return -np.inf

    def sample_expressions(self, n_samples: int, temperature: float = 0.7) -> List[Tuple[str, str, float]]:
        """Generate ``n_samples`` completions of the prompt and score them.

        Also updates ``self.best_r2`` / ``self.best_expression`` whenever a
        sample beats the best score seen so far.

        Returns:
            ``(generated_text, expression, r2)`` tuples for the samples whose
            R^2 is finite; invalid samples are dropped.
        """
        self.model.eval()

        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
        results = []

        for _ in tqdm(range(n_samples), desc="Sampling"):
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=50,
                    do_sample=True,
                    top_k=50,
                    top_p=0.9,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            expr_str = self.extract_expression(text)
            r2 = self.compute_r2(expr_str)

            if not np.isfinite(r2):
                continue

            results.append((text, expr_str, r2))
            if r2 > self.best_r2:
                self.best_r2 = r2
                self.best_expression = expr_str

        return results

    def filter_best(self, results: List[Tuple[str, str, float]], threshold: float = 0.5) -> List[str]:
        """Return the generated texts with R^2 strictly above ``threshold``.

        The texts are ordered from highest to lowest R^2.
        """
        kept = [item for item in results if item[2] > threshold]
        kept.sort(key=lambda item: item[2], reverse=True)
        return [text for text, _, _ in kept]

    def fine_tune(self, good_texts: List[str], epochs: int = 1):
        """Fine-tune the model on ``good_texts`` with a temporary LoRA adapter.

        The adapter is attached to ``self.model``, trained with the Hugging
        Face ``Trainer``, and merged back so ``self.model`` is a plain model
        again afterwards. Does nothing (apart from a warning) when
        ``good_texts`` is empty.
        """
        if not good_texts:
            logger.warning("No good expressions to fine-tune on")
            return

        logger.info(f"Fine-tuning on {len(good_texts)} good expressions")

        dataset = Dataset.from_dict({"text": good_texts})

        def tokenize(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=128,
                padding="max_length",
            )

        tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])

        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["c_attn"],
            lora_dropout=0.05,
            bias="none",
            # Declare the task so PEFT wraps the model with its causal-LM
            # specific class rather than the generic wrapper.
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)

        training_args = TrainingArguments(
            output_dir=str(self.output_dir / "checkpoints"),
            num_train_epochs=epochs,
            # Never request a batch larger than the number of examples.
            per_device_train_batch_size=min(4, len(good_texts)),
            learning_rate=5e-5,
            logging_steps=10,
            save_strategy="no",
            report_to=[],
            use_cpu=self.device.type == "cpu",
        )

        # mlm=False selects causal-LM style label construction.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=data_collator,
        )
        trainer.train()

        # Fold the adapter weights back in so the next round starts from a
        # plain model and can attach a fresh adapter.
        self.model = self.model.merge_and_unload()
        logger.info("Fine-tuning complete")

    def run(
        self,
        n_iterations: int = 5,
        samples_per_iteration: int = 100,
        r2_threshold: float = 0.5,
        target_r2: float = 0.99,
    ):
        """Run the sample -> score -> filter -> fine-tune loop.

        Args:
            n_iterations: Maximum number of rounds.
            samples_per_iteration: Expressions generated per round.
            r2_threshold: Initial R^2 cut-off for fine-tuning data. After each
                round that fine-tuned, it rises by 0.1 up to 0.9; a starting
                value already above 0.9 is left unchanged.
            target_r2: Stop early once the best R^2 reaches this value.

        Returns:
            Dict with ``best_r2``, ``best_expression`` and the per-iteration
            ``history`` list.
        """
        logger.info("=" * 60)
        logger.info("ITERATIVE SAMPLING + SFT")
        logger.info("=" * 60)
        logger.info(f"Iterations: {n_iterations}")
        logger.info(f"Samples per iteration: {samples_per_iteration}")
        logger.info(f"R^2 threshold: {r2_threshold}")
        logger.info("=" * 60)

        for iteration in range(n_iterations):
            logger.info(f"\n{'='*60}")
            logger.info(f"ITERATION {iteration + 1}/{n_iterations}")
            logger.info(f"{'='*60}")

            results = self.sample_expressions(samples_per_iteration)
            if not results:
                logger.warning("No valid expressions generated")
                continue

            r2_scores = [r2 for _, _, r2 in results]
            logger.info(f"Valid expressions: {len(results)}/{samples_per_iteration}")
            logger.info(f"Mean R^2: {np.mean(r2_scores):.4f}")
            logger.info(f"Max R^2: {np.max(r2_scores):.4f}")
            logger.info(f"Best overall: {self.best_r2:.4f} - {self.best_expression}")

            self.history.append({
                "iteration": iteration + 1,
                "valid_count": len(results),
                "mean_r2": float(np.mean(r2_scores)),
                "max_r2": float(np.max(r2_scores)),
                "best_overall_r2": float(self.best_r2),
            })

            if self.best_r2 >= target_r2:
                logger.info(f"Target R^2 {target_r2} reached!")
                break

            good_texts = self.filter_best(results, threshold=r2_threshold)
            if good_texts:
                logger.info(f"Fine-tuning on {len(good_texts)} expressions with R^2 > {r2_threshold}")
                self.fine_tune(good_texts, epochs=1)

                # Raise the bar for the next round, capped at 0.9. The outer
                # max() keeps a caller-supplied threshold above 0.9 from being
                # lowered by the cap.
                r2_threshold = max(r2_threshold, min(r2_threshold + 0.1, 0.9))

        logger.info("\n" + "=" * 60)
        logger.info("FINAL RESULTS")
        logger.info("=" * 60)
        logger.info(f"Best R^2: {self.best_r2:.4f}")
        logger.info(f"Best expression: {self.best_expression}")

        return {
            "best_r2": float(self.best_r2),
            "best_expression": self.best_expression,
            "history": self.history,
        }
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses the CLI options, loads the regression dataset, runs the iterative
    sampling + SFT experiment and writes its result dict to a timestamped
    JSON file inside the output directory. Logs an error and returns early
    when the dataset file does not exist.
    """
    # (flag, argparse keyword arguments), registered in this order.
    cli_options = [
        ("--model_path", {"type": str, "default": "gpt2"}),
        ("--dataset", {"type": str, "default": "./data/ppo_test/sin_x1.csv"}),
        ("--output_dir", {"type": str, "default": "./output/iterative_sft"}),
        ("--iterations", {"type": int, "default": 5}),
        ("--samples", {"type": int, "default": 100}),
        ("--threshold", {"type": float, "default": 0.5}),
        ("--cpu", {"action": "store_true"}),
    ]
    parser = argparse.ArgumentParser(description="Iterative Sampling + SFT")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        logger.error(f"Dataset not found: {dataset_path}")
        return

    # The dataset loader takes the containing directory and the file name
    # as separate arguments.
    data_dir = str(dataset_path.parent)
    reg = RegressionDataset(data_dir, dataset_path.name)
    X, y = reg.get_numpy()

    forced_device = "cpu" if args.cpu else None
    experiment = IterativeSamplingSFT(
        model_path=args.model_path,
        X=X,
        y=y,
        output_dir=args.output_dir,
        device=forced_device,
    )
    results = experiment.run(
        n_iterations=args.iterations,
        samples_per_iteration=args.samples,
        r2_threshold=args.threshold,
    )

    # One result file per run, named by the local start-of-save timestamp.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = Path(args.output_dir) / f"results_{timestamp}.json"
    with results_file.open('w') as handle:
        json.dump(results, handle, indent=2)

    logger.info(f"Results saved to: {results_file}")
|
|
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|