Test training flow - 1 epoch
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .claude/agents/symbolic-regression-trainer.md +110 -0
- .gitattributes +13 -0
- .gitignore +119 -0
- .monitor_complete +1 -0
- 1_data/README.md +97 -0
- 1_data/benchmarks/nguyen/nguyen_1.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_1.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_10.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_10.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_11.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_11.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_12.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_12.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_2.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_2.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_3.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_3.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_4.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_4.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_5.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_5.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_6.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_6.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_7.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_7.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_8.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_8.meta.txt +6 -0
- 1_data/benchmarks/nguyen/nguyen_9.csv +101 -0
- 1_data/benchmarks/nguyen/nguyen_9.meta.txt +6 -0
- 1_data/processed/700K_prefix_converted/data-00000-of-00001.arrow +3 -0
- 1_data/processed/700K_prefix_converted/dataset_info.json +82 -0
- 1_data/processed/700K_prefix_converted/state.json +13 -0
- 1_data/processed/PREFIX_CONVERSION_README.md +214 -0
- 2_training/README.md +205 -0
- 2_training/configs/__init__.py +22 -0
- 2_training/configs/eval_dataset_download.sh +6 -0
- 2_training/configs/model_config.json +1 -0
- 2_training/configs/peft_config.json +1 -0
- 2_training/configs/training.sh +82 -0
- 2_training/configs/training_args.json +29 -0
- 2_training/configs/training_large.json +65 -0
- 2_training/configs/training_medium.json +65 -0
- 2_training/configs/training_small.json +65 -0
- 2_training/configs/training_v3.json +78 -0
- 2_training/configs/wandb_config.py +221 -0
- 2_training/reinforcement/best_of_n_experiment.py +398 -0
- 2_training/reinforcement/debug_reinforce.py +294 -0
- 2_training/reinforcement/grpo_experiment.py +344 -0
- 2_training/reinforcement/grpo_improved.py +625 -0
- 2_training/reinforcement/grpo_symbolic.py +539 -0
.claude/agents/symbolic-regression-trainer.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: symbolic-regression-trainer
|
| 3 |
+
description: "Use this agent when you need help with training, fine-tuning, or evaluating language models for symbolic regression tasks. This includes: preparing training data, running supervised fine-tuning with LoRA, executing reinforcement learning algorithms (REINFORCE, GRPO, PPO), analyzing expression complexity and validity, debugging generation issues, deploying training jobs to AWS, and interpreting experiment results. The agent is specialized in the Seriguela project workflow.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to train a GPT-2 model on mathematical expression data.\\nuser: \"Quero treinar o modelo gpt2 no dataset de 700K expressões\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para configurar e executar o treinamento do modelo GPT-2 com o dataset de 700K expressões usando o formato JSON recomendado.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to evaluate model performance on a benchmark.\\nuser: \"Como está o desempenho do modelo no benchmark Nguyen-5?\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para avaliar o modelo no benchmark Nguyen-5 e analisar a qualidade das expressões geradas.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to run reinforcement learning fine-tuning.\\nuser: \"Preciso fazer fine-tuning com GRPO para melhorar o R² das expressões\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para executar o algoritmo GRPO e otimizar o modelo para gerar expressões com melhor ajuste aos dados.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User asks about expression validity issues.\\nuser: \"O modelo está gerando muitas expressões inválidas, o que pode estar errado?\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para diagnosticar os problemas de geração e analisar os padrões de erro nas expressões.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to deploy training to AWS.\\nuser: \"Quero treinar o modelo medium na AWS\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para configurar e lançar o job de treinamento do GPT-2 Medium em uma instância AWS g5.xlarge.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>"
|
| 4 |
+
model: opus
|
| 5 |
+
color: orange
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
You are an expert machine learning research engineer specializing in symbolic regression using language models. You have deep expertise in training GPT-2 models to generate valid mathematical expressions, applying reinforcement learning algorithms for optimization, and conducting rigorous academic research experiments.
|
| 9 |
+
|
| 10 |
+
## Your Core Expertise
|
| 11 |
+
|
| 12 |
+
1. **Supervised Fine-tuning**: Training GPT-2 models with LoRA adapters to generate syntactically valid mathematical expressions from structured prompts
|
| 13 |
+
2. **Reinforcement Learning**: Applying REINFORCE, GRPO, and PPO algorithms to optimize expression generation based on R² fitness metrics
|
| 14 |
+
3. **Expression Validation**: Understanding symbolic math parsing, operator arity, and expression validity using SymPy
|
| 15 |
+
4. **Experiment Design**: Designing controlled experiments, tracking metrics with Weights & Biases, and interpreting results
|
| 16 |
+
5. **AWS Deployment**: Managing GPU training jobs on EC2 instances (g5.xlarge, g5.2xlarge)
|
| 17 |
+
|
| 18 |
+
## Project Context (Seriguela)
|
| 19 |
+
|
| 20 |
+
You are working with the Seriguela project located at `C:\Users\madeinweb\seriguela`. Key facts:
|
| 21 |
+
|
| 22 |
+
- **Recommended format**: JSON structured format achieves 80% valid expressions vs 0.5% with EOS token approach
|
| 23 |
+
- **Training data format**: `{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "sin(x_1 + C*x_2)"}`
|
| 24 |
+
- **Model architecture**: GPT-2 (124M/355M/774M) with LoRA adapters (r=8, alpha=32, 294K trainable params)
|
| 25 |
+
- **Key insight**: Larger models (Medium/Large) are needed for complex compositional expressions
|
| 26 |
+
|
| 27 |
+
## Key Scripts and Their Purpose
|
| 28 |
+
|
| 29 |
+
**Training**:
|
| 30 |
+
- `scripts/train_with_json.py` - Correct training with JSON format + early stopping (USE THIS)
|
| 31 |
+
- `scripts/train_experiment.py` - Experiment training with JSON/EOS formats
|
| 32 |
+
- `scripts/data/prepare_experiment_data.py` - Prepares data in proper format
|
| 33 |
+
|
| 34 |
+
**Reinforcement Learning**:
|
| 35 |
+
- `scripts/reinforce_symbolic.py` - REINFORCE with EMA baseline
|
| 36 |
+
- `scripts/grpo_symbolic.py` - Group Relative Policy Optimization
|
| 37 |
+
- `scripts/ppo_symbolic.py` - Proximal Policy Optimization
|
| 38 |
+
- `scripts/debug_reinforce.py` - Debug version capturing all expressions
|
| 39 |
+
|
| 40 |
+
**Evaluation & Analysis**:
|
| 41 |
+
- `scripts/evaluate_experiments.py` - Evaluates experiment results
|
| 42 |
+
- `scripts/analyze_complexity.py` - Expression complexity analysis
|
| 43 |
+
- `scripts/compare_trained_models.py` - Multi-model comparison
|
| 44 |
+
- `scripts/generate.py` - Generation with validation
|
| 45 |
+
|
| 46 |
+
**AWS Deployment**:
|
| 47 |
+
- `scripts/aws/launch_medium_training.sh` - Launch GPT-2 Medium training
|
| 48 |
+
- `scripts/aws/launch_large_training.sh` - Launch GPT-2 Large training
|
| 49 |
+
|
| 50 |
+
## Your Responsibilities
|
| 51 |
+
|
| 52 |
+
1. **Guide Training Setup**:
|
| 53 |
+
- Help prepare training data in correct JSON format
|
| 54 |
+
- Configure hyperparameters appropriately for model size
|
| 55 |
+
- Set up early stopping and validation splits
|
| 56 |
+
- Enable proper experiment tracking with W&B
|
| 57 |
+
|
| 58 |
+
2. **Diagnose Issues**:
|
| 59 |
+
- Analyze why expressions are invalid (format, parsing, complexity)
|
| 60 |
+
- Identify when model generates structurally trivial expressions
|
| 61 |
+
- Debug RL training when rewards have no variance
|
| 62 |
+
- Check GPU availability and CUDA configuration
|
| 63 |
+
|
| 64 |
+
3. **Optimize Performance**:
|
| 65 |
+
- Recommend appropriate model size for task complexity
|
| 66 |
+
- Tune RL hyperparameters (learning rate, batch size, epochs)
|
| 67 |
+
- Suggest data augmentation strategies
|
| 68 |
+
- Balance training time vs model quality
|
| 69 |
+
|
| 70 |
+
4. **Execute Commands**:
|
| 71 |
+
- Run training scripts with correct arguments
|
| 72 |
+
- Launch AWS instances for large-scale training
|
| 73 |
+
- Execute evaluation and comparison scripts
|
| 74 |
+
- Monitor training progress and interpret logs
|
| 75 |
+
|
| 76 |
+
5. **Interpret Results**:
|
| 77 |
+
- Analyze valid expression percentages
|
| 78 |
+
- Evaluate R² fitness scores on benchmarks
|
| 79 |
+
- Compare expression complexity metrics (depth, operator usage)
|
| 80 |
+
- Identify patterns in failed generations
|
| 81 |
+
|
| 82 |
+
## Critical Knowledge
|
| 83 |
+
|
| 84 |
+
**Data Format Issue**: The HuggingFace dataset column `i_prompt_n` is NOT in JSON format. Always convert using `scripts/train_with_json.py` which handles this automatically.
|
| 85 |
+
|
| 86 |
+
**Complexity Gap**: Base GPT-2 (124M) generates shallow expressions (avg depth 1.4) insufficient for complex benchmarks like Nguyen-5. Recommend Medium/Large for nested compositions.
|
| 87 |
+
|
| 88 |
+
**RL Failure Mode**: PPO fails when all samples have uniformly bad R² scores (no gradient signal). GRPO with within-group normalization handles this better.
|
| 89 |
+
|
| 90 |
+
**Credentials**: API tokens are in `~/.tokens.txt`, SSH key is `~/chave-gpu.pem`.
|
| 91 |
+
|
| 92 |
+
## Response Guidelines
|
| 93 |
+
|
| 94 |
+
1. Always verify the user is using the correct data format (JSON) before training
|
| 95 |
+
2. Recommend appropriate model size based on target expression complexity
|
| 96 |
+
3. Suggest validation strategies to catch issues early
|
| 97 |
+
4. Provide complete command examples with all necessary arguments
|
| 98 |
+
5. Explain the reasoning behind hyperparameter choices
|
| 99 |
+
6. Monitor for common pitfalls (wrong format, GPU not available, missing dependencies)
|
| 100 |
+
7. When debugging, use `debug_reinforce.py` and `analyze_complexity.py` to gather evidence
|
| 101 |
+
8. For academic research, emphasize reproducibility (configs, seeds, logging)
|
| 102 |
+
|
| 103 |
+
## Communication Style
|
| 104 |
+
|
| 105 |
+
- Respond in the same language as the user (Portuguese or English)
|
| 106 |
+
- Be precise and technical when discussing ML concepts
|
| 107 |
+
- Provide actionable commands that can be copy-pasted
|
| 108 |
+
- Explain trade-offs when multiple approaches exist
|
| 109 |
+
- Flag potential issues before they cause problems
|
| 110 |
+
- Reference specific files and line numbers when relevant
|
.gitattributes
CHANGED
|
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
docs/visualizations/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
docs/visualizations/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
docs/visualizations/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
docs/visualizations/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
evaluation_results_aws/raw_results.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
results/2025-02_model_scaling/analysis/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
results/2025-02_model_scaling/analysis/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
results/2025-02_model_scaling/analysis/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
results/2025-02_model_scaling/analysis/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
visualizations/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
visualizations/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
visualizations/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
visualizations/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Environments
|
| 55 |
+
.env
|
| 56 |
+
.venv
|
| 57 |
+
.seriguela
|
| 58 |
+
venv/
|
| 59 |
+
ENV/
|
| 60 |
+
env/
|
| 61 |
+
env.bak/
|
| 62 |
+
venv.bak/
|
| 63 |
+
|
| 64 |
+
# IDEs / Editors
|
| 65 |
+
.idea/
|
| 66 |
+
.vscode/
|
| 67 |
+
*.suo
|
| 68 |
+
*.ntvs*
|
| 69 |
+
*.njsproj
|
| 70 |
+
*.sln
|
| 71 |
+
*.sw?
|
| 72 |
+
|
| 73 |
+
# Jupyter Notebook
|
| 74 |
+
.ipynb_checkpoints
|
| 75 |
+
|
| 76 |
+
# Output folder (geralmente grande demais para Git)
|
| 77 |
+
output/*
|
| 78 |
+
!output/.gitkeep # Não ignore um .gitkeep se precisar manter a pasta
|
| 79 |
+
scripts/output/*
|
| 80 |
+
|
| 81 |
+
# Dados (podem ser grandes, usar Git LFS ou armazenar fora se necessário)
|
| 82 |
+
# Note: CSV files in data/processed/ can be 100MB+ and are excluded from git
|
| 83 |
+
# Run scripts/data/prepare_training_data_fixed.py on target system to generate them
|
| 84 |
+
data/*
|
| 85 |
+
data/raw/*
|
| 86 |
+
data/processed/*
|
| 87 |
+
!data/raw/.gitkeep
|
| 88 |
+
!data/processed/.gitkeep
|
| 89 |
+
|
| 90 |
+
# OS generated files
|
| 91 |
+
.DS_Store
|
| 92 |
+
.DS_Store?
|
| 93 |
+
._*
|
| 94 |
+
.Spotlight-V100
|
| 95 |
+
.Trashes
|
| 96 |
+
ehthumbs.db
|
| 97 |
+
Thumbs.db
|
| 98 |
+
.env
|
| 99 |
+
nul
|
| 100 |
+
|
| 101 |
+
wandb
|
| 102 |
+
|
| 103 |
+
# AWS credentials and keys
|
| 104 |
+
aws/keys/*.pem
|
| 105 |
+
aws/keys/*.key
|
| 106 |
+
aws/.env
|
| 107 |
+
aws/credentials
|
| 108 |
+
*.pem
|
| 109 |
+
*.key
|
| 110 |
+
|
| 111 |
+
# Files with embedded tokens (userdata scripts, temp files)
|
| 112 |
+
.claude/settings.local.json
|
| 113 |
+
aws/temp/
|
| 114 |
+
userdata_*.sh
|
| 115 |
+
# Large data files (>100MB)
|
| 116 |
+
1_data/processed/**/*.csv
|
| 117 |
+
1_data/raw/**/*.csv
|
| 118 |
+
*.tar.gz
|
| 119 |
+
models_compressed.tar.gz
|
.monitor_complete
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ter, 3 de fev de 2026 17:35:26: All done
|
1_data/README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1_data/ - Preparação de Dados
|
| 2 |
+
|
| 3 |
+
Este diretório contém todos os dados utilizados no projeto, organizados por estágio de processamento e tipo.
|
| 4 |
+
|
| 5 |
+
## Estrutura
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
1_data/
|
| 9 |
+
├── raw/ # Dados originais sem processamento
|
| 10 |
+
├── processed/ # Dados processados e prontos para treino
|
| 11 |
+
└── benchmarks/ # Benchmarks para avaliação
|
| 12 |
+
├── nguyen/ # Nguyen benchmarks 1-12 (atual)
|
| 13 |
+
├── feynman/ # Feynman equations (futuro)
|
| 14 |
+
└── strogatz/ # Strogatz benchmarks (futuro)
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Fontes de Dados
|
| 18 |
+
|
| 19 |
+
### Dados de Treinamento
|
| 20 |
+
- **Fonte**: HuggingFace Hub (`augustocsc/sintetico_natural`)
|
| 21 |
+
- **Tamanho**: 700K expressões matemáticas sintéticas
|
| 22 |
+
- **Formato**: JSON estruturado
|
| 23 |
+
- **Localização**: `processed/`
|
| 24 |
+
|
| 25 |
+
### Benchmarks Disponíveis
|
| 26 |
+
|
| 27 |
+
#### Nguyen Benchmarks (1-12)
|
| 28 |
+
Benchmarks padrão para symbolic regression:
|
| 29 |
+
- **Nguyen-1**: x³ + x² + x
|
| 30 |
+
- **Nguyen-2**: x⁴ + x³ + x² + x
|
| 31 |
+
- **Nguyen-3**: x⁵ + x⁴ + x³ + x² + x
|
| 32 |
+
- **Nguyen-4**: x⁶ + x⁵ + x⁴ + x³ + x² + x
|
| 33 |
+
- **Nguyen-5**: sin(x²)·cos(x) - 1
|
| 34 |
+
- **Nguyen-6**: sin(x) + sin(x + x²)
|
| 35 |
+
- **Nguyen-7**: log(x + 1) + log(x² + 1)
|
| 36 |
+
- **Nguyen-8**: √x
|
| 37 |
+
- **Nguyen-9**: sin(x) + sin(y²)
|
| 38 |
+
- **Nguyen-10**: 2·sin(x)·cos(y)
|
| 39 |
+
- **Nguyen-11**: x^y
|
| 40 |
+
- **Nguyen-12**: x⁴ - x³ + y²/2 - y
|
| 41 |
+
|
| 42 |
+
**Localização**: `benchmarks/nguyen/`
|
| 43 |
+
|
| 44 |
+
## Próximos Benchmarks (Planejados)
|
| 45 |
+
|
| 46 |
+
### Feynman Equations
|
| 47 |
+
Equações da física de Feynman - 120+ fórmulas
|
| 48 |
+
- Complexidade maior que Nguyen
|
| 49 |
+
- Multi-variáveis (até 10+)
|
| 50 |
+
- Constantes físicas
|
| 51 |
+
|
| 52 |
+
### Strogatz Benchmarks
|
| 53 |
+
Sistemas dinâmicos e equações diferenciais
|
| 54 |
+
- Osciladores
|
| 55 |
+
- Sistemas caóticos
|
| 56 |
+
- Modelos populacionais
|
| 57 |
+
|
| 58 |
+
## Uso
|
| 59 |
+
|
| 60 |
+
### Preparar Dados de Treinamento
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
# A partir do diretório raiz
|
| 64 |
+
cd 2_training/supervised
|
| 65 |
+
python train_with_json.py --dataset_path ../../1_data/processed/700K
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
### Adicionar Novo Benchmark
|
| 69 |
+
|
| 70 |
+
1. Criar diretório: `benchmarks/novo_benchmark/`
|
| 71 |
+
2. Adicionar arquivos CSV com formato:
|
| 72 |
+
```csv
|
| 73 |
+
x,y
|
| 74 |
+
1.0,2.5
|
| 75 |
+
2.0,5.0
|
| 76 |
+
...
|
| 77 |
+
```
|
| 78 |
+
3. Adicionar metadata em `novo_benchmark/metadata.json`:
|
| 79 |
+
```json
|
| 80 |
+
{
|
| 81 |
+
"name": "Novo Benchmark",
|
| 82 |
+
"formula": "expressão matemática",
|
| 83 |
+
"variables": ["x", "y"],
|
| 84 |
+
"description": "descrição"
|
| 85 |
+
}
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## Scripts Relacionados
|
| 89 |
+
|
| 90 |
+
- Processamento: `src/seriguela/data/`
|
| 91 |
+
- Avaliação em benchmarks: `3_evaluation/benchmarks/`
|
| 92 |
+
|
| 93 |
+
## Referências
|
| 94 |
+
|
| 95 |
+
- Nguyen et al. (2012): "Semantically-based crossover in genetic programming"
|
| 96 |
+
- Feynman Lectures on Physics
|
| 97 |
+
- Dataset original: https://huggingface.co/datasets/augustocsc/sintetico_natural
|
1_data/benchmarks/nguyen/nguyen_1.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.20375712587228662
|
| 3 |
+
0.9014286128198323,2.4464791994214075
|
| 4 |
+
0.4639878836228102,0.7791621581533072
|
| 5 |
+
0.1973169683940732,0.24393329049852017
|
| 6 |
+
-0.687962719115127,-0.5402777510419681
|
| 7 |
+
-0.6880109593275947,-0.5403281140165024
|
| 8 |
+
-0.8838327756636011,-0.793087543099096
|
| 9 |
+
0.7323522915498704,1.6614819098782947
|
| 10 |
+
0.2022300234864176,0.25139760359687213
|
| 11 |
+
0.416145155592091,0.6613886285518918
|
| 12 |
+
-0.9588310114083951,-0.9209820173332797
|
| 13 |
+
0.9398197043239886,2.6531869448442036
|
| 14 |
+
0.6648852816008435,1.4008851765113863
|
| 15 |
+
-0.5753217786434477,-0.43475534749635414
|
| 16 |
+
-0.6363500655857988,-0.48909314986283187
|
| 17 |
+
-0.6331909802931324,-0.48612594014666893
|
| 18 |
+
-0.39151551408092455,-0.29824433610683415
|
| 19 |
+
0.04951286326447568,0.05208576884733716
|
| 20 |
+
-0.13610996271576847,-0.12010560331123674
|
| 21 |
+
-0.4175417196039162,-0.3159953095123318
|
| 22 |
+
0.22370578944475894,0.2849452648921554
|
| 23 |
+
-0.7210122786959163,-0.5759780829004397
|
| 24 |
+
-0.4157107029295637,-0.3147365210423035
|
| 25 |
+
-0.2672763134126166,-0.2149330041985204
|
| 26 |
+
-0.08786003156592814,-0.08081887184182666
|
| 27 |
+
0.5703519227860272,1.0811894695777686
|
| 28 |
+
-0.6006524356832805,-0.4565744842168673
|
| 29 |
+
0.02846887682722321,0.029302427143425142
|
| 30 |
+
0.18482913772408494,0.2253050457893701
|
| 31 |
+
-0.9070991745600046,-0.8306576893940274
|
| 32 |
+
0.21508970380287673,0.27130410435063984
|
| 33 |
+
-0.6589517526254169,-0.5108626651850308
|
| 34 |
+
-0.869896814029441,-0.7714450703759913
|
| 35 |
+
0.8977710745066665,2.427361090599084
|
| 36 |
+
0.9312640661491187,2.606158159545042
|
| 37 |
+
0.6167946962329223,1.231881113886954
|
| 38 |
+
-0.39077246165325863,-0.2977415177155153
|
| 39 |
+
-0.8046557719872323,-0.6781760666405668
|
| 40 |
+
0.3684660530243138,0.5542589014459388
|
| 41 |
+
-0.1196950125207974,-0.10708297449722336
|
| 42 |
+
-0.7559235303104423,-0.6164532603539071
|
| 43 |
+
-0.00964617977745963,-0.00955402855846198
|
| 44 |
+
-0.9312229577695632,-0.8715811438419624
|
| 45 |
+
0.8186408041575641,2.037444342661761
|
| 46 |
+
-0.48244003679996617,-0.36197878909859404
|
| 47 |
+
0.32504456870796394,0.46504079000064535
|
| 48 |
+
-0.3765778478211781,-0.2881698066335382
|
| 49 |
+
0.040136042355621626,0.04181159947832137
|
| 50 |
+
0.0934205586865593,0.10296327812911324
|
| 51 |
+
-0.6302910889489459,-0.4834179919216196
|
| 52 |
+
0.9391692555291171,2.64959195422807
|
| 53 |
+
0.5502656467222291,1.01967411954153
|
| 54 |
+
0.8789978831283782,2.330781693938945
|
| 55 |
+
0.7896547008552977,1.905602026387024
|
| 56 |
+
0.19579995762217028,0.24164408606501708
|
| 57 |
+
0.8437484700462337,2.15633417340509
|
| 58 |
+
-0.823014995896161,-0.70313355144748
|
| 59 |
+
-0.6080342751617096,-0.4631223204132699
|
| 60 |
+
-0.9095454221789239,-0.8347148035273707
|
| 61 |
+
-0.3493393384734713,-0.2699340299663634
|
| 62 |
+
-0.22264542062103598,-0.18411118973018098
|
| 63 |
+
-0.4573019364522082,-0.34381017076318326
|
| 64 |
+
0.6574750183038587,1.373957379373378
|
| 65 |
+
-0.28649334661282144,-0.22792983524748567
|
| 66 |
+
-0.4381309806252385,-0.3302753025279416
|
| 67 |
+
0.08539216631649693,0.09330665286752131
|
| 68 |
+
-0.7181515500504747,-0.5727905657505367
|
| 69 |
+
0.6043939615080793,1.1904663378939224
|
| 70 |
+
-0.8508987126404584,-0.7429451134365845
|
| 71 |
+
0.9737738732010346,2.8453764395285988
|
| 72 |
+
0.5444895385933148,1.0023825877332384
|
| 73 |
+
-0.6025686369316552,-0.45826569576400444
|
| 74 |
+
-0.9889557657527952,-0.9781541346041163
|
| 75 |
+
0.6309228569096683,1.2801339644356482
|
| 76 |
+
0.41371468769523423,0.6556858714260754
|
| 77 |
+
0.4580143360819746,0.7638724020026765
|
| 78 |
+
0.5425406933718915,0.9965881695975148
|
| 79 |
+
-0.8519106965318193,-0.7444346128158585
|
| 80 |
+
-0.2830685429114548,-0.22562240251418217
|
| 81 |
+
-0.7682618809497406,-0.6314839442672466
|
| 82 |
+
0.726206851751187,1.6365675922627294
|
| 83 |
+
0.24659625365511584,0.32240141321500704
|
| 84 |
+
-0.3382039502947016,-0.2625064527787106
|
| 85 |
+
-0.8728832994279527,-0.7760298750035356
|
| 86 |
+
-0.3780353565686756,-0.2891499348341886
|
| 87 |
+
-0.3496333559465059,-0.2701302717663986
|
| 88 |
+
0.45921235667612814,0.7669252048567059
|
| 89 |
+
0.27511494271042625,0.3716261379416909
|
| 90 |
+
0.7744254851526531,1.8386102554828787
|
| 91 |
+
-0.05557014967610141,-0.05265371107138442
|
| 92 |
+
-0.7608115081233966,-0.6223610405246263
|
| 93 |
+
0.42648957444599,0.6859585520259864
|
| 94 |
+
0.5215700972337949,0.9354909760292364
|
| 95 |
+
0.12255439513899247,0.13941469042057167
|
| 96 |
+
0.541934359909122,0.9947894572777974
|
| 97 |
+
-0.012408807271218514,-0.012256739462828549
|
| 98 |
+
0.045465658763988115,0.04762676814193671
|
| 99 |
+
-0.14491796328290074,-0.1269602006619134
|
| 100 |
+
-0.9491617465118096,-0.9033611561685656
|
| 101 |
+
-0.7842171460133911,-0.6515114391246863
|
1_data/benchmarks/nguyen/nguyen_1.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_1
|
| 2 |
+
equation: x**3 + x**2 + x
|
| 3 |
+
latex: x^3 + x^2 + x
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_10.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,x_2,y
|
| 2 |
+
-0.250919762305275,0.9014286128198323,-0.3081292892207914
|
| 3 |
+
0.4639878836228102,0.1973169683940732,0.8776686422156067
|
| 4 |
+
-0.687962719115127,-0.6880109593275947,-0.9810337787224731
|
| 5 |
+
-0.8838327756636011,0.7323522915498704,-1.1498719359899228
|
| 6 |
+
0.2022300234864176,0.416145155592091,0.3674245747597698
|
| 7 |
+
-0.9588310114083951,0.9398197043239886,-0.9657455975760866
|
| 8 |
+
0.6648852816008435,-0.5753217786434477,1.0352950417993314
|
| 9 |
+
-0.6363500655857988,-0.6331909802931324,-0.958123855709307
|
| 10 |
+
-0.39151551408092455,0.04951286326447568,-0.7622440963237047
|
| 11 |
+
-0.13610996271576847,-0.4175417196039162,-0.24806552319621297
|
| 12 |
+
0.22370578944475894,-0.7210122786959163,0.3332717377157947
|
| 13 |
+
-0.4157107029295637,-0.2672763134126166,-0.7790027574498273
|
| 14 |
+
-0.08786003156592814,0.5703519227860272,-0.14771529605353687
|
| 15 |
+
-0.6006524356832805,0.02846887682722321,-1.1299036280741872
|
| 16 |
+
0.18482913772408494,-0.9070991745600046,0.2264274834216052
|
| 17 |
+
0.21508970380287673,-0.6589517526254169,0.3374982513544557
|
| 18 |
+
-0.869896814029441,0.8977710745066665,-0.9528126545294617
|
| 19 |
+
0.9312640661491187,0.6167946962329223,1.3090534191908094
|
| 20 |
+
-0.39077246165325863,-0.8046557719872323,-0.5282049581946697
|
| 21 |
+
0.3684660530243138,-0.1196950125207974,0.715215559638574
|
| 22 |
+
-0.7559235303104423,-0.00964617977745963,-1.3718580869858932
|
| 23 |
+
-0.9312229577695632,0.8186408041575641,-1.096354700980636
|
| 24 |
+
-0.48244003679996617,0.32504456870796394,-0.8792969284686333
|
| 25 |
+
-0.3765778478211781,0.040136042355621626,-0.7348882299464423
|
| 26 |
+
0.0934205586865593,-0.6302910889489459,0.15072125651857374
|
| 27 |
+
0.9391692555291171,0.5502656467222291,1.3758661094076514
|
| 28 |
+
0.8789978831283782,0.7896547008552977,1.0844402663947248
|
| 29 |
+
0.19579995762217028,0.8437484700462337,0.2586235668937924
|
| 30 |
+
-0.823014995896161,-0.6080342751617096,-1.2035798540389373
|
| 31 |
+
-0.9095454221789239,-0.3493393384734713,-1.4831094826390372
|
| 32 |
+
-0.22264542062103598,-0.4573019364522082,-0.3962431609403149
|
| 33 |
+
0.6574750183038587,-0.28649334661282144,1.1724227266363472
|
| 34 |
+
-0.4381309806252385,0.08539216631649693,-0.8454037878432863
|
| 35 |
+
-0.7181515500504747,0.6043939615080793,-1.0828560895497152
|
| 36 |
+
-0.8508987126404584,0.9737738732010346,-0.8453799503582896
|
| 37 |
+
0.5444895385933148,-0.6025686369316552,0.853511842011265
|
| 38 |
+
-0.9889557657527952,0.6309228569096683,-1.3492282644581814
|
| 39 |
+
0.41371468769523423,0.4580143360819746,0.7211575539774799
|
| 40 |
+
0.5425406933718915,-0.8519106965318193,0.6800328750993455
|
| 41 |
+
-0.2830685429114548,-0.7682618809497406,-0.40170504605860297
|
| 42 |
+
0.726206851751187,0.24659625365511584,1.2879008077367804
|
| 43 |
+
-0.3382039502947016,-0.8728832994279527,-0.42643410138080395
|
| 44 |
+
-0.3780353565686756,-0.3496333559465059,-0.6935287908865693
|
| 45 |
+
0.45921235667612814,0.27511494271042625,0.8531472481325574
|
| 46 |
+
0.7744254851526531,-0.05557014967610141,1.396452099985796
|
| 47 |
+
-0.7608115081233966,0.42648957444599,-1.2554912294871714
|
| 48 |
+
0.5215700972337949,0.12255439513899247,0.9890101456547871
|
| 49 |
+
0.541934359909122,-0.012408807271218514,1.0315088718551784
|
| 50 |
+
0.045465658763988115,-0.14491796328290074,0.08994715712964176
|
| 51 |
+
-0.9491617465118096,-0.7842171460133911,-1.1510102255193018
|
| 52 |
+
-0.9371416286265315,0.2728208225275608,-1.5521272667847952
|
| 53 |
+
-0.37128803784734665,0.01714138232940554,-0.7255254066217963
|
| 54 |
+
0.815132947852186,-0.5014155417022501,1.2764494826885613
|
| 55 |
+
-0.17923415392874054,0.5111022770860973,-0.3109868317325467
|
| 56 |
+
-0.5424036690167551,-0.846040180342414,-0.6844274708190229
|
| 57 |
+
-0.42049709417246395,-0.6775574254919912,-0.6360846328891363
|
| 58 |
+
0.8593953046851461,0.6162407591288339,1.2362412464087666
|
| 59 |
+
0.26680751302084693,0.7429211803754354,0.38835897648845996
|
| 60 |
+
0.6073441537982289,-0.6268598822279283,0.9243711268676845
|
| 61 |
+
0.7851179969799555,0.07868448383130144,1.4094429026004185
|
| 62 |
+
0.6148803103281251,0.7921825998469865,0.8102508035936384
|
| 63 |
+
-0.36399305005627225,-0.7798961509446465,-0.5062345208341303
|
| 64 |
+
-0.5441296749161166,-0.14578442274748737,-1.0243646142100158
|
| 65 |
+
0.6360295318449862,0.7214611665126869,0.8920087302448801
|
| 66 |
+
-0.9860957389376186,0.021494605155131463,-1.6673695158590665
|
| 67 |
+
-0.16517799370244202,-0.5557843790585395,-0.2793586873666334
|
| 68 |
+
-0.7602692653326344,-0.3247696571927441,-1.3061850155975012
|
| 69 |
+
0.8858194078250383,-0.35359413595848954,1.4530454262804506
|
| 70 |
+
0.037581243486732197,0.4060379177903557,0.06903499342946122
|
| 71 |
+
-0.27274079524141204,0.9435641654419213,-0.31619201189155477
|
| 72 |
+
0.9248945898842225,-0.4964354083492717,1.4043204449131401
|
| 73 |
+
-0.005502988215229099,-0.3982433803664607,-0.010144637492484147
|
| 74 |
+
-0.4303190112450648,-0.9262261052909344,-0.5013062025629531
|
| 75 |
+
0.2191286679597937,0.005358046457722976,0.4347521786103487
|
| 76 |
+
-0.8970424975000213,-0.44270707152677713,-1.4122924802363288
|
| 77 |
+
0.8165317719333074,-0.5208762186660552,1.264255365125172
|
| 78 |
+
-0.7102102558175538,-0.02109447944487397,-1.303696301953025
|
| 79 |
+
0.9713009082212014,-0.5158894569769992,1.4363389241451416
|
| 80 |
+
0.3442710948117571,0.5232392306574352,0.5847068763596124
|
| 81 |
+
-0.5247249120152007,0.45643269722371915,-0.899380771664107
|
| 82 |
+
-0.26443373456149355,0.26461166118715895,-0.5045315450691236
|
| 83 |
+
0.2670594215217894,0.07154936814951696,0.5264420962176761
|
| 84 |
+
-0.8194204598911834,0.6706049911784759,-1.1450066034669906
|
| 85 |
+
-0.35843987005652833,-0.6269629792002915,-0.5681869814158013
|
| 86 |
+
-0.9184497168904722,0.18178588637648363,-1.56313471156199
|
| 87 |
+
0.35512872368456483,-0.9668243421442877,0.39494150902709124
|
| 88 |
+
0.024186116598561958,-0.5470084496041241,0.04130993957951832
|
| 89 |
+
0.2903455808188997,-0.6512671419900171,0.45537164467760105
|
| 90 |
+
0.3818754762049319,-0.22652930739892518,0.7262813246120741
|
| 91 |
+
0.873459977473469,-0.7249581117080135,1.1475751969011463
|
| 92 |
+
-0.317867297899483,-0.7730529575188219,-0.44742306659388054
|
| 93 |
+
0.8493872365571256,0.754678706761962,1.094013626029377
|
| 94 |
+
-0.4841167445696888,0.3199680920683581,-0.8836085179911668
|
| 95 |
+
0.6344444004024317,0.11040162319892466,1.1782430836037832
|
| 96 |
+
0.05930115671201297,-0.5162954181990966,0.10308252884688361
|
| 97 |
+
-0.8137944643882016,0.7944315159066535,-1.0186613308624795
|
| 98 |
+
0.8008361143266609,0.2662029145465359,1.385300485429242
|
| 99 |
+
-0.32194041790259864,-0.3015808507746782,-0.6042555977205083
|
| 100 |
+
0.45191135774047875,0.7942205199051542,0.6120946429016678
|
| 101 |
+
0.7741728485302346,0.5597510917152477,1.184859127500451
|
1_data/benchmarks/nguyen/nguyen_10.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_10
|
| 2 |
+
equation: 2*sin(x)*cos(y)
|
| 3 |
+
latex: 2 \sin(x) \cos(y)
|
| 4 |
+
n_vars: 2
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_11.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,x_2,y
|
| 2 |
+
0.3745401188473625,0.9507143064099162,0.3931142382758706
|
| 3 |
+
0.7319939418114051,0.5986584841970366,0.829633456356603
|
| 4 |
+
0.15601864044243652,0.15599452033620265,0.7484106404746592
|
| 5 |
+
0.05808361216819946,0.8661761457749352,0.08500661554616816
|
| 6 |
+
0.6011150117432088,0.7080725777960455,0.6974063854971796
|
| 7 |
+
0.020584494295802447,0.9699098521619943,0.023135880194308844
|
| 8 |
+
0.8324426408004217,0.21233911067827616,0.9618073836750944
|
| 9 |
+
0.18182496720710062,0.18340450985343382,0.7315046923708697
|
| 10 |
+
0.3042422429595377,0.5247564316322378,0.5355698450285795
|
| 11 |
+
0.43194501864211576,0.2912291401980419,0.7831160893255134
|
| 12 |
+
0.6118528947223795,0.13949386065204183,0.9337671019595395
|
| 13 |
+
0.29214464853521815,0.3663618432936917,0.6371115444232952
|
| 14 |
+
0.45606998421703593,0.7851759613930136,0.5398582254082916
|
| 15 |
+
0.19967378215835974,0.5142344384136116,0.4367178921370369
|
| 16 |
+
0.5924145688620425,0.046450412719997725,0.9759742767237903
|
| 17 |
+
0.6075448519014384,0.17052412368729153,0.9185332605846537
|
| 18 |
+
0.06505159298527952,0.9488855372533332,0.07480275943558906
|
| 19 |
+
0.9656320330745594,0.8083973481164611,0.9721242785553075
|
| 20 |
+
0.3046137691733707,0.09767211400638387,0.890382724908758
|
| 21 |
+
0.6842330265121569,0.4401524937396013,0.8461836782187776
|
| 22 |
+
0.12203823484477883,0.4951769101112702,0.3529017979604265
|
| 23 |
+
0.034388521115218396,0.9093204020787821,0.04668000938776754
|
| 24 |
+
0.2587799816000169,0.662522284353982,0.4083696815940864
|
| 25 |
+
0.31171107608941095,0.5200680211778108,0.5454020001935437
|
| 26 |
+
0.5467102793432796,0.18485445552552704,0.894382426580511
|
| 27 |
+
0.9695846277645586,0.7751328233611146,0.9763424052614121
|
| 28 |
+
0.9394989415641891,0.8948273504276488,0.9456857906412531
|
| 29 |
+
0.5978999788110851,0.9218742350231168,0.622414360674025
|
| 30 |
+
0.0884925020519195,0.1959828624191452,0.6217441626231702
|
| 31 |
+
0.045227288910538066,0.32533033076326434,0.3652254374411182
|
| 32 |
+
0.388677289689482,0.2713490317738959,0.7738119236542511
|
| 33 |
+
0.8287375091519293,0.3567533266935893,0.9351795294971191
|
| 34 |
+
0.28093450968738076,0.5426960831582485,0.5020652261408302
|
| 35 |
+
0.14092422497476265,0.8021969807540397,0.207643749220454
|
| 36 |
+
0.07455064367977082,0.9868869366005173,0.07713243002937303
|
| 37 |
+
0.7722447692966574,0.1987156815341724,0.9499377648466204
|
| 38 |
+
0.005522117123602399,0.8154614284548342,0.01441365817791975
|
| 39 |
+
0.7068573438476171,0.7290071680409873,0.7765363384012357
|
| 40 |
+
0.7712703466859457,0.07404465173409036,0.9809531237946688
|
| 41 |
+
0.3584657285442726,0.11586905952512971,0.8879208766651301
|
| 42 |
+
0.8631034258755935,0.6232981268275579,0.9123218750417235
|
| 43 |
+
0.3308980248526492,0.06355835028602363,0.9321215616843733
|
| 44 |
+
0.3109823217156622,0.32518332202674705,0.683984263997003
|
| 45 |
+
0.7296061783380641,0.6375574713552131,0.8179204225106446
|
| 46 |
+
0.8872127425763265,0.4722149251619493,0.9450568572198838
|
| 47 |
+
0.1195942459383017,0.713244787222995,0.21987794449897202
|
| 48 |
+
0.7607850486168974,0.5612771975694962,0.8577387897180954
|
| 49 |
+
0.770967179954561,0.49379559636439074,0.8794655258016348
|
| 50 |
+
0.5227328293819941,0.42754101835854963,0.7577972913056583
|
| 51 |
+
0.02541912674409519,0.10789142699330445,0.6728689751470142
|
| 52 |
+
0.03142918568673425,0.6364104112637804,0.11058269210864166
|
| 53 |
+
0.3143559810763267,0.5085706911647028,0.5551411500446592
|
| 54 |
+
0.907566473926093,0.24929222914887494,0.9761114867199115
|
| 55 |
+
0.41038292303562973,0.7555511385430487,0.5102050118761009
|
| 56 |
+
0.22879816549162246,0.07697990982879299,0.8926695364249662
|
| 57 |
+
0.289751452913768,0.16122128725400442,0.818968274689897
|
| 58 |
+
0.9296976523425731,0.808120379564417,0.9427929152457777
|
| 59 |
+
0.6334037565104235,0.8714605901877177,0.6716955774921506
|
| 60 |
+
0.8036720768991145,0.18657005888603584,0.9600427248758505
|
| 61 |
+
0.8925589984899778,0.5393422419156507,0.9405381426424225
|
| 62 |
+
0.8074401551640625,0.8960912999234932,0.8255861281416367
|
| 63 |
+
0.3180034749718639,0.11005192452767676,0.8815392814173838
|
| 64 |
+
0.22793516254194168,0.4271077886262563,0.5317606737242714
|
| 65 |
+
0.8180147659224931,0.8607305832563434,0.8412224374670049
|
| 66 |
+
0.006952130531190703,0.5107473025775657,0.07904375205389284
|
| 67 |
+
0.417411003148779,0.22210781047073025,0.8236150559828436
|
| 68 |
+
0.1198653673336828,0.33761517140362796,0.48859950491607135
|
| 69 |
+
0.9429097039125192,0.32320293202075523,0.9811799458729573
|
| 70 |
+
0.5187906217433661,0.7030189588951778,0.6304259108490097
|
| 71 |
+
0.363629602379294,0.9717820827209607,0.3741592725494348
|
| 72 |
+
0.9624472949421112,0.25178229582536416,0.9904090768000238
|
| 73 |
+
0.49724850589238545,0.30087830981676966,0.810411403704802
|
| 74 |
+
0.2848404943774676,0.036886947354532795,0.954732975166944
|
| 75 |
+
0.6095643339798968,0.5026790232288615,0.7797113146930209
|
| 76 |
+
0.05147875124998935,0.27864646423661144,0.43752180145116215
|
| 77 |
+
0.9082658859666537,0.23956189066697242,0.9772134321986058
|
| 78 |
+
0.1448948720912231,0.489452760277563,0.3884857408258429
|
| 79 |
+
0.9856504541106007,0.2420552715115004,0.9965075678190551
|
| 80 |
+
0.6721355474058786,0.7616196153287176,0.7389035679220693
|
| 81 |
+
0.23763754399239967,0.7282163486118596,0.351181104428453
|
| 82 |
+
0.3677831327192532,0.6323058305935795,0.5312771863051331
|
| 83 |
+
0.6335297107608947,0.5357746840747585,0.78305410395847
|
| 84 |
+
0.0902897700544083,0.835302495589238,0.13416593821588932
|
| 85 |
+
0.32078006497173583,0.18651851039985423,0.8089068896038694
|
| 86 |
+
0.040775141554763916,0.5908929431882418,0.1509706391829457
|
| 87 |
+
0.6775643618422824,0.016587828927856152,0.9935639759514822
|
| 88 |
+
0.512093058299281,0.22649577519793795,0.8593473681905776
|
| 89 |
+
0.6451727904094499,0.17436642900499144,0.9264327330143394
|
| 90 |
+
0.690937738102466,0.3867353463005374,0.8667729564864897
|
| 91 |
+
0.9367299887367345,0.13752094414599325,0.9910518779206805
|
| 92 |
+
0.3410663510502585,0.11347352124058907,0.8850943779323204
|
| 93 |
+
0.9246936182785628,0.877339353380981,0.9336166425910485
|
| 94 |
+
0.2579416277151556,0.659984046034179,0.40889663483282684
|
| 95 |
+
0.8172222002012158,0.5552008115994623,0.8939869586598578
|
| 96 |
+
0.5296505783560065,0.24185229090045168,0.8575238670157391
|
| 97 |
+
0.09310276780589921,0.8972157579533268,0.11883298066628799
|
| 98 |
+
0.9004180571633305,0.6331014572732679,0.9357472382139566
|
| 99 |
+
0.3390297910487007,0.3492095746126609,0.6854165136378787
|
| 100 |
+
0.7259556788702394,0.8971102599525771,0.7502759576646942
|
| 101 |
+
0.8870864242651173,0.7798755458576239,0.9107934601911546
|
1_data/benchmarks/nguyen/nguyen_11.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_11
|
| 2 |
+
equation: x**y
|
| 3 |
+
latex: x^y
|
| 4 |
+
n_vars: 2
|
| 5 |
+
range: (0, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_12.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,x_2,y
|
| 2 |
+
0.3745401188473625,0.9507143064099162,-0.5316474979283755
|
| 3 |
+
0.7319939418114051,0.5986584841970366,-0.5245780691122485
|
| 4 |
+
0.15601864044243652,0.15599452033620265,-0.1470326281985484
|
| 5 |
+
0.05808361216819946,0.8661761457749352,-0.49123016315900203
|
| 6 |
+
0.6011150117432088,0.7080725777960455,-0.5440295831065014
|
| 7 |
+
0.020584494295802447,0.9699098521619943,-0.4995558340525668
|
| 8 |
+
0.8324426408004217,0.21233911067827616,-0.28645063725362985
|
| 9 |
+
0.18182496720710062,0.18340450985343382,-0.17150410942706706
|
| 10 |
+
0.3042422429595377,0.5247564316322378,-0.40666548191401125
|
| 11 |
+
0.43194501864211576,0.2912291401980419,-0.2946019335150628
|
| 12 |
+
0.6118528947223795,0.13949386065204183,-0.2186718894768038
|
| 13 |
+
0.29214464853521815,0.3663618432936917,-0.3169010837400344
|
| 14 |
+
0.45606998421703593,0.7851759613930136,-0.5285238661436937
|
| 15 |
+
0.19967378215835974,0.5142344384136116,-0.38838724072182046
|
| 16 |
+
0.5924145688620425,0.046450412719997725,-0.13011303294584414
|
| 17 |
+
0.6075448519014384,0.17052412368729153,-0.24399347586867748
|
| 18 |
+
0.06505159298527952,0.9488855372533332,-0.498951027941936
|
| 19 |
+
0.9656320330745594,0.8083973481164611,-0.5125890941467149
|
| 20 |
+
0.3046137691733707,0.09767211400638387,-0.11255726686131867
|
| 21 |
+
0.6842330265121569,0.4401524937396013,-0.44443839313702255
|
| 22 |
+
0.12203823484477883,0.4951769101112702,-0.3741725684537525
|
| 23 |
+
0.034388521115218396,0.9093204020787821,-0.49592787363434615
|
| 24 |
+
0.2587799816000169,0.662522284353982,-0.45589954550743944
|
| 25 |
+
0.31171107608941095,0.5200680211778108,-0.405678875884419
|
| 26 |
+
0.5467102793432796,0.18485445552552704,-0.24183976519509257
|
| 27 |
+
0.9695846277645586,0.7751328233611146,-0.502441019579803
|
| 28 |
+
0.9394989415641891,0.8948273504276488,-0.5446402529864546
|
| 29 |
+
0.5978999788110851,0.9218742350231168,-0.5828930031608885
|
| 30 |
+
0.0884925020519195,0.1959828624191452,-0.1774098758468889
|
| 31 |
+
0.045227288910538066,0.32533033076326434,-0.2724987473704077
|
| 32 |
+
0.388677289689482,0.2713490317738959,-0.2704292195788051
|
| 33 |
+
0.8287375091519293,0.3567533266935893,-0.39059634750094807
|
| 34 |
+
0.28093450968738076,0.5426960831582485,-0.41138006574851044
|
| 35 |
+
0.14092422497476265,0.8021969807540397,-0.48284128157802086
|
| 36 |
+
0.07455064367977082,0.9868869366005173,-0.5002974721138831
|
| 37 |
+
0.7722447692966574,0.1987156815341724,-0.2838615274049386
|
| 38 |
+
0.005522117123602399,0.8154614284548342,-0.4829729252663755
|
| 39 |
+
0.7068573438476171,0.7290071680409873,-0.5668133801237885
|
| 40 |
+
0.7712703466859457,0.07404465173409036,-0.17624366412809184
|
| 41 |
+
0.3584657285442726,0.11586905952512971,-0.1387066006115091
|
| 42 |
+
0.8631034258755935,0.6232981268275579,-0.517067796073964
|
| 43 |
+
0.3308980248526492,0.06355835028602363,-0.08578087500153808
|
| 44 |
+
0.3109823217156622,0.32518332202674705,-0.2930335023172791
|
| 45 |
+
0.7296061783380641,0.6375574713552131,-0.5393353511837361
|
| 46 |
+
0.8872127425763265,0.4722149251619493,-0.4394882839499687
|
| 47 |
+
0.1195942459383017,0.713244787222995,-0.46039168497896
|
| 48 |
+
0.7607850486168974,0.5612771975694962,-0.509096521876274
|
| 49 |
+
0.770967179954561,0.49379559636439074,-0.4768340968434337
|
| 50 |
+
0.5227328293819941,0.42754101835854963,-0.40431654954497565
|
| 51 |
+
0.02541912674409519,0.10789142699330445,-0.10208715360872561
|
| 52 |
+
0.03142918568673425,0.6364104112637804,-0.43393137529691794
|
| 53 |
+
0.3143559810763267,0.5085706911647028,-0.4005478458408727
|
| 54 |
+
0.907566473926093,0.24929222914887494,-0.28731682218657484
|
| 55 |
+
0.41038292303562973,0.7555511385430487,-0.5108733418268507
|
| 56 |
+
0.22879816549162246,0.07697990982879299,-0.08325384436032608
|
| 57 |
+
0.289751452913768,0.16122128725400442,-0.16550288692705029
|
| 58 |
+
0.9296976523425731,0.808120379564417,-0.5380841567189258
|
| 59 |
+
0.6334037565104235,0.8714605901877177,-0.5848989033049367
|
| 60 |
+
0.8036720768991145,0.18657005888603584,-0.27107631331792464
|
| 61 |
+
0.8925589984899778,0.5393422419156507,-0.470295013851557
|
| 62 |
+
0.8074401551640625,0.8960912999234932,-0.59596852950575
|
| 63 |
+
0.3180034749718639,0.11005192452767676,-0.12592818733422592
|
| 64 |
+
0.22793516254194168,0.4271077886262563,-0.34504023675612344
|
| 65 |
+
0.8180147659224931,0.8607305832563434,-0.5899158316402314
|
| 66 |
+
0.006952130531190703,0.5107473025775657,-0.38031623270764187
|
| 67 |
+
0.417411003148779,0.22210781047073025,-0.23981143105710734
|
| 68 |
+
0.1198653673336828,0.33761517140362796,-0.28213892883047403
|
| 69 |
+
0.9429097039125192,0.32320293202075523,-0.318832855236626
|
| 70 |
+
0.5187906217433661,0.7030189588951778,-0.5230920266442246
|
| 71 |
+
0.363629602379294,0.9717820827209607,-0.5301994956717305
|
| 72 |
+
0.9624472949421112,0.25178229582536416,-0.25356410409756097
|
| 73 |
+
0.49724850589238545,0.30087830981676966,-0.3174265784042432
|
| 74 |
+
0.2848404943774676,0.036886947354532795,-0.05273415977110583
|
| 75 |
+
0.6095643339798968,0.5026790232288615,-0.46476765439268997
|
| 76 |
+
0.05147875124998935,0.27864646423661144,-0.2399539372668816
|
| 77 |
+
0.9082658859666537,0.23956189066697242,-0.2796006655774298
|
| 78 |
+
0.1448948720912231,0.489452760277563,-0.3722719868332219
|
| 79 |
+
0.9856504541106007,0.2420552715115004,-0.22650053348928928
|
| 80 |
+
0.6721355474058786,0.7616196153287176,-0.5711428201443614
|
| 81 |
+
0.23763754399239967,0.7282163486118596,-0.473297554430461
|
| 82 |
+
0.3677831327192532,0.6323058305935795,-0.46385200894375067
|
| 83 |
+
0.6335297107608947,0.5357746840747585,-0.4854310810016914
|
| 84 |
+
0.0902897700544083,0.835302495589238,-0.48710697106906575
|
| 85 |
+
0.32078006497173583,0.18651851039985423,-0.19154377448709786
|
| 86 |
+
0.040775141554763916,0.5908929431882418,-0.4163807370007205
|
| 87 |
+
0.6775643618422824,0.016587828927856152,-0.1167488120616192
|
| 88 |
+
0.512093058299281,0.22649577519793795,-0.2663670817810153
|
| 89 |
+
0.6451727904094499,0.17436642900499144,-0.25445410259051215
|
| 90 |
+
0.690937738102466,0.3867353463005374,-0.41389747881703326
|
| 91 |
+
0.9367299887367345,0.13752094414599325,-0.18006947009905816
|
| 92 |
+
0.3410663510502585,0.11347352124058907,-0.13317857503984928
|
| 93 |
+
0.9246936182785628,0.877339353380981,-0.552019449425134
|
| 94 |
+
0.2579416277151556,0.659984046034179,-0.4549296760550325
|
| 95 |
+
0.8172222002012158,0.5552008115994623,-0.5008339633920046
|
| 96 |
+
0.5296505783560065,0.24185229090045168,-0.28249182975838943
|
| 97 |
+
0.09310276780589921,0.8972157579533268,-0.49544958985988724
|
| 98 |
+
0.9004180571633305,0.6331014572732679,-0.5053891761940527
|
| 99 |
+
0.3390297910487007,0.3492095746126609,-0.31399292258823414
|
| 100 |
+
0.7259556788702394,0.8971102599525771,-0.5995526723688602
|
| 101 |
+
0.8870864242651173,0.7798755458576239,-0.5545939788269545
|
1_data/benchmarks/nguyen/nguyen_12.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_12
|
| 2 |
+
equation: x**4 - x**3 + y**2/2 - y
|
| 3 |
+
latex: x^4 - x^3 + \frac{y^2}{2} - y
|
| 4 |
+
n_vars: 2
|
| 5 |
+
range: (0, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_2.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.19979307271339486
|
| 3 |
+
0.9014286128198323,3.106754963846846
|
| 4 |
+
0.4639878836228102,0.8255096843833445
|
| 5 |
+
0.1973169683940732,0.24544914576563198
|
| 6 |
+
-0.687962719115127,-0.316271768430889
|
| 7 |
+
-0.6880109593275947,-0.3162592952514309
|
| 8 |
+
-0.8838327756636011,-0.18287601110210117
|
| 9 |
+
0.7323522915498704,1.9491423756178945
|
| 10 |
+
0.2022300234864176,0.25307016676624217
|
| 11 |
+
0.416145155592091,0.6913788293276577
|
| 12 |
+
-0.9588310114083951,-0.07576489223978267
|
| 13 |
+
0.9398197043239886,3.4333370743437346
|
| 14 |
+
0.6648852816008435,1.596313216676064
|
| 15 |
+
-0.5753217786434477,-0.3251975588470951
|
| 16 |
+
-0.6363500655857988,-0.32511560759302083
|
| 17 |
+
-0.6331909802931324,-0.3253804197057425
|
| 18 |
+
-0.39151551408092455,-0.27474822950833333
|
| 19 |
+
0.04951286326447568,0.05209177881543897
|
| 20 |
+
-0.13610996271576847,-0.11976239352712116
|
| 21 |
+
-0.4175417196039162,-0.28560049468336546
|
| 22 |
+
0.22370578944475894,0.2874496948760045
|
| 23 |
+
-0.7210122786959163,-0.30572500866496494
|
| 24 |
+
-0.4157107029295637,-0.28487136252946227
|
| 25 |
+
-0.2672763134126166,-0.2098298124197376
|
| 26 |
+
-0.08786003156592814,-0.08075928293478254
|
| 27 |
+
0.5703519227860272,1.1870104156557124
|
| 28 |
+
-0.6006524356832805,-0.3264098596675816
|
| 29 |
+
0.02846887682722321,0.029303084016308063
|
| 30 |
+
0.18482913772408494,0.22647207506221967
|
| 31 |
+
-0.9070991745600046,-0.1536102701687616
|
| 32 |
+
0.21508970380287673,0.2734444232481606
|
| 33 |
+
-0.6589517526254169,-0.32231790405084937
|
| 34 |
+
-0.869896814029441,-0.1988192051106482
|
| 35 |
+
0.8977710745066665,3.0769856490294796
|
| 36 |
+
0.9312640661491187,3.358285510834738
|
| 37 |
+
0.6167946962329223,1.3766124336679
|
| 38 |
+
-0.39077246165325863,-0.27442327583918935
|
| 39 |
+
-0.8046557719872323,-0.25895748554130227
|
| 40 |
+
0.3684660530243138,0.572691642793691
|
| 41 |
+
-0.1196950125207974,-0.10687771454758802
|
| 42 |
+
-0.7559235303104423,-0.28993200547233466
|
| 43 |
+
-0.00964617977745963,-0.009554019900385721
|
| 44 |
+
-0.9312229577695632,-0.11958658706487191
|
| 45 |
+
0.8186408041575641,2.4865758792604673
|
| 46 |
+
-0.48244003679996617,-0.3078069764664333
|
| 47 |
+
0.32504456870796394,0.4762035517253345
|
| 48 |
+
-0.3765778478211781,-0.2680594822320752
|
| 49 |
+
0.040136042355621626,0.04181419448323982
|
| 50 |
+
0.0934205586865593,0.10303944565358064
|
| 51 |
+
-0.6302910889489459,-0.3255970364031556
|
| 52 |
+
0.9391692555291171,3.4275845586374323
|
| 53 |
+
0.5502656467222291,1.1113572855576686
|
| 54 |
+
0.8789978831283782,2.9277500581350866
|
| 55 |
+
0.7896547008552977,2.2944222989511918
|
| 56 |
+
0.19579995762217028,0.2431138594333487
|
| 57 |
+
0.8437484700462337,2.663152129765188
|
| 58 |
+
-0.823014995896161,-0.24432553893716014
|
| 59 |
+
-0.6080342751617096,-0.326440030758018
|
| 60 |
+
-0.9095454221789239,-0.15033439380562408
|
| 61 |
+
-0.3493393384734713,-0.2550407630135437
|
| 62 |
+
-0.22264542062103598,-0.18165390734252046
|
| 63 |
+
-0.4573019364522082,-0.3000768795902401
|
| 64 |
+
0.6574750183038587,1.560817671456092
|
| 65 |
+
-0.28649334661282144,-0.22119296531986027
|
| 66 |
+
-0.4381309806252385,-0.2934271384523741
|
| 67 |
+
0.08539216631649693,0.09335982353659596
|
| 68 |
+
-0.7181515500504747,-0.3068011174024385
|
| 69 |
+
0.6043939615080793,1.3239046275098028
|
| 70 |
+
-0.8508987126404584,-0.2187276720547493
|
| 71 |
+
0.9737738732010346,3.7445271094357677
|
| 72 |
+
0.5444895385933148,1.0902763712821586
|
| 73 |
+
-0.6025686369316552,-0.32643210128260247
|
| 74 |
+
-0.9889557657527952,-0.021604594541118627
|
| 75 |
+
0.6309228569096683,1.4385886349785073
|
| 76 |
+
0.41371468769523423,0.6849815632184505
|
| 77 |
+
0.4580143360819746,0.8078788471365737
|
| 78 |
+
0.5425406933718915,1.0832303299115513
|
| 79 |
+
-0.8519106965318193,-0.21771888700546604
|
| 80 |
+
-0.2830685429114548,-0.21920193818358347
|
| 81 |
+
-0.7682618809497406,-0.2831168381374246
|
| 82 |
+
0.726206851751187,1.9146934506063242
|
| 83 |
+
0.24659625365511584,0.32609923432705157
|
| 84 |
+
-0.3382039502947016,-0.24942323098709215
|
| 85 |
+
-0.8728832994279527,-0.19549978168020488
|
| 86 |
+
-0.3780353565686756,-0.2687264578518238
|
| 87 |
+
-0.3496333559465059,-0.2551868024860783
|
| 88 |
+
0.45921235667612814,0.8113938873926984
|
| 89 |
+
0.27511494271042625,0.37735484635995153
|
| 90 |
+
0.7744254851526531,2.198292124261625
|
| 91 |
+
-0.05557014967610141,-0.05264417507086238
|
| 92 |
+
-0.7608115081233966,-0.2873120662846093
|
| 93 |
+
0.42648957444599,0.7190437453871404
|
| 94 |
+
0.5215700972337949,1.0094942165627014
|
| 95 |
+
0.12255439513899247,0.13964027819697553
|
| 96 |
+
0.541934359909122,1.081044947683308
|
| 97 |
+
-0.012408807271218514,-0.012256715753450735
|
| 98 |
+
0.045465658763988115,0.04763104115236099
|
| 99 |
+
-0.14491796328290074,-0.12651914958498786
|
| 100 |
+
-0.9491617465118096,-0.0917258937919263
|
| 101 |
+
-0.7842171460133911,-0.27329070462795224
|
1_data/benchmarks/nguyen/nguyen_2.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_2
|
| 2 |
+
equation: x**4 + x**3 + x**2 + x
|
| 3 |
+
latex: x^4 + x^3 + x^2 + x
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_3.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.20078773198978944
|
| 3 |
+
0.9014286128198323,3.701946430251423
|
| 4 |
+
0.4639878836228102,0.8470143749899722
|
| 5 |
+
0.1973169683940732,0.24574824973146267
|
| 6 |
+
-0.687962719115127,-0.47037953332606286
|
| 7 |
+
-0.6880109593275947,-0.4704210982053887
|
| 8 |
+
-0.8838327756636011,-0.7222009631689434
|
| 9 |
+
0.7323522915498704,2.1598111768905937
|
| 10 |
+
0.2022300234864176,0.25340840925526636
|
| 11 |
+
0.416145155592091,0.7038591060957268
|
| 12 |
+
-0.9588310114083951,-0.8861852831528764
|
| 13 |
+
0.9398197043239886,4.166537538378305
|
| 14 |
+
0.6648852816008435,1.7262504441936566
|
| 15 |
+
-0.5753217786434477,-0.3882285406770297
|
| 16 |
+
-0.6363500655857988,-0.42946272737101315
|
| 17 |
+
-0.6331909802931324,-0.42716303337146244
|
| 18 |
+
-0.39151551408092455,-0.2839473197621456
|
| 19 |
+
0.04951286326447568,0.05209207638616782
|
| 20 |
+
-0.13610996271576847,-0.11980910779804083
|
| 21 |
+
-0.4175417196039162,-0.29829159793409465
|
| 22 |
+
0.22370578944475894,0.2880099503626506
|
| 23 |
+
-0.7210122786959163,-0.5005807935440612
|
| 24 |
+
-0.4157107029295637,-0.29728662856793836
|
| 25 |
+
-0.2672763134126166,-0.21119377470500825
|
| 26 |
+
-0.08786003156592814,-0.08076451841803642
|
| 27 |
+
0.5703519227860272,1.2473655957223042
|
| 28 |
+
-0.6006524356832805,-0.40459355844290984
|
| 29 |
+
0.02846887682722321,0.029303102716741258
|
| 30 |
+
0.18482913772408494,0.22668777607641924
|
| 31 |
+
-0.9070991745600046,-0.7677594252859816
|
| 32 |
+
0.21508970380287673,0.2739047838058721
|
| 33 |
+
-0.6589517526254169,-0.4465598048485588
|
| 34 |
+
-0.869896814029441,-0.696944620935822
|
| 35 |
+
0.8977710745066665,3.660199786877455
|
| 36 |
+
0.9312640661491187,4.058714686258748
|
| 37 |
+
0.6167946962329223,1.4658819440875785
|
| 38 |
+
-0.39077246165325863,-0.28353540261862736
|
| 39 |
+
-0.8046557719872323,-0.5962841365471232
|
| 40 |
+
0.3684660530243138,0.5794834822445153
|
| 41 |
+
-0.1196950125207974,-0.10690228313982963
|
| 42 |
+
-0.7559235303104423,-0.5367571051838087
|
| 43 |
+
-0.00964617977745963,-0.009554019983903081
|
| 44 |
+
-0.9312229577695632,-0.8198611824534459
|
| 45 |
+
0.8186408041575641,2.854253281554156
|
| 46 |
+
-0.48244003679996617,-0.3339416277462138
|
| 47 |
+
0.32504456870796394,0.4798319467957259
|
| 48 |
+
-0.3765778478211781,-0.27563258491416387
|
| 49 |
+
0.040136042355621626,0.04181429863646714
|
| 50 |
+
0.0934205586865593,0.10304656126627017
|
| 51 |
+
-0.6302910889489459,-0.4250701783158514
|
| 52 |
+
0.9391692555291171,4.158251293727732
|
| 53 |
+
0.5502656467222291,1.1618073821990806
|
| 54 |
+
0.8789978831283782,3.4524839865581054
|
| 55 |
+
0.7896547008552977,2.6014560549693257
|
| 56 |
+
0.19579995762217028,0.24340164099658224
|
| 57 |
+
0.8437484700462337,3.09077900503598
|
| 58 |
+
-0.823014995896161,-0.6219314134704668
|
| 59 |
+
-0.6080342751617096,-0.4095475476759919
|
| 60 |
+
-0.9095454221789239,-0.772809462496975
|
| 61 |
+
-0.3493393384734713,-0.26024356703855056
|
| 62 |
+
-0.22264542062103598,-0.1822010100133058
|
| 63 |
+
-0.4573019364522082,-0.3200761983310553
|
| 64 |
+
0.6574750183038587,1.683673645413439
|
| 65 |
+
-0.28649334661282144,-0.22312303373112094
|
| 66 |
+
-0.4381309806252385,-0.3095714607130422
|
| 67 |
+
0.08539216631649693,0.09336436389521274
|
| 68 |
+
-0.7181515500504747,-0.49782185203069584
|
| 69 |
+
0.6043939615080793,1.4045539239876073
|
| 70 |
+
-0.8508987126404584,-0.6647836180702278
|
| 71 |
+
0.9737738732010346,4.620096539862576
|
| 72 |
+
0.5444895385933148,1.1381336169319312
|
| 73 |
+
-0.6025686369316552,-0.4058708906110614
|
| 74 |
+
-0.9889557657527952,-0.9675897774146045
|
| 75 |
+
0.6309228569096683,1.538561308408088
|
| 76 |
+
0.41371468769523423,0.6971016211991488
|
| 77 |
+
0.4580143360819746,0.8280344298879034
|
| 78 |
+
0.5425406933718915,1.1302372276435673
|
| 79 |
+
-0.8519106965318193,-0.6664336478548603
|
| 80 |
+
-0.2830685429114548,-0.22101936966646102
|
| 81 |
+
-0.7682618809497406,-0.5507540063537395
|
| 82 |
+
0.726206851751187,2.1166703545846226
|
| 83 |
+
0.24659625365511584,0.3270111031599685
|
| 84 |
+
-0.3382039502947016,-0.25384802827959924
|
| 85 |
+
-0.8728832994279527,-0.702234804957491
|
| 86 |
+
-0.3780353565686756,-0.2764472542552242
|
| 87 |
+
-0.3496333559465059,-0.2604115378000402
|
| 88 |
+
0.45921235667612814,0.8318144558983341
|
| 89 |
+
0.27511494271042625,0.378930899648246
|
| 90 |
+
0.7744254851526531,2.476838929991218
|
| 91 |
+
-0.05557014967610141,-0.0526447049878387
|
| 92 |
+
-0.7608115081233966,-0.5422211816713537
|
| 93 |
+
0.42648957444599,0.7331542354242023
|
| 94 |
+
0.5215700972337949,1.0480920939233567
|
| 95 |
+
0.12255439513899247,0.13966792497046343
|
| 96 |
+
0.541934359909122,1.1277897616648658
|
| 97 |
+
-0.012408807271218514,-0.012256716047655835
|
| 98 |
+
0.045465658763988115,0.047631235427594835
|
| 99 |
+
-0.14491796328290074,-0.12658306580875964
|
| 100 |
+
-0.9491617465118096,-0.8620990369599081
|
| 101 |
+
-0.7842171460133911,-0.5698978895980695
|
1_data/benchmarks/nguyen/nguyen_3.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_3
|
| 2 |
+
equation: x**5 + x**4 + x**3 + x**2 + x
|
| 3 |
+
latex: x^5 + x^4 + x^3 + x^2 + x
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_4.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.2005381523205818
|
| 3 |
+
0.9014286128198323,4.238469048174702
|
| 4 |
+
0.4639878836228102,0.8569922908725047
|
| 5 |
+
0.1973169683940732,0.24580726801923503
|
| 6 |
+
-0.687962719115127,-0.3643591363520243
|
| 7 |
+
-0.6880109593275947,-0.36435608826336463
|
| 8 |
+
-0.8838327756636011,-0.2455278937990677
|
| 9 |
+
0.7323522915498704,2.3140949562607194
|
| 10 |
+
0.2022300234864176,0.25347681204176586
|
| 11 |
+
0.416145155592091,0.7090527128132073
|
| 12 |
+
-0.9588310114083951,-0.10912908006768773
|
| 13 |
+
0.9398197043239886,4.855613781697487
|
| 14 |
+
0.6648852816008435,1.812643794302124
|
| 15 |
+
-0.5753217786434477,-0.3519654441009889
|
| 16 |
+
-0.6363500655857988,-0.36306143085659853
|
| 17 |
+
-0.6331909802931324,-0.3627152004476681
|
| 18 |
+
-0.39151551408092455,-0.28034573321234746
|
| 19 |
+
0.04951286326447568,0.05209209111974663
|
| 20 |
+
-0.13610996271576847,-0.11980274952036765
|
| 21 |
+
-0.4175417196039162,-0.2929925328591143
|
| 22 |
+
0.22370578944475894,0.2881352827585815
|
| 23 |
+
-0.7210122786959163,-0.3600873800713027
|
| 24 |
+
-0.4157107029295637,-0.2921254695960259
|
| 25 |
+
-0.2672763134126166,-0.21082921989376727
|
| 26 |
+
-0.08786003156592814,-0.08076405842831248
|
| 27 |
+
0.5703519227860272,1.2817892887233815
|
| 28 |
+
-0.6006524356832805,-0.35763232934278105
|
| 29 |
+
0.02846887682722321,0.02930310324912159
|
| 30 |
+
0.18482913772408494,0.22672764390887995
|
| 31 |
+
-0.9070991745600046,-0.21066523362242717
|
| 32 |
+
0.21508970380287673,0.27400380262187274
|
| 33 |
+
-0.6589517526254169,-0.36469038656839503
|
| 34 |
+
-0.869896814029441,-0.2636269087224129
|
| 35 |
+
0.8977710745066665,4.183792570080711
|
| 36 |
+
0.9312640661491187,4.710999208213584
|
| 37 |
+
0.6167946962329223,1.520942904649746
|
| 38 |
+
-0.39077246165325863,-0.2799746344061298
|
| 39 |
+
-0.8046557719872323,-0.32485229977016666
|
| 40 |
+
0.3684660530243138,0.5819860445197353
|
| 41 |
+
-0.1196950125207974,-0.10689934240187365
|
| 42 |
+
-0.7559235303104423,-0.3501762044406843
|
| 43 |
+
-0.00964617977745963,-0.009554019983097458
|
| 44 |
+
-0.9312229577695632,-0.16774940248481396
|
| 45 |
+
0.8186408041575641,3.1552490058384244
|
| 46 |
+
-0.48244003679996617,-0.32133322562104216
|
| 47 |
+
0.32504456870796394,0.48101133690648334
|
| 48 |
+
-0.3765778478211781,-0.27278072220481414
|
| 49 |
+
0.040136042355621626,0.041814302816765486
|
| 50 |
+
0.0934205586865593,0.10304722601078302
|
| 51 |
+
-0.6302910889489459,-0.36237314337852533
|
| 52 |
+
0.9391692555291171,4.8444710273623794
|
| 53 |
+
0.5502656467222291,1.1895683372546662
|
| 54 |
+
0.8789978831283782,3.913723998847577
|
| 55 |
+
0.7896547008552977,2.843906703730303
|
| 56 |
+
0.19579995762217028,0.2434579886144678
|
| 57 |
+
0.8437484700462337,3.451588526796362
|
| 58 |
+
-0.823014995896161,-0.3111561161910711
|
| 59 |
+
-0.6080342751617096,-0.3590153288662821
|
| 60 |
+
-0.9095454221789239,-0.2066401133482456
|
| 61 |
+
-0.3493393384734713,-0.2584260229222476
|
| 62 |
+
-0.22264542062103598,-0.18207920010904594
|
| 63 |
+
-0.4573019364522082,-0.31093047114315553
|
| 64 |
+
0.6574750183038587,1.7644483791397838
|
| 65 |
+
-0.28649334661282144,-0.22257008197278716
|
| 66 |
+
-0.4381309806252385,-0.30249813296944583
|
| 67 |
+
0.08539216631649693,0.09336475160627088
|
| 68 |
+
-0.7181515500504747,-0.36064001536563245
|
| 69 |
+
0.6043939615080793,1.453297871778667
|
| 70 |
+
-0.8508987126404584,-0.28523518784003543
|
| 71 |
+
0.9737738732010346,5.472703175385714
|
| 72 |
+
0.5444895385933148,1.1641913865341225
|
| 73 |
+
-0.6025686369316552,-0.358003567605911
|
| 74 |
+
-0.9889557657527952,-0.03205227649515818
|
| 75 |
+
0.6309228569096683,1.6016363531411766
|
| 76 |
+
0.41371468769523423,0.7021158672014816
|
| 77 |
+
0.4580143360819746,0.8372659757400991
|
| 78 |
+
0.5425406933718915,1.155740382532357
|
| 79 |
+
-0.8519106965318193,-0.2841687433955442
|
| 80 |
+
-0.2830685429114548,-0.22050491198476146
|
| 81 |
+
-0.7682618809497406,-0.3451385720878113
|
| 82 |
+
0.726206851751187,2.263347366149155
|
| 83 |
+
0.24659625365511584,0.3272359665979907
|
| 84 |
+
-0.3382039502947016,-0.25235154435602003
|
| 85 |
+
-0.8728832994279527,-0.2599142659035131
|
| 86 |
+
-0.3780353565686756,-0.2735285202338706
|
| 87 |
+
-0.3496333559465059,-0.25858479605828744
|
| 88 |
+
0.45921235667612814,0.8411918332864734
|
| 89 |
+
0.27511494271042625,0.3793644954583637
|
| 90 |
+
0.7744254851526531,2.6925526751560804
|
| 91 |
+
-0.05557014967610141,-0.05264467554027301
|
| 92 |
+
-0.7608115081233966,-0.34828339315956375
|
| 93 |
+
0.42648957444599,0.7391722123153333
|
| 94 |
+
0.5215700972337949,1.0682235925713717
|
| 95 |
+
0.12255439513899247,0.13967131320406578
|
| 96 |
+
0.541934359909122,1.1531223825090322
|
| 97 |
+
-0.012408807271218514,-0.0122567160440051
|
| 98 |
+
0.045465658763988115,0.04763124426044632
|
| 99 |
+
-0.14491796328290074,-0.1265738031997899
|
| 100 |
+
-0.9491617465118096,-0.13089031892479408
|
| 101 |
+
-0.7842171460133911,-0.3372934495137382
|
1_data/benchmarks/nguyen/nguyen_4.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_4
|
| 2 |
+
equation: x**6 + x**5 + x**4 + x**3 + x**2 + x
|
| 3 |
+
latex: x^6 + x^5 + x^4 + x^3 + x^2 + x
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_5.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.9390512081168352
|
| 3 |
+
0.9014286128198323,-0.5494873180553288
|
| 4 |
+
0.4639878836228102,-0.8089599566948397
|
| 5 |
+
0.1973169683940732,-0.9618311304971612
|
| 6 |
+
-0.687962719115127,-0.6478606522679959
|
| 7 |
+
-0.6880109593275947,-0.6478289752250018
|
| 8 |
+
-0.8838327756636011,-0.5534626739045931
|
| 9 |
+
0.7323522915498704,-0.6200235532537778
|
| 10 |
+
0.2022300234864176,-0.9599476201534037
|
| 11 |
+
0.416145155592091,-0.8423936940912042
|
| 12 |
+
-0.9588310114083951,-0.5431689029988238
|
| 13 |
+
0.9398197043239886,-0.5440918837654212
|
| 14 |
+
0.6648852816008435,-0.6633160228877684
|
| 15 |
+
-0.5753217786434477,-0.7273325440876017
|
| 16 |
+
-0.6363500655857988,-0.6831457922480837
|
| 17 |
+
-0.6331909802931324,-0.6853819519705832
|
| 18 |
+
-0.39151551408092455,-0.8588685461295467
|
| 19 |
+
0.04951286326447568,-0.9975514831941417
|
| 20 |
+
-0.13610996271576847,-0.9816464680708747
|
| 21 |
+
-0.4175417196039162,-0.8414428810237741
|
| 22 |
+
0.22370578944475894,-0.9512230869588283
|
| 23 |
+
-0.7210122786959163,-0.6268666215487173
|
| 24 |
+
-0.4157107029295637,-0.8426891332699092
|
| 25 |
+
-0.2672763134126166,-0.9311584013772981
|
| 26 |
+
-0.08786003156592814,-0.9923104665144623
|
| 27 |
+
0.5703519227860272,-0.7309939462300948
|
| 28 |
+
-0.6006524356832805,-0.7087806446206153
|
| 29 |
+
0.02846887682722321,-0.99918985155515
|
| 30 |
+
0.18482913772408494,-0.9664265762992899
|
| 31 |
+
-0.9070991745600046,-0.5484026596371498
|
| 32 |
+
0.21508970380287673,-0.9548185804730333
|
| 33 |
+
-0.6589517526254169,-0.6673796194282944
|
| 34 |
+
-0.869896814029441,-0.5572462380695584
|
| 35 |
+
0.8977710745066665,-0.550238281253637
|
| 36 |
+
0.9312640661491187,-0.5448920303797724
|
| 37 |
+
0.6167946962329223,-0.6970965407107139
|
| 38 |
+
-0.39077246165325863,-0.8593564758850973
|
| 39 |
+
-0.8046557719872323,-0.58178568969466
|
| 40 |
+
0.3684660530243138,-0.8737340760964288
|
| 41 |
+
-0.1196950125207974,-0.9857760980914346
|
| 42 |
+
-0.7559235303104423,-0.6064732977627244
|
| 43 |
+
-0.00964617977745963,-0.9999069555448398
|
| 44 |
+
-0.9312229577695632,-0.5448964424582949
|
| 45 |
+
0.8186408041575641,-0.5756409511274684
|
| 46 |
+
-0.48244003679996617,-0.7956727484647694
|
| 47 |
+
0.32504456870796394,-0.900064608947692
|
| 48 |
+
-0.3765778478211781,-0.8685675799244211
|
| 49 |
+
0.040136042355621626,-0.9983903961284725
|
| 50 |
+
0.0934205586865593,-0.991310765594519
|
| 51 |
+
-0.6302910889489459,-0.687440925563
|
| 52 |
+
0.9391692555291171,-0.5441444192356306
|
| 53 |
+
0.5502656467222291,-0.7458300065420963
|
| 54 |
+
0.8789978831283782,-0.5547129310790808
|
| 55 |
+
0.7896547008552977,-0.5888643062279737
|
| 56 |
+
0.19579995762217028,-0.9624041296078009
|
| 57 |
+
0.8437484700462337,-0.5657849821252405
|
| 58 |
+
-0.823014995896161,-0.5738128862301527
|
| 59 |
+
-0.6080342751617096,-0.7034212209476314
|
| 60 |
+
-0.9095454221789239,-0.5479648926611902
|
| 61 |
+
-0.3493393384734713,-0.8856176517083943
|
| 62 |
+
-0.22264542062103598,-0.9516723889435259
|
| 63 |
+
-0.4573019364522082,-0.8137278342896282
|
| 64 |
+
0.6574750183038587,-0.668396180415866
|
| 65 |
+
-0.28649334661282144,-0.9213553931307055
|
| 66 |
+
-0.4381309806252385,-0.8272380555912469
|
| 67 |
+
0.08539216631649693,-0.9927348114985
|
| 68 |
+
-0.7181515500504747,-0.6286238526457859
|
| 69 |
+
0.6043939615080793,-0.7060609958180059
|
| 70 |
+
-0.8508987126404584,-0.56326927462963
|
| 71 |
+
0.9737738732010346,-0.5432897447861715
|
| 72 |
+
0.5444895385933148,-0.7501016086636129
|
| 73 |
+
-0.6025686369316552,-0.7073869979373362
|
| 74 |
+
-0.9889557657527952,-0.5441925947478524
|
| 75 |
+
0.6309228569096683,-0.6869918579399132
|
| 76 |
+
0.41371468769523423,-0.8440443466281196
|
| 77 |
+
0.4580143360819746,-0.8132211009838715
|
| 78 |
+
0.5425406933718915,-0.7515428281198429
|
| 79 |
+
-0.8519106965318193,-0.5629240207718037
|
| 80 |
+
-0.2830685429114548,-0.9231433578147953
|
| 81 |
+
-0.7682618809497406,-0.5997746863584423
|
| 82 |
+
0.726206851751187,-0.6237072834805513
|
| 83 |
+
0.24659625365511584,-0.9410661846934961
|
| 84 |
+
-0.3382039502947016,-0.8923327139993233
|
| 85 |
+
-0.8728832994279527,-0.5563892940528554
|
| 86 |
+
-0.3780353565686756,-0.8676316179551925
|
| 87 |
+
-0.3496333559465059,-0.8854382694554389
|
| 88 |
+
0.45921235667612814,-0.8123682362870449
|
| 89 |
+
0.27511494271042625,-0.927227629948057
|
| 90 |
+
0.7744254851526531,-0.5965370771413236
|
| 91 |
+
-0.05557014967610141,-0.9969167301387304
|
| 92 |
+
-0.7608115081233966,-0.603785604256095
|
| 93 |
+
0.42648957444599,-0.8353116478277598
|
| 94 |
+
0.5215700972337949,-0.7670333944975515
|
| 95 |
+
0.12255439513899247,-0.9850936334783424
|
| 96 |
+
0.541934359909122,-0.7519912093984299
|
| 97 |
+
-0.012408807271218514,-0.999846033357251
|
| 98 |
+
0.045465658763988115,-0.997935011480979
|
| 99 |
+
-0.14491796328290074,-0.9792204513367526
|
| 100 |
+
-0.9491617465118096,-0.5434895142161542
|
| 101 |
+
-0.7842171460133911,-0.5915491851197053
|
1_data/benchmarks/nguyen/nguyen_5.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_5
|
| 2 |
+
equation: sin(x**2)*cos(x) - 1
|
| 3 |
+
latex: \sin(x^2) \cos(x) - 1
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_6.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
-0.250919762305275,-0.43514929053571705
|
| 3 |
+
0.9014286128198323,1.7739777070144762
|
| 4 |
+
0.4639878836228102,1.0757452027433314
|
| 5 |
+
0.1973169683940732,0.4300984433827959
|
| 6 |
+
-0.687962719115127,-0.8479896478976541
|
| 7 |
+
-0.6880109593275947,-0.8480091939305666
|
| 8 |
+
-0.8838327756636011,-0.8756673756307944
|
| 9 |
+
0.7323522915498704,1.623333209866147
|
| 10 |
+
0.2022300234864176,0.4415932453735026
|
| 11 |
+
0.416145155592091,0.9600350888652831
|
| 12 |
+
-0.9588310114083951,-0.8579844227187403
|
| 13 |
+
0.9398197043239886,1.7757964620012405
|
| 14 |
+
0.6648852816008435,1.511310612392346
|
| 15 |
+
-0.5753217786434477,-0.7860078048335528
|
| 16 |
+
-0.6363500655857988,-0.8236127344518008
|
| 17 |
+
-0.6331909802931324,-0.8218977378877982
|
| 18 |
+
-0.39151551408092455,-0.6175737622323656
|
| 19 |
+
0.04951286326447568,0.10143363891812185
|
| 20 |
+
-0.13610996271576847,-0.2530033664640832
|
| 21 |
+
-0.4175417196039162,-0.646324893603273
|
| 22 |
+
0.22370578944475894,0.4921883475197036
|
| 23 |
+
-0.7210122786959163,-0.859945147297601
|
| 24 |
+
-0.4157107029295637,-0.6443541593670672
|
| 25 |
+
-0.2672763134126166,-0.4586956789581246
|
| 26 |
+
-0.08786003156592814,-0.16780192738539715
|
| 27 |
+
0.5703519227860272,1.3205458271668318
|
| 28 |
+
-0.6006524356832805,-0.8027562953942011
|
| 29 |
+
0.02846887682722321,0.0577402019392117
|
| 30 |
+
0.18482913772408494,0.40102336487359014
|
| 31 |
+
-0.9070991745600046,-0.8718906080781982
|
| 32 |
+
0.21508970380287673,0.4718231825909881
|
| 33 |
+
-0.6589517526254169,-0.8351357987774414
|
| 34 |
+
-0.869896814029441,-0.8771972870866676
|
| 35 |
+
0.8977710745066665,1.773112261855638
|
| 36 |
+
0.9312640661491187,1.7765585415628578
|
| 37 |
+
0.6167946962329223,1.4183947871732054
|
| 38 |
+
-0.39077246165325863,-0.616729624340086
|
| 39 |
+
-0.8046557719872323,-0.8771304085531529
|
| 40 |
+
0.3684660530243138,0.8433211553872129
|
| 41 |
+
-0.1196950125207974,-0.22458265836212168
|
| 42 |
+
-0.7559235303104423,-0.8694190928201299
|
| 43 |
+
-0.00964617977745963,-0.019199015870959124
|
| 44 |
+
-0.9312229577695632,-0.8663534501230806
|
| 45 |
+
0.8186408041575641,1.7268591780806577
|
| 46 |
+
-0.48244003679996617,-0.7110472800982361
|
| 47 |
+
0.32504456870796394,0.7368566710350718
|
| 48 |
+
-0.3765778478211781,-0.6003566284265445
|
| 49 |
+
0.040136042355621626,0.08186008651372706
|
| 50 |
+
0.0934205586865593,0.19525514485076978
|
| 51 |
+
-0.6302910889489459,-0.8203010105715378
|
| 52 |
+
0.9391692555291171,1.7758783124934705
|
| 53 |
+
0.5502656467222291,1.2762087516369758
|
| 54 |
+
0.8789978831283782,1.7668343124729904
|
| 55 |
+
0.7896547008552977,1.6977190241718285
|
| 56 |
+
0.19579995762217028,0.42655545364748826
|
| 57 |
+
0.8437484700462337,1.7470252942741467
|
| 58 |
+
-0.823014995896161,-0.8783461682408447
|
| 59 |
+
-0.6080342751617096,-0.8073339482620789
|
| 60 |
+
-0.9095454221789239,-0.8714044274820212
|
| 61 |
+
-0.3493393384734713,-0.5676262516988537
|
| 62 |
+
-0.22264542062103598,-0.39302217898548675
|
| 63 |
+
-0.4573019364522082,-0.6871659884222171
|
| 64 |
+
0.6574750183038587,1.497630717093317
|
| 65 |
+
-0.28649334661282144,-0.48558453002226076
|
| 66 |
+
-0.4381309806252385,-0.667941100272472
|
| 67 |
+
0.08539216631649693,0.17783977455509073
|
| 68 |
+
-0.7181515500504747,-0.8590244862623294
|
| 69 |
+
0.6043939615080793,1.3929716850978804
|
| 70 |
+
-0.8508987126404584,-0.8784032535275912
|
| 71 |
+
0.9737738732010346,1.7659692545236667
|
| 72 |
+
0.5444895385933148,1.2632639852448593
|
| 73 |
+
-0.6025686369316552,-0.8039577784468424
|
| 74 |
+
-0.9889557657527952,-0.8463746041245153
|
| 75 |
+
0.6309228569096683,1.4466669850764435
|
| 76 |
+
0.41371468769523423,0.9541081599557866
|
| 77 |
+
0.4580143360819746,1.0614213526661842
|
| 78 |
+
0.5425406933718915,1.2588779524345566
|
| 79 |
+
-0.8519106965318193,-0.8783645297620329
|
| 80 |
+
-0.2830685429114548,-0.48085397416735837
|
| 81 |
+
-0.7682618809497406,-0.8719829040159035
|
| 82 |
+
0.726206851751187,1.614146688730815
|
| 83 |
+
0.24659625365511584,0.5466918275185086
|
| 84 |
+
-0.3382039502947016,-0.5537512584147581
|
| 85 |
+
-0.8728832994279527,-0.8769154903633847
|
| 86 |
+
-0.3780353565686756,-0.602059443902563
|
| 87 |
+
-0.3496333559465059,-0.5679887252654134
|
| 88 |
+
0.45921235667612814,1.06429743105736
|
| 89 |
+
0.27511494271042625,0.6153097358245831
|
| 90 |
+
0.7744254851526531,1.6800348672828789
|
| 91 |
+
-0.05557014967610141,-0.10799957252353272
|
| 92 |
+
-0.7608115081233966,-0.8704840611139395
|
| 93 |
+
0.42648957444599,0.985218668558129
|
| 94 |
+
0.5215700972337949,1.2111284238136855
|
| 95 |
+
0.12255439513899247,0.25938825702662677
|
| 96 |
+
0.541934359909122,1.2575114815069453
|
| 97 |
+
-0.012408807271218514,-0.024663010860991184
|
| 98 |
+
0.045465658763988115,0.09296488443027026
|
| 99 |
+
-0.14491796328290074,-0.26801111295446844
|
| 100 |
+
-0.9491617465118096,-0.8611626229977559
|
| 101 |
+
-0.7842171460133911,-0.8746853295473812
|
1_data/benchmarks/nguyen/nguyen_6.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_6
|
| 2 |
+
equation: sin(x) + sin(x + x**2)
|
| 3 |
+
latex: \sin(x) + \sin(x + x^2)
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_7.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
0.749080237694725,1.0044943539710771
|
| 3 |
+
1.9014286128198323,2.5946084456041874
|
| 4 |
+
1.4639878836228102,2.04704177271813
|
| 5 |
+
1.1973169683940732,1.6765955177464844
|
| 6 |
+
0.31203728088487304,0.3644950207193922
|
| 7 |
+
0.3119890406724053,0.36443082009006955
|
| 8 |
+
0.11616722433639892,0.12330527518118403
|
| 9 |
+
1.7323522915498704,2.3917183259947485
|
| 10 |
+
1.2022300234864176,1.683661630559711
|
| 11 |
+
1.416145155592091,1.9826063566319803
|
| 12 |
+
0.041168988591604894,0.042037560323179395
|
| 13 |
+
1.9398197043239886,2.6392050827257325
|
| 14 |
+
1.6648852816008435,2.307724752047451
|
| 15 |
+
0.4246782213565523,0.5197583318072219
|
| 16 |
+
0.36364993441420124,0.4343639775100977
|
| 17 |
+
0.36680901970686763,0.43871392942129606
|
| 18 |
+
0.6084844859190754,0.7902880866466716
|
| 19 |
+
1.0495128632644757,1.4602426871151915
|
| 20 |
+
0.8638900372842315,1.1801684194501334
|
| 21 |
+
0.5824582803960838,0.7510949783115568
|
| 22 |
+
1.223705789444759,1.7144476360772682
|
| 23 |
+
0.27898772130408367,0.3210225319850415
|
| 24 |
+
0.5842892970704363,0.7538452725717105
|
| 25 |
+
0.7327236865873834,0.9794515463618694
|
| 26 |
+
0.9121399684340719,1.2536309133704537
|
| 27 |
+
1.5703519227860272,2.187045504912016
|
| 28 |
+
0.39934756431671947,0.4839764178477628
|
| 29 |
+
1.0284688768272232,1.4288935535974223
|
| 30 |
+
1.184829137724085,1.6585968019236683
|
| 31 |
+
0.09290082543999545,0.09742900194019119
|
| 32 |
+
1.2150897038028767,1.702116152706024
|
| 33 |
+
0.34104824737458306,0.4034836854209194
|
| 34 |
+
0.13010318597055903,0.13909411999883536
|
| 35 |
+
1.8977710745066665,2.5903318010915215
|
| 36 |
+
1.9312640661491187,2.629312630527168
|
| 37 |
+
1.6167946962329223,2.2467723049533905
|
| 38 |
+
0.6092275383467414,0.7914100534079171
|
| 39 |
+
0.19534422801276774,0.21588350693408828
|
| 40 |
+
1.3684660530243138,1.917494634319942
|
| 41 |
+
0.8803049874792026,1.2051988477172373
|
| 42 |
+
0.24407646968955765,0.27625976490939214
|
| 43 |
+
0.9903538202225404,1.3718135735019552
|
| 44 |
+
0.06877704223043679,0.07123417253294272
|
| 45 |
+
1.8186408041575641,2.4966018816848856
|
| 46 |
+
0.5175599632000338,0.6544407561341439
|
| 47 |
+
1.325044568707964,1.8574263454835047
|
| 48 |
+
0.6234221521788219,0.8128721430860975
|
| 49 |
+
1.0401360423556216,1.4462892592771355
|
| 50 |
+
1.0934205586865593,1.5252403780918733
|
| 51 |
+
0.3697089110510541,0.44271409262957717
|
| 52 |
+
1.9391692555291171,2.6384539264232347
|
| 53 |
+
1.550265646722229,2.160950005266661
|
| 54 |
+
1.8789978831283782,2.5683039510702823
|
| 55 |
+
1.7896547008552977,2.461684010396808
|
| 56 |
+
1.1957999576221703,1.6744119782815021
|
| 57 |
+
1.8437484700462337,2.5265931477776116
|
| 58 |
+
0.176985004103839,0.19379920212879384
|
| 59 |
+
0.3919657248382904,0.47363661139990393
|
| 60 |
+
0.09045457782107613,0.09474339248156945
|
| 61 |
+
0.6506606615265287,0.8541953881759485
|
| 62 |
+
0.777354579378964,1.0478012138309305
|
| 63 |
+
0.5426980635477918,0.6916737607221609
|
| 64 |
+
1.6574750183038587,2.298391579491666
|
| 65 |
+
0.7135066533871786,0.9500499213476303
|
| 66 |
+
0.5618690193747615,0.7202496008381812
|
| 67 |
+
1.085392166316497,1.513398916251658
|
| 68 |
+
0.2818484499495253,0.3247441814248926
|
| 69 |
+
1.6043939615080793,2.230907799621039
|
| 70 |
+
0.14910128735954165,0.1609678309146809
|
| 71 |
+
1.9737738732010346,2.6782060947078543
|
| 72 |
+
1.5444895385933148,2.153416282610123
|
| 73 |
+
0.3974313630683448,0.4812884691333153
|
| 74 |
+
0.011044234247204798,0.011105659717543844
|
| 75 |
+
1.6309228569096683,2.2647730643008543
|
| 76 |
+
1.4137146876952342,1.9793088424715877
|
| 77 |
+
1.4580143360819746,2.0390459508436884
|
| 78 |
+
1.5425406933718915,2.150871440176907
|
| 79 |
+
0.14808930346818072,0.15979251395522687
|
| 80 |
+
0.7169314570885452,0.9552876826431883
|
| 81 |
+
0.23173811905025943,0.26073648199843447
|
| 82 |
+
1.726206851751187,2.3841402894128048
|
| 83 |
+
1.2465962536551158,1.7470779213505572
|
| 84 |
+
0.6617960497052984,0.8711341615273885
|
| 85 |
+
0.12711670057204727,0.13569227346664953
|
| 86 |
+
0.6219646434313244,0.8106659478651683
|
| 87 |
+
0.6503666440534941,0.8537484672462254
|
| 88 |
+
1.4592123566761281,2.0406506801223507
|
| 89 |
+
1.2751149427104262,1.7874611691912827
|
| 90 |
+
1.774425485152653,2.4432111990381067
|
| 91 |
+
0.9444298503238986,1.302575647854838
|
| 92 |
+
0.2391884918766034,0.27009115923794424
|
| 93 |
+
1.42648957444599,1.996615009036411
|
| 94 |
+
1.5215700972337949,2.1233923437150204
|
| 95 |
+
1.1225543951389925,1.5680418793142075
|
| 96 |
+
1.541934359909122,2.150079370436032
|
| 97 |
+
0.9875911927287815,1.3676621443430883
|
| 98 |
+
1.0454656587639881,1.454223169438932
|
| 99 |
+
0.8550820367170993,1.1667236792399285
|
| 100 |
+
0.05083825348819038,0.05216937619240417
|
| 101 |
+
0.2157828539866089,0.24089892921884035
|
1_data/benchmarks/nguyen/nguyen_7.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_7
|
| 2 |
+
equation: log(x + 1) + log(x**2 + 1)
|
| 3 |
+
latex: \ln(x+1) + \ln(x^2+1)
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (0, 2)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_8.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,y
|
| 2 |
+
1.49816047538945,1.2239936582308955
|
| 3 |
+
3.8028572256396647,1.9500915941667112
|
| 4 |
+
2.9279757672456204,1.7111328900017146
|
| 5 |
+
2.3946339367881464,1.547460479879259
|
| 6 |
+
0.6240745617697461,0.7899838996902064
|
| 7 |
+
0.6239780813448106,0.789922832525311
|
| 8 |
+
0.23233444867279784,0.48201083875033124
|
| 9 |
+
3.4647045830997407,1.8613716939665061
|
| 10 |
+
2.404460046972835,1.5506321443117432
|
| 11 |
+
2.832290311184182,1.6829409707961185
|
| 12 |
+
0.08233797718320979,0.2869459481909612
|
| 13 |
+
3.8796394086479773,1.9696800269708725
|
| 14 |
+
3.329770563201687,1.8247658927110861
|
| 15 |
+
0.8493564427131046,0.9216053616994123
|
| 16 |
+
0.7272998688284025,0.8528187784215369
|
| 17 |
+
0.7336180394137353,0.8565150549837027
|
| 18 |
+
1.216968971838151,1.103163166461857
|
| 19 |
+
2.0990257265289514,1.4488014793369557
|
| 20 |
+
1.727780074568463,1.3144504838785154
|
| 21 |
+
1.1649165607921677,1.0793130040874
|
| 22 |
+
2.447411578889518,1.564420524951497
|
| 23 |
+
0.5579754426081673,0.7469775382219785
|
| 24 |
+
1.1685785941408726,1.0810081378698648
|
| 25 |
+
1.4654473731747668,1.2105566377393364
|
| 26 |
+
1.8242799368681437,1.3506590749956644
|
| 27 |
+
3.1407038455720544,1.7722031050565437
|
| 28 |
+
0.7986951286334389,0.893697448040129
|
| 29 |
+
2.0569377536544464,1.4342028286314479
|
| 30 |
+
2.36965827544817,1.5393694408582268
|
| 31 |
+
0.1858016508799909,0.4310471562137847
|
| 32 |
+
2.4301794076057535,1.5589032707662633
|
| 33 |
+
0.6820964947491661,0.825891333499248
|
| 34 |
+
0.26020637194111806,0.5101042755565944
|
| 35 |
+
3.795542149013333,1.948215118772394
|
| 36 |
+
3.8625281322982374,1.965331557854358
|
| 37 |
+
3.2335893924658445,1.7982183939849588
|
| 38 |
+
1.2184550766934827,1.1038365262544463
|
| 39 |
+
0.3906884560255355,0.6250507627589422
|
| 40 |
+
2.7369321060486276,1.6543675849244108
|
| 41 |
+
1.7606099749584052,1.3268797891890602
|
| 42 |
+
0.4881529393791153,0.698679425329754
|
| 43 |
+
1.9807076404450807,1.4073761545674564
|
| 44 |
+
0.13755408446087358,0.370882844657007
|
| 45 |
+
3.6372816083151283,1.9071658575790225
|
| 46 |
+
1.0351199264000677,1.017408436371582
|
| 47 |
+
2.650089137415928,1.6279094377194108
|
| 48 |
+
1.2468443043576438,1.1166218269215606
|
| 49 |
+
2.0802720847112433,1.4423148355027218
|
| 50 |
+
2.1868411173731186,1.4787971860174467
|
| 51 |
+
0.7394178221021082,0.8598940760943223
|
| 52 |
+
3.8783385110582342,1.9693497685932366
|
| 53 |
+
3.100531293444458,1.7608325569015522
|
| 54 |
+
3.7579957662567565,1.9385550717626663
|
| 55 |
+
3.5793094017105953,1.8919062877718325
|
| 56 |
+
2.3915999152443406,1.546479846375096
|
| 57 |
+
3.6874969400924673,1.920285640234928
|
| 58 |
+
0.353970008207678,0.5949537866151269
|
| 59 |
+
0.7839314496765808,0.8853990341515969
|
| 60 |
+
0.18090915564215226,0.4253341693799738
|
| 61 |
+
1.3013213230530574,1.1407547164281449
|
| 62 |
+
1.554709158757928,1.2468797691669906
|
| 63 |
+
1.0853961270955836,1.0418234625384397
|
| 64 |
+
3.3149500366077174,1.8207004247288232
|
| 65 |
+
1.4270133067743571,1.194576622395716
|
| 66 |
+
1.123738038749523,1.0600651106179861
|
| 67 |
+
2.170784332632994,1.473358182056554
|
| 68 |
+
0.5636968998990506,0.750797509252029
|
| 69 |
+
3.2087879230161587,1.7913089970789962
|
| 70 |
+
0.2982025747190833,0.5460792751232034
|
| 71 |
+
3.947547746402069,1.9868436643083092
|
| 72 |
+
3.0889790771866297,1.7575491677863893
|
| 73 |
+
0.7948627261366896,0.8915507423229985
|
| 74 |
+
0.022088468494409597,0.14862189776210502
|
| 75 |
+
3.2618457138193366,1.806058059371109
|
| 76 |
+
2.8274293753904685,1.6814961716847494
|
| 77 |
+
2.9160286721639492,1.7076383317798736
|
| 78 |
+
3.085081386743783,1.756439975274926
|
| 79 |
+
0.29617860693636144,0.5442229386348589
|
| 80 |
+
1.4338629141770904,1.197440150561643
|
| 81 |
+
0.46347623810051886,0.6807908916110136
|
| 82 |
+
3.452413703502374,1.8580671956370076
|
| 83 |
+
2.4931925073102317,1.578984644418758
|
| 84 |
+
1.3235920994105967,1.1504747278452476
|
| 85 |
+
0.25423340114409454,0.5042156296110768
|
| 86 |
+
1.2439292868626488,1.1153157789893626
|
| 87 |
+
1.3007332881069882,1.1404969478727194
|
| 88 |
+
2.9184247133522563,1.7083397534894094
|
| 89 |
+
2.5502298854208525,1.5969439205622884
|
| 90 |
+
3.548850970305306,1.8838394226433701
|
| 91 |
+
1.8888597006477972,1.3743579230490859
|
| 92 |
+
0.4783769837532068,0.6916480201324998
|
| 93 |
+
2.85297914889198,1.689076418902348
|
| 94 |
+
3.0431401944675898,1.7444598575110835
|
| 95 |
+
2.245108790277985,1.4983687097233394
|
| 96 |
+
3.083868719818244,1.7560947354337817
|
| 97 |
+
1.975182385457563,1.4054118205912327
|
| 98 |
+
2.0909313175279762,1.4460052965075807
|
| 99 |
+
1.7101640734341985,1.3077324166029527
|
| 100 |
+
0.10167650697638075,0.3188675382919697
|
| 101 |
+
0.4315657079732178,0.6569366087935866
|
1_data/benchmarks/nguyen/nguyen_8.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_8
|
| 2 |
+
equation: sqrt(x)
|
| 3 |
+
latex: \sqrt{x}
|
| 4 |
+
n_vars: 1
|
| 5 |
+
range: (0, 4)
|
| 6 |
+
n_samples: 100
|
1_data/benchmarks/nguyen/nguyen_9.csv
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_1,x_2,y
|
| 2 |
+
-0.250919762305275,0.9014286128198323,0.47776420484711724
|
| 3 |
+
0.4639878836228102,0.1973169683940732,0.48644207085812474
|
| 4 |
+
-0.687962719115127,-0.6880109593275947,-0.17908604352731888
|
| 5 |
+
-0.8838327756636011,0.7323522915498704,-0.26218203193510137
|
| 6 |
+
0.2022300234864176,0.416145155592091,0.37316689293318206
|
| 7 |
+
-0.9588310114083951,0.9398197043239886,-0.0457079947346013
|
| 8 |
+
0.6648852816008435,-0.5753217786434477,0.9419531756947639
|
| 9 |
+
-0.6363500655857988,-0.6331909802931324,-0.2039883598325784
|
| 10 |
+
-0.39151551408092455,0.04951286326447568,-0.37913816951228
|
| 11 |
+
-0.13610996271576847,-0.4175417196039162,0.037769159921857864
|
| 12 |
+
0.22370578944475894,-0.7210122786959163,0.718602102079344
|
| 13 |
+
-0.4157107029295637,-0.2672763134126166,-0.33246432027777695
|
| 14 |
+
-0.08786003156592814,0.5703519227860272,0.23184727590209103
|
| 15 |
+
-0.6006524356832805,0.02846887682722321,-0.5643703547272768
|
| 16 |
+
0.18482913772408494,-0.9070991745600046,0.9168514297790041
|
| 17 |
+
0.21508970380287673,-0.6589517526254169,0.6341356329828418
|
| 18 |
+
-0.869896814029441,0.8977710745066665,-0.04274391665004529
|
| 19 |
+
0.9312640661491187,0.6167946962329223,1.1737000528120816
|
| 20 |
+
-0.39077246165325863,-0.8046557719872323,0.2222683475391934
|
| 21 |
+
0.3684660530243138,-0.1196950125207974,0.37451127227879794
|
| 22 |
+
-0.7559235303104423,-0.00964617977745963,-0.6858679083778902
|
| 23 |
+
-0.9312229577695632,0.8186408041575641,-0.1812290712696445
|
| 24 |
+
-0.48244003679996617,0.32504456870796394,-0.35848458267243266
|
| 25 |
+
-0.3765778478211781,0.040136042355621626,-0.3661293707647741
|
| 26 |
+
0.0934205586865593,-0.6302910889489459,0.4801842308956415
|
| 27 |
+
0.9391692555291171,0.5502656467222291,1.1052544786389042
|
| 28 |
+
0.8789978831283782,0.7896547008552977,1.3540244447960008
|
| 29 |
+
0.19579995762217028,0.8437484700462337,0.8478334428521738
|
| 30 |
+
-0.823014995896161,-0.6080342751617096,-0.3718583841132946
|
| 31 |
+
-0.9095454221789239,-0.3493393384734713,-0.6674893880400039
|
| 32 |
+
-0.22264542062103598,-0.4573019364522082,-0.01320641283786203
|
| 33 |
+
0.6574750183038587,-0.28649334661282144,0.6931064937612617
|
| 34 |
+
-0.4381309806252385,0.08539216631649693,-0.4169559662599447
|
| 35 |
+
-0.7181515500504747,0.6043939615080793,-0.30077177092335655
|
| 36 |
+
-0.8508987126404584,0.9737738732010346,0.060514655114857874
|
| 37 |
+
0.5444895385933148,-0.6025686369316552,0.8731450114593164
|
| 38 |
+
-0.9889557657527952,0.6309228569096683,-0.44781844397033194
|
| 39 |
+
0.41371468769523423,0.4580143360819746,0.6102553103342886
|
| 40 |
+
0.5425406933718915,-0.8519106965318193,1.1800115078177271
|
| 41 |
+
-0.2830685429114548,-0.7682618809497406,0.27724568254954496
|
| 42 |
+
0.726206851751187,0.24659625365511584,0.7248105293625019
|
| 43 |
+
-0.3382039502947016,-0.8728832994279527,0.3585223402710981
|
| 44 |
+
-0.3780353565686756,-0.3496333559465059,-0.24715600583220282
|
| 45 |
+
0.45921235667612814,0.27511494271042625,0.5188581859931962
|
| 46 |
+
0.7744254851526531,-0.05557014967610141,0.7023935510282614
|
| 47 |
+
-0.7608115081233966,0.42648957444599,-0.5086174096415731
|
| 48 |
+
0.5215700972337949,0.12255439513899247,0.5132611003903497
|
| 49 |
+
0.541934359909122,-0.012408807271218514,0.515948124519243
|
| 50 |
+
0.045465658763988115,-0.14491796328290074,0.06644966885563074
|
| 51 |
+
-0.9491617465118096,-0.7842171460133911,-0.23597193139241113
|
| 52 |
+
-0.9371416286265315,0.2728208225275608,-0.7315064751214
|
| 53 |
+
-0.37128803784734665,0.01714138232940554,-0.362522177588772
|
| 54 |
+
0.815132947852186,-0.5014155417022501,0.9765939645263036
|
| 55 |
+
-0.17923415392874054,0.5111022770860973,0.07998865766083815
|
| 56 |
+
-0.5424036690167551,-0.846040180342414,0.14001304604376585
|
| 57 |
+
-0.42049709417246395,-0.6775574254919912,0.034912901016997455
|
| 58 |
+
0.8593953046851461,0.6162407591288339,1.1281386729129697
|
| 59 |
+
0.26680751302084693,0.7429211803754354,0.7879864844362704
|
| 60 |
+
0.6073441537982289,-0.6268598822279283,0.95360688133832
|
| 61 |
+
0.7851179969799555,0.07868448383130144,0.7130998543064418
|
| 62 |
+
0.6148803103281251,0.7921825998469865,1.164026731337986
|
| 63 |
+
-0.36399305005627225,-0.7798961509446465,0.21541385909859811
|
| 64 |
+
-0.5441296749161166,-0.14578442274748737,-0.496422157556886
|
| 65 |
+
0.6360295318449862,0.7214611665126869,1.0913254238549446
|
| 66 |
+
-0.9860957389376186,0.021494605155131463,-0.8334153656803461
|
| 67 |
+
-0.16517799370244202,-0.5557843790585395,0.1395794327899027
|
| 68 |
+
-0.7602692653326344,-0.3247696571927441,-0.5838367239902751
|
| 69 |
+
0.8858194078250383,-0.35359413595848954,0.8991369711397181
|
| 70 |
+
0.037581243486732197,0.4060379177903557,0.2016933272194322
|
| 71 |
+
-0.27274079524141204,0.9435641654419213,0.5078969955272314
|
| 72 |
+
0.9248945898842225,-0.4964354083492717,1.0425182593662667
|
| 73 |
+
-0.005502988215229099,-0.3982433803664607,0.15243079003847154
|
| 74 |
+
-0.4303190112450648,-0.9262261052909344,0.3393066204879884
|
| 75 |
+
0.2191286679597937,0.005358046457722976,0.217407918292659
|
| 76 |
+
-0.8970424975000213,-0.44270707152677713,-0.586747836651068
|
| 77 |
+
0.8165317719333074,-0.5208762186660552,0.9967710464625614
|
| 78 |
+
-0.7102102558175538,-0.02109447944487397,-0.6515482295602607
|
| 79 |
+
0.9713009082212014,-0.5158894569769992,1.0886315849160677
|
| 80 |
+
0.3442710948117571,0.5232392306574352,0.6078825279565068
|
| 81 |
+
-0.5247249120152007,0.45643269722371915,-0.29414785755880746
|
| 82 |
+
-0.26443373456149355,0.26461166118715895,-0.1914005964405623
|
| 83 |
+
0.2670594215217894,0.07154936814951696,0.26901553332662537
|
| 84 |
+
-0.8194204598911834,0.6706049911784759,-0.2960449968694709
|
| 85 |
+
-0.35843987005652833,-0.6269629792002915,0.03222402345128389
|
| 86 |
+
-0.9184497168904722,0.18178588637648363,-0.7616213774179811
|
| 87 |
+
0.35512872368456483,-0.9668243421442877,1.152161247302312
|
| 88 |
+
0.024186116598561958,-0.5470084496041241,0.31895703499280215
|
| 89 |
+
0.2903455808188997,-0.6512671419900171,0.6978285972621991
|
| 90 |
+
0.3818754762049319,-0.22652930739892518,0.4239545131674068
|
| 91 |
+
0.873459977473469,-0.7249581117080135,1.268256638618119
|
| 92 |
+
-0.317867297899483,-0.7730529575188219,0.2501276235695559
|
| 93 |
+
0.8493872365571256,0.754678706761962,1.290120526044866
|
| 94 |
+
-0.4841167445696888,0.3199680920683581,-0.36322596002563673
|
| 95 |
+
0.6344444004024317,0.11040162319892466,0.6049183419335217
|
| 96 |
+
0.05930115671201297,-0.5162954181990966,0.32268182428446945
|
| 97 |
+
-0.8137944643882016,0.7944315159066535,-0.13684769483552173
|
| 98 |
+
0.8008361143266609,0.2662029145465359,0.7887430634679737
|
| 99 |
+
-0.32194041790259864,-0.3015808507746782,-0.22558221167734632
|
| 100 |
+
0.45191135774047875,0.7942205199051542,1.0264656897505602
|
| 101 |
+
0.7741728485302346,0.5597510917152477,1.0073448206218134
|
1_data/benchmarks/nguyen/nguyen_9.meta.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: nguyen_9
|
| 2 |
+
equation: sin(x) + sin(y**2)
|
| 3 |
+
latex: \sin(x) + \sin(y^2)
|
| 4 |
+
n_vars: 2
|
| 5 |
+
range: (-1, 1)
|
| 6 |
+
n_samples: 100
|
1_data/processed/700K_prefix_converted/data-00000-of-00001.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e6c9ddae5a332af4ee05d1aebbb913757316ce685a179712212e6e8334ba5f1
|
| 3 |
+
size 7296336
|
1_data/processed/700K_prefix_converted/dataset_info.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"builder_name": "csv",
|
| 3 |
+
"citation": "",
|
| 4 |
+
"config_name": "default",
|
| 5 |
+
"dataset_name": "sintetico_natural",
|
| 6 |
+
"dataset_size": 5704135,
|
| 7 |
+
"description": "",
|
| 8 |
+
"download_checksums": {
|
| 9 |
+
"hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/test_mini_test.csv": {
|
| 10 |
+
"num_bytes": 560871,
|
| 11 |
+
"checksum": null
|
| 12 |
+
},
|
| 13 |
+
"hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/train_mini_test.csv": {
|
| 14 |
+
"num_bytes": 4378438,
|
| 15 |
+
"checksum": null
|
| 16 |
+
},
|
| 17 |
+
"hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/val_mini_test.csv": {
|
| 18 |
+
"num_bytes": 545160,
|
| 19 |
+
"checksum": null
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"download_size": 5484469,
|
| 23 |
+
"features": {
|
| 24 |
+
"infix_expr_n": {
|
| 25 |
+
"dtype": "string",
|
| 26 |
+
"_type": "Value"
|
| 27 |
+
},
|
| 28 |
+
"infix_expr_c": {
|
| 29 |
+
"dtype": "string",
|
| 30 |
+
"_type": "Value"
|
| 31 |
+
},
|
| 32 |
+
"expression_objects": {
|
| 33 |
+
"dtype": "string",
|
| 34 |
+
"_type": "Value"
|
| 35 |
+
},
|
| 36 |
+
"prefix_expr_c": {
|
| 37 |
+
"dtype": "string",
|
| 38 |
+
"_type": "Value"
|
| 39 |
+
},
|
| 40 |
+
"prefix_expr_n": {
|
| 41 |
+
"dtype": "string",
|
| 42 |
+
"_type": "Value"
|
| 43 |
+
},
|
| 44 |
+
"i_prompt_n": {
|
| 45 |
+
"dtype": "string",
|
| 46 |
+
"_type": "Value"
|
| 47 |
+
},
|
| 48 |
+
"p_prompt_n": {
|
| 49 |
+
"dtype": "string",
|
| 50 |
+
"_type": "Value"
|
| 51 |
+
},
|
| 52 |
+
"skeleton": {
|
| 53 |
+
"dtype": "string",
|
| 54 |
+
"_type": "Value"
|
| 55 |
+
},
|
| 56 |
+
"p_prompt_n_converted": {
|
| 57 |
+
"dtype": "string",
|
| 58 |
+
"_type": "Value"
|
| 59 |
+
},
|
| 60 |
+
"conversion_success": {
|
| 61 |
+
"dtype": "bool",
|
| 62 |
+
"_type": "Value"
|
| 63 |
+
}
|
| 64 |
+
},
|
| 65 |
+
"homepage": "",
|
| 66 |
+
"license": "",
|
| 67 |
+
"size_in_bytes": 11188604,
|
| 68 |
+
"splits": {
|
| 69 |
+
"test": {
|
| 70 |
+
"name": "test",
|
| 71 |
+
"num_bytes": 5704135,
|
| 72 |
+
"num_examples": 12221,
|
| 73 |
+
"dataset_name": "sintetico_natural"
|
| 74 |
+
}
|
| 75 |
+
},
|
| 76 |
+
"version": {
|
| 77 |
+
"version_str": "0.0.0",
|
| 78 |
+
"major": 0,
|
| 79 |
+
"minor": 0,
|
| 80 |
+
"patch": 0
|
| 81 |
+
}
|
| 82 |
+
}
|
1_data/processed/700K_prefix_converted/state.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "d4d2d5394689c8df",
|
| 8 |
+
"_format_columns": null,
|
| 9 |
+
"_format_kwargs": {},
|
| 10 |
+
"_format_type": null,
|
| 11 |
+
"_output_all_columns": false,
|
| 12 |
+
"_split": "test"
|
| 13 |
+
}
|
1_data/processed/PREFIX_CONVERSION_README.md
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Conversão Infix → Prefix Dataset
|
| 2 |
+
|
| 3 |
+
## Objetivo
|
| 4 |
+
|
| 5 |
+
Converter o dataset `augustocsc/sintetico_natural` de notação **infix** para **prefix (Polish notation)**, mantendo as mesmas expressões matemáticas mas em formato prefixado.
|
| 6 |
+
|
| 7 |
+
## Motivação
|
| 8 |
+
|
| 9 |
+
O dataset original contém:
|
| 10 |
+
- `i_prompt_n`: Prompts com expressões em notação infix (e.g., `x_1 + x_2`)
|
| 11 |
+
- `p_prompt_n`: Prompts com expressões em notação prefix (e.g., `+ x_1 x_2`)
|
| 12 |
+
|
| 13 |
+
**PROBLEMA**: As expressões em `i_prompt_n` e `p_prompt_n` são DIFERENTES! Não são a mesma expressão convertida.
|
| 14 |
+
|
| 15 |
+
**SOLUÇÃO**: Converter automaticamente as expressões de `i_prompt_n` para notação prefix, criando `p_prompt_n_converted`.
|
| 16 |
+
|
| 17 |
+
## Script de Conversão
|
| 18 |
+
|
| 19 |
+
**Arquivo**: `scripts/data/convert_infix_to_prefix.py`
|
| 20 |
+
|
| 21 |
+
**Funcionalidade**:
|
| 22 |
+
1. Lê expressões infix do campo `i_prompt_n`
|
| 23 |
+
2. Parseia usando SymPy
|
| 24 |
+
3. Converte para notação prefix (Polish notation)
|
| 25 |
+
4. Mantém as mesmas variáveis e operadores do prompt original
|
| 26 |
+
5. Salva em nova coluna `p_prompt_n_converted`
|
| 27 |
+
|
| 28 |
+
## Exemplos de Conversão
|
| 29 |
+
|
| 30 |
+
### Exemplo 1
|
| 31 |
+
**INFIX**:
|
| 32 |
+
```
|
| 33 |
+
vars: x_1, x_2, x_3, x_4, x_5
|
| 34 |
+
oper: *, +, -, /, abs, asin, cos, exp, log, sin, sqrt, tan
|
| 35 |
+
cons: C
|
| 36 |
+
expr: x_2 - (x_5 - C)*(x_4 + exp(C*x_2) + C)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
**PREFIX**:
|
| 40 |
+
```
|
| 41 |
+
vars: x_1, x_2, x_3, x_4, x_5
|
| 42 |
+
oper: *, +, -, /, abs, asin, cos, exp, log, sin, sqrt, tan
|
| 43 |
+
cons: C
|
| 44 |
+
expr: - x_2 * - x_5 C + + x_4 exp * C x_2 C
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Exemplo 2
|
| 48 |
+
**INFIX**:
|
| 49 |
+
```
|
| 50 |
+
vars: x_1, x_2, x_3
|
| 51 |
+
oper: +, -, /, abs, cos, exp
|
| 52 |
+
cons: C
|
| 53 |
+
expr: (x_1 - C)/(x_2 + C)
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
**PREFIX**:
|
| 57 |
+
```
|
| 58 |
+
vars: x_1, x_2, x_3
|
| 59 |
+
oper: +, -, /, abs, cos, exp
|
| 60 |
+
cons: C
|
| 61 |
+
expr: / - x_1 C + x_2 C
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Exemplo 3
|
| 65 |
+
**INFIX**:
|
| 66 |
+
```
|
| 67 |
+
vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10
|
| 68 |
+
oper: *, +, /, asin, sin, tan
|
| 69 |
+
cons: C
|
| 70 |
+
expr: (tan(x_7) + C)*(asin(x_5) + C)
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
**PREFIX**:
|
| 74 |
+
```
|
| 75 |
+
vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10
|
| 76 |
+
oper: *, +, /, asin, sin, tan
|
| 77 |
+
cons: C
|
| 78 |
+
expr: * + tan x_7 C + asin x_5 C
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Regras de Conversão
|
| 82 |
+
|
| 83 |
+
### Operadores Binários
|
| 84 |
+
- Infix: `a + b` → Prefix: `+ a b`
|
| 85 |
+
- Infix: `a - b` → Prefix: `- a b`
|
| 86 |
+
- Infix: `a * b` → Prefix: `* a b`
|
| 87 |
+
- Infix: `a / b` → Prefix: `/ a b`
|
| 88 |
+
- Infix: `a ** b` → Prefix: `** a b`
|
| 89 |
+
|
| 90 |
+
### Funções Unárias
|
| 91 |
+
- Infix: `sin(x)` → Prefix: `sin x`
|
| 92 |
+
- Infix: `exp(x)` → Prefix: `exp x`
|
| 93 |
+
- Infix: `log(x)` → Prefix: `log x`
|
| 94 |
+
|
| 95 |
+
### Expressões Complexas
|
| 96 |
+
- Infix: `sin(x**2)` → Prefix: `sin ** x 2`
|
| 97 |
+
- Infix: `x*(y + z)` → Prefix: `* x + y z`
|
| 98 |
+
- Infix: `(a + b)*(c + d)` → Prefix: `* + a b + c d`
|
| 99 |
+
|
| 100 |
+
### Casos Especiais
|
| 101 |
+
- **Negação**: `-x` → `* -1 x`
|
| 102 |
+
- **Múltiplas adições**: `a + b + c` → `+ + a b c` (aninhado à esquerda)
|
| 103 |
+
- **Divisão**: `x/y` → `/ x y` (SymPy representa como `x * y**-1`, conversão detecta e corrige)
|
| 104 |
+
|
| 105 |
+
## Uso do Script
|
| 106 |
+
|
| 107 |
+
### Teste (10 exemplos)
|
| 108 |
+
```bash
|
| 109 |
+
python scripts/data/convert_infix_to_prefix.py --test_only
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Conversão Completa
|
| 113 |
+
```bash
|
| 114 |
+
python scripts/data/convert_infix_to_prefix.py \
|
| 115 |
+
--split test \
|
| 116 |
+
--output_path ./1_data/processed/700K_prefix_converted
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### Upload para HuggingFace
|
| 120 |
+
```bash
|
| 121 |
+
python scripts/data/convert_infix_to_prefix.py \
|
| 122 |
+
--split test \
|
| 123 |
+
--output_path ./1_data/processed/700K_prefix_converted \
|
| 124 |
+
--upload \
|
| 125 |
+
--repo_id augustocsc/sintetico_natural_prefix_converted
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Output do Dataset
|
| 129 |
+
|
| 130 |
+
**Colunas adicionadas**:
|
| 131 |
+
- `p_prompt_n_converted`: Prompt prefixado convertido do infix
|
| 132 |
+
- `conversion_success`: Boolean indicando se conversão foi bem-sucedida
|
| 133 |
+
|
| 134 |
+
**Taxa de sucesso esperada**: ~99%+
|
| 135 |
+
|
| 136 |
+
## Vantagens
|
| 137 |
+
|
| 138 |
+
1. **Comparabilidade**: Agora é possível treinar modelos prefix com as MESMAS expressões dos modelos infix
|
| 139 |
+
2. **Consistência**: Mantém vars/operators/constants do prompt original
|
| 140 |
+
3. **Reprodutibilidade**: Conversão automática e determinística
|
| 141 |
+
4. **Escalabilidade**: Fácil aplicar a novos datasets
|
| 142 |
+
|
| 143 |
+
## Treinamento com Dataset Convertido
|
| 144 |
+
|
| 145 |
+
### Usando o dataset local
|
| 146 |
+
```bash
|
| 147 |
+
python 2_training/supervised/train.py \
|
| 148 |
+
--dataset_path ./1_data/processed/700K_prefix_converted \
|
| 149 |
+
--data_column p_prompt_n_converted \
|
| 150 |
+
--approach prefix \
|
| 151 |
+
--output_dir ./output/gpt2_prefix_converted
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Comparação: Infix vs Prefix (mesma expressão)
|
| 155 |
+
Agora você pode treinar dois modelos com a MESMA expressão:
|
| 156 |
+
- Modelo A: `--data_column i_prompt_n --approach infix`
|
| 157 |
+
- Modelo B: `--data_column p_prompt_n_converted --approach prefix`
|
| 158 |
+
|
| 159 |
+
E comparar diretamente qual notação o modelo aprende melhor!
|
| 160 |
+
|
| 161 |
+
## Validação
|
| 162 |
+
|
| 163 |
+
Para verificar a correção da conversão:
|
| 164 |
+
1. Parsear expressão prefix convertida
|
| 165 |
+
2. Avaliar em pontos de teste
|
| 166 |
+
3. Comparar com avaliação da expressão infix original
|
| 167 |
+
4. R² score deve ser ~1.0 (mesma expressão, mesmos resultados)
|
| 168 |
+
|
| 169 |
+
```python
|
| 170 |
+
from classes.expression import Expression
|
| 171 |
+
|
| 172 |
+
# Expressão infix original
|
| 173 |
+
expr_infix = Expression("x_1 + x_2", is_prefix=False)
|
| 174 |
+
|
| 175 |
+
# Expressão prefix convertida
|
| 176 |
+
expr_prefix = Expression("+ x_1 x_2", is_prefix=True)
|
| 177 |
+
|
| 178 |
+
# Testar
|
| 179 |
+
import numpy as np
|
| 180 |
+
x = np.array([[1.0, 2.0], [3.0, 4.0]])
|
| 181 |
+
|
| 182 |
+
result_infix = expr_infix.evaluate(x)
|
| 183 |
+
result_prefix = expr_prefix.evaluate(x)
|
| 184 |
+
|
| 185 |
+
# Devem ser idênticos
|
| 186 |
+
assert np.allclose(result_infix, result_prefix)
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## Limitações Conhecidas
|
| 190 |
+
|
| 191 |
+
1. **Expressões muito complexas**: Pode haver casos edge com aninhamento profundo
|
| 192 |
+
2. **Operadores customizados**: Se houver operadores não mapeados no SymPy
|
| 193 |
+
3. **Simplificação automática**: SymPy pode simplificar algumas expressões, alterando forma mas não valor
|
| 194 |
+
|
| 195 |
+
## Status
|
| 196 |
+
|
| 197 |
+
- [x] Script de conversão implementado
|
| 198 |
+
- [x] Testes em 10 exemplos (100% sucesso)
|
| 199 |
+
- [ ] Conversão dataset completo (~12k exemplos)
|
| 200 |
+
- [ ] Upload para HuggingFace Hub
|
| 201 |
+
- [ ] Treinamento de modelo teste
|
| 202 |
+
- [ ] Comparação infix vs prefix
|
| 203 |
+
|
| 204 |
+
## Referências
|
| 205 |
+
|
| 206 |
+
- **Polish Notation**: https://en.wikipedia.org/wiki/Polish_notation
|
| 207 |
+
- **SymPy Documentation**: https://docs.sympy.org/
|
| 208 |
+
- **Dataset Original**: https://huggingface.co/datasets/augustocsc/sintetico_natural
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
**Data de Criação**: 2026-02-09
|
| 213 |
+
**Autor**: Claude Sonnet 4.5 (co-authored)
|
| 214 |
+
**Última Atualização**: 2026-02-09
|
2_training/README.md
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 2_training/ - Treinamento e Fine-tuning
|
| 2 |
+
|
| 3 |
+
Este diretório contém todos os scripts e configurações para treinamento de modelos LLM para symbolic regression.
|
| 4 |
+
|
| 5 |
+
## Estrutura
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
2_training/
|
| 9 |
+
├── supervised/ # Fine-tuning supervisionado
|
| 10 |
+
│ ├── train_with_json.py # Principal: treino com formato JSON
|
| 11 |
+
│ ├── train.py # Script base
|
| 12 |
+
│ ├── train_experiment.py # Experimentos controlados
|
| 13 |
+
│ └── iterative_sampling_sft.py # SFT iterativo
|
| 14 |
+
│
|
| 15 |
+
├── reinforcement/ # Reinforcement Learning
|
| 16 |
+
│ ├── ppo_symbolic.py # Proximal Policy Optimization
|
| 17 |
+
│ ├── grpo_symbolic.py # Group Relative PO
|
| 18 |
+
│ ├── reinforce_*.py # REINFORCE algorithm
|
| 19 |
+
│ └── best_of_n_experiment.py # Best-of-N sampling
|
| 20 |
+
│
|
| 21 |
+
└── configs/ # Configurações
|
| 22 |
+
└── wandb_config.py # Wandb naming standards
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Métodos de Treinamento
|
| 26 |
+
|
| 27 |
+
### 1. Supervised Fine-tuning (Recomendado)
|
| 28 |
+
|
| 29 |
+
Script principal: `supervised/train_with_json.py`
|
| 30 |
+
|
| 31 |
+
**Características**:
|
| 32 |
+
- LoRA fine-tuning (apenas 294K parâmetros treináveis)
|
| 33 |
+
- Formato JSON estruturado (80% valid rate)
|
| 34 |
+
- Early stopping automático
|
| 35 |
+
- Split train/validation 90/10
|
| 36 |
+
- Integração com Wandb
|
| 37 |
+
|
| 38 |
+
**Uso**:
|
| 39 |
+
```bash
|
| 40 |
+
cd supervised
|
| 41 |
+
python train_with_json.py \
|
| 42 |
+
--model_size gpt2-medium \
|
| 43 |
+
--dataset_path ../../1_data/processed/700K \
|
| 44 |
+
--output_dir ../../models/gpt2/medium_test \
|
| 45 |
+
--num_train_epochs 3 \
|
| 46 |
+
--per_device_train_batch_size 4
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
**Modelos suportados**:
|
| 50 |
+
- `gpt2` (124M params)
|
| 51 |
+
- `gpt2-medium` (355M params)
|
| 52 |
+
- `gpt2-large` (774M params)
|
| 53 |
+
- GPT-Neo, LLaMA, Phi (futuro)
|
| 54 |
+
|
| 55 |
+
### 2. Reinforcement Learning
|
| 56 |
+
|
| 57 |
+
#### PPO (Proximal Policy Optimization)
|
| 58 |
+
Script: `reinforcement/ppo_symbolic.py`
|
| 59 |
+
|
| 60 |
+
**Quando usar**:
|
| 61 |
+
- Problemas complexos (Nguyen 4+)
|
| 62 |
+
- Otimização de R² score
|
| 63 |
+
- Após supervised fine-tuning
|
| 64 |
+
|
| 65 |
+
**Uso**:
|
| 66 |
+
```bash
|
| 67 |
+
cd reinforcement
|
| 68 |
+
python ppo_symbolic.py \
|
| 69 |
+
--model_path ../../models/gpt2/base_700k_json \
|
| 70 |
+
--dataset ../../1_data/benchmarks/nguyen/nguyen_5.csv \
|
| 71 |
+
--epochs 20
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
#### GRPO (Group Relative Policy Optimization)
|
| 75 |
+
Script: `reinforcement/grpo_symbolic.py`
|
| 76 |
+
|
| 77 |
+
**Vantagens**:
|
| 78 |
+
- Mais estável que PPO
|
| 79 |
+
- Melhor para multi-modal rewards
|
| 80 |
+
- Baseado em DeepSeek-R1
|
| 81 |
+
|
| 82 |
+
#### REINFORCE
|
| 83 |
+
Script: `reinforcement/reinforce_symbolic.py`
|
| 84 |
+
|
| 85 |
+
**Características**:
|
| 86 |
+
- Simples e eficaz
|
| 87 |
+
- EMA baseline
|
| 88 |
+
- Bom para benchmarks fáceis (Nguyen 1-3)
|
| 89 |
+
|
| 90 |
+
## Configuração LoRA
|
| 91 |
+
|
| 92 |
+
Configuração padrão (todos os modelos):
|
| 93 |
+
```python
|
| 94 |
+
{
|
| 95 |
+
"r": 8,
|
| 96 |
+
"lora_alpha": 32,
|
| 97 |
+
"target_modules": ["c_attn"],
|
| 98 |
+
"lora_dropout": 0.05
|
| 99 |
+
}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
**Resultado**: ~294K parâmetros treináveis (vs 124M-774M total)
|
| 103 |
+
|
| 104 |
+
## Hiperparâmetros Recomendados
|
| 105 |
+
|
| 106 |
+
### Por Tamanho de Modelo
|
| 107 |
+
|
| 108 |
+
| Modelo | Batch Size | Instance | VRAM | Tempo |
|
| 109 |
+
|--------|-----------|----------|------|-------|
|
| 110 |
+
| GPT-2 Base | 8 | g5.xlarge | 24GB | 2-3h |
|
| 111 |
+
| GPT-2 Medium | 4 | g5.xlarge | 24GB | 3-4h |
|
| 112 |
+
| GPT-2 Large | 2 | g5.2xlarge | 48GB | 4-5h |
|
| 113 |
+
|
| 114 |
+
### Outros Hiperparâmetros
|
| 115 |
+
```python
|
| 116 |
+
learning_rate = 5e-5
|
| 117 |
+
num_train_epochs = 3
|
| 118 |
+
gradient_accumulation_steps = 4
|
| 119 |
+
warmup_steps = 500
|
| 120 |
+
weight_decay = 0.01
|
| 121 |
+
early_stopping_patience = 3
|
| 122 |
+
seed = 42
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Formato de Dados
|
| 126 |
+
|
| 127 |
+
### JSON Format (Recomendado)
|
| 128 |
+
```json
|
| 129 |
+
{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "sin(x_1 + C*x_2)"}
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
**Vantagens**:
|
| 133 |
+
- 80% valid expression rate
|
| 134 |
+
- Structured boundaries
|
| 135 |
+
- Lower loss (0.343 vs 0.415)
|
| 136 |
+
|
| 137 |
+
## Wandb Tracking
|
| 138 |
+
|
| 139 |
+
Naming standard: `seriguela-{type}-{model}-{dataset}-{timestamp}`
|
| 140 |
+
|
| 141 |
+
Exemplos:
|
| 142 |
+
```python
|
| 143 |
+
# Supervised
|
| 144 |
+
seriguela-supervised-medium-700k-20260204-120000
|
| 145 |
+
|
| 146 |
+
# PPO
|
| 147 |
+
seriguela-ppo-large-nguyen5-20260204-120000
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
## Deploy AWS
|
| 151 |
+
|
| 152 |
+
Scripts disponíveis em: `../../scripts/aws/`
|
| 153 |
+
|
| 154 |
+
Lançar treinamento:
|
| 155 |
+
```bash
|
| 156 |
+
# Medium model
|
| 157 |
+
bash ../../scripts/aws/launch_medium_training.sh \
|
| 158 |
+
--wandb-key YOUR_KEY \
|
| 159 |
+
--hf-token YOUR_TOKEN
|
| 160 |
+
|
| 161 |
+
# Large model
|
| 162 |
+
bash ../../scripts/aws/launch_large_training.sh \
|
| 163 |
+
--wandb-key YOUR_KEY \
|
| 164 |
+
--hf-token YOUR_TOKEN
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
## Troubleshooting
|
| 168 |
+
|
| 169 |
+
### OOM (Out of Memory)
|
| 170 |
+
- Reduzir `per_device_train_batch_size`
|
| 171 |
+
- Usar `gradient_accumulation_steps` maior
|
| 172 |
+
- Usar instância maior (g5.2xlarge)
|
| 173 |
+
|
| 174 |
+
### Low Valid Rate
|
| 175 |
+
- Verificar formato de dados (deve ser JSON)
|
| 176 |
+
- Aumentar `num_train_epochs`
|
| 177 |
+
- Verificar conversão de dados
|
| 178 |
+
|
| 179 |
+
### Early Stopping Prematuro
|
| 180 |
+
- Aumentar `early_stopping_patience`
|
| 181 |
+
- Verificar validation loss
|
| 182 |
+
|
| 183 |
+
## Próximos Modelos (Planejados)
|
| 184 |
+
|
| 185 |
+
### GPT-Neo (EleutherAI)
|
| 186 |
+
- 125M, 1.3B, 2.7B params
|
| 187 |
+
- Similar ao GPT-2
|
| 188 |
+
- Compatível com mesma pipeline
|
| 189 |
+
|
| 190 |
+
### LLaMA 2/3 (Meta)
|
| 191 |
+
- 7B, 13B, 70B params
|
| 192 |
+
- Melhor performance
|
| 193 |
+
- Requer mais VRAM
|
| 194 |
+
|
| 195 |
+
### Phi-2/3 (Microsoft)
|
| 196 |
+
- 2.7B params
|
| 197 |
+
- Otimizado para reasoning
|
| 198 |
+
- Bom para symbolic tasks
|
| 199 |
+
|
| 200 |
+
## Referências
|
| 201 |
+
|
| 202 |
+
- LoRA: https://arxiv.org/abs/2106.09685
|
| 203 |
+
- PPO: https://arxiv.org/abs/1707.06347
|
| 204 |
+
- GRPO: DeepSeek-R1 technical report
|
| 205 |
+
- Dataset: https://huggingface.co/datasets/augustocsc/sintetico_natural
|
2_training/configs/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Seriguela - Configuration utilities.
|
| 3 |
+
|
| 4 |
+
This module provides standardized configuration utilities for
|
| 5 |
+
experiment tracking and naming conventions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .wandb_config import (
|
| 9 |
+
generate_run_name,
|
| 10 |
+
get_wandb_project_name,
|
| 11 |
+
parse_run_name,
|
| 12 |
+
is_valid_run_name,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'generate_run_name',
|
| 17 |
+
'get_wandb_project_name',
|
| 18 |
+
'parse_run_name',
|
| 19 |
+
'is_valid_run_name',
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
__version__ = '1.0.0'
|
2_training/configs/eval_dataset_download.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy_dummy
|
| 2 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium_dummy
|
| 3 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard_dummy
|
| 4 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy
|
| 5 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium
|
| 6 |
+
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard
|
2_training/configs/model_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
2_training/configs/peft_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
2_training/configs/training.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CUDA_VISIBLE_DEVICES=0 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
|
| 2 |
+
--dataset_repo_id augustocsc/sintetico_natural \
|
| 3 |
+
--data_dir 500k \
|
| 4 |
+
--output_dir ./output \
|
| 5 |
+
--push_to_hub \
|
| 6 |
+
--hub_model_id augustocsc/Se124M500KInfPrompt_EOS \
|
| 7 |
+
--source_data_column i_prompt \
|
| 8 |
+
--report_to wandb \
|
| 9 |
+
--run_name Se124M500KInfPrompt_EOS \
|
| 10 |
+
--model_name_or_path gpt2 \
|
| 11 |
+
--bf16 \
|
| 12 |
+
--eval_strategy steps \
|
| 13 |
+
--num_train_epochs 3 \
|
| 14 |
+
--per_device_train_batch_size 16 \
|
| 15 |
+
--per_device_eval_batch_size 16 \
|
| 16 |
+
--gradient_accumulation_steps 4 \
|
| 17 |
+
--dataloader_num_workers 8 \
|
| 18 |
+
--learning_rate 5e-5 \
|
| 19 |
+
--warmup_ratio 0.03 \
|
| 20 |
+
--weight_decay 0.01 \
|
| 21 |
+
--max_grad_norm 1.0 \
|
| 22 |
+
--lr_scheduler_type cosine \
|
| 23 |
+
--optim adamw_torch_fused \
|
| 24 |
+
--logging_steps 20 \
|
| 25 |
+
--eval_steps 500 \
|
| 26 |
+
--save_steps 1000 \
|
| 27 |
+
--save_total_limit 3 \
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# CUDA_VISIBLE_DEVICES=1 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
|
| 31 |
+
# --dataset_repo_id augustocsc/sintetico_final \
|
| 32 |
+
# --data_dir 100k \
|
| 33 |
+
# --output_dir ./output \
|
| 34 |
+
# --push_to_hub \
|
| 35 |
+
# --hub_model_id augustocsc/Se124M100KInfPrompt_NT \
|
| 36 |
+
# --source_data_column i_prompt \
|
| 37 |
+
# --report_to wandb \
|
| 38 |
+
# --run_name Se124M100KInfPrompt_NT \
|
| 39 |
+
# --bf16 \
|
| 40 |
+
# --eval_strategy steps \
|
| 41 |
+
# --num_train_epochs 3 \
|
| 42 |
+
# --per_device_train_batch_size 16 \
|
| 43 |
+
# --per_device_eval_batch_size 16 \
|
| 44 |
+
# --gradient_accumulation_steps 2 \
|
| 45 |
+
# --dataloader_num_workers 8 \
|
| 46 |
+
# --learning_rate 2e-5 \
|
| 47 |
+
# --warmup_ratio 0.03 \
|
| 48 |
+
# --weight_decay 0.01 \
|
| 49 |
+
# --max_grad_norm 1.0 \
|
| 50 |
+
# --lr_scheduler_type cosine \
|
| 51 |
+
# --optim adamw_torch_fused \
|
| 52 |
+
# --logging_steps 20 \
|
| 53 |
+
# --eval_steps 500 \
|
| 54 |
+
# --save_steps 1000 \
|
| 55 |
+
# --save_total_limit 3
|
| 56 |
+
|
| 57 |
+
# CUDA_VISIBLE_DEVICES=0 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
|
| 58 |
+
# --dataset_repo_id augustocsc/sintetico_final \
|
| 59 |
+
# --data_dir 100k \
|
| 60 |
+
# --output_dir ./output \
|
| 61 |
+
# --push_to_hub \
|
| 62 |
+
# --hub_model_id augustocsc/Se124M100KInfPrompt_WT \
|
| 63 |
+
# --source_data_column i_prompt \
|
| 64 |
+
# --report_to wandb \
|
| 65 |
+
# --run_name Se124M100KInfPrompt_WT \
|
| 66 |
+
# --bf16 \
|
| 67 |
+
# --eval_strategy steps \
|
| 68 |
+
# --num_train_epochs 3 \
|
| 69 |
+
# --per_device_train_batch_size 16 \
|
| 70 |
+
# --per_device_eval_batch_size 16 \
|
| 71 |
+
# --gradient_accumulation_steps 2 \
|
| 72 |
+
# --dataloader_num_workers 8 \
|
| 73 |
+
# --learning_rate 2e-5 \
|
| 74 |
+
# --warmup_ratio 0.03 \
|
| 75 |
+
# --weight_decay 0.01 \
|
| 76 |
+
# --max_grad_norm 1.0 \
|
| 77 |
+
# --lr_scheduler_type cosine \
|
| 78 |
+
# --optim adamw_torch_fused \
|
| 79 |
+
# --logging_steps 20 \
|
| 80 |
+
# --eval_steps 500 \
|
| 81 |
+
# --save_steps 1000 \
|
| 82 |
+
# --save_total_limit 3
|
2_training/configs/training_args.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"output_dir": "./output",
|
| 3 |
+
"overwrite_output_dir": true,
|
| 4 |
+
"num_train_epochs": 50,
|
| 5 |
+
"per_device_train_batch_size": 8,
|
| 6 |
+
"gradient_accumulation_steps": 1,
|
| 7 |
+
"learning_rate": 5e-5,
|
| 8 |
+
"weight_decay": 0.01,
|
| 9 |
+
"warmup_steps": 0,
|
| 10 |
+
"fp16": true,
|
| 11 |
+
"seed": 42,
|
| 12 |
+
"per_device_eval_batch_size": 8,
|
| 13 |
+
"eval_strategy": "epoch",
|
| 14 |
+
"metric_for_best_model": "eval_loss",
|
| 15 |
+
"greater_is_better": false,
|
| 16 |
+
"eval_steps": null,
|
| 17 |
+
"load_best_model_at_end": true,
|
| 18 |
+
"save_strategy": "epoch",
|
| 19 |
+
"save_steps": null,
|
| 20 |
+
"save_total_limit": 2,
|
| 21 |
+
"logging_dir": "./output/logs",
|
| 22 |
+
"logging_steps": 100,
|
| 23 |
+
"report_to": "wandb",
|
| 24 |
+
"run_name": "Se124M100K",
|
| 25 |
+
"push_to_hub": true,
|
| 26 |
+
"hub_model_id": "augustocsc/Se124M100K",
|
| 27 |
+
"hub_token": null
|
| 28 |
+
|
| 29 |
+
}
|
2_training/configs/training_large.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2-large",
|
| 4 |
+
"model_size": "774M",
|
| 5 |
+
"description": "GPT-2 Large - 774M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 2,
|
| 9 |
+
"per_device_train_batch_size": 4,
|
| 10 |
+
"per_device_eval_batch_size": 4,
|
| 11 |
+
"gradient_accumulation_steps": 16,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 2e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se774M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 180,
|
| 62 |
+
"total_hours": 6,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU. May need gradient checkpointing for memory optimization."
|
| 64 |
+
}
|
| 65 |
+
}
|
2_training/configs/training_medium.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2-medium",
|
| 4 |
+
"model_size": "355M",
|
| 5 |
+
"description": "GPT-2 Medium - 355M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 2,
|
| 9 |
+
"per_device_train_batch_size": 8,
|
| 10 |
+
"per_device_eval_batch_size": 8,
|
| 11 |
+
"gradient_accumulation_steps": 8,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 3e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se355M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 90,
|
| 62 |
+
"total_hours": 3,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU"
|
| 64 |
+
}
|
| 65 |
+
}
|
2_training/configs/training_small.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2",
|
| 4 |
+
"model_size": "124M",
|
| 5 |
+
"description": "GPT-2 Small - 124M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 3,
|
| 9 |
+
"per_device_train_batch_size": 16,
|
| 10 |
+
"per_device_eval_batch_size": 16,
|
| 11 |
+
"gradient_accumulation_steps": 4,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 5e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se124M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 40,
|
| 62 |
+
"total_hours": 2,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU"
|
| 64 |
+
}
|
| 65 |
+
}
|
2_training/configs/training_v3.json
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2",
|
| 4 |
+
"model_size": "124M",
|
| 5 |
+
"description": "GPT-2 Small (124M) - v3 with proper end markers"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 3,
|
| 9 |
+
"per_device_train_batch_size": 8,
|
| 10 |
+
"per_device_eval_batch_size": 8,
|
| 11 |
+
"gradient_accumulation_steps": 4,
|
| 12 |
+
"effective_batch_size": 32,
|
| 13 |
+
"learning_rate": 5e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"use_local_csvs": true,
|
| 49 |
+
"train_file": "./data/processed/700K_fixed/train_700K.csv",
|
| 50 |
+
"validation_file": "./data/processed/700K_fixed/validation_700K.csv",
|
| 51 |
+
"test_file": "./data/processed/700K_fixed/test_700K.csv",
|
| 52 |
+
"data_column": "text"
|
| 53 |
+
},
|
| 54 |
+
"hub_config": {
|
| 55 |
+
"push_to_hub": true,
|
| 56 |
+
"hub_model_id": "augustocsc/Se124M_700K_infix_v3"
|
| 57 |
+
},
|
| 58 |
+
"special_tokens": {
|
| 59 |
+
"start_token": "<|startofex|>",
|
| 60 |
+
"end_token": "<|endofex|>",
|
| 61 |
+
"notes": "End token configured as EOS token for proper stopping"
|
| 62 |
+
},
|
| 63 |
+
"estimated_time": {
|
| 64 |
+
"per_epoch_minutes": 45,
|
| 65 |
+
"total_hours": 2.25,
|
| 66 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU, GPT-2 Small, 3 epochs"
|
| 67 |
+
},
|
| 68 |
+
"version_info": {
|
| 69 |
+
"model_version": "v3",
|
| 70 |
+
"improvements": [
|
| 71 |
+
"Training data includes proper <|endofex|> markers",
|
| 72 |
+
"100% validation rate on prepared dataset",
|
| 73 |
+
"Addresses v1 non-stopping issue and v2 garbage generation",
|
| 74 |
+
"Uses local CSVs with validated end markers"
|
| 75 |
+
],
|
| 76 |
+
"training_date": "2026-02-01"
|
| 77 |
+
}
|
| 78 |
+
}
|
2_training/configs/wandb_config.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wandb Configuration and Naming Standards for Seriguela Project
|
| 3 |
+
|
| 4 |
+
This module provides standardized naming conventions for Wandb experiment tracking.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Default Wandb project name
|
| 13 |
+
DEFAULT_PROJECT = "seriguela"
|
| 14 |
+
|
| 15 |
+
# Alternative project name for experiments
|
| 16 |
+
EXPERIMENTS_PROJECT = "seriguela-experiments"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_wandb_project_name(use_experiments: bool = False) -> str:
|
| 20 |
+
"""
|
| 21 |
+
Get the standard Wandb project name.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
use_experiments: If True, use experiments project name
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Project name string
|
| 28 |
+
"""
|
| 29 |
+
return EXPERIMENTS_PROJECT if use_experiments else DEFAULT_PROJECT
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def generate_run_name(
|
| 33 |
+
experiment_type: str,
|
| 34 |
+
model_size: str = "base",
|
| 35 |
+
dataset: Optional[str] = None,
|
| 36 |
+
extra_info: Optional[str] = None,
|
| 37 |
+
include_timestamp: bool = True
|
| 38 |
+
) -> str:
|
| 39 |
+
"""
|
| 40 |
+
Generate a standardized Wandb run name.
|
| 41 |
+
|
| 42 |
+
Naming Convention: seriguela-{type}-{model}-{dataset}-{extra}-{timestamp}
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
experiment_type: Type of experiment (supervised, ppo, grpo, reinforce, iterative-sft)
|
| 46 |
+
model_size: Model size (base, medium, large) or full name (gpt2, gpt2-medium)
|
| 47 |
+
dataset: Dataset identifier (700K, nguyen5, nguyen7, etc)
|
| 48 |
+
extra_info: Additional information (optional)
|
| 49 |
+
include_timestamp: Whether to include timestamp suffix
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Formatted run name
|
| 53 |
+
|
| 54 |
+
Examples:
|
| 55 |
+
>>> generate_run_name("supervised", "medium", "700K")
|
| 56 |
+
'seriguela-supervised-medium-700K-20260203-143022'
|
| 57 |
+
|
| 58 |
+
>>> generate_run_name("ppo", "base", "nguyen5", "lr3e5")
|
| 59 |
+
'seriguela-ppo-base-nguyen5-lr3e5-20260203-143022'
|
| 60 |
+
|
| 61 |
+
>>> generate_run_name("grpo", "large", "nguyen7", include_timestamp=False)
|
| 62 |
+
'seriguela-grpo-large-nguyen7'
|
| 63 |
+
"""
|
| 64 |
+
# Normalize model size
|
| 65 |
+
model_map = {
|
| 66 |
+
"gpt2": "base",
|
| 67 |
+
"gpt2-base": "base",
|
| 68 |
+
"124m": "base",
|
| 69 |
+
"gpt2-medium": "medium",
|
| 70 |
+
"355m": "medium",
|
| 71 |
+
"gpt2-large": "large",
|
| 72 |
+
"774m": "large"
|
| 73 |
+
}
|
| 74 |
+
model_size = model_map.get(model_size.lower(), model_size.lower())
|
| 75 |
+
|
| 76 |
+
# Build run name parts
|
| 77 |
+
parts = ["seriguela", experiment_type.lower()]
|
| 78 |
+
|
| 79 |
+
# Add model size
|
| 80 |
+
parts.append(model_size)
|
| 81 |
+
|
| 82 |
+
# Add dataset if provided
|
| 83 |
+
if dataset:
|
| 84 |
+
parts.append(dataset.lower().replace("_", "").replace("-", ""))
|
| 85 |
+
|
| 86 |
+
# Add extra info if provided
|
| 87 |
+
if extra_info:
|
| 88 |
+
parts.append(extra_info.lower().replace("_", ""))
|
| 89 |
+
|
| 90 |
+
# Add timestamp if requested
|
| 91 |
+
if include_timestamp:
|
| 92 |
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
| 93 |
+
parts.append(timestamp)
|
| 94 |
+
|
| 95 |
+
return "-".join(parts)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_run_tags(
|
| 99 |
+
experiment_type: str,
|
| 100 |
+
model_size: str,
|
| 101 |
+
dataset: Optional[str] = None,
|
| 102 |
+
success: Optional[bool] = None
|
| 103 |
+
) -> list:
|
| 104 |
+
"""
|
| 105 |
+
Generate standardized tags for Wandb runs.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
experiment_type: Type of experiment
|
| 109 |
+
model_size: Model size
|
| 110 |
+
dataset: Dataset name
|
| 111 |
+
success: Whether experiment was successful (optional)
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
List of tags
|
| 115 |
+
|
| 116 |
+
Examples:
|
| 117 |
+
>>> get_run_tags("ppo", "medium", "nguyen5", True)
|
| 118 |
+
['ppo', 'gpt2-medium', 'nguyen5', 'rl', 'success']
|
| 119 |
+
"""
|
| 120 |
+
tags = [experiment_type.lower()]
|
| 121 |
+
|
| 122 |
+
# Add model size
|
| 123 |
+
if model_size.lower() in ["base", "124m", "gpt2"]:
|
| 124 |
+
tags.append("gpt2-base")
|
| 125 |
+
elif model_size.lower() in ["medium", "355m", "gpt2-medium"]:
|
| 126 |
+
tags.append("gpt2-medium")
|
| 127 |
+
elif model_size.lower() in ["large", "774m", "gpt2-large"]:
|
| 128 |
+
tags.append("gpt2-large")
|
| 129 |
+
else:
|
| 130 |
+
tags.append(model_size.lower())
|
| 131 |
+
|
| 132 |
+
# Add dataset
|
| 133 |
+
if dataset:
|
| 134 |
+
tags.append(dataset.lower())
|
| 135 |
+
|
| 136 |
+
# Add category based on experiment type
|
| 137 |
+
if experiment_type.lower() in ["ppo", "grpo", "reinforce"]:
|
| 138 |
+
tags.append("rl")
|
| 139 |
+
elif experiment_type.lower() in ["supervised", "sft"]:
|
| 140 |
+
tags.append("supervised")
|
| 141 |
+
elif experiment_type.lower() == "iterative-sft":
|
| 142 |
+
tags.append("iterative")
|
| 143 |
+
|
| 144 |
+
# Add success tag if provided
|
| 145 |
+
if success is not None:
|
| 146 |
+
tags.append("success" if success else "failed")
|
| 147 |
+
|
| 148 |
+
return tags
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Common experiment types
|
| 152 |
+
EXPERIMENT_TYPES = {
|
| 153 |
+
"SUPERVISED": "supervised",
|
| 154 |
+
"SFT": "sft",
|
| 155 |
+
"PPO": "ppo",
|
| 156 |
+
"GRPO": "grpo",
|
| 157 |
+
"REINFORCE": "reinforce",
|
| 158 |
+
"ITERATIVE_SFT": "iterative-sft",
|
| 159 |
+
"BEST_OF_N": "best-of-n",
|
| 160 |
+
"EVALUATION": "eval"
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
# Common datasets
|
| 164 |
+
DATASETS = {
|
| 165 |
+
"MAIN_700K": "700K",
|
| 166 |
+
"NGUYEN_1": "nguyen1",
|
| 167 |
+
"NGUYEN_5": "nguyen5",
|
| 168 |
+
"NGUYEN_7": "nguyen7",
|
| 169 |
+
"NGUYEN_10": "nguyen10",
|
| 170 |
+
"CUSTOM": "custom"
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def setup_wandb_env():
|
| 175 |
+
"""
|
| 176 |
+
Setup Wandb environment from credentials file.
|
| 177 |
+
Reads from ~/.tokens.txt if available.
|
| 178 |
+
"""
|
| 179 |
+
tokens_file = os.path.expanduser("~/.tokens.txt")
|
| 180 |
+
if os.path.exists(tokens_file):
|
| 181 |
+
with open(tokens_file) as f:
|
| 182 |
+
for line in f:
|
| 183 |
+
if "=" in line and not line.strip().startswith("#"):
|
| 184 |
+
key, value = line.strip().split("=", 1)
|
| 185 |
+
key = key.strip()
|
| 186 |
+
value = value.strip()
|
| 187 |
+
if key.lower() == "wandb":
|
| 188 |
+
os.environ["WANDB_API_KEY"] = value
|
| 189 |
+
print(f"[OK] Wandb API key loaded from {tokens_file}")
|
| 190 |
+
return True
|
| 191 |
+
|
| 192 |
+
# Check if already in environment
|
| 193 |
+
if "WANDB_API_KEY" in os.environ:
|
| 194 |
+
print("[OK] Wandb API key found in environment")
|
| 195 |
+
return True
|
| 196 |
+
|
| 197 |
+
print("[WARN] Wandb API key not found. Run 'wandb login' or add to ~/.tokens.txt")
|
| 198 |
+
return False
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
if __name__ == "__main__":
|
| 202 |
+
# Example usage
|
| 203 |
+
print("Wandb Configuration Examples:\n")
|
| 204 |
+
|
| 205 |
+
print("1. Supervised training on 700K dataset:")
|
| 206 |
+
print(f" {generate_run_name('supervised', 'medium', '700K')}\n")
|
| 207 |
+
|
| 208 |
+
print("2. PPO on Nguyen-5 benchmark:")
|
| 209 |
+
print(f" {generate_run_name('ppo', 'base', 'nguyen5')}\n")
|
| 210 |
+
|
| 211 |
+
print("3. GRPO with custom learning rate:")
|
| 212 |
+
print(f" {generate_run_name('grpo', 'large', 'nguyen7', 'lr5e5')}\n")
|
| 213 |
+
|
| 214 |
+
print("4. Evaluation run (no timestamp):")
|
| 215 |
+
print(f" {generate_run_name('eval', 'medium', 'nguyen5', include_timestamp=False)}\n")
|
| 216 |
+
|
| 217 |
+
print("5. Tags example:")
|
| 218 |
+
print(f" {get_run_tags('ppo', 'medium', 'nguyen5', True)}\n")
|
| 219 |
+
|
| 220 |
+
print("6. Setup Wandb environment:")
|
| 221 |
+
setup_wandb_env()
|
2_training/reinforcement/best_of_n_experiment.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Best-of-N Sampling Experiment for Symbolic Regression
|
| 4 |
+
|
| 5 |
+
Instead of PPO (which has API compatibility issues with TRL 0.16+),
|
| 6 |
+
this script tests if the base model can find correct expressions
|
| 7 |
+
through random sampling. If the model generates the correct expression
|
| 8 |
+
even occasionally, PPO should be able to learn to find it consistently.
|
| 9 |
+
|
| 10 |
+
This is a diagnostic experiment to understand model capabilities.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import json
|
| 16 |
+
import argparse
|
| 17 |
+
import logging
|
| 18 |
+
import datetime
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
# Add project root to path
|
| 27 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 28 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 29 |
+
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
|
| 30 |
+
|
| 31 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 32 |
+
from peft import PeftModel
|
| 33 |
+
|
| 34 |
+
from expression import Expression
|
| 35 |
+
from dataset import RegressionDataset
|
| 36 |
+
|
| 37 |
+
# Configure logging
|
| 38 |
+
logging.basicConfig(
|
| 39 |
+
level=logging.INFO,
|
| 40 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 41 |
+
)
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BestOfNSampler:
|
| 46 |
+
"""Generate N expressions and find the best one for a given dataset."""
|
| 47 |
+
|
| 48 |
+
def __init__(self, model_path: str, device: str = None):
|
| 49 |
+
self.model_path = model_path
|
| 50 |
+
|
| 51 |
+
# Device setup
|
| 52 |
+
if device:
|
| 53 |
+
self.device = torch.device(device)
|
| 54 |
+
else:
|
| 55 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 56 |
+
logger.info(f"Using device: {self.device}")
|
| 57 |
+
|
| 58 |
+
self._load_model()
|
| 59 |
+
|
| 60 |
+
def _load_model(self):
|
| 61 |
+
"""Load the JSON format model with LoRA adapters."""
|
| 62 |
+
logger.info(f"Loading model from {self.model_path}")
|
| 63 |
+
|
| 64 |
+
# Load tokenizer from trained model (has special tokens)
|
| 65 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
|
| 66 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 67 |
+
logger.info(f"Tokenizer loaded with vocab size: {len(self.tokenizer)}")
|
| 68 |
+
|
| 69 |
+
# Load base GPT-2
|
| 70 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 71 |
+
"gpt2",
|
| 72 |
+
torch_dtype=torch.float16,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Resize embeddings to match tokenizer (handles special tokens)
|
| 76 |
+
if len(self.tokenizer) != base_model.config.vocab_size:
|
| 77 |
+
logger.info(f"Resizing embeddings: {base_model.config.vocab_size} -> {len(self.tokenizer)}")
|
| 78 |
+
base_model.resize_token_embeddings(len(self.tokenizer))
|
| 79 |
+
|
| 80 |
+
# Load LoRA adapter
|
| 81 |
+
try:
|
| 82 |
+
model_with_lora = PeftModel.from_pretrained(base_model, self.model_path)
|
| 83 |
+
self.model = model_with_lora.merge_and_unload()
|
| 84 |
+
logger.info("LoRA adapter loaded and merged")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.warning(f"Could not load as PEFT model: {e}")
|
| 87 |
+
self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
|
| 88 |
+
|
| 89 |
+
self.model = self.model.to(self.device)
|
| 90 |
+
self.model.eval()
|
| 91 |
+
logger.info("Model loaded successfully")
|
| 92 |
+
|
| 93 |
+
def build_prompt(self, n_vars: int) -> str:
|
| 94 |
+
"""Build JSON format prompt matching training data."""
|
| 95 |
+
vars_list = [f"x_{i+1}" for i in range(n_vars)]
|
| 96 |
+
ops_list = ["+", "-", "*", "sin", "cos"]
|
| 97 |
+
|
| 98 |
+
prompt = json.dumps({
|
| 99 |
+
"vars": vars_list,
|
| 100 |
+
"ops": ops_list,
|
| 101 |
+
"cons": None,
|
| 102 |
+
"expr": ""
|
| 103 |
+
})[:-3] # Remove trailing '"}'
|
| 104 |
+
|
| 105 |
+
return prompt
|
| 106 |
+
|
| 107 |
+
def extract_expression(self, generated_text: str) -> str:
|
| 108 |
+
"""Extract expression from JSON format output.
|
| 109 |
+
|
| 110 |
+
Handles two formats:
|
| 111 |
+
1. Standard JSON: "expr": "value"}
|
| 112 |
+
2. Model output: "expr": value"} (no quotes around value)
|
| 113 |
+
"""
|
| 114 |
+
try:
|
| 115 |
+
# Case 1: Standard JSON with quotes around expression value
|
| 116 |
+
if '"expr": "' in generated_text:
|
| 117 |
+
expr_start = generated_text.index('"expr": "') + len('"expr": "')
|
| 118 |
+
remaining = generated_text[expr_start:]
|
| 119 |
+
# Find closing "}
|
| 120 |
+
if '"}' in remaining:
|
| 121 |
+
return remaining[:remaining.index('"}')].strip()
|
| 122 |
+
# Fallback: find first quote
|
| 123 |
+
if '"' in remaining:
|
| 124 |
+
return remaining[:remaining.index('"')].strip()
|
| 125 |
+
return remaining.strip()
|
| 126 |
+
|
| 127 |
+
# Case 2: Model output WITHOUT quotes: "expr": value"}
|
| 128 |
+
# This is what the model actually generates
|
| 129 |
+
if '"expr": ' in generated_text:
|
| 130 |
+
expr_start = generated_text.index('"expr": ') + len('"expr": ')
|
| 131 |
+
remaining = generated_text[expr_start:]
|
| 132 |
+
# Find closing "} which ends the JSON object
|
| 133 |
+
if '"}' in remaining:
|
| 134 |
+
return remaining[:remaining.index('"}')].strip()
|
| 135 |
+
# Fallback: find "{ which starts next object
|
| 136 |
+
if '"{' in remaining:
|
| 137 |
+
return remaining[:remaining.index('"{')].strip().rstrip('}')
|
| 138 |
+
return remaining.strip()
|
| 139 |
+
|
| 140 |
+
# Case 3: Compact JSON without space
|
| 141 |
+
if '"expr":"' in generated_text:
|
| 142 |
+
expr_start = generated_text.index('"expr":"') + len('"expr":"')
|
| 143 |
+
remaining = generated_text[expr_start:]
|
| 144 |
+
if '"}' in remaining:
|
| 145 |
+
return remaining[:remaining.index('"}')].strip()
|
| 146 |
+
if '"' in remaining:
|
| 147 |
+
return remaining[:remaining.index('"')].strip()
|
| 148 |
+
return remaining.strip()
|
| 149 |
+
|
| 150 |
+
except (ValueError, IndexError):
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
# Last resort: split on "expr" and clean up
|
| 154 |
+
fallback = generated_text.split('"expr"')[-1].strip(' ":}')
|
| 155 |
+
if '"}' in fallback:
|
| 156 |
+
fallback = fallback[:fallback.index('"}')]
|
| 157 |
+
return fallback.strip()
|
| 158 |
+
|
| 159 |
+
def compute_r2(self, expression_str: str, X: np.ndarray, y: np.ndarray) -> float:
|
| 160 |
+
"""Compute R² score for an expression."""
|
| 161 |
+
if not expression_str or expression_str.isspace():
|
| 162 |
+
return -np.inf
|
| 163 |
+
|
| 164 |
+
# Replace constant placeholder C with 1
|
| 165 |
+
if 'C' in expression_str:
|
| 166 |
+
expression_str = expression_str.replace('C', '1')
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
expr = Expression(expression_str, is_prefix=False)
|
| 170 |
+
|
| 171 |
+
if not expr.is_valid_on_dataset(X):
|
| 172 |
+
return -np.inf
|
| 173 |
+
|
| 174 |
+
y_pred = expr.evaluate(X)
|
| 175 |
+
|
| 176 |
+
if not np.all(np.isfinite(y_pred)):
|
| 177 |
+
return -np.inf
|
| 178 |
+
|
| 179 |
+
ss_res = np.sum((y - y_pred) ** 2)
|
| 180 |
+
ss_tot = np.sum((y - np.mean(y)) ** 2)
|
| 181 |
+
|
| 182 |
+
if ss_tot == 0:
|
| 183 |
+
return 0.0
|
| 184 |
+
|
| 185 |
+
return 1 - (ss_res / ss_tot)
|
| 186 |
+
except Exception:
|
| 187 |
+
return -np.inf
|
| 188 |
+
|
| 189 |
+
def sample_expressions(self, n_vars: int, n_samples: int = 100,
|
| 190 |
+
temperature: float = 0.7) -> list:
|
| 191 |
+
"""Generate N expression samples."""
|
| 192 |
+
prompt = self.build_prompt(n_vars)
|
| 193 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
| 194 |
+
|
| 195 |
+
expressions = []
|
| 196 |
+
|
| 197 |
+
debug_count = 0
|
| 198 |
+
for _ in tqdm(range(n_samples), desc="Sampling expressions"):
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
output = self.model.generate(
|
| 201 |
+
**inputs,
|
| 202 |
+
max_new_tokens=50,
|
| 203 |
+
do_sample=True,
|
| 204 |
+
top_k=50,
|
| 205 |
+
top_p=0.9,
|
| 206 |
+
temperature=temperature,
|
| 207 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
text = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
| 211 |
+
expr_str = self.extract_expression(text)
|
| 212 |
+
|
| 213 |
+
# Debug: print first 5 extractions
|
| 214 |
+
if debug_count < 5:
|
| 215 |
+
logger.info(f"DEBUG [{debug_count}] raw text (last 80 chars): ...{text[-80:]}")
|
| 216 |
+
logger.info(f"DEBUG [{debug_count}] extracted: '{expr_str}'")
|
| 217 |
+
debug_count += 1
|
| 218 |
+
|
| 219 |
+
expressions.append(expr_str)
|
| 220 |
+
|
| 221 |
+
return expressions
|
| 222 |
+
|
| 223 |
+
def find_best_expression(self, X: np.ndarray, y: np.ndarray,
|
| 224 |
+
n_samples: int = 500, temperature: float = 0.7):
|
| 225 |
+
"""Sample N expressions and find the best one for the dataset."""
|
| 226 |
+
n_vars = X.shape[1]
|
| 227 |
+
|
| 228 |
+
logger.info(f"Sampling {n_samples} expressions for {n_vars}-variable dataset...")
|
| 229 |
+
expressions = self.sample_expressions(n_vars, n_samples, temperature)
|
| 230 |
+
|
| 231 |
+
# Compute R² for each
|
| 232 |
+
results = []
|
| 233 |
+
unique_expressions = set()
|
| 234 |
+
|
| 235 |
+
for expr_str in tqdm(expressions, desc="Computing R² scores"):
|
| 236 |
+
if expr_str in unique_expressions:
|
| 237 |
+
continue
|
| 238 |
+
unique_expressions.add(expr_str)
|
| 239 |
+
|
| 240 |
+
r2 = self.compute_r2(expr_str, X, y)
|
| 241 |
+
results.append({
|
| 242 |
+
"expression": expr_str,
|
| 243 |
+
"r2": float(r2) if np.isfinite(r2) else None,
|
| 244 |
+
"is_valid": bool(np.isfinite(r2) and r2 > -1),
|
| 245 |
+
})
|
| 246 |
+
|
| 247 |
+
# Sort by R²
|
| 248 |
+
results.sort(key=lambda x: x["r2"] if x["r2"] is not None else -np.inf, reverse=True)
|
| 249 |
+
|
| 250 |
+
# Statistics
|
| 251 |
+
valid_count = sum(1 for r in results if r["is_valid"])
|
| 252 |
+
valid_r2s = [r["r2"] for r in results if r["r2"] is not None and r["r2"] > -1]
|
| 253 |
+
|
| 254 |
+
return {
|
| 255 |
+
"n_samples": n_samples,
|
| 256 |
+
"unique_expressions": len(unique_expressions),
|
| 257 |
+
"valid_count": valid_count,
|
| 258 |
+
"valid_rate": valid_count / len(unique_expressions) if unique_expressions else 0,
|
| 259 |
+
"best_r2": results[0]["r2"] if results and results[0]["r2"] else None,
|
| 260 |
+
"best_expression": results[0]["expression"] if results else None,
|
| 261 |
+
"mean_r2": float(np.mean(valid_r2s)) if valid_r2s else None,
|
| 262 |
+
"median_r2": float(np.median(valid_r2s)) if valid_r2s else None,
|
| 263 |
+
"top_10": results[:10],
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def run_experiment(model_path: str, datasets_dir: str, n_samples: int = 500,
|
| 268 |
+
output_dir: str = "./output/best_of_n"):
|
| 269 |
+
"""Run Best-of-N experiment on multiple datasets."""
|
| 270 |
+
|
| 271 |
+
output_dir = Path(output_dir)
|
| 272 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 273 |
+
|
| 274 |
+
# Test datasets
|
| 275 |
+
test_datasets = {
|
| 276 |
+
"add_x1_x2": {"formula": "x_1 + x_2", "difficulty": "easy"},
|
| 277 |
+
"mul_x1_x2": {"formula": "x_1 * x_2", "difficulty": "easy"},
|
| 278 |
+
"sub_x1_x2": {"formula": "x_1 - x_2", "difficulty": "easy"},
|
| 279 |
+
"sin_x1": {"formula": "sin(x_1)", "difficulty": "medium"},
|
| 280 |
+
"cos_x1": {"formula": "cos(x_1)", "difficulty": "medium"},
|
| 281 |
+
"square_x1": {"formula": "x_1 * x_1", "difficulty": "medium"},
|
| 282 |
+
"sin_x1_plus_x2": {"formula": "sin(x_1) + x_2", "difficulty": "hard"},
|
| 283 |
+
"x1_mul_sin_x2": {"formula": "x_1 * sin(x_2)", "difficulty": "hard"},
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
# Initialize sampler
|
| 287 |
+
sampler = BestOfNSampler(model_path)
|
| 288 |
+
|
| 289 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 290 |
+
|
| 291 |
+
results = {
|
| 292 |
+
"timestamp": timestamp,
|
| 293 |
+
"model_path": model_path,
|
| 294 |
+
"n_samples": n_samples,
|
| 295 |
+
"datasets": {},
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
print("\n" + "=" * 70)
|
| 299 |
+
print("BEST-OF-N SAMPLING EXPERIMENT")
|
| 300 |
+
print("=" * 70)
|
| 301 |
+
print(f"Model: {model_path}")
|
| 302 |
+
print(f"Samples per dataset: {n_samples}")
|
| 303 |
+
print("=" * 70)
|
| 304 |
+
|
| 305 |
+
for dataset_name, info in test_datasets.items():
|
| 306 |
+
dataset_path = Path(datasets_dir) / f"{dataset_name}.csv"
|
| 307 |
+
|
| 308 |
+
if not dataset_path.exists():
|
| 309 |
+
logger.warning(f"Dataset not found: {dataset_path}")
|
| 310 |
+
continue
|
| 311 |
+
|
| 312 |
+
print(f"\n{'='*70}")
|
| 313 |
+
print(f"Dataset: {dataset_name}")
|
| 314 |
+
print(f"Ground truth: {info['formula']}")
|
| 315 |
+
print(f"Difficulty: {info['difficulty']}")
|
| 316 |
+
print(f"{'='*70}")
|
| 317 |
+
|
| 318 |
+
# Load dataset
|
| 319 |
+
reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
|
| 320 |
+
X, y = reg.get_numpy()
|
| 321 |
+
|
| 322 |
+
# Run Best-of-N
|
| 323 |
+
result = sampler.find_best_expression(X, y, n_samples)
|
| 324 |
+
result["ground_truth"] = info["formula"]
|
| 325 |
+
result["difficulty"] = info["difficulty"]
|
| 326 |
+
|
| 327 |
+
results["datasets"][dataset_name] = result
|
| 328 |
+
|
| 329 |
+
# Print results
|
| 330 |
+
print(f"\nResults:")
|
| 331 |
+
print(f" Valid expressions: {result['valid_count']}/{result['unique_expressions']} ({result['valid_rate']:.1%})")
|
| 332 |
+
print(f" Best R²: {result['best_r2']:.4f}" if result['best_r2'] else " Best R²: N/A")
|
| 333 |
+
print(f" Best expression: {result['best_expression']}")
|
| 334 |
+
|
| 335 |
+
if result['best_r2'] and result['best_r2'] > 0.99:
|
| 336 |
+
print(f" ✅ FOUND NEAR-PERFECT MATCH!")
|
| 337 |
+
elif result['best_r2'] and result['best_r2'] > 0.9:
|
| 338 |
+
print(f" ⚠️ Found good match (R² > 0.9)")
|
| 339 |
+
else:
|
| 340 |
+
print(f" ❌ No good match found")
|
| 341 |
+
|
| 342 |
+
print("\n Top 5 expressions:")
|
| 343 |
+
for i, expr in enumerate(result['top_10'][:5]):
|
| 344 |
+
r2_str = f"{expr['r2']:.4f}" if expr['r2'] else "N/A"
|
| 345 |
+
print(f" {i+1}. {expr['expression'][:40]:<40} R²={r2_str}")
|
| 346 |
+
|
| 347 |
+
# Save results
|
| 348 |
+
results_file = output_dir / f"best_of_n_results_{timestamp}.json"
|
| 349 |
+
with open(results_file, 'w') as f:
|
| 350 |
+
json.dump(results, f, indent=2)
|
| 351 |
+
|
| 352 |
+
print("\n" + "=" * 70)
|
| 353 |
+
print("SUMMARY")
|
| 354 |
+
print("=" * 70)
|
| 355 |
+
|
| 356 |
+
# Summary table
|
| 357 |
+
print(f"\n{'Dataset':<25} {'Difficulty':<10} {'Best R²':<10} {'Found?':<10}")
|
| 358 |
+
print("-" * 60)
|
| 359 |
+
|
| 360 |
+
success_count = 0
|
| 361 |
+
for name, res in results["datasets"].items():
|
| 362 |
+
r2 = res["best_r2"]
|
| 363 |
+
r2_str = f"{r2:.4f}" if r2 else "N/A"
|
| 364 |
+
found = "✅" if r2 and r2 > 0.99 else ("⚠️" if r2 and r2 > 0.9 else "❌")
|
| 365 |
+
if r2 and r2 > 0.99:
|
| 366 |
+
success_count += 1
|
| 367 |
+
print(f"{name:<25} {res['difficulty']:<10} {r2_str:<10} {found:<10}")
|
| 368 |
+
|
| 369 |
+
print("-" * 60)
|
| 370 |
+
print(f"Success rate (R² > 0.99): {success_count}/{len(results['datasets'])}")
|
| 371 |
+
print(f"\nResults saved to: {results_file}")
|
| 372 |
+
|
| 373 |
+
return results
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def main():
|
| 377 |
+
parser = argparse.ArgumentParser(description="Best-of-N Sampling Experiment")
|
| 378 |
+
parser.add_argument("--model_path", type=str, default="./output/exp_a_json",
|
| 379 |
+
help="Path to trained model")
|
| 380 |
+
parser.add_argument("--datasets_dir", type=str, default="./data/ppo_test",
|
| 381 |
+
help="Directory containing test datasets")
|
| 382 |
+
parser.add_argument("--n_samples", type=int, default=500,
|
| 383 |
+
help="Number of samples per dataset")
|
| 384 |
+
parser.add_argument("--output_dir", type=str, default="./output/best_of_n",
|
| 385 |
+
help="Output directory for results")
|
| 386 |
+
|
| 387 |
+
args = parser.parse_args()
|
| 388 |
+
|
| 389 |
+
run_experiment(
|
| 390 |
+
model_path=args.model_path,
|
| 391 |
+
datasets_dir=args.datasets_dir,
|
| 392 |
+
n_samples=args.n_samples,
|
| 393 |
+
output_dir=args.output_dir,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
main()
|
2_training/reinforcement/debug_reinforce.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Debug version of REINFORCE that saves ALL expressions (valid and invalid).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List, Dict
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
# Add project root to path
|
| 18 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 19 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 20 |
+
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
|
| 21 |
+
|
| 22 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 23 |
+
from peft import PeftModel, LoraConfig, get_peft_model
|
| 24 |
+
from expression import Expression
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DebugREINFORCE:
|
| 28 |
+
"""REINFORCE that logs all expressions."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, model_path: str, X: np.ndarray, y: np.ndarray, device: str = None):
|
| 31 |
+
self.X = X
|
| 32 |
+
self.y = y
|
| 33 |
+
self.n_vars = X.shape[1]
|
| 34 |
+
|
| 35 |
+
if device:
|
| 36 |
+
self.device = torch.device(device)
|
| 37 |
+
else:
|
| 38 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
|
| 40 |
+
# Load model
|
| 41 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 42 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 46 |
+
if len(self.tokenizer) != base_model.config.vocab_size:
|
| 47 |
+
base_model.resize_token_embeddings(len(self.tokenizer))
|
| 48 |
+
model_with_lora = PeftModel.from_pretrained(base_model, model_path)
|
| 49 |
+
self.model = model_with_lora.merge_and_unload()
|
| 50 |
+
except:
|
| 51 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 52 |
+
|
| 53 |
+
# Add LoRA
|
| 54 |
+
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], lora_dropout=0.05, bias="none")
|
| 55 |
+
self.model = get_peft_model(self.model, lora_config)
|
| 56 |
+
self.model = self.model.to(self.device)
|
| 57 |
+
self.model.train()
|
| 58 |
+
|
| 59 |
+
# Build prompt
|
| 60 |
+
vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
|
| 61 |
+
ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
|
| 62 |
+
self.prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]
|
| 63 |
+
self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
|
| 64 |
+
|
| 65 |
+
# Optimizer
|
| 66 |
+
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)
|
| 67 |
+
|
| 68 |
+
# Baseline
|
| 69 |
+
self.baseline = 0.0
|
| 70 |
+
self.baseline_decay = 0.9
|
| 71 |
+
|
| 72 |
+
# ALL expressions log
|
| 73 |
+
self.all_expressions = []
|
| 74 |
+
|
| 75 |
+
def extract_expression(self, text: str) -> str:
|
| 76 |
+
"""Extract expression from generated text."""
|
| 77 |
+
try:
|
| 78 |
+
if '"expr": "' in text:
|
| 79 |
+
start = text.index('"expr": "') + len('"expr": "')
|
| 80 |
+
remaining = text[start:]
|
| 81 |
+
for terminator in ['"}', '"']:
|
| 82 |
+
if terminator in remaining:
|
| 83 |
+
return remaining[:remaining.index(terminator)].strip()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
return text.strip()
|
| 87 |
+
|
| 88 |
+
def compute_r2(self, expression_str: str) -> tuple:
|
| 89 |
+
"""Compute R^2 and detailed error info."""
|
| 90 |
+
result = {
|
| 91 |
+
"expression": expression_str,
|
| 92 |
+
"r2": -1.0,
|
| 93 |
+
"is_valid": False,
|
| 94 |
+
"error_type": None,
|
| 95 |
+
"error_message": None,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
if not expression_str or expression_str.isspace():
|
| 99 |
+
result["error_type"] = "empty"
|
| 100 |
+
return result
|
| 101 |
+
|
| 102 |
+
test_expr = expression_str.replace('C', '1')
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
expr = Expression(test_expr, is_prefix=False)
|
| 106 |
+
|
| 107 |
+
if not expr.is_valid_on_dataset(self.X):
|
| 108 |
+
result["error_type"] = "invalid_on_dataset"
|
| 109 |
+
result["error_message"] = "NaN/Inf on dataset"
|
| 110 |
+
return result
|
| 111 |
+
|
| 112 |
+
y_pred = expr.evaluate(self.X)
|
| 113 |
+
|
| 114 |
+
if not np.all(np.isfinite(y_pred)):
|
| 115 |
+
result["error_type"] = "non_finite_output"
|
| 116 |
+
return result
|
| 117 |
+
|
| 118 |
+
ss_res = np.sum((self.y - y_pred) ** 2)
|
| 119 |
+
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
|
| 120 |
+
|
| 121 |
+
if ss_tot == 0:
|
| 122 |
+
r2 = 0.0
|
| 123 |
+
else:
|
| 124 |
+
r2 = 1 - (ss_res / ss_tot)
|
| 125 |
+
|
| 126 |
+
result["r2"] = float(np.clip(r2, -1.0, 1.0))
|
| 127 |
+
result["is_valid"] = True
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
result["error_type"] = "parse_error"
|
| 131 |
+
result["error_message"] = str(e)[:100]
|
| 132 |
+
|
| 133 |
+
return result
|
| 134 |
+
|
| 135 |
+
def generate_batch(self, batch_size: int = 16, max_new_tokens: int = 50):
|
| 136 |
+
"""Generate batch and evaluate."""
|
| 137 |
+
results = []
|
| 138 |
+
|
| 139 |
+
for _ in range(batch_size):
|
| 140 |
+
generated_ids = self.prompt_ids.clone()
|
| 141 |
+
generated_tokens = []
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
for _ in range(max_new_tokens):
|
| 145 |
+
outputs = self.model(generated_ids)
|
| 146 |
+
logits = outputs.logits[:, -1, :] / 0.7
|
| 147 |
+
|
| 148 |
+
probs = F.softmax(logits, dim=-1)
|
| 149 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 150 |
+
|
| 151 |
+
generated_tokens.append(next_token.item())
|
| 152 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 153 |
+
|
| 154 |
+
if next_token.item() == self.tokenizer.eos_token_id:
|
| 155 |
+
break
|
| 156 |
+
|
| 157 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 158 |
+
if '"}' in text[len(self.prompt):]:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 162 |
+
expr_str = self.extract_expression(text)
|
| 163 |
+
|
| 164 |
+
# Evaluate with detailed info
|
| 165 |
+
eval_result = self.compute_r2(expr_str)
|
| 166 |
+
|
| 167 |
+
# Compute log prob
|
| 168 |
+
if len(generated_tokens) > 0:
|
| 169 |
+
full_ids = torch.cat([self.prompt_ids, torch.tensor([generated_tokens], device=self.device)], dim=1)
|
| 170 |
+
outputs = self.model(full_ids[:, :-1])
|
| 171 |
+
logits = outputs.logits / 0.7
|
| 172 |
+
prompt_len = self.prompt_ids.shape[1]
|
| 173 |
+
gen_logits = logits[:, prompt_len-1:, :]
|
| 174 |
+
log_probs_all = F.log_softmax(gen_logits, dim=-1)
|
| 175 |
+
target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
|
| 176 |
+
selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
|
| 177 |
+
total_log_prob = selected_log_probs.sum()
|
| 178 |
+
else:
|
| 179 |
+
total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
|
| 180 |
+
|
| 181 |
+
eval_result["log_prob"] = total_log_prob
|
| 182 |
+
results.append(eval_result)
|
| 183 |
+
|
| 184 |
+
# Log ALL expressions
|
| 185 |
+
self.all_expressions.append(eval_result.copy())
|
| 186 |
+
|
| 187 |
+
return results
|
| 188 |
+
|
| 189 |
+
def train_step(self, batch_size: int = 16):
|
| 190 |
+
"""One training step."""
|
| 191 |
+
results = self.generate_batch(batch_size)
|
| 192 |
+
|
| 193 |
+
# Compute rewards
|
| 194 |
+
rewards = [r["r2"] if r["is_valid"] else -0.1 for r in results]
|
| 195 |
+
|
| 196 |
+
# Update baseline
|
| 197 |
+
valid_rewards = [r for r in rewards if r > -0.1]
|
| 198 |
+
if valid_rewards:
|
| 199 |
+
mean_reward = np.mean(valid_rewards)
|
| 200 |
+
self.baseline = self.baseline_decay * self.baseline + (1 - self.baseline_decay) * mean_reward
|
| 201 |
+
|
| 202 |
+
# Advantages
|
| 203 |
+
advantages = [r - self.baseline for r in rewards]
|
| 204 |
+
|
| 205 |
+
# Update
|
| 206 |
+
self.optimizer.zero_grad()
|
| 207 |
+
policy_loss = torch.tensor(0.0, device=self.device)
|
| 208 |
+
|
| 209 |
+
for result, advantage in zip(results, advantages):
|
| 210 |
+
if result["is_valid"] or result["error_type"] == "parse_error":
|
| 211 |
+
policy_loss = policy_loss - result["log_prob"] * advantage
|
| 212 |
+
|
| 213 |
+
if len(results) > 0:
|
| 214 |
+
policy_loss = policy_loss / len(results)
|
| 215 |
+
policy_loss.backward()
|
| 216 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 217 |
+
self.optimizer.step()
|
| 218 |
+
|
| 219 |
+
# Stats
|
| 220 |
+
valid_count = sum(1 for r in results if r["is_valid"])
|
| 221 |
+
valid_r2 = [r["r2"] for r in results if r["is_valid"]]
|
| 222 |
+
|
| 223 |
+
return {
|
| 224 |
+
"valid_count": valid_count,
|
| 225 |
+
"total_count": len(results),
|
| 226 |
+
"mean_r2": np.mean(valid_r2) if valid_r2 else -1.0,
|
| 227 |
+
"max_r2": max(r["r2"] for r in results),
|
| 228 |
+
"baseline": self.baseline,
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
def run(self, epochs: int = 10):
|
| 232 |
+
"""Run training."""
|
| 233 |
+
print(f"Running debug REINFORCE for {epochs} epochs...")
|
| 234 |
+
print()
|
| 235 |
+
|
| 236 |
+
for epoch in range(1, epochs + 1):
|
| 237 |
+
stats = self.train_step()
|
| 238 |
+
print(f"Epoch {epoch:2d} | Valid: {stats['valid_count']}/{stats['total_count']} | Mean R²: {stats['mean_r2']:.4f} | Max R²: {stats['max_r2']:.4f}")
|
| 239 |
+
|
| 240 |
+
# Save ALL expressions
|
| 241 |
+
output_file = "debug_expressions.json"
|
| 242 |
+
with open(output_file, "w") as f:
|
| 243 |
+
json.dump({"all_expressions": self.all_expressions}, f, indent=2, default=str)
|
| 244 |
+
|
| 245 |
+
print()
|
| 246 |
+
print(f"Saved {len(self.all_expressions)} expressions to {output_file}")
|
| 247 |
+
|
| 248 |
+
# Analyze
|
| 249 |
+
valid = [e for e in self.all_expressions if e["is_valid"]]
|
| 250 |
+
invalid = [e for e in self.all_expressions if not e["is_valid"]]
|
| 251 |
+
|
| 252 |
+
print()
|
| 253 |
+
print("SUMMARY:")
|
| 254 |
+
print(f" Total: {len(self.all_expressions)}")
|
| 255 |
+
print(f" Valid: {len(valid)} ({100*len(valid)/len(self.all_expressions):.1f}%)")
|
| 256 |
+
print(f" Invalid: {len(invalid)} ({100*len(invalid)/len(self.all_expressions):.1f}%)")
|
| 257 |
+
|
| 258 |
+
if invalid:
|
| 259 |
+
error_types = {}
|
| 260 |
+
for e in invalid:
|
| 261 |
+
et = e.get("error_type", "unknown")
|
| 262 |
+
error_types[et] = error_types.get(et, 0) + 1
|
| 263 |
+
|
| 264 |
+
print()
|
| 265 |
+
print("Invalid expression types:")
|
| 266 |
+
for et, count in sorted(error_types.items(), key=lambda x: -x[1]):
|
| 267 |
+
print(f" {et}: {count} ({100*count/len(invalid):.1f}%)")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def main():
|
| 271 |
+
parser = argparse.ArgumentParser()
|
| 272 |
+
parser.add_argument("--model_path", type=str, required=True)
|
| 273 |
+
parser.add_argument("--dataset", type=str, required=True)
|
| 274 |
+
parser.add_argument("--epochs", type=int, default=10)
|
| 275 |
+
args = parser.parse_args()
|
| 276 |
+
|
| 277 |
+
# Load dataset
|
| 278 |
+
import pandas as pd
|
| 279 |
+
df = pd.read_csv(args.dataset)
|
| 280 |
+
x_cols = [c for c in df.columns if c.startswith('x_')]
|
| 281 |
+
X = df[x_cols].values
|
| 282 |
+
y = df['y'].values
|
| 283 |
+
|
| 284 |
+
print(f"Dataset: {args.dataset}")
|
| 285 |
+
print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
|
| 286 |
+
print()
|
| 287 |
+
|
| 288 |
+
# Run
|
| 289 |
+
reinforce = DebugREINFORCE(args.model_path, X, y)
|
| 290 |
+
reinforce.run(epochs=args.epochs)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
if __name__ == "__main__":
|
| 294 |
+
main()
|
2_training/reinforcement/grpo_experiment.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO Experiment for Symbolic Regression
|
| 4 |
+
|
| 5 |
+
GRPO (Group Relative Policy Optimization) supports custom reward functions
|
| 6 |
+
via the reward_funcs parameter, making it ideal for symbolic regression
|
| 7 |
+
where we compute R^2 scores as rewards.
|
| 8 |
+
|
| 9 |
+
This is the recommended approach for TRL 0.27+ since PPO experimental
|
| 10 |
+
has compatibility issues.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python scripts/grpo_experiment.py --dataset ./data/ppo_test/sin_x1.csv
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'
|
| 18 |
+
|
| 19 |
+
import sys
|
| 20 |
+
import json
|
| 21 |
+
import argparse
|
| 22 |
+
import logging
|
| 23 |
+
import datetime
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import List
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
# Add project root to path
|
| 31 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 32 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 33 |
+
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
|
| 34 |
+
|
| 35 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 36 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 37 |
+
from datasets import Dataset
|
| 38 |
+
from peft import PeftModel
|
| 39 |
+
|
| 40 |
+
from expression import Expression
|
| 41 |
+
from dataset import RegressionDataset
|
| 42 |
+
|
| 43 |
+
# Configure logging
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 47 |
+
)
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SymbolicRegressionReward:
|
| 52 |
+
"""
|
| 53 |
+
Reward function for symbolic regression.
|
| 54 |
+
Computes R^2 score for generated expressions.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, X: np.ndarray, y: np.ndarray, tokenizer):
|
| 58 |
+
self.X = X
|
| 59 |
+
self.y = y
|
| 60 |
+
self.tokenizer = tokenizer
|
| 61 |
+
self.n_vars = X.shape[1]
|
| 62 |
+
self.best_r2 = -np.inf
|
| 63 |
+
self.best_expression = None
|
| 64 |
+
self.history = []
|
| 65 |
+
|
| 66 |
+
def extract_expression(self, text: str) -> str:
|
| 67 |
+
"""Extract expression from JSON format output."""
|
| 68 |
+
try:
|
| 69 |
+
# Case 1: Standard JSON with quotes
|
| 70 |
+
if '"expr": "' in text:
|
| 71 |
+
start = text.index('"expr": "') + len('"expr": "')
|
| 72 |
+
remaining = text[start:]
|
| 73 |
+
if '"}' in remaining:
|
| 74 |
+
return remaining[:remaining.index('"}')].strip()
|
| 75 |
+
if '"' in remaining:
|
| 76 |
+
return remaining[:remaining.index('"')].strip()
|
| 77 |
+
return remaining.strip()
|
| 78 |
+
|
| 79 |
+
# Case 2: Model output without quotes
|
| 80 |
+
if '"expr": ' in text:
|
| 81 |
+
start = text.index('"expr": ') + len('"expr": ')
|
| 82 |
+
remaining = text[start:]
|
| 83 |
+
if '"}' in remaining:
|
| 84 |
+
return remaining[:remaining.index('"}')].strip()
|
| 85 |
+
return remaining.strip()
|
| 86 |
+
|
| 87 |
+
except (ValueError, IndexError):
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
return text.split('"expr"')[-1].strip(' ":}')
|
| 91 |
+
|
| 92 |
+
def compute_r2(self, expression_str: str) -> float:
|
| 93 |
+
"""Compute R^2 score for an expression."""
|
| 94 |
+
if not expression_str or expression_str.isspace():
|
| 95 |
+
return -1.0
|
| 96 |
+
|
| 97 |
+
# Substitute C with 1
|
| 98 |
+
if 'C' in expression_str:
|
| 99 |
+
expression_str = expression_str.replace('C', '1')
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
expr = Expression(expression_str, is_prefix=False)
|
| 103 |
+
|
| 104 |
+
if not expr.is_valid_on_dataset(self.X):
|
| 105 |
+
return -1.0
|
| 106 |
+
|
| 107 |
+
y_pred = expr.evaluate(self.X)
|
| 108 |
+
|
| 109 |
+
if not np.all(np.isfinite(y_pred)):
|
| 110 |
+
return -1.0
|
| 111 |
+
|
| 112 |
+
ss_res = np.sum((self.y - y_pred) ** 2)
|
| 113 |
+
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
|
| 114 |
+
|
| 115 |
+
if ss_tot == 0:
|
| 116 |
+
return 0.0
|
| 117 |
+
|
| 118 |
+
r2 = 1 - (ss_res / ss_tot)
|
| 119 |
+
return float(np.clip(r2, -1.0, 1.0))
|
| 120 |
+
except Exception:
|
| 121 |
+
return -1.0
|
| 122 |
+
|
| 123 |
+
def __call__(self, completions: List[str], **kwargs) -> List[float]:
|
| 124 |
+
"""
|
| 125 |
+
Compute rewards for a batch of completions.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
completions: List of generated completion strings
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
List of R^2 scores
|
| 132 |
+
"""
|
| 133 |
+
rewards = []
|
| 134 |
+
|
| 135 |
+
for completion in completions:
|
| 136 |
+
# Extract expression from completion
|
| 137 |
+
expr_str = self.extract_expression(completion)
|
| 138 |
+
|
| 139 |
+
# Compute R^2
|
| 140 |
+
r2 = self.compute_r2(expr_str)
|
| 141 |
+
rewards.append(r2)
|
| 142 |
+
|
| 143 |
+
# Track best
|
| 144 |
+
if r2 > self.best_r2:
|
| 145 |
+
self.best_r2 = r2
|
| 146 |
+
self.best_expression = expr_str
|
| 147 |
+
logger.info(f"New best R^2: {r2:.4f} - {expr_str}")
|
| 148 |
+
|
| 149 |
+
# Log batch statistics
|
| 150 |
+
valid_rewards = [r for r in rewards if r > -1.0]
|
| 151 |
+
if valid_rewards:
|
| 152 |
+
self.history.append({
|
| 153 |
+
"mean_r2": np.mean(valid_rewards),
|
| 154 |
+
"max_r2": max(valid_rewards),
|
| 155 |
+
"valid_rate": len(valid_rewards) / len(rewards),
|
| 156 |
+
})
|
| 157 |
+
|
| 158 |
+
return rewards
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def build_prompt(n_vars: int) -> str:
|
| 162 |
+
"""Build JSON format prompt matching training data."""
|
| 163 |
+
vars_list = [f"x_{i+1}" for i in range(n_vars)]
|
| 164 |
+
ops_list = ["+", "-", "*", "sin", "cos"]
|
| 165 |
+
|
| 166 |
+
prompt = json.dumps({
|
| 167 |
+
"vars": vars_list,
|
| 168 |
+
"ops": ops_list,
|
| 169 |
+
"cons": None,
|
| 170 |
+
"expr": ""
|
| 171 |
+
})[:-3] # Remove trailing '"}' for model to complete
|
| 172 |
+
|
| 173 |
+
return prompt
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def run_grpo_experiment(
|
| 177 |
+
model_path: str,
|
| 178 |
+
dataset_path: str,
|
| 179 |
+
output_dir: str = "./output/grpo_results",
|
| 180 |
+
num_episodes: int = 100,
|
| 181 |
+
batch_size: int = 4,
|
| 182 |
+
learning_rate: float = 1e-5,
|
| 183 |
+
use_cpu: bool = False,
|
| 184 |
+
):
|
| 185 |
+
"""Run GRPO experiment with custom R^2 reward function."""
|
| 186 |
+
|
| 187 |
+
output_dir = Path(output_dir)
|
| 188 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 189 |
+
|
| 190 |
+
# Device setup
|
| 191 |
+
device = "cpu" if use_cpu else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 192 |
+
logger.info(f"Using device: {device}")
|
| 193 |
+
|
| 194 |
+
# Load dataset
|
| 195 |
+
logger.info(f"Loading dataset from {dataset_path}")
|
| 196 |
+
dataset_path = Path(dataset_path)
|
| 197 |
+
reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
|
| 198 |
+
X, y = reg.get_numpy()
|
| 199 |
+
n_vars = X.shape[1]
|
| 200 |
+
logger.info(f"Dataset: {X.shape[0]} samples, {n_vars} variables")
|
| 201 |
+
|
| 202 |
+
# Load tokenizer and model
|
| 203 |
+
logger.info(f"Loading model from {model_path}")
|
| 204 |
+
|
| 205 |
+
# Check if model_path is a local path or HuggingFace model
|
| 206 |
+
if Path(model_path).exists():
|
| 207 |
+
# Load tokenizer from trained model
|
| 208 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 209 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 210 |
+
|
| 211 |
+
# Load base model and LoRA
|
| 212 |
+
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 213 |
+
if len(tokenizer) != base_model.config.vocab_size:
|
| 214 |
+
base_model.resize_token_embeddings(len(tokenizer))
|
| 215 |
+
|
| 216 |
+
try:
|
| 217 |
+
model_with_lora = PeftModel.from_pretrained(base_model, model_path)
|
| 218 |
+
model = model_with_lora.merge_and_unload()
|
| 219 |
+
logger.info("LoRA adapter loaded and merged")
|
| 220 |
+
except Exception as e:
|
| 221 |
+
logger.warning(f"Could not load LoRA: {e}")
|
| 222 |
+
model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 223 |
+
else:
|
| 224 |
+
# Load from HuggingFace Hub
|
| 225 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 226 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 227 |
+
model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 228 |
+
|
| 229 |
+
logger.info("Model loaded successfully")
|
| 230 |
+
|
| 231 |
+
# Build prompt and create dataset
|
| 232 |
+
prompt = build_prompt(n_vars)
|
| 233 |
+
logger.info(f"Prompt: {prompt}...")
|
| 234 |
+
|
| 235 |
+
train_dataset = Dataset.from_dict({"prompt": [prompt] * num_episodes})
|
| 236 |
+
|
| 237 |
+
# Create reward function
|
| 238 |
+
reward_func = SymbolicRegressionReward(X, y, tokenizer)
|
| 239 |
+
|
| 240 |
+
# GRPO Config
|
| 241 |
+
grpo_config = GRPOConfig(
|
| 242 |
+
output_dir=str(output_dir),
|
| 243 |
+
learning_rate=learning_rate,
|
| 244 |
+
per_device_train_batch_size=batch_size,
|
| 245 |
+
num_generations=batch_size, # Generate batch_size samples per prompt
|
| 246 |
+
max_completion_length=50,
|
| 247 |
+
num_train_epochs=1,
|
| 248 |
+
report_to=[],
|
| 249 |
+
use_cpu=use_cpu or device == "cpu",
|
| 250 |
+
bf16=False if use_cpu or device == "cpu" else True,
|
| 251 |
+
logging_steps=10,
|
| 252 |
+
save_strategy="epoch",
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Create trainer
|
| 256 |
+
logger.info("Creating GRPO Trainer...")
|
| 257 |
+
trainer = GRPOTrainer(
|
| 258 |
+
model=model,
|
| 259 |
+
args=grpo_config,
|
| 260 |
+
processing_class=tokenizer,
|
| 261 |
+
train_dataset=train_dataset,
|
| 262 |
+
reward_funcs=reward_func,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Train
|
| 266 |
+
logger.info("="*60)
|
| 267 |
+
logger.info("GRPO SYMBOLIC REGRESSION EXPERIMENT")
|
| 268 |
+
logger.info("="*60)
|
| 269 |
+
logger.info(f"Dataset: {dataset_path}")
|
| 270 |
+
logger.info(f"Model: {model_path}")
|
| 271 |
+
logger.info(f"Episodes: {num_episodes}")
|
| 272 |
+
logger.info("="*60)
|
| 273 |
+
|
| 274 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 275 |
+
|
| 276 |
+
try:
|
| 277 |
+
trainer.train()
|
| 278 |
+
logger.info("Training completed!")
|
| 279 |
+
except Exception as e:
|
| 280 |
+
logger.error(f"Training failed: {e}")
|
| 281 |
+
import traceback
|
| 282 |
+
traceback.print_exc()
|
| 283 |
+
|
| 284 |
+
# Results
|
| 285 |
+
logger.info("\n" + "="*60)
|
| 286 |
+
logger.info("RESULTS")
|
| 287 |
+
logger.info("="*60)
|
| 288 |
+
logger.info(f"Best R^2: {reward_func.best_r2:.4f}")
|
| 289 |
+
logger.info(f"Best expression: {reward_func.best_expression}")
|
| 290 |
+
|
| 291 |
+
# Save results
|
| 292 |
+
results = {
|
| 293 |
+
"timestamp": timestamp,
|
| 294 |
+
"model_path": model_path,
|
| 295 |
+
"dataset_path": str(dataset_path),
|
| 296 |
+
"best_r2": reward_func.best_r2,
|
| 297 |
+
"best_expression": reward_func.best_expression,
|
| 298 |
+
"history": reward_func.history,
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
results_file = output_dir / f"grpo_results_{timestamp}.json"
|
| 302 |
+
with open(results_file, 'w') as f:
|
| 303 |
+
json.dump(results, f, indent=2)
|
| 304 |
+
|
| 305 |
+
logger.info(f"Results saved to: {results_file}")
|
| 306 |
+
|
| 307 |
+
# Save model
|
| 308 |
+
trainer.save_model(str(output_dir / "final_model"))
|
| 309 |
+
|
| 310 |
+
return results
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def main():
|
| 314 |
+
parser = argparse.ArgumentParser(description="GRPO Symbolic Regression")
|
| 315 |
+
parser.add_argument("--model_path", type=str, default="gpt2",
|
| 316 |
+
help="Path to model (local or HuggingFace)")
|
| 317 |
+
parser.add_argument("--dataset", type=str, default="./data/ppo_test/sin_x1.csv",
|
| 318 |
+
help="Path to test dataset CSV")
|
| 319 |
+
parser.add_argument("--output_dir", type=str, default="./output/grpo_results",
|
| 320 |
+
help="Output directory")
|
| 321 |
+
parser.add_argument("--num_episodes", type=int, default=100,
|
| 322 |
+
help="Number of training episodes")
|
| 323 |
+
parser.add_argument("--batch_size", type=int, default=4,
|
| 324 |
+
help="Batch size")
|
| 325 |
+
parser.add_argument("--lr", type=float, default=1e-5,
|
| 326 |
+
help="Learning rate")
|
| 327 |
+
parser.add_argument("--cpu", action="store_true",
|
| 328 |
+
help="Force CPU usage")
|
| 329 |
+
|
| 330 |
+
args = parser.parse_args()
|
| 331 |
+
|
| 332 |
+
run_grpo_experiment(
|
| 333 |
+
model_path=args.model_path,
|
| 334 |
+
dataset_path=args.dataset,
|
| 335 |
+
output_dir=args.output_dir,
|
| 336 |
+
num_episodes=args.num_episodes,
|
| 337 |
+
batch_size=args.batch_size,
|
| 338 |
+
learning_rate=args.lr,
|
| 339 |
+
use_cpu=args.cpu,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
if __name__ == "__main__":
|
| 344 |
+
main()
|
2_training/reinforcement/grpo_improved.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Improved GRPO (Group Relative Policy Optimization) for Symbolic Regression
|
| 4 |
+
|
| 5 |
+
Improvements over basic GRPO:
|
| 6 |
+
1. Filter invalid expressions before computing group statistics
|
| 7 |
+
2. Reward shaping with softer penalties
|
| 8 |
+
3. Hybrid baseline: group stats + exponential moving average
|
| 9 |
+
4. Entropy bonus for exploration
|
| 10 |
+
5. Advantage clipping to prevent extreme updates
|
| 11 |
+
6. Minimum valid ratio check before updates
|
| 12 |
+
7. Temperature annealing for better exploration/exploitation
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import json
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
import datetime
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Dict, Tuple
|
| 23 |
+
from collections import deque
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
# Add project root to path
|
| 30 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 31 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 32 |
+
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
|
| 33 |
+
|
| 34 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 35 |
+
from peft import PeftModel, LoraConfig, get_peft_model
|
| 36 |
+
|
| 37 |
+
from expression import Expression
|
| 38 |
+
from dataset import RegressionDataset
|
| 39 |
+
|
| 40 |
+
# Configure logging
|
| 41 |
+
logging.basicConfig(
|
| 42 |
+
level=logging.INFO,
|
| 43 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ImprovedGRPO:
|
| 49 |
+
"""Improved GRPO for symbolic regression."""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
model_path: str,
|
| 54 |
+
X: np.ndarray,
|
| 55 |
+
y: np.ndarray,
|
| 56 |
+
output_dir: str = "./output/grpo",
|
| 57 |
+
learning_rate: float = 5e-5,
|
| 58 |
+
device: str = None,
|
| 59 |
+
group_size: int = 16, # Larger groups for better statistics
|
| 60 |
+
entropy_coef: float = 0.01,
|
| 61 |
+
advantage_clip: float = 2.0, # Clip extreme advantages
|
| 62 |
+
min_valid_ratio: float = 0.2, # Minimum valid expressions to update
|
| 63 |
+
):
|
| 64 |
+
self.X = X
|
| 65 |
+
self.y = y
|
| 66 |
+
self.n_vars = X.shape[1]
|
| 67 |
+
self.output_dir = Path(output_dir)
|
| 68 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
self.learning_rate = learning_rate
|
| 70 |
+
self.group_size = group_size
|
| 71 |
+
self.entropy_coef = entropy_coef
|
| 72 |
+
self.advantage_clip = advantage_clip
|
| 73 |
+
self.min_valid_ratio = min_valid_ratio
|
| 74 |
+
|
| 75 |
+
# Device
|
| 76 |
+
if device:
|
| 77 |
+
self.device = torch.device(device)
|
| 78 |
+
else:
|
| 79 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 80 |
+
logger.info(f"Using device: {self.device}")
|
| 81 |
+
|
| 82 |
+
# Load model
|
| 83 |
+
self._load_model(model_path)
|
| 84 |
+
|
| 85 |
+
# Build prompt
|
| 86 |
+
self.prompt = self._build_prompt()
|
| 87 |
+
self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
|
| 88 |
+
|
| 89 |
+
# Optimizer
|
| 90 |
+
self.optimizer = torch.optim.AdamW(
|
| 91 |
+
self.model.parameters(),
|
| 92 |
+
lr=learning_rate,
|
| 93 |
+
weight_decay=0.01
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Scheduler
|
| 97 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
| 98 |
+
self.optimizer, T_0=10, T_mult=2
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Tracking
|
| 102 |
+
self.best_r2 = -np.inf
|
| 103 |
+
self.best_expression = None
|
| 104 |
+
self.history = []
|
| 105 |
+
self.discovered_expressions: Dict[str, float] = {}
|
| 106 |
+
|
| 107 |
+
# Hybrid baseline: EMA of valid rewards
|
| 108 |
+
self.ema_baseline = 0.0
|
| 109 |
+
self.ema_decay = 0.9
|
| 110 |
+
self.reward_buffer = deque(maxlen=100)
|
| 111 |
+
|
| 112 |
+
# Temperature annealing
|
| 113 |
+
self.initial_temp = 0.8
|
| 114 |
+
self.min_temp = 0.5
|
| 115 |
+
self.current_temp = self.initial_temp
|
| 116 |
+
|
| 117 |
+
def _load_model(self, model_path: str):
|
| 118 |
+
"""Load model and tokenizer."""
|
| 119 |
+
logger.info(f"Loading model from {model_path}")
|
| 120 |
+
|
| 121 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 122 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
logger.info("Attempting to load as LoRA adapter...")
|
| 126 |
+
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 127 |
+
if len(self.tokenizer) != base_model.config.vocab_size:
|
| 128 |
+
base_model.resize_token_embeddings(len(self.tokenizer))
|
| 129 |
+
logger.info(f"Resized embeddings to {len(self.tokenizer)}")
|
| 130 |
+
|
| 131 |
+
model_with_lora = PeftModel.from_pretrained(base_model, model_path)
|
| 132 |
+
self.model = model_with_lora.merge_and_unload()
|
| 133 |
+
logger.info("LoRA adapter loaded and merged successfully")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.info(f"LoRA load failed ({e}), loading as standalone model...")
|
| 136 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 137 |
+
|
| 138 |
+
# Add LoRA for training
|
| 139 |
+
lora_config = LoraConfig(
|
| 140 |
+
r=8,
|
| 141 |
+
lora_alpha=16,
|
| 142 |
+
target_modules=["c_attn"],
|
| 143 |
+
lora_dropout=0.05,
|
| 144 |
+
bias="none",
|
| 145 |
+
)
|
| 146 |
+
self.model = get_peft_model(self.model, lora_config)
|
| 147 |
+
self.model = self.model.to(self.device)
|
| 148 |
+
self.model.train()
|
| 149 |
+
|
| 150 |
+
trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 151 |
+
logger.info(f"Model loaded with {trainable} trainable params")
|
| 152 |
+
|
| 153 |
+
def _build_prompt(self, ops: list = None) -> str:
|
| 154 |
+
"""Build JSON format prompt."""
|
| 155 |
+
vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
|
| 156 |
+
|
| 157 |
+
if ops is None:
|
| 158 |
+
ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
|
| 159 |
+
else:
|
| 160 |
+
ops_list = ops
|
| 161 |
+
|
| 162 |
+
prompt = json.dumps({
|
| 163 |
+
"vars": vars_list,
|
| 164 |
+
"ops": ops_list,
|
| 165 |
+
"cons": "C",
|
| 166 |
+
"expr": ""
|
| 167 |
+
})
|
| 168 |
+
prompt = prompt[:-2]
|
| 169 |
+
return prompt
|
| 170 |
+
|
| 171 |
+
def extract_expression(self, text: str) -> str:
|
| 172 |
+
"""Extract expression from generated text."""
|
| 173 |
+
try:
|
| 174 |
+
eos_token = "<|endoftext|>"
|
| 175 |
+
if eos_token in text:
|
| 176 |
+
text = text[:text.index(eos_token)]
|
| 177 |
+
|
| 178 |
+
if '"expr": "' in text:
|
| 179 |
+
start = text.index('"expr": "') + len('"expr": "')
|
| 180 |
+
remaining = text[start:]
|
| 181 |
+
for terminator in ['"}', '"']:
|
| 182 |
+
if terminator in remaining:
|
| 183 |
+
return remaining[:remaining.index(terminator)].strip()
|
| 184 |
+
return remaining.strip()
|
| 185 |
+
|
| 186 |
+
if '"expr": ' in text:
|
| 187 |
+
start = text.index('"expr": ') + len('"expr": ')
|
| 188 |
+
remaining = text[start:]
|
| 189 |
+
if '"}' in remaining:
|
| 190 |
+
return remaining[:remaining.index('"}')].strip()
|
| 191 |
+
return remaining.strip(' "')
|
| 192 |
+
|
| 193 |
+
except (ValueError, IndexError):
|
| 194 |
+
pass
|
| 195 |
+
|
| 196 |
+
if '"expr"' in text:
|
| 197 |
+
return text.split('"expr"')[-1].strip(' ":{}')
|
| 198 |
+
return text.strip()
|
| 199 |
+
|
| 200 |
+
def compute_r2(self, expression_str: str) -> Tuple[float, bool]:
|
| 201 |
+
"""Compute R^2 score."""
|
| 202 |
+
if not expression_str or expression_str.isspace():
|
| 203 |
+
return -1.0, False
|
| 204 |
+
|
| 205 |
+
if 'C' in expression_str:
|
| 206 |
+
expression_str = expression_str.replace('C', '1')
|
| 207 |
+
|
| 208 |
+
try:
|
| 209 |
+
expr = Expression(expression_str, is_prefix=False)
|
| 210 |
+
if not expr.is_valid_on_dataset(self.X):
|
| 211 |
+
return -1.0, False
|
| 212 |
+
|
| 213 |
+
y_pred = expr.evaluate(self.X)
|
| 214 |
+
if not np.all(np.isfinite(y_pred)):
|
| 215 |
+
return -1.0, False
|
| 216 |
+
|
| 217 |
+
ss_res = np.sum((self.y - y_pred) ** 2)
|
| 218 |
+
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
|
| 219 |
+
|
| 220 |
+
if ss_tot == 0:
|
| 221 |
+
return 0.0, True
|
| 222 |
+
|
| 223 |
+
r2 = 1 - (ss_res / ss_tot)
|
| 224 |
+
return float(np.clip(r2, -1.0, 1.0)), True
|
| 225 |
+
except Exception:
|
| 226 |
+
return -1.0, False
|
| 227 |
+
|
| 228 |
+
def shape_reward(self, r2: float, is_valid: bool) -> float:
|
| 229 |
+
"""Shape reward for better learning signal."""
|
| 230 |
+
if not is_valid:
|
| 231 |
+
return -0.1 # Small penalty, not -1.0
|
| 232 |
+
|
| 233 |
+
# Bonus for high R²
|
| 234 |
+
if r2 >= 0.99:
|
| 235 |
+
return 2.0 # Big bonus for near-perfect
|
| 236 |
+
elif r2 >= 0.9:
|
| 237 |
+
return r2 * 1.5
|
| 238 |
+
elif r2 >= 0.5:
|
| 239 |
+
return r2 * 1.2
|
| 240 |
+
elif r2 >= 0:
|
| 241 |
+
return r2
|
| 242 |
+
else:
|
| 243 |
+
return r2 * 0.5 # Reduce negative penalty
|
| 244 |
+
|
| 245 |
+
def generate_group(self, max_new_tokens: int = 50) -> List[Dict]:
|
| 246 |
+
"""Generate a group of expressions."""
|
| 247 |
+
results = []
|
| 248 |
+
|
| 249 |
+
for _ in range(self.group_size):
|
| 250 |
+
generated_ids = self.prompt_ids.clone()
|
| 251 |
+
generated_tokens = []
|
| 252 |
+
|
| 253 |
+
# Phase 1: Generate tokens
|
| 254 |
+
with torch.no_grad():
|
| 255 |
+
for _ in range(max_new_tokens):
|
| 256 |
+
outputs = self.model(generated_ids)
|
| 257 |
+
logits = outputs.logits[:, -1, :] / self.current_temp
|
| 258 |
+
|
| 259 |
+
probs = F.softmax(logits, dim=-1)
|
| 260 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 261 |
+
generated_tokens.append(next_token.item())
|
| 262 |
+
|
| 263 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 264 |
+
|
| 265 |
+
if next_token.item() == self.tokenizer.eos_token_id:
|
| 266 |
+
break
|
| 267 |
+
|
| 268 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 269 |
+
if '"}' in text[len(self.prompt):]:
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
# Decode and evaluate
|
| 273 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 274 |
+
expr_str = self.extract_expression(text)
|
| 275 |
+
r2, is_valid = self.compute_r2(expr_str)
|
| 276 |
+
reward = self.shape_reward(r2, is_valid)
|
| 277 |
+
|
| 278 |
+
# Phase 2: Compute log probs with gradients
|
| 279 |
+
if len(generated_tokens) > 0:
|
| 280 |
+
full_ids = torch.cat([
|
| 281 |
+
self.prompt_ids,
|
| 282 |
+
torch.tensor([generated_tokens], device=self.device)
|
| 283 |
+
], dim=1)
|
| 284 |
+
|
| 285 |
+
outputs = self.model(full_ids[:, :-1])
|
| 286 |
+
logits = outputs.logits / self.current_temp
|
| 287 |
+
|
| 288 |
+
prompt_len = self.prompt_ids.shape[1]
|
| 289 |
+
gen_logits = logits[:, prompt_len-1:, :]
|
| 290 |
+
|
| 291 |
+
log_probs_all = F.log_softmax(gen_logits, dim=-1)
|
| 292 |
+
probs_all = F.softmax(gen_logits, dim=-1)
|
| 293 |
+
|
| 294 |
+
target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
|
| 295 |
+
selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
|
| 296 |
+
total_log_prob = selected_log_probs.sum()
|
| 297 |
+
|
| 298 |
+
# Entropy for exploration
|
| 299 |
+
entropy_per_pos = -(probs_all * log_probs_all).sum(dim=-1)
|
| 300 |
+
total_entropy = entropy_per_pos.mean()
|
| 301 |
+
else:
|
| 302 |
+
total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
|
| 303 |
+
total_entropy = torch.tensor(0.0, device=self.device)
|
| 304 |
+
|
| 305 |
+
results.append({
|
| 306 |
+
"text": text,
|
| 307 |
+
"expression": expr_str,
|
| 308 |
+
"r2": r2,
|
| 309 |
+
"is_valid": is_valid,
|
| 310 |
+
"reward": reward,
|
| 311 |
+
"log_prob": total_log_prob,
|
| 312 |
+
"entropy": total_entropy,
|
| 313 |
+
})
|
| 314 |
+
|
| 315 |
+
# Track best
|
| 316 |
+
if is_valid:
|
| 317 |
+
self.discovered_expressions[expr_str] = max(
|
| 318 |
+
self.discovered_expressions.get(expr_str, -np.inf), r2
|
| 319 |
+
)
|
| 320 |
+
self.reward_buffer.append(reward)
|
| 321 |
+
|
| 322 |
+
if r2 > self.best_r2:
|
| 323 |
+
self.best_r2 = r2
|
| 324 |
+
self.best_expression = expr_str
|
| 325 |
+
|
| 326 |
+
if self.device.type == "cuda":
|
| 327 |
+
torch.cuda.empty_cache()
|
| 328 |
+
|
| 329 |
+
return results
|
| 330 |
+
|
| 331 |
+
def compute_advantages(self, results: List[Dict]) -> Tuple[List[float], dict]:
|
| 332 |
+
"""
|
| 333 |
+
Compute improved GRPO advantages.
|
| 334 |
+
|
| 335 |
+
Key improvement: Only use VALID expressions for group statistics.
|
| 336 |
+
Invalid expressions get a fixed small negative advantage.
|
| 337 |
+
"""
|
| 338 |
+
valid_results = [r for r in results if r["is_valid"]]
|
| 339 |
+
valid_rewards = [r["reward"] for r in valid_results]
|
| 340 |
+
|
| 341 |
+
stats = {
|
| 342 |
+
"valid_count": len(valid_results),
|
| 343 |
+
"total_count": len(results),
|
| 344 |
+
"valid_ratio": len(valid_results) / len(results),
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
# If too few valid expressions, use EMA baseline only
|
| 348 |
+
if len(valid_rewards) < 2:
|
| 349 |
+
advantages = []
|
| 350 |
+
for r in results:
|
| 351 |
+
if r["is_valid"]:
|
| 352 |
+
adv = r["reward"] - self.ema_baseline
|
| 353 |
+
else:
|
| 354 |
+
adv = -0.5 # Fixed penalty for invalid
|
| 355 |
+
advantages.append(adv)
|
| 356 |
+
stats["method"] = "ema_only"
|
| 357 |
+
return advantages, stats
|
| 358 |
+
|
| 359 |
+
# Compute group statistics from valid expressions only
|
| 360 |
+
group_mean = np.mean(valid_rewards)
|
| 361 |
+
group_std = np.std(valid_rewards)
|
| 362 |
+
|
| 363 |
+
# Update EMA baseline
|
| 364 |
+
self.ema_baseline = self.ema_decay * self.ema_baseline + (1 - self.ema_decay) * group_mean
|
| 365 |
+
|
| 366 |
+
# Hybrid baseline: combine group mean with EMA
|
| 367 |
+
hybrid_baseline = 0.7 * group_mean + 0.3 * self.ema_baseline
|
| 368 |
+
|
| 369 |
+
# Avoid division by zero
|
| 370 |
+
if group_std < 1e-8:
|
| 371 |
+
group_std = 1.0
|
| 372 |
+
|
| 373 |
+
# Compute advantages
|
| 374 |
+
advantages = []
|
| 375 |
+
for r in results:
|
| 376 |
+
if r["is_valid"]:
|
| 377 |
+
# Normalized advantage for valid expressions
|
| 378 |
+
adv = (r["reward"] - hybrid_baseline) / group_std
|
| 379 |
+
# Clip to prevent extreme updates
|
| 380 |
+
adv = np.clip(adv, -self.advantage_clip, self.advantage_clip)
|
| 381 |
+
else:
|
| 382 |
+
# Small fixed penalty for invalid (doesn't pollute group stats)
|
| 383 |
+
adv = -0.3
|
| 384 |
+
advantages.append(adv)
|
| 385 |
+
|
| 386 |
+
stats["method"] = "hybrid"
|
| 387 |
+
stats["group_mean"] = group_mean
|
| 388 |
+
stats["group_std"] = group_std
|
| 389 |
+
stats["ema_baseline"] = self.ema_baseline
|
| 390 |
+
|
| 391 |
+
return advantages, stats
|
| 392 |
+
|
| 393 |
+
def train_step(self, num_groups: int = 2) -> dict:
|
| 394 |
+
"""Perform one training step."""
|
| 395 |
+
self.model.train()
|
| 396 |
+
|
| 397 |
+
all_results = []
|
| 398 |
+
all_advantages = []
|
| 399 |
+
total_policy_loss = 0.0
|
| 400 |
+
total_entropy_loss = 0.0
|
| 401 |
+
skipped_groups = 0
|
| 402 |
+
|
| 403 |
+
self.optimizer.zero_grad()
|
| 404 |
+
|
| 405 |
+
for _ in range(num_groups):
|
| 406 |
+
if self.device.type == "cuda":
|
| 407 |
+
torch.cuda.empty_cache()
|
| 408 |
+
|
| 409 |
+
# Generate group
|
| 410 |
+
group_results = self.generate_group()
|
| 411 |
+
all_results.extend(group_results)
|
| 412 |
+
|
| 413 |
+
# Compute advantages
|
| 414 |
+
advantages, adv_stats = self.compute_advantages(group_results)
|
| 415 |
+
all_advantages.extend(advantages)
|
| 416 |
+
|
| 417 |
+
# Skip update if too few valid expressions
|
| 418 |
+
if adv_stats["valid_ratio"] < self.min_valid_ratio:
|
| 419 |
+
skipped_groups += 1
|
| 420 |
+
continue
|
| 421 |
+
|
| 422 |
+
# Compute loss
|
| 423 |
+
policy_loss = torch.tensor(0.0, device=self.device)
|
| 424 |
+
entropy_loss = torch.tensor(0.0, device=self.device)
|
| 425 |
+
valid_count = 0
|
| 426 |
+
|
| 427 |
+
for result, advantage in zip(group_results, advantages):
|
| 428 |
+
if result["is_valid"] and advantage != 0:
|
| 429 |
+
policy_loss = policy_loss - result["log_prob"] * advantage
|
| 430 |
+
entropy_loss = entropy_loss - result["entropy"]
|
| 431 |
+
valid_count += 1
|
| 432 |
+
|
| 433 |
+
if valid_count > 0:
|
| 434 |
+
policy_loss = policy_loss / valid_count
|
| 435 |
+
entropy_loss = entropy_loss / valid_count
|
| 436 |
+
|
| 437 |
+
# Combined loss
|
| 438 |
+
loss = policy_loss + self.entropy_coef * entropy_loss
|
| 439 |
+
loss = loss / num_groups
|
| 440 |
+
loss.backward()
|
| 441 |
+
|
| 442 |
+
total_policy_loss += policy_loss.item()
|
| 443 |
+
total_entropy_loss += entropy_loss.item()
|
| 444 |
+
|
| 445 |
+
# Only update if we had valid groups
|
| 446 |
+
if skipped_groups < num_groups:
|
| 447 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 448 |
+
self.optimizer.step()
|
| 449 |
+
self.scheduler.step()
|
| 450 |
+
|
| 451 |
+
# Statistics
|
| 452 |
+
r2_values = [r["r2"] for r in all_results]
|
| 453 |
+
valid_mask = [r["is_valid"] for r in all_results]
|
| 454 |
+
valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v]
|
| 455 |
+
|
| 456 |
+
return {
|
| 457 |
+
"valid_count": int(sum(valid_mask)),
|
| 458 |
+
"total_count": len(all_results),
|
| 459 |
+
"valid_rate": sum(valid_mask) / len(all_results) if all_results else 0,
|
| 460 |
+
"mean_r2": float(np.mean(valid_r2)) if valid_r2 else 0.0,
|
| 461 |
+
"max_r2": float(max(r2_values)) if r2_values else 0.0,
|
| 462 |
+
"mean_advantage": float(np.mean(all_advantages)) if all_advantages else 0.0,
|
| 463 |
+
"ema_baseline": self.ema_baseline,
|
| 464 |
+
"policy_loss": total_policy_loss / max(num_groups - skipped_groups, 1),
|
| 465 |
+
"entropy_loss": total_entropy_loss / max(num_groups - skipped_groups, 1),
|
| 466 |
+
"lr": self.scheduler.get_last_lr()[0],
|
| 467 |
+
"temperature": self.current_temp,
|
| 468 |
+
"skipped_groups": skipped_groups,
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
def anneal_temperature(self, epoch: int, total_epochs: int):
|
| 472 |
+
"""Anneal temperature from initial to minimum."""
|
| 473 |
+
progress = epoch / total_epochs
|
| 474 |
+
self.current_temp = self.initial_temp - progress * (self.initial_temp - self.min_temp)
|
| 475 |
+
|
| 476 |
+
def run(
|
| 477 |
+
self,
|
| 478 |
+
epochs: int = 50,
|
| 479 |
+
num_groups: int = 2,
|
| 480 |
+
target_r2: float = 0.99,
|
| 481 |
+
patience: int = 20,
|
| 482 |
+
) -> dict:
|
| 483 |
+
"""Run improved GRPO training."""
|
| 484 |
+
logger.info("=" * 60)
|
| 485 |
+
logger.info("IMPROVED GRPO SYMBOLIC REGRESSION")
|
| 486 |
+
logger.info("=" * 60)
|
| 487 |
+
logger.info(f"Epochs: {epochs}")
|
| 488 |
+
logger.info(f"Group size: {self.group_size}")
|
| 489 |
+
logger.info(f"Num groups: {num_groups}")
|
| 490 |
+
logger.info(f"Effective batch: {self.group_size * num_groups}")
|
| 491 |
+
logger.info(f"Entropy coef: {self.entropy_coef}")
|
| 492 |
+
logger.info(f"Advantage clip: {self.advantage_clip}")
|
| 493 |
+
logger.info(f"Min valid ratio: {self.min_valid_ratio}")
|
| 494 |
+
logger.info(f"Target R^2: {target_r2}")
|
| 495 |
+
logger.info("=" * 60)
|
| 496 |
+
|
| 497 |
+
no_improvement_count = 0
|
| 498 |
+
best_r2_at_start = self.best_r2
|
| 499 |
+
|
| 500 |
+
for epoch in range(1, epochs + 1):
|
| 501 |
+
# Anneal temperature
|
| 502 |
+
self.anneal_temperature(epoch, epochs)
|
| 503 |
+
|
| 504 |
+
stats = self.train_step(num_groups)
|
| 505 |
+
self.history.append({
|
| 506 |
+
"epoch": epoch,
|
| 507 |
+
**stats,
|
| 508 |
+
"best_r2": self.best_r2,
|
| 509 |
+
})
|
| 510 |
+
|
| 511 |
+
logger.info(
|
| 512 |
+
f"Epoch {epoch:3d} | "
|
| 513 |
+
f"Valid: {stats['valid_count']}/{stats['total_count']} | "
|
| 514 |
+
f"Mean R²: {stats['mean_r2']:.4f} | "
|
| 515 |
+
f"Best: {self.best_r2:.4f} | "
|
| 516 |
+
f"EMA: {stats['ema_baseline']:.3f} | "
|
| 517 |
+
f"Temp: {stats['temperature']:.2f} | "
|
| 518 |
+
f"LR: {stats['lr']:.2e}"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Check for target
|
| 522 |
+
if self.best_r2 >= target_r2:
|
| 523 |
+
logger.info(f"Target R^2 {target_r2} reached at epoch {epoch}!")
|
| 524 |
+
break
|
| 525 |
+
|
| 526 |
+
# Early stopping
|
| 527 |
+
if self.best_r2 > best_r2_at_start:
|
| 528 |
+
best_r2_at_start = self.best_r2
|
| 529 |
+
no_improvement_count = 0
|
| 530 |
+
else:
|
| 531 |
+
no_improvement_count += 1
|
| 532 |
+
|
| 533 |
+
if no_improvement_count >= patience:
|
| 534 |
+
logger.info(f"No improvement for {patience} epochs. Early stopping.")
|
| 535 |
+
break
|
| 536 |
+
|
| 537 |
+
# Final results
|
| 538 |
+
logger.info("")
|
| 539 |
+
logger.info("=" * 60)
|
| 540 |
+
logger.info("FINAL RESULTS")
|
| 541 |
+
logger.info("=" * 60)
|
| 542 |
+
logger.info(f"Best R^2: {self.best_r2:.4f}")
|
| 543 |
+
logger.info(f"Best expression: {self.best_expression}")
|
| 544 |
+
logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}")
|
| 545 |
+
|
| 546 |
+
top_exprs = sorted(
|
| 547 |
+
self.discovered_expressions.items(),
|
| 548 |
+
key=lambda x: x[1],
|
| 549 |
+
reverse=True
|
| 550 |
+
)[:5]
|
| 551 |
+
logger.info("Top 5 expressions:")
|
| 552 |
+
for expr, r2 in top_exprs:
|
| 553 |
+
logger.info(f" R²={r2:.4f}: {expr}")
|
| 554 |
+
|
| 555 |
+
# Save results
|
| 556 |
+
results = {
|
| 557 |
+
"algorithm": "ImprovedGRPO",
|
| 558 |
+
"best_r2": self.best_r2,
|
| 559 |
+
"best_expression": self.best_expression,
|
| 560 |
+
"history": self.history,
|
| 561 |
+
"discovered_expressions": dict(list(self.discovered_expressions.items())[:100]),
|
| 562 |
+
"config": {
|
| 563 |
+
"group_size": self.group_size,
|
| 564 |
+
"num_groups": num_groups,
|
| 565 |
+
"learning_rate": self.learning_rate,
|
| 566 |
+
"entropy_coef": self.entropy_coef,
|
| 567 |
+
"advantage_clip": self.advantage_clip,
|
| 568 |
+
"min_valid_ratio": self.min_valid_ratio,
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 573 |
+
output_path = self.output_dir / f"results_grpo_improved_{timestamp}.json"
|
| 574 |
+
with open(output_path, "w") as f:
|
| 575 |
+
json.dump(results, f, indent=2)
|
| 576 |
+
logger.info(f"Results saved to: {output_path}")
|
| 577 |
+
|
| 578 |
+
return results
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def main():
|
| 582 |
+
parser = argparse.ArgumentParser(description="Improved GRPO for Symbolic Regression")
|
| 583 |
+
parser.add_argument("--model_path", type=str, required=True)
|
| 584 |
+
parser.add_argument("--dataset", type=str, required=True)
|
| 585 |
+
parser.add_argument("--output_dir", type=str, default="./output/grpo")
|
| 586 |
+
parser.add_argument("--epochs", type=int, default=50)
|
| 587 |
+
parser.add_argument("--group_size", type=int, default=16)
|
| 588 |
+
parser.add_argument("--num_groups", type=int, default=2)
|
| 589 |
+
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
| 590 |
+
parser.add_argument("--target_r2", type=float, default=0.99)
|
| 591 |
+
parser.add_argument("--entropy_coef", type=float, default=0.01)
|
| 592 |
+
args = parser.parse_args()
|
| 593 |
+
|
| 594 |
+
# Load dataset
|
| 595 |
+
import pandas as pd
|
| 596 |
+
df = pd.read_csv(args.dataset)
|
| 597 |
+
|
| 598 |
+
x_cols = [c for c in df.columns if c.startswith('x_')]
|
| 599 |
+
X = df[x_cols].values
|
| 600 |
+
y = df['y'].values
|
| 601 |
+
|
| 602 |
+
logger.info(f"Loaded dataset: {args.dataset}")
|
| 603 |
+
logger.info(f" Samples: {len(df)}, Variables: {len(x_cols)}")
|
| 604 |
+
|
| 605 |
+
# Create trainer
|
| 606 |
+
grpo = ImprovedGRPO(
|
| 607 |
+
model_path=args.model_path,
|
| 608 |
+
X=X,
|
| 609 |
+
y=y,
|
| 610 |
+
output_dir=args.output_dir,
|
| 611 |
+
learning_rate=args.learning_rate,
|
| 612 |
+
group_size=args.group_size,
|
| 613 |
+
entropy_coef=args.entropy_coef,
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
# Run training
|
| 617 |
+
results = grpo.run(
|
| 618 |
+
epochs=args.epochs,
|
| 619 |
+
num_groups=args.num_groups,
|
| 620 |
+
target_r2=args.target_r2,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
if __name__ == "__main__":
|
| 625 |
+
main()
|
2_training/reinforcement/grpo_symbolic.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO (Group Relative Policy Optimization) for Symbolic Regression
|
| 4 |
+
|
| 5 |
+
Based on DeepSeek-R1 approach:
|
| 6 |
+
- Generate a group of N samples
|
| 7 |
+
- Compute advantages relative to group mean/std
|
| 8 |
+
- No external baseline needed
|
| 9 |
+
|
| 10 |
+
Comparison with REINFORCE:
|
| 11 |
+
- REINFORCE: advantage = reward - moving_average_baseline
|
| 12 |
+
- GRPO: advantage = (reward - group_mean) / group_std
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import json
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
import datetime
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Dict, Tuple
|
| 23 |
+
from copy import deepcopy
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
# Add project root to path
|
| 30 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 31 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 32 |
+
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
|
| 33 |
+
|
| 34 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 35 |
+
from peft import PeftModel, LoraConfig, get_peft_model
|
| 36 |
+
|
| 37 |
+
from expression import Expression
|
| 38 |
+
from dataset import RegressionDataset
|
| 39 |
+
|
| 40 |
+
# Configure logging
|
| 41 |
+
logging.basicConfig(
|
| 42 |
+
level=logging.INFO,
|
| 43 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GRPO:
|
| 49 |
+
"""Group Relative Policy Optimization for symbolic regression."""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
model_path: str,
|
| 54 |
+
X: np.ndarray,
|
| 55 |
+
y: np.ndarray,
|
| 56 |
+
output_dir: str = "./output/grpo",
|
| 57 |
+
learning_rate: float = 5e-5,
|
| 58 |
+
device: str = None,
|
| 59 |
+
group_size: int = 8, # Number of samples per group
|
| 60 |
+
kl_coef: float = 0.01, # KL penalty coefficient
|
| 61 |
+
clip_range: float = 0.2, # PPO-style clipping (optional)
|
| 62 |
+
):
|
| 63 |
+
self.X = X
|
| 64 |
+
self.y = y
|
| 65 |
+
self.n_vars = X.shape[1]
|
| 66 |
+
self.output_dir = Path(output_dir)
|
| 67 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
self.learning_rate = learning_rate
|
| 69 |
+
self.group_size = group_size
|
| 70 |
+
self.kl_coef = kl_coef
|
| 71 |
+
self.clip_range = clip_range
|
| 72 |
+
|
| 73 |
+
# Device
|
| 74 |
+
if device:
|
| 75 |
+
self.device = torch.device(device)
|
| 76 |
+
else:
|
| 77 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 78 |
+
logger.info(f"Using device: {self.device}")
|
| 79 |
+
|
| 80 |
+
# Load model
|
| 81 |
+
self._load_model(model_path)
|
| 82 |
+
|
| 83 |
+
# Keep reference model for KL penalty
|
| 84 |
+
self.ref_model = None # Will be set after first update
|
| 85 |
+
|
| 86 |
+
# Build prompt
|
| 87 |
+
self.prompt = self._build_prompt()
|
| 88 |
+
self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
|
| 89 |
+
|
| 90 |
+
# Optimizer
|
| 91 |
+
self.optimizer = torch.optim.AdamW(
|
| 92 |
+
self.model.parameters(),
|
| 93 |
+
lr=learning_rate,
|
| 94 |
+
weight_decay=0.01
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Scheduler
|
| 98 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
| 99 |
+
self.optimizer, T_0=10, T_mult=2
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Tracking
|
| 103 |
+
self.best_r2 = -np.inf
|
| 104 |
+
self.best_expression = None
|
| 105 |
+
self.history = []
|
| 106 |
+
self.discovered_expressions: Dict[str, float] = {}
|
| 107 |
+
|
| 108 |
+
def _load_model(self, model_path: str):
|
| 109 |
+
"""Load model and tokenizer."""
|
| 110 |
+
logger.info(f"Loading model from {model_path}")
|
| 111 |
+
|
| 112 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 113 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
logger.info("Attempting to load as LoRA adapter...")
|
| 117 |
+
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 118 |
+
if len(self.tokenizer) != base_model.config.vocab_size:
|
| 119 |
+
base_model.resize_token_embeddings(len(self.tokenizer))
|
| 120 |
+
logger.info(f"Resized embeddings to {len(self.tokenizer)}")
|
| 121 |
+
|
| 122 |
+
model_with_lora = PeftModel.from_pretrained(base_model, model_path)
|
| 123 |
+
self.model = model_with_lora.merge_and_unload()
|
| 124 |
+
logger.info("LoRA adapter loaded and merged successfully")
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.info(f"LoRA load failed ({e}), loading as standalone model...")
|
| 127 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_path)
|
| 128 |
+
|
| 129 |
+
# Add LoRA for training
|
| 130 |
+
lora_config = LoraConfig(
|
| 131 |
+
r=8,
|
| 132 |
+
lora_alpha=16,
|
| 133 |
+
target_modules=["c_attn"],
|
| 134 |
+
lora_dropout=0.05,
|
| 135 |
+
bias="none",
|
| 136 |
+
)
|
| 137 |
+
self.model = get_peft_model(self.model, lora_config)
|
| 138 |
+
self.model = self.model.to(self.device)
|
| 139 |
+
self.model.train()
|
| 140 |
+
|
| 141 |
+
trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 142 |
+
logger.info(f"Model loaded with {trainable} trainable params")
|
| 143 |
+
|
| 144 |
+
def _build_prompt(self, ops: list = None) -> str:
|
| 145 |
+
"""Build JSON format prompt."""
|
| 146 |
+
vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
|
| 147 |
+
|
| 148 |
+
if ops is None:
|
| 149 |
+
ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
|
| 150 |
+
else:
|
| 151 |
+
ops_list = ops
|
| 152 |
+
|
| 153 |
+
prompt = json.dumps({
|
| 154 |
+
"vars": vars_list,
|
| 155 |
+
"ops": ops_list,
|
| 156 |
+
"cons": "C",
|
| 157 |
+
"expr": ""
|
| 158 |
+
})
|
| 159 |
+
prompt = prompt[:-2]
|
| 160 |
+
return prompt
|
| 161 |
+
|
| 162 |
+
def extract_expression(self, text: str) -> str:
|
| 163 |
+
"""Extract expression from generated text."""
|
| 164 |
+
try:
|
| 165 |
+
eos_token = "<|endoftext|>"
|
| 166 |
+
if eos_token in text:
|
| 167 |
+
text = text[:text.index(eos_token)]
|
| 168 |
+
|
| 169 |
+
if '"expr": "' in text:
|
| 170 |
+
start = text.index('"expr": "') + len('"expr": "')
|
| 171 |
+
remaining = text[start:]
|
| 172 |
+
for terminator in ['"}', '"']:
|
| 173 |
+
if terminator in remaining:
|
| 174 |
+
return remaining[:remaining.index(terminator)].strip()
|
| 175 |
+
return remaining.strip()
|
| 176 |
+
|
| 177 |
+
if '"expr": ' in text:
|
| 178 |
+
start = text.index('"expr": ') + len('"expr": ')
|
| 179 |
+
remaining = text[start:]
|
| 180 |
+
if '"}' in remaining:
|
| 181 |
+
return remaining[:remaining.index('"}')].strip()
|
| 182 |
+
return remaining.strip(' "')
|
| 183 |
+
|
| 184 |
+
except (ValueError, IndexError):
|
| 185 |
+
pass
|
| 186 |
+
|
| 187 |
+
if '"expr"' in text:
|
| 188 |
+
return text.split('"expr"')[-1].strip(' ":{}')
|
| 189 |
+
return text.strip()
|
| 190 |
+
|
| 191 |
+
def compute_r2(self, expression_str: str) -> Tuple[float, bool]:
|
| 192 |
+
"""Compute R^2 score. Returns (score, is_valid)."""
|
| 193 |
+
if not expression_str or expression_str.isspace():
|
| 194 |
+
return -1.0, False
|
| 195 |
+
|
| 196 |
+
if 'C' in expression_str:
|
| 197 |
+
expression_str = expression_str.replace('C', '1')
|
| 198 |
+
|
| 199 |
+
try:
|
| 200 |
+
expr = Expression(expression_str, is_prefix=False)
|
| 201 |
+
if not expr.is_valid_on_dataset(self.X):
|
| 202 |
+
return -1.0, False
|
| 203 |
+
|
| 204 |
+
y_pred = expr.evaluate(self.X)
|
| 205 |
+
if not np.all(np.isfinite(y_pred)):
|
| 206 |
+
return -1.0, False
|
| 207 |
+
|
| 208 |
+
ss_res = np.sum((self.y - y_pred) ** 2)
|
| 209 |
+
ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
|
| 210 |
+
|
| 211 |
+
if ss_tot == 0:
|
| 212 |
+
return 0.0, True
|
| 213 |
+
|
| 214 |
+
r2 = 1 - (ss_res / ss_tot)
|
| 215 |
+
return float(np.clip(r2, -1.0, 1.0)), True
|
| 216 |
+
except Exception:
|
| 217 |
+
return -1.0, False
|
| 218 |
+
|
| 219 |
+
def generate_group(
|
| 220 |
+
self,
|
| 221 |
+
temperature: float = 0.7,
|
| 222 |
+
max_new_tokens: int = 50
|
| 223 |
+
) -> List[Dict]:
|
| 224 |
+
"""Generate a group of expressions."""
|
| 225 |
+
results = []
|
| 226 |
+
|
| 227 |
+
for _ in range(self.group_size):
|
| 228 |
+
generated_ids = self.prompt_ids.clone()
|
| 229 |
+
generated_tokens = []
|
| 230 |
+
|
| 231 |
+
# Phase 1: Generate tokens without gradients
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
for _ in range(max_new_tokens):
|
| 234 |
+
outputs = self.model(generated_ids)
|
| 235 |
+
logits = outputs.logits[:, -1, :] / temperature
|
| 236 |
+
|
| 237 |
+
probs = F.softmax(logits, dim=-1)
|
| 238 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 239 |
+
generated_tokens.append(next_token.item())
|
| 240 |
+
|
| 241 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 242 |
+
|
| 243 |
+
if next_token.item() == self.tokenizer.eos_token_id:
|
| 244 |
+
break
|
| 245 |
+
|
| 246 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 247 |
+
if '"}' in text[len(self.prompt):]:
|
| 248 |
+
break
|
| 249 |
+
|
| 250 |
+
# Decode and extract expression
|
| 251 |
+
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 252 |
+
expr_str = self.extract_expression(text)
|
| 253 |
+
r2, is_valid = self.compute_r2(expr_str)
|
| 254 |
+
|
| 255 |
+
# Phase 2: Efficient log prob computation
|
| 256 |
+
if len(generated_tokens) > 0:
|
| 257 |
+
full_ids = torch.cat([
|
| 258 |
+
self.prompt_ids,
|
| 259 |
+
torch.tensor([generated_tokens], device=self.device)
|
| 260 |
+
], dim=1)
|
| 261 |
+
|
| 262 |
+
outputs = self.model(full_ids[:, :-1])
|
| 263 |
+
logits = outputs.logits / temperature
|
| 264 |
+
|
| 265 |
+
prompt_len = self.prompt_ids.shape[1]
|
| 266 |
+
gen_logits = logits[:, prompt_len-1:, :]
|
| 267 |
+
|
| 268 |
+
log_probs_all = F.log_softmax(gen_logits, dim=-1)
|
| 269 |
+
|
| 270 |
+
target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
|
| 271 |
+
selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
|
| 272 |
+
total_log_prob = selected_log_probs.sum()
|
| 273 |
+
else:
|
| 274 |
+
total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
|
| 275 |
+
|
| 276 |
+
results.append({
|
| 277 |
+
"text": text,
|
| 278 |
+
"expression": expr_str,
|
| 279 |
+
"r2": r2,
|
| 280 |
+
"is_valid": is_valid,
|
| 281 |
+
"log_prob": total_log_prob,
|
| 282 |
+
"generated_tokens": generated_tokens,
|
| 283 |
+
})
|
| 284 |
+
|
| 285 |
+
# Track best
|
| 286 |
+
if is_valid:
|
| 287 |
+
self.discovered_expressions[expr_str] = max(
|
| 288 |
+
self.discovered_expressions.get(expr_str, -np.inf), r2
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
if r2 > self.best_r2:
|
| 292 |
+
self.best_r2 = r2
|
| 293 |
+
self.best_expression = expr_str
|
| 294 |
+
|
| 295 |
+
# Clear cache
|
| 296 |
+
if self.device.type == "cuda":
|
| 297 |
+
torch.cuda.empty_cache()
|
| 298 |
+
|
| 299 |
+
return results
|
| 300 |
+
|
| 301 |
+
def compute_group_advantages(self, results: List[Dict]) -> List[float]:
|
| 302 |
+
"""
|
| 303 |
+
Compute GRPO advantages: (reward - mean) / std
|
| 304 |
+
|
| 305 |
+
This is the key difference from REINFORCE:
|
| 306 |
+
- REINFORCE uses external moving average baseline
|
| 307 |
+
- GRPO uses within-group statistics
|
| 308 |
+
"""
|
| 309 |
+
# Get rewards (R² values, with penalty for invalid)
|
| 310 |
+
rewards = []
|
| 311 |
+
for r in results:
|
| 312 |
+
if r["is_valid"]:
|
| 313 |
+
rewards.append(r["r2"])
|
| 314 |
+
else:
|
| 315 |
+
rewards.append(-0.1) # Small penalty for invalid
|
| 316 |
+
|
| 317 |
+
rewards = np.array(rewards)
|
| 318 |
+
|
| 319 |
+
# Compute group statistics
|
| 320 |
+
mean_reward = np.mean(rewards)
|
| 321 |
+
std_reward = np.std(rewards)
|
| 322 |
+
|
| 323 |
+
# Avoid division by zero
|
| 324 |
+
if std_reward < 1e-8:
|
| 325 |
+
std_reward = 1.0
|
| 326 |
+
|
| 327 |
+
# Compute normalized advantages
|
| 328 |
+
advantages = (rewards - mean_reward) / std_reward
|
| 329 |
+
|
| 330 |
+
return advantages.tolist(), mean_reward, std_reward
|
| 331 |
+
|
| 332 |
+
def train_step(self, num_groups: int = 4) -> dict:
|
| 333 |
+
"""
|
| 334 |
+
Perform one GRPO training step.
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
num_groups: Number of groups to sample (effective batch = num_groups * group_size)
|
| 338 |
+
"""
|
| 339 |
+
self.model.train()
|
| 340 |
+
|
| 341 |
+
all_results = []
|
| 342 |
+
all_advantages = []
|
| 343 |
+
total_loss = 0.0
|
| 344 |
+
|
| 345 |
+
self.optimizer.zero_grad()
|
| 346 |
+
|
| 347 |
+
# Generate multiple groups
|
| 348 |
+
for _ in range(num_groups):
|
| 349 |
+
if self.device.type == "cuda":
|
| 350 |
+
torch.cuda.empty_cache()
|
| 351 |
+
|
| 352 |
+
# Generate a group of samples
|
| 353 |
+
group_results = self.generate_group()
|
| 354 |
+
all_results.extend(group_results)
|
| 355 |
+
|
| 356 |
+
# Compute group-relative advantages
|
| 357 |
+
advantages, group_mean, group_std = self.compute_group_advantages(group_results)
|
| 358 |
+
all_advantages.extend(advantages)
|
| 359 |
+
|
| 360 |
+
# Compute loss for this group
|
| 361 |
+
group_loss = torch.tensor(0.0, device=self.device)
|
| 362 |
+
valid_count = 0
|
| 363 |
+
|
| 364 |
+
for result, advantage in zip(group_results, advantages):
|
| 365 |
+
if result["is_valid"]:
|
| 366 |
+
# Policy gradient loss with advantage
|
| 367 |
+
group_loss = group_loss - result["log_prob"] * advantage
|
| 368 |
+
valid_count += 1
|
| 369 |
+
|
| 370 |
+
if valid_count > 0:
|
| 371 |
+
group_loss = group_loss / valid_count
|
| 372 |
+
group_loss = group_loss / num_groups # Scale for accumulation
|
| 373 |
+
group_loss.backward()
|
| 374 |
+
total_loss += group_loss.item()
|
| 375 |
+
|
| 376 |
+
# Gradient clipping
|
| 377 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 378 |
+
|
| 379 |
+
# Update
|
| 380 |
+
self.optimizer.step()
|
| 381 |
+
self.scheduler.step()
|
| 382 |
+
|
| 383 |
+
# Statistics
|
| 384 |
+
r2_values = [r["r2"] for r in all_results]
|
| 385 |
+
valid_mask = [r["is_valid"] for r in all_results]
|
| 386 |
+
valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v]
|
| 387 |
+
|
| 388 |
+
return {
|
| 389 |
+
"valid_count": int(sum(valid_mask)),
|
| 390 |
+
"total_count": len(all_results),
|
| 391 |
+
"valid_rate": sum(valid_mask) / len(all_results),
|
| 392 |
+
"mean_r2": float(np.mean(valid_r2)) if valid_r2 else 0.0,
|
| 393 |
+
"max_r2": float(max(r2_values)),
|
| 394 |
+
"mean_advantage": float(np.mean(all_advantages)),
|
| 395 |
+
"std_advantage": float(np.std(all_advantages)),
|
| 396 |
+
"loss": total_loss,
|
| 397 |
+
"lr": self.scheduler.get_last_lr()[0],
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
def run(
|
| 401 |
+
self,
|
| 402 |
+
epochs: int = 50,
|
| 403 |
+
num_groups: int = 4,
|
| 404 |
+
target_r2: float = 0.99,
|
| 405 |
+
patience: int = 20,
|
| 406 |
+
) -> dict:
|
| 407 |
+
"""Run GRPO training."""
|
| 408 |
+
logger.info("=" * 60)
|
| 409 |
+
logger.info("GRPO SYMBOLIC REGRESSION")
|
| 410 |
+
logger.info("=" * 60)
|
| 411 |
+
logger.info(f"Epochs: {epochs}")
|
| 412 |
+
logger.info(f"Group size: {self.group_size}")
|
| 413 |
+
logger.info(f"Num groups: {num_groups}")
|
| 414 |
+
logger.info(f"Effective batch: {self.group_size * num_groups}")
|
| 415 |
+
logger.info(f"Target R^2: {target_r2}")
|
| 416 |
+
logger.info("=" * 60)
|
| 417 |
+
|
| 418 |
+
no_improvement_count = 0
|
| 419 |
+
best_r2_at_start = self.best_r2
|
| 420 |
+
|
| 421 |
+
for epoch in range(1, epochs + 1):
|
| 422 |
+
stats = self.train_step(num_groups)
|
| 423 |
+
self.history.append({
|
| 424 |
+
"epoch": epoch,
|
| 425 |
+
**stats,
|
| 426 |
+
"best_r2": self.best_r2,
|
| 427 |
+
})
|
| 428 |
+
|
| 429 |
+
logger.info(
|
| 430 |
+
f"Epoch {epoch:3d} | "
|
| 431 |
+
f"Valid: {stats['valid_count']}/{stats['total_count']} | "
|
| 432 |
+
f"Mean R²: {stats['mean_r2']:.4f} | "
|
| 433 |
+
f"Best: {self.best_r2:.4f} | "
|
| 434 |
+
f"Adv μ: {stats['mean_advantage']:.3f} σ: {stats['std_advantage']:.3f} | "
|
| 435 |
+
f"LR: {stats['lr']:.2e}"
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Check for target
|
| 439 |
+
if self.best_r2 >= target_r2:
|
| 440 |
+
logger.info(f"Target R^2 {target_r2} reached at epoch {epoch}!")
|
| 441 |
+
break
|
| 442 |
+
|
| 443 |
+
# Early stopping
|
| 444 |
+
if self.best_r2 > best_r2_at_start:
|
| 445 |
+
best_r2_at_start = self.best_r2
|
| 446 |
+
no_improvement_count = 0
|
| 447 |
+
else:
|
| 448 |
+
no_improvement_count += 1
|
| 449 |
+
|
| 450 |
+
if no_improvement_count >= patience:
|
| 451 |
+
logger.info(f"No improvement for {patience} epochs. Early stopping.")
|
| 452 |
+
break
|
| 453 |
+
|
| 454 |
+
# Final results
|
| 455 |
+
logger.info("")
|
| 456 |
+
logger.info("=" * 60)
|
| 457 |
+
logger.info("FINAL RESULTS")
|
| 458 |
+
logger.info("=" * 60)
|
| 459 |
+
logger.info(f"Best R^2: {self.best_r2:.4f}")
|
| 460 |
+
logger.info(f"Best expression: {self.best_expression}")
|
| 461 |
+
logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}")
|
| 462 |
+
|
| 463 |
+
# Top expressions
|
| 464 |
+
top_exprs = sorted(
|
| 465 |
+
self.discovered_expressions.items(),
|
| 466 |
+
key=lambda x: x[1],
|
| 467 |
+
reverse=True
|
| 468 |
+
)[:5]
|
| 469 |
+
logger.info("Top 5 expressions:")
|
| 470 |
+
for expr, r2 in top_exprs:
|
| 471 |
+
logger.info(f" R²={r2:.4f}: {expr}")
|
| 472 |
+
|
| 473 |
+
# Save results
|
| 474 |
+
results = {
|
| 475 |
+
"algorithm": "GRPO",
|
| 476 |
+
"best_r2": self.best_r2,
|
| 477 |
+
"best_expression": self.best_expression,
|
| 478 |
+
"history": self.history,
|
| 479 |
+
"discovered_expressions": dict(list(self.discovered_expressions.items())[:100]),
|
| 480 |
+
"config": {
|
| 481 |
+
"group_size": self.group_size,
|
| 482 |
+
"num_groups": num_groups,
|
| 483 |
+
"learning_rate": self.learning_rate,
|
| 484 |
+
"kl_coef": self.kl_coef,
|
| 485 |
+
}
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 489 |
+
output_path = self.output_dir / f"results_grpo_{timestamp}.json"
|
| 490 |
+
with open(output_path, "w") as f:
|
| 491 |
+
json.dump(results, f, indent=2)
|
| 492 |
+
logger.info(f"Results saved to: {output_path}")
|
| 493 |
+
|
| 494 |
+
return results
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def main():
|
| 498 |
+
parser = argparse.ArgumentParser(description="GRPO for Symbolic Regression")
|
| 499 |
+
parser.add_argument("--model_path", type=str, required=True)
|
| 500 |
+
parser.add_argument("--dataset", type=str, required=True)
|
| 501 |
+
parser.add_argument("--output_dir", type=str, default="./output/grpo")
|
| 502 |
+
parser.add_argument("--epochs", type=int, default=50)
|
| 503 |
+
parser.add_argument("--group_size", type=int, default=8)
|
| 504 |
+
parser.add_argument("--num_groups", type=int, default=4)
|
| 505 |
+
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
| 506 |
+
parser.add_argument("--target_r2", type=float, default=0.99)
|
| 507 |
+
args = parser.parse_args()
|
| 508 |
+
|
| 509 |
+
# Load dataset
|
| 510 |
+
import pandas as pd
|
| 511 |
+
df = pd.read_csv(args.dataset)
|
| 512 |
+
|
| 513 |
+
x_cols = [c for c in df.columns if c.startswith('x_')]
|
| 514 |
+
X = df[x_cols].values
|
| 515 |
+
y = df['y'].values
|
| 516 |
+
|
| 517 |
+
logger.info(f"Loaded dataset: {args.dataset}")
|
| 518 |
+
logger.info(f" Samples: {len(df)}, Variables: {len(x_cols)}")
|
| 519 |
+
|
| 520 |
+
# Create GRPO trainer
|
| 521 |
+
grpo = GRPO(
|
| 522 |
+
model_path=args.model_path,
|
| 523 |
+
X=X,
|
| 524 |
+
y=y,
|
| 525 |
+
output_dir=args.output_dir,
|
| 526 |
+
learning_rate=args.learning_rate,
|
| 527 |
+
group_size=args.group_size,
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# Run training
|
| 531 |
+
results = grpo.run(
|
| 532 |
+
epochs=args.epochs,
|
| 533 |
+
num_groups=args.num_groups,
|
| 534 |
+
target_r2=args.target_r2,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if __name__ == "__main__":
|
| 539 |
+
main()
|