|
|
|
|
|
""" |
|
|
Improved GRPO (Group Relative Policy Optimization) for Symbolic Regression |
|
|
|
|
|
Improvements over basic GRPO: |
|
|
1. Filter invalid expressions before computing group statistics |
|
|
2. Reward shaping with softer penalties |
|
|
3. Hybrid baseline: group stats + exponential moving average |
|
|
4. Entropy bonus for exploration |
|
|
5. Advantage clipping to prevent extreme updates |
|
|
6. Minimum valid ratio check before updates |
|
|
7. Temperature annealing for better exploration/exploitation |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import logging |
|
|
import datetime |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Tuple |
|
|
from collections import deque |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
# Make the project root and its "classes" directory importable so the
# project-local modules below (expression, dataset) resolve regardless of
# the working directory this script is launched from.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model

# Project-local modules (resolved via the sys.path entries above).
# NOTE(review): RegressionDataset appears unused in this file — confirm
# before removing the import.
from expression import Expression
from dataset import RegressionDataset

# Module-wide logging: timestamped INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ImprovedGRPO: |
|
|
"""Improved GRPO for symbolic regression.""" |
|
|
|
|
|
    def __init__(
        self,
        model_path: str,
        X: np.ndarray,
        y: np.ndarray,
        output_dir: str = "./output/grpo",
        learning_rate: float = 5e-5,
        device: str = None,
        group_size: int = 16,
        entropy_coef: float = 0.01,
        advantage_clip: float = 2.0,
        min_valid_ratio: float = 0.2,
    ):
        """Set up data, device, model, optimizer, and GRPO hyperparameters.

        Args:
            model_path: Path handed to ``_load_model`` (LoRA adapter dir or
                standalone HF model dir).
            X: Input features, shape (n_samples, n_vars).
            y: Target values aligned with ``X``.
            output_dir: Directory where result JSON files are written
                (created if missing).
            learning_rate: AdamW learning rate.
            device: Explicit torch device string; auto-selects CUDA when None.
            group_size: Number of expressions sampled per GRPO group.
            entropy_coef: Weight of the entropy bonus in the loss.
            advantage_clip: Symmetric clip bound for normalized advantages.
            min_valid_ratio: Minimum fraction of valid expressions a group
                needs to contribute a gradient update.
        """
        self.X = X
        self.y = y
        self.n_vars = X.shape[1]  # number of input variables (x_1..x_n)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.learning_rate = learning_rate
        self.group_size = group_size
        self.entropy_coef = entropy_coef
        self.advantage_clip = advantage_clip
        self.min_valid_ratio = min_valid_ratio

        # Device selection: honor an explicit request, else prefer CUDA.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Loads self.model and self.tokenizer (must run before prompt encoding).
        self._load_model(model_path)

        # The prompt is fixed for the whole run; pre-tokenize it once.
        self.prompt = self._build_prompt()
        self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        # Warm-restart cosine schedule: first restart after 10 steps, then 20, 40, ...
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2
        )

        # Best-so-far tracking and per-epoch history for the final report.
        self.best_r2 = -np.inf
        self.best_expression = None
        self.history = []
        self.discovered_expressions: Dict[str, float] = {}

        # Hybrid baseline state: EMA of group means plus a recent-reward buffer.
        self.ema_baseline = 0.0
        self.ema_decay = 0.9
        self.reward_buffer = deque(maxlen=100)

        # Sampling temperature, annealed from initial_temp down to min_temp.
        self.initial_temp = 0.8
        self.min_temp = 0.5
        self.current_temp = self.initial_temp
|
|
|
|
|
    def _load_model(self, model_path: str):
        """Load model and tokenizer.

        Tries ``model_path`` as a LoRA adapter on a GPT-2 base first; on any
        failure falls back to loading it as a standalone causal LM. In both
        cases a fresh LoRA adapter is then attached, so only the new
        low-rank weights are trainable during GRPO.
        """
        logger.info(f"Loading model from {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # GPT-2 ships without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            logger.info("Attempting to load as LoRA adapter...")
            base_model = AutoModelForCausalLM.from_pretrained("gpt2")
            # The tokenizer may carry added tokens; keep embeddings in sync.
            if len(self.tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(self.tokenizer))
                logger.info(f"Resized embeddings to {len(self.tokenizer)}")

            model_with_lora = PeftModel.from_pretrained(base_model, model_path)
            # Merge the old adapter into the base so a fresh one can be added below.
            self.model = model_with_lora.merge_and_unload()
            logger.info("LoRA adapter loaded and merged successfully")
        except Exception as e:
            logger.info(f"LoRA load failed ({e}), loading as standalone model...")
            self.model = AutoModelForCausalLM.from_pretrained(model_path)

        # Attach a new trainable LoRA adapter targeting GPT-2's fused
        # attention projection ("c_attn").
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["c_attn"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model = self.model.to(self.device)
        self.model.train()

        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Model loaded with {trainable} trainable params")
|
|
|
|
|
def _build_prompt(self, ops: list = None) -> str: |
|
|
"""Build JSON format prompt.""" |
|
|
vars_list = [f"x_{i+1}" for i in range(self.n_vars)] |
|
|
|
|
|
if ops is None: |
|
|
ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"] |
|
|
else: |
|
|
ops_list = ops |
|
|
|
|
|
prompt = json.dumps({ |
|
|
"vars": vars_list, |
|
|
"ops": ops_list, |
|
|
"cons": "C", |
|
|
"expr": "" |
|
|
}) |
|
|
prompt = prompt[:-2] |
|
|
return prompt |
|
|
|
|
|
def extract_expression(self, text: str) -> str: |
|
|
"""Extract expression from generated text.""" |
|
|
try: |
|
|
eos_token = "<|endoftext|>" |
|
|
if eos_token in text: |
|
|
text = text[:text.index(eos_token)] |
|
|
|
|
|
if '"expr": "' in text: |
|
|
start = text.index('"expr": "') + len('"expr": "') |
|
|
remaining = text[start:] |
|
|
for terminator in ['"}', '"']: |
|
|
if terminator in remaining: |
|
|
return remaining[:remaining.index(terminator)].strip() |
|
|
return remaining.strip() |
|
|
|
|
|
if '"expr": ' in text: |
|
|
start = text.index('"expr": ') + len('"expr": ') |
|
|
remaining = text[start:] |
|
|
if '"}' in remaining: |
|
|
return remaining[:remaining.index('"}')].strip() |
|
|
return remaining.strip(' "') |
|
|
|
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
|
|
|
if '"expr"' in text: |
|
|
return text.split('"expr"')[-1].strip(' ":{}') |
|
|
return text.strip() |
|
|
|
|
|
def compute_r2(self, expression_str: str) -> Tuple[float, bool]: |
|
|
"""Compute R^2 score.""" |
|
|
if not expression_str or expression_str.isspace(): |
|
|
return -1.0, False |
|
|
|
|
|
if 'C' in expression_str: |
|
|
expression_str = expression_str.replace('C', '1') |
|
|
|
|
|
try: |
|
|
expr = Expression(expression_str, is_prefix=False) |
|
|
if not expr.is_valid_on_dataset(self.X): |
|
|
return -1.0, False |
|
|
|
|
|
y_pred = expr.evaluate(self.X) |
|
|
if not np.all(np.isfinite(y_pred)): |
|
|
return -1.0, False |
|
|
|
|
|
ss_res = np.sum((self.y - y_pred) ** 2) |
|
|
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2) |
|
|
|
|
|
if ss_tot == 0: |
|
|
return 0.0, True |
|
|
|
|
|
r2 = 1 - (ss_res / ss_tot) |
|
|
return float(np.clip(r2, -1.0, 1.0)), True |
|
|
except Exception: |
|
|
return -1.0, False |
|
|
|
|
|
def shape_reward(self, r2: float, is_valid: bool) -> float: |
|
|
"""Shape reward for better learning signal.""" |
|
|
if not is_valid: |
|
|
return -0.1 |
|
|
|
|
|
|
|
|
if r2 >= 0.99: |
|
|
return 2.0 |
|
|
elif r2 >= 0.9: |
|
|
return r2 * 1.5 |
|
|
elif r2 >= 0.5: |
|
|
return r2 * 1.2 |
|
|
elif r2 >= 0: |
|
|
return r2 |
|
|
else: |
|
|
return r2 * 0.5 |
|
|
|
|
|
    def generate_group(self, max_new_tokens: int = 50) -> List[Dict]:
        """Generate and score a group of ``self.group_size`` expressions.

        Two-phase procedure per sample:
          1. Sampling (no grad): autoregressive temperature sampling until
             EOS, a closed JSON object, or ``max_new_tokens``.
          2. Scoring + grad pass: the sampled sequence is re-run through the
             model WITH gradients to collect its total log-probability and
             mean per-token entropy for the policy update.

        Side effects: updates ``discovered_expressions``, ``reward_buffer``,
        and ``best_r2`` / ``best_expression`` for valid samples.

        Returns:
            One dict per sample with keys: text, expression, r2, is_valid,
            reward, log_prob (grad-carrying tensor), entropy.
        """
        results = []

        for _ in range(self.group_size):
            generated_ids = self.prompt_ids.clone()
            generated_tokens = []

            # --- Phase 1: sampling (gradient-free) ---
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    outputs = self.model(generated_ids)
                    # Temperature-scaled logits of the final position only.
                    logits = outputs.logits[:, -1, :] / self.current_temp

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    generated_tokens.append(next_token.item())

                    generated_ids = torch.cat([generated_ids, next_token], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Stop early once the JSON object closes after the prompt.
                    text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                    if '"}' in text[len(self.prompt):]:
                        break

            # --- Scoring ---
            text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            expr_str = self.extract_expression(text)
            r2, is_valid = self.compute_r2(expr_str)
            reward = self.shape_reward(r2, is_valid)

            # --- Phase 2: grad-carrying log-prob recomputation ---
            if len(generated_tokens) > 0:
                full_ids = torch.cat([
                    self.prompt_ids,
                    torch.tensor([generated_tokens], device=self.device)
                ], dim=1)

                # Feed all but the last token: logits at position t predict t+1.
                outputs = self.model(full_ids[:, :-1])
                logits = outputs.logits / self.current_temp

                prompt_len = self.prompt_ids.shape[1]
                # Positions prompt_len-1 .. end predict exactly the generated tokens.
                gen_logits = logits[:, prompt_len-1:, :]

                log_probs_all = F.log_softmax(gen_logits, dim=-1)
                probs_all = F.softmax(gen_logits, dim=-1)

                # Gather the log-prob of each actually-sampled token.
                target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
                selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
                total_log_prob = selected_log_probs.sum()

                # Mean per-position entropy of the (temperature-scaled) policy.
                entropy_per_pos = -(probs_all * log_probs_all).sum(dim=-1)
                total_entropy = entropy_per_pos.mean()
            else:
                # Nothing generated: keep a grad-capable zero so downstream
                # loss arithmetic stays well-formed.
                total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
                total_entropy = torch.tensor(0.0, device=self.device)

            results.append({
                "text": text,
                "expression": expr_str,
                "r2": r2,
                "is_valid": is_valid,
                "reward": reward,
                "log_prob": total_log_prob,
                "entropy": total_entropy,
            })

            # Bookkeeping only for valid expressions.
            if is_valid:
                self.discovered_expressions[expr_str] = max(
                    self.discovered_expressions.get(expr_str, -np.inf), r2
                )
                self.reward_buffer.append(reward)

                if r2 > self.best_r2:
                    self.best_r2 = r2
                    self.best_expression = expr_str

            # Keep GPU memory bounded between samples.
            if self.device.type == "cuda":
                torch.cuda.empty_cache()

        return results
|
|
|
|
|
def compute_advantages(self, results: List[Dict]) -> Tuple[List[float], dict]: |
|
|
""" |
|
|
Compute improved GRPO advantages. |
|
|
|
|
|
Key improvement: Only use VALID expressions for group statistics. |
|
|
Invalid expressions get a fixed small negative advantage. |
|
|
""" |
|
|
valid_results = [r for r in results if r["is_valid"]] |
|
|
valid_rewards = [r["reward"] for r in valid_results] |
|
|
|
|
|
stats = { |
|
|
"valid_count": len(valid_results), |
|
|
"total_count": len(results), |
|
|
"valid_ratio": len(valid_results) / len(results), |
|
|
} |
|
|
|
|
|
|
|
|
if len(valid_rewards) < 2: |
|
|
advantages = [] |
|
|
for r in results: |
|
|
if r["is_valid"]: |
|
|
adv = r["reward"] - self.ema_baseline |
|
|
else: |
|
|
adv = -0.5 |
|
|
advantages.append(adv) |
|
|
stats["method"] = "ema_only" |
|
|
return advantages, stats |
|
|
|
|
|
|
|
|
group_mean = np.mean(valid_rewards) |
|
|
group_std = np.std(valid_rewards) |
|
|
|
|
|
|
|
|
self.ema_baseline = self.ema_decay * self.ema_baseline + (1 - self.ema_decay) * group_mean |
|
|
|
|
|
|
|
|
hybrid_baseline = 0.7 * group_mean + 0.3 * self.ema_baseline |
|
|
|
|
|
|
|
|
if group_std < 1e-8: |
|
|
group_std = 1.0 |
|
|
|
|
|
|
|
|
advantages = [] |
|
|
for r in results: |
|
|
if r["is_valid"]: |
|
|
|
|
|
adv = (r["reward"] - hybrid_baseline) / group_std |
|
|
|
|
|
adv = np.clip(adv, -self.advantage_clip, self.advantage_clip) |
|
|
else: |
|
|
|
|
|
adv = -0.3 |
|
|
advantages.append(adv) |
|
|
|
|
|
stats["method"] = "hybrid" |
|
|
stats["group_mean"] = group_mean |
|
|
stats["group_std"] = group_std |
|
|
stats["ema_baseline"] = self.ema_baseline |
|
|
|
|
|
return advantages, stats |
|
|
|
|
|
def train_step(self, num_groups: int = 2) -> dict: |
|
|
"""Perform one training step.""" |
|
|
self.model.train() |
|
|
|
|
|
all_results = [] |
|
|
all_advantages = [] |
|
|
total_policy_loss = 0.0 |
|
|
total_entropy_loss = 0.0 |
|
|
skipped_groups = 0 |
|
|
|
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
for _ in range(num_groups): |
|
|
if self.device.type == "cuda": |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
group_results = self.generate_group() |
|
|
all_results.extend(group_results) |
|
|
|
|
|
|
|
|
advantages, adv_stats = self.compute_advantages(group_results) |
|
|
all_advantages.extend(advantages) |
|
|
|
|
|
|
|
|
if adv_stats["valid_ratio"] < self.min_valid_ratio: |
|
|
skipped_groups += 1 |
|
|
continue |
|
|
|
|
|
|
|
|
policy_loss = torch.tensor(0.0, device=self.device) |
|
|
entropy_loss = torch.tensor(0.0, device=self.device) |
|
|
valid_count = 0 |
|
|
|
|
|
for result, advantage in zip(group_results, advantages): |
|
|
if result["is_valid"] and advantage != 0: |
|
|
policy_loss = policy_loss - result["log_prob"] * advantage |
|
|
entropy_loss = entropy_loss - result["entropy"] |
|
|
valid_count += 1 |
|
|
|
|
|
if valid_count > 0: |
|
|
policy_loss = policy_loss / valid_count |
|
|
entropy_loss = entropy_loss / valid_count |
|
|
|
|
|
|
|
|
loss = policy_loss + self.entropy_coef * entropy_loss |
|
|
loss = loss / num_groups |
|
|
loss.backward() |
|
|
|
|
|
total_policy_loss += policy_loss.item() |
|
|
total_entropy_loss += entropy_loss.item() |
|
|
|
|
|
|
|
|
if skipped_groups < num_groups: |
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) |
|
|
self.optimizer.step() |
|
|
self.scheduler.step() |
|
|
|
|
|
|
|
|
r2_values = [r["r2"] for r in all_results] |
|
|
valid_mask = [r["is_valid"] for r in all_results] |
|
|
valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v] |
|
|
|
|
|
return { |
|
|
"valid_count": int(sum(valid_mask)), |
|
|
"total_count": len(all_results), |
|
|
"valid_rate": sum(valid_mask) / len(all_results) if all_results else 0, |
|
|
"mean_r2": float(np.mean(valid_r2)) if valid_r2 else 0.0, |
|
|
"max_r2": float(max(r2_values)) if r2_values else 0.0, |
|
|
"mean_advantage": float(np.mean(all_advantages)) if all_advantages else 0.0, |
|
|
"ema_baseline": self.ema_baseline, |
|
|
"policy_loss": total_policy_loss / max(num_groups - skipped_groups, 1), |
|
|
"entropy_loss": total_entropy_loss / max(num_groups - skipped_groups, 1), |
|
|
"lr": self.scheduler.get_last_lr()[0], |
|
|
"temperature": self.current_temp, |
|
|
"skipped_groups": skipped_groups, |
|
|
} |
|
|
|
|
|
def anneal_temperature(self, epoch: int, total_epochs: int): |
|
|
"""Anneal temperature from initial to minimum.""" |
|
|
progress = epoch / total_epochs |
|
|
self.current_temp = self.initial_temp - progress * (self.initial_temp - self.min_temp) |
|
|
|
|
|
    def run(
        self,
        epochs: int = 50,
        num_groups: int = 2,
        target_r2: float = 0.99,
        patience: int = 20,
    ) -> dict:
        """Run improved GRPO training.

        Args:
            epochs: Maximum number of training epochs.
            num_groups: Groups sampled per epoch (see ``train_step``).
            target_r2: Stop early once the best R^2 reaches this value.
            patience: Stop after this many epochs without best-R^2 improvement.

        Returns:
            Results dict (also written to a timestamped JSON file in
            ``self.output_dir``).
        """
        logger.info("=" * 60)
        logger.info("IMPROVED GRPO SYMBOLIC REGRESSION")
        logger.info("=" * 60)
        logger.info(f"Epochs: {epochs}")
        logger.info(f"Group size: {self.group_size}")
        logger.info(f"Num groups: {num_groups}")
        logger.info(f"Effective batch: {self.group_size * num_groups}")
        logger.info(f"Entropy coef: {self.entropy_coef}")
        logger.info(f"Advantage clip: {self.advantage_clip}")
        logger.info(f"Min valid ratio: {self.min_valid_ratio}")
        logger.info(f"Target R^2: {target_r2}")
        logger.info("=" * 60)

        no_improvement_count = 0
        best_r2_at_start = self.best_r2

        for epoch in range(1, epochs + 1):
            # Anneal exploration temperature over the full schedule.
            self.anneal_temperature(epoch, epochs)

            stats = self.train_step(num_groups)
            self.history.append({
                "epoch": epoch,
                **stats,
                "best_r2": self.best_r2,
            })

            logger.info(
                f"Epoch {epoch:3d} | "
                f"Valid: {stats['valid_count']}/{stats['total_count']} | "
                f"Mean R²: {stats['mean_r2']:.4f} | "
                f"Best: {self.best_r2:.4f} | "
                f"EMA: {stats['ema_baseline']:.3f} | "
                f"Temp: {stats['temperature']:.2f} | "
                f"LR: {stats['lr']:.2e}"
            )

            # Early exit once the target fit quality is reached.
            if self.best_r2 >= target_r2:
                logger.info(f"Target R^2 {target_r2} reached at epoch {epoch}!")
                break

            # Patience-based early stopping on best-R^2 stagnation.
            if self.best_r2 > best_r2_at_start:
                best_r2_at_start = self.best_r2
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= patience:
                logger.info(f"No improvement for {patience} epochs. Early stopping.")
                break

        logger.info("")
        logger.info("=" * 60)
        logger.info("FINAL RESULTS")
        logger.info("=" * 60)
        logger.info(f"Best R^2: {self.best_r2:.4f}")
        logger.info(f"Best expression: {self.best_expression}")
        logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}")

        # Report the five best-scoring unique expressions.
        top_exprs = sorted(
            self.discovered_expressions.items(),
            key=lambda x: x[1],
            reverse=True
        )[:5]
        logger.info("Top 5 expressions:")
        for expr, r2 in top_exprs:
            logger.info(f"  R²={r2:.4f}: {expr}")

        results = {
            "algorithm": "ImprovedGRPO",
            "best_r2": self.best_r2,
            "best_expression": self.best_expression,
            "history": self.history,
            # Cap the dump at 100 expressions to keep the file reasonable.
            "discovered_expressions": dict(list(self.discovered_expressions.items())[:100]),
            "config": {
                "group_size": self.group_size,
                "num_groups": num_groups,
                "learning_rate": self.learning_rate,
                "entropy_coef": self.entropy_coef,
                "advantage_clip": self.advantage_clip,
                "min_valid_ratio": self.min_valid_ratio,
            }
        }

        # Persist results under a timestamped filename so runs never clobber
        # each other.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = self.output_dir / f"results_grpo_improved_{timestamp}.json"
        with open(output_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to: {output_path}")

        return results
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, load a CSV dataset, and run GRPO."""
    parser = argparse.ArgumentParser(description="Improved GRPO for Symbolic Regression")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="./output/grpo")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--group_size", type=int, default=16)
    parser.add_argument("--num_groups", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--target_r2", type=float, default=0.99)
    parser.add_argument("--entropy_coef", type=float, default=0.01)
    args = parser.parse_args()

    # Local import keeps pandas off the module import path.
    import pandas as pd
    df = pd.read_csv(args.dataset)

    # Expected CSV layout: feature columns named x_1..x_n plus target 'y'.
    x_cols = [c for c in df.columns if c.startswith('x_')]
    X = df[x_cols].values
    y = df['y'].values

    logger.info(f"Loaded dataset: {args.dataset}")
    logger.info(f"  Samples: {len(df)}, Variables: {len(x_cols)}")

    grpo = ImprovedGRPO(
        model_path=args.model_path,
        X=X,
        y=y,
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        group_size=args.group_size,
        entropy_coef=args.entropy_coef,
    )

    # NOTE(review): the return value is unused here; run() also persists
    # results to disk itself.
    results = grpo.run(
        epochs=args.epochs,
        num_groups=args.num_groups,
        target_r2=args.target_r2,
    )
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|