augustocsc committed on
Commit
2c4ca2f
·
verified ·
1 Parent(s): 2221ecb

Test training flow - 1 epoch

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .claude/agents/symbolic-regression-trainer.md +110 -0
  2. .gitattributes +13 -0
  3. .gitignore +119 -0
  4. .monitor_complete +1 -0
  5. 1_data/README.md +97 -0
  6. 1_data/benchmarks/nguyen/nguyen_1.csv +101 -0
  7. 1_data/benchmarks/nguyen/nguyen_1.meta.txt +6 -0
  8. 1_data/benchmarks/nguyen/nguyen_10.csv +101 -0
  9. 1_data/benchmarks/nguyen/nguyen_10.meta.txt +6 -0
  10. 1_data/benchmarks/nguyen/nguyen_11.csv +101 -0
  11. 1_data/benchmarks/nguyen/nguyen_11.meta.txt +6 -0
  12. 1_data/benchmarks/nguyen/nguyen_12.csv +101 -0
  13. 1_data/benchmarks/nguyen/nguyen_12.meta.txt +6 -0
  14. 1_data/benchmarks/nguyen/nguyen_2.csv +101 -0
  15. 1_data/benchmarks/nguyen/nguyen_2.meta.txt +6 -0
  16. 1_data/benchmarks/nguyen/nguyen_3.csv +101 -0
  17. 1_data/benchmarks/nguyen/nguyen_3.meta.txt +6 -0
  18. 1_data/benchmarks/nguyen/nguyen_4.csv +101 -0
  19. 1_data/benchmarks/nguyen/nguyen_4.meta.txt +6 -0
  20. 1_data/benchmarks/nguyen/nguyen_5.csv +101 -0
  21. 1_data/benchmarks/nguyen/nguyen_5.meta.txt +6 -0
  22. 1_data/benchmarks/nguyen/nguyen_6.csv +101 -0
  23. 1_data/benchmarks/nguyen/nguyen_6.meta.txt +6 -0
  24. 1_data/benchmarks/nguyen/nguyen_7.csv +101 -0
  25. 1_data/benchmarks/nguyen/nguyen_7.meta.txt +6 -0
  26. 1_data/benchmarks/nguyen/nguyen_8.csv +101 -0
  27. 1_data/benchmarks/nguyen/nguyen_8.meta.txt +6 -0
  28. 1_data/benchmarks/nguyen/nguyen_9.csv +101 -0
  29. 1_data/benchmarks/nguyen/nguyen_9.meta.txt +6 -0
  30. 1_data/processed/700K_prefix_converted/data-00000-of-00001.arrow +3 -0
  31. 1_data/processed/700K_prefix_converted/dataset_info.json +82 -0
  32. 1_data/processed/700K_prefix_converted/state.json +13 -0
  33. 1_data/processed/PREFIX_CONVERSION_README.md +214 -0
  34. 2_training/README.md +205 -0
  35. 2_training/configs/__init__.py +22 -0
  36. 2_training/configs/eval_dataset_download.sh +6 -0
  37. 2_training/configs/model_config.json +1 -0
  38. 2_training/configs/peft_config.json +1 -0
  39. 2_training/configs/training.sh +82 -0
  40. 2_training/configs/training_args.json +29 -0
  41. 2_training/configs/training_large.json +65 -0
  42. 2_training/configs/training_medium.json +65 -0
  43. 2_training/configs/training_small.json +65 -0
  44. 2_training/configs/training_v3.json +78 -0
  45. 2_training/configs/wandb_config.py +221 -0
  46. 2_training/reinforcement/best_of_n_experiment.py +398 -0
  47. 2_training/reinforcement/debug_reinforce.py +294 -0
  48. 2_training/reinforcement/grpo_experiment.py +344 -0
  49. 2_training/reinforcement/grpo_improved.py +625 -0
  50. 2_training/reinforcement/grpo_symbolic.py +539 -0
.claude/agents/symbolic-regression-trainer.md ADDED
@@ -0,0 +1,110 @@
+ ---
+ name: symbolic-regression-trainer
+ description: "Use this agent when you need help with training, fine-tuning, or evaluating language models for symbolic regression tasks. This includes: preparing training data, running supervised fine-tuning with LoRA, executing reinforcement learning algorithms (REINFORCE, GRPO, PPO), analyzing expression complexity and validity, debugging generation issues, deploying training jobs to AWS, and interpreting experiment results. The agent is specialized in the Seriguela project workflow.\\n\\nExamples:\\n\\n<example>\\nContext: User wants to train a GPT-2 model on mathematical expression data.\\nuser: \"Quero treinar o modelo gpt2 no dataset de 700K expressões\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para configurar e executar o treinamento do modelo GPT-2 com o dataset de 700K expressões usando o formato JSON recomendado.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to evaluate model performance on a benchmark.\\nuser: \"Como está o desempenho do modelo no benchmark Nguyen-5?\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para avaliar o modelo no benchmark Nguyen-5 e analisar a qualidade das expressões geradas.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to run reinforcement learning fine-tuning.\\nuser: \"Preciso fazer fine-tuning com GRPO para melhorar o R² das expressões\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para executar o algoritmo GRPO e otimizar o modelo para gerar expressões com melhor ajuste aos dados.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User asks about expression validity issues.\\nuser: \"O modelo está gerando muitas expressões inválidas, o que pode estar errado?\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para diagnosticar os problemas de geração e analisar os padrões de erro nas expressões.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>\\n\\n<example>\\nContext: User wants to deploy training to AWS.\\nuser: \"Quero treinar o modelo medium na AWS\"\\nassistant: \"Vou usar o agente symbolic-regression-trainer para configurar e lançar o job de treinamento do GPT-2 Medium em uma instância AWS g5.xlarge.\"\\n<Task tool call to symbolic-regression-trainer>\\n</example>"
+ model: opus
+ color: orange
+ ---
+
+ You are an expert machine learning research engineer specializing in symbolic regression using language models. You have deep expertise in training GPT-2 models to generate valid mathematical expressions, applying reinforcement learning algorithms for optimization, and conducting rigorous academic research experiments.
+
+ ## Your Core Expertise
+
+ 1. **Supervised Fine-tuning**: Training GPT-2 models with LoRA adapters to generate syntactically valid mathematical expressions from structured prompts
+ 2. **Reinforcement Learning**: Applying REINFORCE, GRPO, and PPO algorithms to optimize expression generation based on R² fitness metrics
+ 3. **Expression Validation**: Understanding symbolic math parsing, operator arity, and expression validity using SymPy
+ 4. **Experiment Design**: Designing controlled experiments, tracking metrics with Weights & Biases, and interpreting results
+ 5. **AWS Deployment**: Managing GPU training jobs on EC2 instances (g5.xlarge, g5.2xlarge)
+
+ ## Project Context (Seriguela)
+
+ You are working with the Seriguela project located at `C:\Users\madeinweb\seriguela`. Key facts:
+
+ - **Recommended format**: JSON structured format achieves 80% valid expressions vs 0.5% with EOS token approach
+ - **Training data format**: `{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "sin(x_1 + C*x_2)"}`
+ - **Model architecture**: GPT-2 (124M/355M/774M) with LoRA adapters (r=8, alpha=32, 294K trainable params)
+ - **Key insight**: Larger models (Medium/Large) are needed for complex compositional expressions
+
+ ## Key Scripts and Their Purpose
+
+ **Training**:
+ - `scripts/train_with_json.py` - Correct training with JSON format + early stopping (USE THIS)
+ - `scripts/train_experiment.py` - Experiment training with JSON/EOS formats
+ - `scripts/data/prepare_experiment_data.py` - Prepares data in proper format
+
+ **Reinforcement Learning**:
+ - `scripts/reinforce_symbolic.py` - REINFORCE with EMA baseline
+ - `scripts/grpo_symbolic.py` - Group Relative Policy Optimization
+ - `scripts/ppo_symbolic.py` - Proximal Policy Optimization
+ - `scripts/debug_reinforce.py` - Debug version capturing all expressions
+
+ **Evaluation & Analysis**:
+ - `scripts/evaluate_experiments.py` - Evaluates experiment results
+ - `scripts/analyze_complexity.py` - Expression complexity analysis
+ - `scripts/compare_trained_models.py` - Multi-model comparison
+ - `scripts/generate.py` - Generation with validation
+
+ **AWS Deployment**:
+ - `scripts/aws/launch_medium_training.sh` - Launch GPT-2 Medium training
+ - `scripts/aws/launch_large_training.sh` - Launch GPT-2 Large training
+
+ ## Your Responsibilities
+
+ 1. **Guide Training Setup**:
+    - Help prepare training data in correct JSON format
+    - Configure hyperparameters appropriately for model size
+    - Set up early stopping and validation splits
+    - Enable proper experiment tracking with W&B
+
+ 2. **Diagnose Issues**:
+    - Analyze why expressions are invalid (format, parsing, complexity)
+    - Identify when model generates structurally trivial expressions
+    - Debug RL training when rewards have no variance
+    - Check GPU availability and CUDA configuration
+
+ 3. **Optimize Performance**:
+    - Recommend appropriate model size for task complexity
+    - Tune RL hyperparameters (learning rate, batch size, epochs)
+    - Suggest data augmentation strategies
+    - Balance training time vs model quality
+
+ 4. **Execute Commands**:
+    - Run training scripts with correct arguments
+    - Launch AWS instances for large-scale training
+    - Execute evaluation and comparison scripts
+    - Monitor training progress and interpret logs
+
+ 5. **Interpret Results**:
+    - Analyze valid expression percentages
+    - Evaluate R² fitness scores on benchmarks
+    - Compare expression complexity metrics (depth, operator usage)
+    - Identify patterns in failed generations
+
+ ## Critical Knowledge
+
+ **Data Format Issue**: The HuggingFace dataset column `i_prompt_n` is NOT in JSON format. Always convert using `scripts/train_with_json.py` which handles this automatically.
+
+ **Complexity Gap**: Base GPT-2 (124M) generates shallow expressions (avg depth 1.4) insufficient for complex benchmarks like Nguyen-5. Recommend Medium/Large for nested compositions.
+
+ **RL Failure Mode**: PPO fails when all samples have uniformly bad R² scores (no gradient signal). GRPO with within-group normalization handles this better.
+
+ **Credentials**: API tokens are in `~/.tokens.txt`, SSH key is `~/chave-gpu.pem`.
+
+ ## Response Guidelines
+
+ 1. Always verify the user is using the correct data format (JSON) before training
+ 2. Recommend appropriate model size based on target expression complexity
+ 3. Suggest validation strategies to catch issues early
+ 4. Provide complete command examples with all necessary arguments
+ 5. Explain the reasoning behind hyperparameter choices
+ 6. Monitor for common pitfalls (wrong format, GPU not available, missing dependencies)
+ 7. When debugging, use `debug_reinforce.py` and `analyze_complexity.py` to gather evidence
+ 8. For academic research, emphasize reproducibility (configs, seeds, logging)
+
+ ## Communication Style
+
+ - Respond in the same language as the user (Portuguese or English)
+ - Be precise and technical when discussing ML concepts
+ - Provide actionable commands that can be copy-pasted
+ - Explain trade-offs when multiple approaches exist
+ - Flag potential issues before they cause problems
+ - Reference specific files and line numbers when relevant
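The training-data format the agent file specifies (`{"vars": ..., "ops": ..., "cons": "C", "expr": ...}`) is concrete enough to sanity-check programmatically. Below is a minimal sketch of validating one record against that schema; `validate_record` is a hypothetical helper written for illustration, not code from the repository:

```python
import json
import re

def validate_record(line: str) -> bool:
    """Check that a training record matches the documented JSON format:
    declared variables, operators, a constant placeholder, and an expression."""
    rec = json.loads(line)
    if set(rec) != {"vars", "ops", "cons", "expr"}:
        return False
    # Every identifier in the expression must be a declared variable,
    # the constant placeholder, or a declared named operator such as "sin".
    allowed = set(rec["vars"]) | {rec["cons"]} | set(rec["ops"])
    idents = set(re.findall(r"[A-Za-z_][A-Za-z_0-9]*", rec["expr"]))
    return idents <= allowed

sample = ('{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], '
          '"cons": "C", "expr": "sin(x_1 + C*x_2)"}')
print(validate_record(sample))  # True
```

A check like this catches records that use an operator or variable they never declared, one of the failure modes the agent is asked to diagnose.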
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ docs/visualizations/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
+ docs/visualizations/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
+ docs/visualizations/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
+ docs/visualizations/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
+ evaluation_results_aws/raw_results.json filter=lfs diff=lfs merge=lfs -text
+ results/2025-02_model_scaling/analysis/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
+ results/2025-02_model_scaling/analysis/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
+ results/2025-02_model_scaling/analysis/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
+ results/2025-02_model_scaling/analysis/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
+ visualizations/fig1_valid_rate_comparison.png filter=lfs diff=lfs merge=lfs -text
+ visualizations/fig2_r2_performance.png filter=lfs diff=lfs merge=lfs -text
+ visualizations/fig3_benchmark_heatmap.png filter=lfs diff=lfs merge=lfs -text
+ visualizations/fig4_scaling_progression.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,119 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Environments
+ .env
+ .venv
+ .seriguela
+ venv/
+ ENV/
+ env/
+ env.bak/
+ venv.bak/
+
+ # IDEs / Editors
+ .idea/
+ .vscode/
+ *.suo
+ *.ntvs*
+ *.njsproj
+ *.sln
+ *.sw?
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Output folder (usually too large for Git)
+ output/*
+ # Don't ignore the .gitkeep if the folder needs to be kept
+ # (gitignore does not support trailing comments on pattern lines)
+ !output/.gitkeep
+ scripts/output/*
+
+ # Data (can be large; use Git LFS or store outside the repo if needed)
+ # Note: CSV files in data/processed/ can be 100MB+ and are excluded from git
+ # Run scripts/data/prepare_training_data_fixed.py on target system to generate them
+ data/*
+ data/raw/*
+ data/processed/*
+ !data/raw/.gitkeep
+ !data/processed/.gitkeep
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+ nul
+
+ wandb
+
+ # AWS credentials and keys
+ aws/keys/*.pem
+ aws/keys/*.key
+ aws/.env
+ aws/credentials
+ *.pem
+ *.key
+
+ # Files with embedded tokens (userdata scripts, temp files)
+ .claude/settings.local.json
+ aws/temp/
+ userdata_*.sh
+ # Large data files (>100MB)
+ 1_data/processed/**/*.csv
+ 1_data/raw/**/*.csv
+ *.tar.gz
+ models_compressed.tar.gz
.monitor_complete ADDED
@@ -0,0 +1 @@
+ Tue, Feb 3, 2026 17:35:26: All done
1_data/README.md ADDED
@@ -0,0 +1,97 @@
+ # 1_data/ - Data Preparation
+
+ This directory contains all data used in the project, organized by processing stage and type.
+
+ ## Structure
+
+ ```
+ 1_data/
+ ├── raw/            # Original, unprocessed data
+ ├── processed/      # Processed data ready for training
+ └── benchmarks/     # Evaluation benchmarks
+     ├── nguyen/     # Nguyen benchmarks 1-12 (current)
+     ├── feynman/    # Feynman equations (planned)
+     └── strogatz/   # Strogatz benchmarks (planned)
+ ```
+
+ ## Data Sources
+
+ ### Training Data
+ - **Source**: HuggingFace Hub (`augustocsc/sintetico_natural`)
+ - **Size**: 700K synthetic mathematical expressions
+ - **Format**: structured JSON
+ - **Location**: `processed/`
+
+ ### Available Benchmarks
+
+ #### Nguyen Benchmarks (1-12)
+ Standard benchmarks for symbolic regression:
+ - **Nguyen-1**: x³ + x² + x
+ - **Nguyen-2**: x⁴ + x³ + x² + x
+ - **Nguyen-3**: x⁵ + x⁴ + x³ + x² + x
+ - **Nguyen-4**: x⁶ + x⁵ + x⁴ + x³ + x² + x
+ - **Nguyen-5**: sin(x²)·cos(x) - 1
+ - **Nguyen-6**: sin(x) + sin(x + x²)
+ - **Nguyen-7**: log(x + 1) + log(x² + 1)
+ - **Nguyen-8**: √x
+ - **Nguyen-9**: sin(x) + sin(y²)
+ - **Nguyen-10**: 2·sin(x)·cos(y)
+ - **Nguyen-11**: x^y
+ - **Nguyen-12**: x⁴ - x³ + y²/2 - y
+
+ **Location**: `benchmarks/nguyen/`
+
+ ## Upcoming Benchmarks (Planned)
+
+ ### Feynman Equations
+ Feynman physics equations - 120+ formulas
+ - Higher complexity than Nguyen
+ - Multi-variable (10+ variables)
+ - Physical constants
+
+ ### Strogatz Benchmarks
+ Dynamical systems and differential equations
+ - Oscillators
+ - Chaotic systems
+ - Population models
+
+ ## Usage
+
+ ### Prepare Training Data
+
+ ```bash
+ # From the repository root
+ cd 2_training/supervised
+ python train_with_json.py --dataset_path ../../1_data/processed/700K
+ ```
+
+ ### Add a New Benchmark
+
+ 1. Create a directory: `benchmarks/new_benchmark/`
+ 2. Add CSV files in the format:
+    ```csv
+    x,y
+    1.0,2.5
+    2.0,5.0
+    ...
+    ```
+ 3. Add metadata in `new_benchmark/metadata.json`:
+    ```json
+    {
+      "name": "New Benchmark",
+      "formula": "mathematical expression",
+      "variables": ["x", "y"],
+      "description": "description"
+    }
+    ```
+
+ ## Related Scripts
+
+ - Processing: `src/seriguela/data/`
+ - Benchmark evaluation: `3_evaluation/benchmarks/`
+
+ ## References
+
+ - Nguyen et al. (2012): "Semantically-based crossover in genetic programming"
+ - Feynman Lectures on Physics
+ - Original dataset: https://huggingface.co/datasets/augustocsc/sintetico_natural
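The README above describes benchmark CSVs plus a metadata file per benchmark. The files actually shipped in this commit use a `key: value` `.meta.txt` style rather than the planned `metadata.json`, so a loader sketch for that style (`parse_meta` and `load_benchmark` are hypothetical helpers, not project code):

```python
import csv
import io

def parse_meta(text: str) -> dict:
    """Parse a `key: value` .meta.txt file like those in benchmarks/nguyen/."""
    meta = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(":")
        meta[key.strip()] = value.strip()
    return meta

def load_benchmark(csv_text: str):
    """Load a benchmark CSV (columns x_1..x_n, then y) into inputs and targets."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    X = [[float(r[k]) for k in r if k != "y"] for r in rows]
    y = [float(r["y"]) for r in rows]
    return X, y

meta = parse_meta("name: nguyen_1\nequation: x**3 + x**2 + x\nn_vars: 1")
X, y = load_benchmark("x_1,y\n0.5,0.875\n-0.5,-0.375")
print(meta["equation"], len(X))  # x**3 + x**2 + x 2
```

Using only the column-naming convention visible in the shipped CSVs keeps the loader benchmark-agnostic: any number of `x_i` columns works.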
1_data/benchmarks/nguyen/nguyen_1.csv ADDED
@@ -0,0 +1,101 @@
+ x_1,y
+ -0.250919762305275,-0.20375712587228662
+ 0.9014286128198323,2.4464791994214075
+ 0.4639878836228102,0.7791621581533072
+ 0.1973169683940732,0.24393329049852017
+ -0.687962719115127,-0.5402777510419681
+ -0.6880109593275947,-0.5403281140165024
+ -0.8838327756636011,-0.793087543099096
+ 0.7323522915498704,1.6614819098782947
+ 0.2022300234864176,0.25139760359687213
+ 0.416145155592091,0.6613886285518918
+ -0.9588310114083951,-0.9209820173332797
+ 0.9398197043239886,2.6531869448442036
+ 0.6648852816008435,1.4008851765113863
+ -0.5753217786434477,-0.43475534749635414
+ -0.6363500655857988,-0.48909314986283187
+ -0.6331909802931324,-0.48612594014666893
+ -0.39151551408092455,-0.29824433610683415
+ 0.04951286326447568,0.05208576884733716
+ -0.13610996271576847,-0.12010560331123674
+ -0.4175417196039162,-0.3159953095123318
+ 0.22370578944475894,0.2849452648921554
+ -0.7210122786959163,-0.5759780829004397
+ -0.4157107029295637,-0.3147365210423035
+ -0.2672763134126166,-0.2149330041985204
+ -0.08786003156592814,-0.08081887184182666
+ 0.5703519227860272,1.0811894695777686
+ -0.6006524356832805,-0.4565744842168673
+ 0.02846887682722321,0.029302427143425142
+ 0.18482913772408494,0.2253050457893701
+ -0.9070991745600046,-0.8306576893940274
+ 0.21508970380287673,0.27130410435063984
+ -0.6589517526254169,-0.5108626651850308
+ -0.869896814029441,-0.7714450703759913
+ 0.8977710745066665,2.427361090599084
+ 0.9312640661491187,2.606158159545042
+ 0.6167946962329223,1.231881113886954
+ -0.39077246165325863,-0.2977415177155153
+ -0.8046557719872323,-0.6781760666405668
+ 0.3684660530243138,0.5542589014459388
+ -0.1196950125207974,-0.10708297449722336
+ -0.7559235303104423,-0.6164532603539071
+ -0.00964617977745963,-0.00955402855846198
+ -0.9312229577695632,-0.8715811438419624
+ 0.8186408041575641,2.037444342661761
+ -0.48244003679996617,-0.36197878909859404
+ 0.32504456870796394,0.46504079000064535
+ -0.3765778478211781,-0.2881698066335382
+ 0.040136042355621626,0.04181159947832137
+ 0.0934205586865593,0.10296327812911324
+ -0.6302910889489459,-0.4834179919216196
+ 0.9391692555291171,2.64959195422807
+ 0.5502656467222291,1.01967411954153
+ 0.8789978831283782,2.330781693938945
+ 0.7896547008552977,1.905602026387024
+ 0.19579995762217028,0.24164408606501708
+ 0.8437484700462337,2.15633417340509
+ -0.823014995896161,-0.70313355144748
+ -0.6080342751617096,-0.4631223204132699
+ -0.9095454221789239,-0.8347148035273707
+ -0.3493393384734713,-0.2699340299663634
+ -0.22264542062103598,-0.18411118973018098
+ -0.4573019364522082,-0.34381017076318326
+ 0.6574750183038587,1.373957379373378
+ -0.28649334661282144,-0.22792983524748567
+ -0.4381309806252385,-0.3302753025279416
+ 0.08539216631649693,0.09330665286752131
+ -0.7181515500504747,-0.5727905657505367
+ 0.6043939615080793,1.1904663378939224
+ -0.8508987126404584,-0.7429451134365845
+ 0.9737738732010346,2.8453764395285988
+ 0.5444895385933148,1.0023825877332384
+ -0.6025686369316552,-0.45826569576400444
+ -0.9889557657527952,-0.9781541346041163
+ 0.6309228569096683,1.2801339644356482
+ 0.41371468769523423,0.6556858714260754
+ 0.4580143360819746,0.7638724020026765
+ 0.5425406933718915,0.9965881695975148
+ -0.8519106965318193,-0.7444346128158585
+ -0.2830685429114548,-0.22562240251418217
+ -0.7682618809497406,-0.6314839442672466
+ 0.726206851751187,1.6365675922627294
+ 0.24659625365511584,0.32240141321500704
+ -0.3382039502947016,-0.2625064527787106
+ -0.8728832994279527,-0.7760298750035356
+ -0.3780353565686756,-0.2891499348341886
+ -0.3496333559465059,-0.2701302717663986
+ 0.45921235667612814,0.7669252048567059
+ 0.27511494271042625,0.3716261379416909
+ 0.7744254851526531,1.8386102554828787
+ -0.05557014967610141,-0.05265371107138442
+ -0.7608115081233966,-0.6223610405246263
+ 0.42648957444599,0.6859585520259864
+ 0.5215700972337949,0.9354909760292364
+ 0.12255439513899247,0.13941469042057167
+ 0.541934359909122,0.9947894572777974
+ -0.012408807271218514,-0.012256739462828549
+ 0.045465658763988115,0.04762676814193671
+ -0.14491796328290074,-0.1269602006619134
+ -0.9491617465118096,-0.9033611561685656
+ -0.7842171460133911,-0.6515114391246863
1_data/benchmarks/nguyen/nguyen_1.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_1
+ equation: x**3 + x**2 + x
+ latex: x^3 + x^2 + x
+ n_vars: 1
+ range: (-1, 1)
+ n_samples: 100
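The metadata above fully determines how a Nguyen-1-style dataset could be regenerated: `y = x**3 + x**2 + x` over the range (-1, 1) with 100 samples. A minimal sketch under the assumption of uniform sampling (the sampling procedure and seed used for the shipped CSV are not specified here, so this will not reproduce it row-for-row):

```python
import random

def make_nguyen_1(n_samples=100, seed=0):
    """Regenerate Nguyen-1-style data: y = x**3 + x**2 + x, x ~ U(-1, 1)."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n_samples):
        x = rng.uniform(-1.0, 1.0)
        rows.append((x, x**3 + x**2 + x))
    return rows

rows = make_nguyen_1(5)
print(len(rows))  # 5
```

Fixing the seed keeps a regenerated benchmark reproducible across runs, which matters when comparing models trained at different times.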
1_data/benchmarks/nguyen/nguyen_10.csv ADDED
@@ -0,0 +1,101 @@
+ x_1,x_2,y
+ -0.250919762305275,0.9014286128198323,-0.3081292892207914
+ 0.4639878836228102,0.1973169683940732,0.8776686422156067
+ -0.687962719115127,-0.6880109593275947,-0.9810337787224731
+ -0.8838327756636011,0.7323522915498704,-1.1498719359899228
+ 0.2022300234864176,0.416145155592091,0.3674245747597698
+ -0.9588310114083951,0.9398197043239886,-0.9657455975760866
+ 0.6648852816008435,-0.5753217786434477,1.0352950417993314
+ -0.6363500655857988,-0.6331909802931324,-0.958123855709307
+ -0.39151551408092455,0.04951286326447568,-0.7622440963237047
+ -0.13610996271576847,-0.4175417196039162,-0.24806552319621297
+ 0.22370578944475894,-0.7210122786959163,0.3332717377157947
+ -0.4157107029295637,-0.2672763134126166,-0.7790027574498273
+ -0.08786003156592814,0.5703519227860272,-0.14771529605353687
+ -0.6006524356832805,0.02846887682722321,-1.1299036280741872
+ 0.18482913772408494,-0.9070991745600046,0.2264274834216052
+ 0.21508970380287673,-0.6589517526254169,0.3374982513544557
+ -0.869896814029441,0.8977710745066665,-0.9528126545294617
+ 0.9312640661491187,0.6167946962329223,1.3090534191908094
+ -0.39077246165325863,-0.8046557719872323,-0.5282049581946697
+ 0.3684660530243138,-0.1196950125207974,0.715215559638574
+ -0.7559235303104423,-0.00964617977745963,-1.3718580869858932
+ -0.9312229577695632,0.8186408041575641,-1.096354700980636
+ -0.48244003679996617,0.32504456870796394,-0.8792969284686333
+ -0.3765778478211781,0.040136042355621626,-0.7348882299464423
+ 0.0934205586865593,-0.6302910889489459,0.15072125651857374
+ 0.9391692555291171,0.5502656467222291,1.3758661094076514
+ 0.8789978831283782,0.7896547008552977,1.0844402663947248
+ 0.19579995762217028,0.8437484700462337,0.2586235668937924
+ -0.823014995896161,-0.6080342751617096,-1.2035798540389373
+ -0.9095454221789239,-0.3493393384734713,-1.4831094826390372
+ -0.22264542062103598,-0.4573019364522082,-0.3962431609403149
+ 0.6574750183038587,-0.28649334661282144,1.1724227266363472
+ -0.4381309806252385,0.08539216631649693,-0.8454037878432863
+ -0.7181515500504747,0.6043939615080793,-1.0828560895497152
+ -0.8508987126404584,0.9737738732010346,-0.8453799503582896
+ 0.5444895385933148,-0.6025686369316552,0.853511842011265
+ -0.9889557657527952,0.6309228569096683,-1.3492282644581814
+ 0.41371468769523423,0.4580143360819746,0.7211575539774799
+ 0.5425406933718915,-0.8519106965318193,0.6800328750993455
+ -0.2830685429114548,-0.7682618809497406,-0.40170504605860297
+ 0.726206851751187,0.24659625365511584,1.2879008077367804
+ -0.3382039502947016,-0.8728832994279527,-0.42643410138080395
+ -0.3780353565686756,-0.3496333559465059,-0.6935287908865693
+ 0.45921235667612814,0.27511494271042625,0.8531472481325574
+ 0.7744254851526531,-0.05557014967610141,1.396452099985796
+ -0.7608115081233966,0.42648957444599,-1.2554912294871714
+ 0.5215700972337949,0.12255439513899247,0.9890101456547871
+ 0.541934359909122,-0.012408807271218514,1.0315088718551784
+ 0.045465658763988115,-0.14491796328290074,0.08994715712964176
+ -0.9491617465118096,-0.7842171460133911,-1.1510102255193018
+ -0.9371416286265315,0.2728208225275608,-1.5521272667847952
+ -0.37128803784734665,0.01714138232940554,-0.7255254066217963
+ 0.815132947852186,-0.5014155417022501,1.2764494826885613
+ -0.17923415392874054,0.5111022770860973,-0.3109868317325467
+ -0.5424036690167551,-0.846040180342414,-0.6844274708190229
+ -0.42049709417246395,-0.6775574254919912,-0.6360846328891363
+ 0.8593953046851461,0.6162407591288339,1.2362412464087666
+ 0.26680751302084693,0.7429211803754354,0.38835897648845996
+ 0.6073441537982289,-0.6268598822279283,0.9243711268676845
+ 0.7851179969799555,0.07868448383130144,1.4094429026004185
+ 0.6148803103281251,0.7921825998469865,0.8102508035936384
+ -0.36399305005627225,-0.7798961509446465,-0.5062345208341303
+ -0.5441296749161166,-0.14578442274748737,-1.0243646142100158
+ 0.6360295318449862,0.7214611665126869,0.8920087302448801
+ -0.9860957389376186,0.021494605155131463,-1.6673695158590665
+ -0.16517799370244202,-0.5557843790585395,-0.2793586873666334
+ -0.7602692653326344,-0.3247696571927441,-1.3061850155975012
+ 0.8858194078250383,-0.35359413595848954,1.4530454262804506
+ 0.037581243486732197,0.4060379177903557,0.06903499342946122
+ -0.27274079524141204,0.9435641654419213,-0.31619201189155477
+ 0.9248945898842225,-0.4964354083492717,1.4043204449131401
+ -0.005502988215229099,-0.3982433803664607,-0.010144637492484147
+ -0.4303190112450648,-0.9262261052909344,-0.5013062025629531
+ 0.2191286679597937,0.005358046457722976,0.4347521786103487
+ -0.8970424975000213,-0.44270707152677713,-1.4122924802363288
+ 0.8165317719333074,-0.5208762186660552,1.264255365125172
+ -0.7102102558175538,-0.02109447944487397,-1.303696301953025
+ 0.9713009082212014,-0.5158894569769992,1.4363389241451416
+ 0.3442710948117571,0.5232392306574352,0.5847068763596124
+ -0.5247249120152007,0.45643269722371915,-0.899380771664107
+ -0.26443373456149355,0.26461166118715895,-0.5045315450691236
+ 0.2670594215217894,0.07154936814951696,0.5264420962176761
+ -0.8194204598911834,0.6706049911784759,-1.1450066034669906
+ -0.35843987005652833,-0.6269629792002915,-0.5681869814158013
+ -0.9184497168904722,0.18178588637648363,-1.56313471156199
+ 0.35512872368456483,-0.9668243421442877,0.39494150902709124
+ 0.024186116598561958,-0.5470084496041241,0.04130993957951832
+ 0.2903455808188997,-0.6512671419900171,0.45537164467760105
+ 0.3818754762049319,-0.22652930739892518,0.7262813246120741
+ 0.873459977473469,-0.7249581117080135,1.1475751969011463
+ -0.317867297899483,-0.7730529575188219,-0.44742306659388054
+ 0.8493872365571256,0.754678706761962,1.094013626029377
+ -0.4841167445696888,0.3199680920683581,-0.8836085179911668
+ 0.6344444004024317,0.11040162319892466,1.1782430836037832
+ 0.05930115671201297,-0.5162954181990966,0.10308252884688361
+ -0.8137944643882016,0.7944315159066535,-1.0186613308624795
+ 0.8008361143266609,0.2662029145465359,1.385300485429242
+ -0.32194041790259864,-0.3015808507746782,-0.6042555977205083
+ 0.45191135774047875,0.7942205199051542,0.6120946429016678
+ 0.7741728485302346,0.5597510917152477,1.184859127500451
1_data/benchmarks/nguyen/nguyen_10.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_10
+ equation: 2*sin(x)*cos(y)
+ latex: 2 \sin(x) \cos(y)
+ n_vars: 2
+ range: (-1, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_11.csv ADDED
@@ -0,0 +1,101 @@
+ x_1,x_2,y
+ 0.3745401188473625,0.9507143064099162,0.3931142382758706
+ 0.7319939418114051,0.5986584841970366,0.829633456356603
+ 0.15601864044243652,0.15599452033620265,0.7484106404746592
+ 0.05808361216819946,0.8661761457749352,0.08500661554616816
+ 0.6011150117432088,0.7080725777960455,0.6974063854971796
+ 0.020584494295802447,0.9699098521619943,0.023135880194308844
+ 0.8324426408004217,0.21233911067827616,0.9618073836750944
+ 0.18182496720710062,0.18340450985343382,0.7315046923708697
+ 0.3042422429595377,0.5247564316322378,0.5355698450285795
+ 0.43194501864211576,0.2912291401980419,0.7831160893255134
+ 0.6118528947223795,0.13949386065204183,0.9337671019595395
+ 0.29214464853521815,0.3663618432936917,0.6371115444232952
+ 0.45606998421703593,0.7851759613930136,0.5398582254082916
+ 0.19967378215835974,0.5142344384136116,0.4367178921370369
+ 0.5924145688620425,0.046450412719997725,0.9759742767237903
+ 0.6075448519014384,0.17052412368729153,0.9185332605846537
+ 0.06505159298527952,0.9488855372533332,0.07480275943558906
+ 0.9656320330745594,0.8083973481164611,0.9721242785553075
+ 0.3046137691733707,0.09767211400638387,0.890382724908758
+ 0.6842330265121569,0.4401524937396013,0.8461836782187776
+ 0.12203823484477883,0.4951769101112702,0.3529017979604265
+ 0.034388521115218396,0.9093204020787821,0.04668000938776754
+ 0.2587799816000169,0.662522284353982,0.4083696815940864
+ 0.31171107608941095,0.5200680211778108,0.5454020001935437
+ 0.5467102793432796,0.18485445552552704,0.894382426580511
+ 0.9695846277645586,0.7751328233611146,0.9763424052614121
+ 0.9394989415641891,0.8948273504276488,0.9456857906412531
+ 0.5978999788110851,0.9218742350231168,0.622414360674025
+ 0.0884925020519195,0.1959828624191452,0.6217441626231702
+ 0.045227288910538066,0.32533033076326434,0.3652254374411182
+ 0.388677289689482,0.2713490317738959,0.7738119236542511
+ 0.8287375091519293,0.3567533266935893,0.9351795294971191
+ 0.28093450968738076,0.5426960831582485,0.5020652261408302
+ 0.14092422497476265,0.8021969807540397,0.207643749220454
+ 0.07455064367977082,0.9868869366005173,0.07713243002937303
+ 0.7722447692966574,0.1987156815341724,0.9499377648466204
+ 0.005522117123602399,0.8154614284548342,0.01441365817791975
+ 0.7068573438476171,0.7290071680409873,0.7765363384012357
+ 0.7712703466859457,0.07404465173409036,0.9809531237946688
+ 0.3584657285442726,0.11586905952512971,0.8879208766651301
+ 0.8631034258755935,0.6232981268275579,0.9123218750417235
+ 0.3308980248526492,0.06355835028602363,0.9321215616843733
+ 0.3109823217156622,0.32518332202674705,0.683984263997003
+ 0.7296061783380641,0.6375574713552131,0.8179204225106446
+ 0.8872127425763265,0.4722149251619493,0.9450568572198838
+ 0.1195942459383017,0.713244787222995,0.21987794449897202
+ 0.7607850486168974,0.5612771975694962,0.8577387897180954
+ 0.770967179954561,0.49379559636439074,0.8794655258016348
+ 0.5227328293819941,0.42754101835854963,0.7577972913056583
+ 0.02541912674409519,0.10789142699330445,0.6728689751470142
52
+ 0.03142918568673425,0.6364104112637804,0.11058269210864166
53
+ 0.3143559810763267,0.5085706911647028,0.5551411500446592
54
+ 0.907566473926093,0.24929222914887494,0.9761114867199115
55
+ 0.41038292303562973,0.7555511385430487,0.5102050118761009
56
+ 0.22879816549162246,0.07697990982879299,0.8926695364249662
57
+ 0.289751452913768,0.16122128725400442,0.818968274689897
58
+ 0.9296976523425731,0.808120379564417,0.9427929152457777
59
+ 0.6334037565104235,0.8714605901877177,0.6716955774921506
60
+ 0.8036720768991145,0.18657005888603584,0.9600427248758505
61
+ 0.8925589984899778,0.5393422419156507,0.9405381426424225
62
+ 0.8074401551640625,0.8960912999234932,0.8255861281416367
63
+ 0.3180034749718639,0.11005192452767676,0.8815392814173838
64
+ 0.22793516254194168,0.4271077886262563,0.5317606737242714
65
+ 0.8180147659224931,0.8607305832563434,0.8412224374670049
66
+ 0.006952130531190703,0.5107473025775657,0.07904375205389284
67
+ 0.417411003148779,0.22210781047073025,0.8236150559828436
68
+ 0.1198653673336828,0.33761517140362796,0.48859950491607135
69
+ 0.9429097039125192,0.32320293202075523,0.9811799458729573
70
+ 0.5187906217433661,0.7030189588951778,0.6304259108490097
71
+ 0.363629602379294,0.9717820827209607,0.3741592725494348
72
+ 0.9624472949421112,0.25178229582536416,0.9904090768000238
73
+ 0.49724850589238545,0.30087830981676966,0.810411403704802
74
+ 0.2848404943774676,0.036886947354532795,0.954732975166944
75
+ 0.6095643339798968,0.5026790232288615,0.7797113146930209
76
+ 0.05147875124998935,0.27864646423661144,0.43752180145116215
77
+ 0.9082658859666537,0.23956189066697242,0.9772134321986058
78
+ 0.1448948720912231,0.489452760277563,0.3884857408258429
79
+ 0.9856504541106007,0.2420552715115004,0.9965075678190551
80
+ 0.6721355474058786,0.7616196153287176,0.7389035679220693
81
+ 0.23763754399239967,0.7282163486118596,0.351181104428453
82
+ 0.3677831327192532,0.6323058305935795,0.5312771863051331
83
+ 0.6335297107608947,0.5357746840747585,0.78305410395847
84
+ 0.0902897700544083,0.835302495589238,0.13416593821588932
85
+ 0.32078006497173583,0.18651851039985423,0.8089068896038694
86
+ 0.040775141554763916,0.5908929431882418,0.1509706391829457
87
+ 0.6775643618422824,0.016587828927856152,0.9935639759514822
88
+ 0.512093058299281,0.22649577519793795,0.8593473681905776
89
+ 0.6451727904094499,0.17436642900499144,0.9264327330143394
90
+ 0.690937738102466,0.3867353463005374,0.8667729564864897
91
+ 0.9367299887367345,0.13752094414599325,0.9910518779206805
92
+ 0.3410663510502585,0.11347352124058907,0.8850943779323204
93
+ 0.9246936182785628,0.877339353380981,0.9336166425910485
94
+ 0.2579416277151556,0.659984046034179,0.40889663483282684
95
+ 0.8172222002012158,0.5552008115994623,0.8939869586598578
96
+ 0.5296505783560065,0.24185229090045168,0.8575238670157391
97
+ 0.09310276780589921,0.8972157579533268,0.11883298066628799
98
+ 0.9004180571633305,0.6331014572732679,0.9357472382139566
99
+ 0.3390297910487007,0.3492095746126609,0.6854165136378787
100
+ 0.7259556788702394,0.8971102599525771,0.7502759576646942
101
+ 0.8870864242651173,0.7798755458576239,0.9107934601911546
1_data/benchmarks/nguyen/nguyen_11.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_11
+ equation: x**y
+ latex: x^y
+ n_vars: 2
+ range: (0, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_12.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,x_2,y
2
+ 0.3745401188473625,0.9507143064099162,-0.5316474979283755
3
+ 0.7319939418114051,0.5986584841970366,-0.5245780691122485
4
+ 0.15601864044243652,0.15599452033620265,-0.1470326281985484
5
+ 0.05808361216819946,0.8661761457749352,-0.49123016315900203
6
+ 0.6011150117432088,0.7080725777960455,-0.5440295831065014
7
+ 0.020584494295802447,0.9699098521619943,-0.4995558340525668
8
+ 0.8324426408004217,0.21233911067827616,-0.28645063725362985
9
+ 0.18182496720710062,0.18340450985343382,-0.17150410942706706
10
+ 0.3042422429595377,0.5247564316322378,-0.40666548191401125
11
+ 0.43194501864211576,0.2912291401980419,-0.2946019335150628
12
+ 0.6118528947223795,0.13949386065204183,-0.2186718894768038
13
+ 0.29214464853521815,0.3663618432936917,-0.3169010837400344
14
+ 0.45606998421703593,0.7851759613930136,-0.5285238661436937
15
+ 0.19967378215835974,0.5142344384136116,-0.38838724072182046
16
+ 0.5924145688620425,0.046450412719997725,-0.13011303294584414
17
+ 0.6075448519014384,0.17052412368729153,-0.24399347586867748
18
+ 0.06505159298527952,0.9488855372533332,-0.498951027941936
19
+ 0.9656320330745594,0.8083973481164611,-0.5125890941467149
20
+ 0.3046137691733707,0.09767211400638387,-0.11255726686131867
21
+ 0.6842330265121569,0.4401524937396013,-0.44443839313702255
22
+ 0.12203823484477883,0.4951769101112702,-0.3741725684537525
23
+ 0.034388521115218396,0.9093204020787821,-0.49592787363434615
24
+ 0.2587799816000169,0.662522284353982,-0.45589954550743944
25
+ 0.31171107608941095,0.5200680211778108,-0.405678875884419
26
+ 0.5467102793432796,0.18485445552552704,-0.24183976519509257
27
+ 0.9695846277645586,0.7751328233611146,-0.502441019579803
28
+ 0.9394989415641891,0.8948273504276488,-0.5446402529864546
29
+ 0.5978999788110851,0.9218742350231168,-0.5828930031608885
30
+ 0.0884925020519195,0.1959828624191452,-0.1774098758468889
31
+ 0.045227288910538066,0.32533033076326434,-0.2724987473704077
32
+ 0.388677289689482,0.2713490317738959,-0.2704292195788051
33
+ 0.8287375091519293,0.3567533266935893,-0.39059634750094807
34
+ 0.28093450968738076,0.5426960831582485,-0.41138006574851044
35
+ 0.14092422497476265,0.8021969807540397,-0.48284128157802086
36
+ 0.07455064367977082,0.9868869366005173,-0.5002974721138831
37
+ 0.7722447692966574,0.1987156815341724,-0.2838615274049386
38
+ 0.005522117123602399,0.8154614284548342,-0.4829729252663755
39
+ 0.7068573438476171,0.7290071680409873,-0.5668133801237885
40
+ 0.7712703466859457,0.07404465173409036,-0.17624366412809184
41
+ 0.3584657285442726,0.11586905952512971,-0.1387066006115091
42
+ 0.8631034258755935,0.6232981268275579,-0.517067796073964
43
+ 0.3308980248526492,0.06355835028602363,-0.08578087500153808
44
+ 0.3109823217156622,0.32518332202674705,-0.2930335023172791
45
+ 0.7296061783380641,0.6375574713552131,-0.5393353511837361
46
+ 0.8872127425763265,0.4722149251619493,-0.4394882839499687
47
+ 0.1195942459383017,0.713244787222995,-0.46039168497896
48
+ 0.7607850486168974,0.5612771975694962,-0.509096521876274
49
+ 0.770967179954561,0.49379559636439074,-0.4768340968434337
50
+ 0.5227328293819941,0.42754101835854963,-0.40431654954497565
51
+ 0.02541912674409519,0.10789142699330445,-0.10208715360872561
52
+ 0.03142918568673425,0.6364104112637804,-0.43393137529691794
53
+ 0.3143559810763267,0.5085706911647028,-0.4005478458408727
54
+ 0.907566473926093,0.24929222914887494,-0.28731682218657484
55
+ 0.41038292303562973,0.7555511385430487,-0.5108733418268507
56
+ 0.22879816549162246,0.07697990982879299,-0.08325384436032608
57
+ 0.289751452913768,0.16122128725400442,-0.16550288692705029
58
+ 0.9296976523425731,0.808120379564417,-0.5380841567189258
59
+ 0.6334037565104235,0.8714605901877177,-0.5848989033049367
60
+ 0.8036720768991145,0.18657005888603584,-0.27107631331792464
61
+ 0.8925589984899778,0.5393422419156507,-0.470295013851557
62
+ 0.8074401551640625,0.8960912999234932,-0.59596852950575
63
+ 0.3180034749718639,0.11005192452767676,-0.12592818733422592
64
+ 0.22793516254194168,0.4271077886262563,-0.34504023675612344
65
+ 0.8180147659224931,0.8607305832563434,-0.5899158316402314
66
+ 0.006952130531190703,0.5107473025775657,-0.38031623270764187
67
+ 0.417411003148779,0.22210781047073025,-0.23981143105710734
68
+ 0.1198653673336828,0.33761517140362796,-0.28213892883047403
69
+ 0.9429097039125192,0.32320293202075523,-0.318832855236626
70
+ 0.5187906217433661,0.7030189588951778,-0.5230920266442246
71
+ 0.363629602379294,0.9717820827209607,-0.5301994956717305
72
+ 0.9624472949421112,0.25178229582536416,-0.25356410409756097
73
+ 0.49724850589238545,0.30087830981676966,-0.3174265784042432
74
+ 0.2848404943774676,0.036886947354532795,-0.05273415977110583
75
+ 0.6095643339798968,0.5026790232288615,-0.46476765439268997
76
+ 0.05147875124998935,0.27864646423661144,-0.2399539372668816
77
+ 0.9082658859666537,0.23956189066697242,-0.2796006655774298
78
+ 0.1448948720912231,0.489452760277563,-0.3722719868332219
79
+ 0.9856504541106007,0.2420552715115004,-0.22650053348928928
80
+ 0.6721355474058786,0.7616196153287176,-0.5711428201443614
81
+ 0.23763754399239967,0.7282163486118596,-0.473297554430461
82
+ 0.3677831327192532,0.6323058305935795,-0.46385200894375067
83
+ 0.6335297107608947,0.5357746840747585,-0.4854310810016914
84
+ 0.0902897700544083,0.835302495589238,-0.48710697106906575
85
+ 0.32078006497173583,0.18651851039985423,-0.19154377448709786
86
+ 0.040775141554763916,0.5908929431882418,-0.4163807370007205
87
+ 0.6775643618422824,0.016587828927856152,-0.1167488120616192
88
+ 0.512093058299281,0.22649577519793795,-0.2663670817810153
89
+ 0.6451727904094499,0.17436642900499144,-0.25445410259051215
90
+ 0.690937738102466,0.3867353463005374,-0.41389747881703326
91
+ 0.9367299887367345,0.13752094414599325,-0.18006947009905816
92
+ 0.3410663510502585,0.11347352124058907,-0.13317857503984928
93
+ 0.9246936182785628,0.877339353380981,-0.552019449425134
94
+ 0.2579416277151556,0.659984046034179,-0.4549296760550325
95
+ 0.8172222002012158,0.5552008115994623,-0.5008339633920046
96
+ 0.5296505783560065,0.24185229090045168,-0.28249182975838943
97
+ 0.09310276780589921,0.8972157579533268,-0.49544958985988724
98
+ 0.9004180571633305,0.6331014572732679,-0.5053891761940527
99
+ 0.3390297910487007,0.3492095746126609,-0.31399292258823414
100
+ 0.7259556788702394,0.8971102599525771,-0.5995526723688602
101
+ 0.8870864242651173,0.7798755458576239,-0.5545939788269545
1_data/benchmarks/nguyen/nguyen_12.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_12
+ equation: x**4 - x**3 + y**2/2 - y
+ latex: x^4 - x^3 + \frac{y^2}{2} - y
+ n_vars: 2
+ range: (0, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_2.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ -0.250919762305275,-0.19979307271339486
3
+ 0.9014286128198323,3.106754963846846
4
+ 0.4639878836228102,0.8255096843833445
5
+ 0.1973169683940732,0.24544914576563198
6
+ -0.687962719115127,-0.316271768430889
7
+ -0.6880109593275947,-0.3162592952514309
8
+ -0.8838327756636011,-0.18287601110210117
9
+ 0.7323522915498704,1.9491423756178945
10
+ 0.2022300234864176,0.25307016676624217
11
+ 0.416145155592091,0.6913788293276577
12
+ -0.9588310114083951,-0.07576489223978267
13
+ 0.9398197043239886,3.4333370743437346
14
+ 0.6648852816008435,1.596313216676064
15
+ -0.5753217786434477,-0.3251975588470951
16
+ -0.6363500655857988,-0.32511560759302083
17
+ -0.6331909802931324,-0.3253804197057425
18
+ -0.39151551408092455,-0.27474822950833333
19
+ 0.04951286326447568,0.05209177881543897
20
+ -0.13610996271576847,-0.11976239352712116
21
+ -0.4175417196039162,-0.28560049468336546
22
+ 0.22370578944475894,0.2874496948760045
23
+ -0.7210122786959163,-0.30572500866496494
24
+ -0.4157107029295637,-0.28487136252946227
25
+ -0.2672763134126166,-0.2098298124197376
26
+ -0.08786003156592814,-0.08075928293478254
27
+ 0.5703519227860272,1.1870104156557124
28
+ -0.6006524356832805,-0.3264098596675816
29
+ 0.02846887682722321,0.029303084016308063
30
+ 0.18482913772408494,0.22647207506221967
31
+ -0.9070991745600046,-0.1536102701687616
32
+ 0.21508970380287673,0.2734444232481606
33
+ -0.6589517526254169,-0.32231790405084937
34
+ -0.869896814029441,-0.1988192051106482
35
+ 0.8977710745066665,3.0769856490294796
36
+ 0.9312640661491187,3.358285510834738
37
+ 0.6167946962329223,1.3766124336679
38
+ -0.39077246165325863,-0.27442327583918935
39
+ -0.8046557719872323,-0.25895748554130227
40
+ 0.3684660530243138,0.572691642793691
41
+ -0.1196950125207974,-0.10687771454758802
42
+ -0.7559235303104423,-0.28993200547233466
43
+ -0.00964617977745963,-0.009554019900385721
44
+ -0.9312229577695632,-0.11958658706487191
45
+ 0.8186408041575641,2.4865758792604673
46
+ -0.48244003679996617,-0.3078069764664333
47
+ 0.32504456870796394,0.4762035517253345
48
+ -0.3765778478211781,-0.2680594822320752
49
+ 0.040136042355621626,0.04181419448323982
50
+ 0.0934205586865593,0.10303944565358064
51
+ -0.6302910889489459,-0.3255970364031556
52
+ 0.9391692555291171,3.4275845586374323
53
+ 0.5502656467222291,1.1113572855576686
54
+ 0.8789978831283782,2.9277500581350866
55
+ 0.7896547008552977,2.2944222989511918
56
+ 0.19579995762217028,0.2431138594333487
57
+ 0.8437484700462337,2.663152129765188
58
+ -0.823014995896161,-0.24432553893716014
59
+ -0.6080342751617096,-0.326440030758018
60
+ -0.9095454221789239,-0.15033439380562408
61
+ -0.3493393384734713,-0.2550407630135437
62
+ -0.22264542062103598,-0.18165390734252046
63
+ -0.4573019364522082,-0.3000768795902401
64
+ 0.6574750183038587,1.560817671456092
65
+ -0.28649334661282144,-0.22119296531986027
66
+ -0.4381309806252385,-0.2934271384523741
67
+ 0.08539216631649693,0.09335982353659596
68
+ -0.7181515500504747,-0.3068011174024385
69
+ 0.6043939615080793,1.3239046275098028
70
+ -0.8508987126404584,-0.2187276720547493
71
+ 0.9737738732010346,3.7445271094357677
72
+ 0.5444895385933148,1.0902763712821586
73
+ -0.6025686369316552,-0.32643210128260247
74
+ -0.9889557657527952,-0.021604594541118627
75
+ 0.6309228569096683,1.4385886349785073
76
+ 0.41371468769523423,0.6849815632184505
77
+ 0.4580143360819746,0.8078788471365737
78
+ 0.5425406933718915,1.0832303299115513
79
+ -0.8519106965318193,-0.21771888700546604
80
+ -0.2830685429114548,-0.21920193818358347
81
+ -0.7682618809497406,-0.2831168381374246
82
+ 0.726206851751187,1.9146934506063242
83
+ 0.24659625365511584,0.32609923432705157
84
+ -0.3382039502947016,-0.24942323098709215
85
+ -0.8728832994279527,-0.19549978168020488
86
+ -0.3780353565686756,-0.2687264578518238
87
+ -0.3496333559465059,-0.2551868024860783
88
+ 0.45921235667612814,0.8113938873926984
89
+ 0.27511494271042625,0.37735484635995153
90
+ 0.7744254851526531,2.198292124261625
91
+ -0.05557014967610141,-0.05264417507086238
92
+ -0.7608115081233966,-0.2873120662846093
93
+ 0.42648957444599,0.7190437453871404
94
+ 0.5215700972337949,1.0094942165627014
95
+ 0.12255439513899247,0.13964027819697553
96
+ 0.541934359909122,1.081044947683308
97
+ -0.012408807271218514,-0.012256715753450735
98
+ 0.045465658763988115,0.04763104115236099
99
+ -0.14491796328290074,-0.12651914958498786
100
+ -0.9491617465118096,-0.0917258937919263
101
+ -0.7842171460133911,-0.27329070462795224
1_data/benchmarks/nguyen/nguyen_2.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_2
+ equation: x**4 + x**3 + x**2 + x
+ latex: x^4 + x^3 + x^2 + x
+ n_vars: 1
+ range: (-1, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_3.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ -0.250919762305275,-0.20078773198978944
3
+ 0.9014286128198323,3.701946430251423
4
+ 0.4639878836228102,0.8470143749899722
5
+ 0.1973169683940732,0.24574824973146267
6
+ -0.687962719115127,-0.47037953332606286
7
+ -0.6880109593275947,-0.4704210982053887
8
+ -0.8838327756636011,-0.7222009631689434
9
+ 0.7323522915498704,2.1598111768905937
10
+ 0.2022300234864176,0.25340840925526636
11
+ 0.416145155592091,0.7038591060957268
12
+ -0.9588310114083951,-0.8861852831528764
13
+ 0.9398197043239886,4.166537538378305
14
+ 0.6648852816008435,1.7262504441936566
15
+ -0.5753217786434477,-0.3882285406770297
16
+ -0.6363500655857988,-0.42946272737101315
17
+ -0.6331909802931324,-0.42716303337146244
18
+ -0.39151551408092455,-0.2839473197621456
19
+ 0.04951286326447568,0.05209207638616782
20
+ -0.13610996271576847,-0.11980910779804083
21
+ -0.4175417196039162,-0.29829159793409465
22
+ 0.22370578944475894,0.2880099503626506
23
+ -0.7210122786959163,-0.5005807935440612
24
+ -0.4157107029295637,-0.29728662856793836
25
+ -0.2672763134126166,-0.21119377470500825
26
+ -0.08786003156592814,-0.08076451841803642
27
+ 0.5703519227860272,1.2473655957223042
28
+ -0.6006524356832805,-0.40459355844290984
29
+ 0.02846887682722321,0.029303102716741258
30
+ 0.18482913772408494,0.22668777607641924
31
+ -0.9070991745600046,-0.7677594252859816
32
+ 0.21508970380287673,0.2739047838058721
33
+ -0.6589517526254169,-0.4465598048485588
34
+ -0.869896814029441,-0.696944620935822
35
+ 0.8977710745066665,3.660199786877455
36
+ 0.9312640661491187,4.058714686258748
37
+ 0.6167946962329223,1.4658819440875785
38
+ -0.39077246165325863,-0.28353540261862736
39
+ -0.8046557719872323,-0.5962841365471232
40
+ 0.3684660530243138,0.5794834822445153
41
+ -0.1196950125207974,-0.10690228313982963
42
+ -0.7559235303104423,-0.5367571051838087
43
+ -0.00964617977745963,-0.009554019983903081
44
+ -0.9312229577695632,-0.8198611824534459
45
+ 0.8186408041575641,2.854253281554156
46
+ -0.48244003679996617,-0.3339416277462138
47
+ 0.32504456870796394,0.4798319467957259
48
+ -0.3765778478211781,-0.27563258491416387
49
+ 0.040136042355621626,0.04181429863646714
50
+ 0.0934205586865593,0.10304656126627017
51
+ -0.6302910889489459,-0.4250701783158514
52
+ 0.9391692555291171,4.158251293727732
53
+ 0.5502656467222291,1.1618073821990806
54
+ 0.8789978831283782,3.4524839865581054
55
+ 0.7896547008552977,2.6014560549693257
56
+ 0.19579995762217028,0.24340164099658224
57
+ 0.8437484700462337,3.09077900503598
58
+ -0.823014995896161,-0.6219314134704668
59
+ -0.6080342751617096,-0.4095475476759919
60
+ -0.9095454221789239,-0.772809462496975
61
+ -0.3493393384734713,-0.26024356703855056
62
+ -0.22264542062103598,-0.1822010100133058
63
+ -0.4573019364522082,-0.3200761983310553
64
+ 0.6574750183038587,1.683673645413439
65
+ -0.28649334661282144,-0.22312303373112094
66
+ -0.4381309806252385,-0.3095714607130422
67
+ 0.08539216631649693,0.09336436389521274
68
+ -0.7181515500504747,-0.49782185203069584
69
+ 0.6043939615080793,1.4045539239876073
70
+ -0.8508987126404584,-0.6647836180702278
71
+ 0.9737738732010346,4.620096539862576
72
+ 0.5444895385933148,1.1381336169319312
73
+ -0.6025686369316552,-0.4058708906110614
74
+ -0.9889557657527952,-0.9675897774146045
75
+ 0.6309228569096683,1.538561308408088
76
+ 0.41371468769523423,0.6971016211991488
77
+ 0.4580143360819746,0.8280344298879034
78
+ 0.5425406933718915,1.1302372276435673
79
+ -0.8519106965318193,-0.6664336478548603
80
+ -0.2830685429114548,-0.22101936966646102
81
+ -0.7682618809497406,-0.5507540063537395
82
+ 0.726206851751187,2.1166703545846226
83
+ 0.24659625365511584,0.3270111031599685
84
+ -0.3382039502947016,-0.25384802827959924
85
+ -0.8728832994279527,-0.702234804957491
86
+ -0.3780353565686756,-0.2764472542552242
87
+ -0.3496333559465059,-0.2604115378000402
88
+ 0.45921235667612814,0.8318144558983341
89
+ 0.27511494271042625,0.378930899648246
90
+ 0.7744254851526531,2.476838929991218
91
+ -0.05557014967610141,-0.0526447049878387
92
+ -0.7608115081233966,-0.5422211816713537
93
+ 0.42648957444599,0.7331542354242023
94
+ 0.5215700972337949,1.0480920939233567
95
+ 0.12255439513899247,0.13966792497046343
96
+ 0.541934359909122,1.1277897616648658
97
+ -0.012408807271218514,-0.012256716047655835
98
+ 0.045465658763988115,0.047631235427594835
99
+ -0.14491796328290074,-0.12658306580875964
100
+ -0.9491617465118096,-0.8620990369599081
101
+ -0.7842171460133911,-0.5698978895980695
1_data/benchmarks/nguyen/nguyen_3.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_3
+ equation: x**5 + x**4 + x**3 + x**2 + x
+ latex: x^5 + x^4 + x^3 + x^2 + x
+ n_vars: 1
+ range: (-1, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_4.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ -0.250919762305275,-0.2005381523205818
3
+ 0.9014286128198323,4.238469048174702
4
+ 0.4639878836228102,0.8569922908725047
5
+ 0.1973169683940732,0.24580726801923503
6
+ -0.687962719115127,-0.3643591363520243
7
+ -0.6880109593275947,-0.36435608826336463
8
+ -0.8838327756636011,-0.2455278937990677
9
+ 0.7323522915498704,2.3140949562607194
10
+ 0.2022300234864176,0.25347681204176586
11
+ 0.416145155592091,0.7090527128132073
12
+ -0.9588310114083951,-0.10912908006768773
13
+ 0.9398197043239886,4.855613781697487
14
+ 0.6648852816008435,1.812643794302124
15
+ -0.5753217786434477,-0.3519654441009889
16
+ -0.6363500655857988,-0.36306143085659853
17
+ -0.6331909802931324,-0.3627152004476681
18
+ -0.39151551408092455,-0.28034573321234746
19
+ 0.04951286326447568,0.05209209111974663
20
+ -0.13610996271576847,-0.11980274952036765
21
+ -0.4175417196039162,-0.2929925328591143
22
+ 0.22370578944475894,0.2881352827585815
23
+ -0.7210122786959163,-0.3600873800713027
24
+ -0.4157107029295637,-0.2921254695960259
25
+ -0.2672763134126166,-0.21082921989376727
26
+ -0.08786003156592814,-0.08076405842831248
27
+ 0.5703519227860272,1.2817892887233815
28
+ -0.6006524356832805,-0.35763232934278105
29
+ 0.02846887682722321,0.02930310324912159
30
+ 0.18482913772408494,0.22672764390887995
31
+ -0.9070991745600046,-0.21066523362242717
32
+ 0.21508970380287673,0.27400380262187274
33
+ -0.6589517526254169,-0.36469038656839503
34
+ -0.869896814029441,-0.2636269087224129
35
+ 0.8977710745066665,4.183792570080711
36
+ 0.9312640661491187,4.710999208213584
37
+ 0.6167946962329223,1.520942904649746
38
+ -0.39077246165325863,-0.2799746344061298
39
+ -0.8046557719872323,-0.32485229977016666
40
+ 0.3684660530243138,0.5819860445197353
41
+ -0.1196950125207974,-0.10689934240187365
42
+ -0.7559235303104423,-0.3501762044406843
43
+ -0.00964617977745963,-0.009554019983097458
44
+ -0.9312229577695632,-0.16774940248481396
45
+ 0.8186408041575641,3.1552490058384244
46
+ -0.48244003679996617,-0.32133322562104216
47
+ 0.32504456870796394,0.48101133690648334
48
+ -0.3765778478211781,-0.27278072220481414
49
+ 0.040136042355621626,0.041814302816765486
50
+ 0.0934205586865593,0.10304722601078302
51
+ -0.6302910889489459,-0.36237314337852533
52
+ 0.9391692555291171,4.8444710273623794
53
+ 0.5502656467222291,1.1895683372546662
54
+ 0.8789978831283782,3.913723998847577
55
+ 0.7896547008552977,2.843906703730303
56
+ 0.19579995762217028,0.2434579886144678
57
+ 0.8437484700462337,3.451588526796362
58
+ -0.823014995896161,-0.3111561161910711
59
+ -0.6080342751617096,-0.3590153288662821
60
+ -0.9095454221789239,-0.2066401133482456
61
+ -0.3493393384734713,-0.2584260229222476
62
+ -0.22264542062103598,-0.18207920010904594
63
+ -0.4573019364522082,-0.31093047114315553
64
+ 0.6574750183038587,1.7644483791397838
65
+ -0.28649334661282144,-0.22257008197278716
66
+ -0.4381309806252385,-0.30249813296944583
67
+ 0.08539216631649693,0.09336475160627088
68
+ -0.7181515500504747,-0.36064001536563245
69
+ 0.6043939615080793,1.453297871778667
70
+ -0.8508987126404584,-0.28523518784003543
71
+ 0.9737738732010346,5.472703175385714
72
+ 0.5444895385933148,1.1641913865341225
73
+ -0.6025686369316552,-0.358003567605911
74
+ -0.9889557657527952,-0.03205227649515818
75
+ 0.6309228569096683,1.6016363531411766
76
+ 0.41371468769523423,0.7021158672014816
77
+ 0.4580143360819746,0.8372659757400991
78
+ 0.5425406933718915,1.155740382532357
79
+ -0.8519106965318193,-0.2841687433955442
80
+ -0.2830685429114548,-0.22050491198476146
81
+ -0.7682618809497406,-0.3451385720878113
82
+ 0.726206851751187,2.263347366149155
83
+ 0.24659625365511584,0.3272359665979907
84
+ -0.3382039502947016,-0.25235154435602003
85
+ -0.8728832994279527,-0.2599142659035131
86
+ -0.3780353565686756,-0.2735285202338706
87
+ -0.3496333559465059,-0.25858479605828744
88
+ 0.45921235667612814,0.8411918332864734
89
+ 0.27511494271042625,0.3793644954583637
90
+ 0.7744254851526531,2.6925526751560804
91
+ -0.05557014967610141,-0.05264467554027301
92
+ -0.7608115081233966,-0.34828339315956375
93
+ 0.42648957444599,0.7391722123153333
94
+ 0.5215700972337949,1.0682235925713717
95
+ 0.12255439513899247,0.13967131320406578
96
+ 0.541934359909122,1.1531223825090322
97
+ -0.012408807271218514,-0.0122567160440051
98
+ 0.045465658763988115,0.04763124426044632
99
+ -0.14491796328290074,-0.1265738031997899
100
+ -0.9491617465118096,-0.13089031892479408
101
+ -0.7842171460133911,-0.3372934495137382
1_data/benchmarks/nguyen/nguyen_4.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_4
+ equation: x**6 + x**5 + x**4 + x**3 + x**2 + x
+ latex: x^6 + x^5 + x^4 + x^3 + x^2 + x
+ n_vars: 1
+ range: (-1, 1)
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_5.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ -0.250919762305275,-0.9390512081168352
3
+ 0.9014286128198323,-0.5494873180553288
4
+ 0.4639878836228102,-0.8089599566948397
5
+ 0.1973169683940732,-0.9618311304971612
6
+ -0.687962719115127,-0.6478606522679959
7
+ -0.6880109593275947,-0.6478289752250018
8
+ -0.8838327756636011,-0.5534626739045931
9
+ 0.7323522915498704,-0.6200235532537778
10
+ 0.2022300234864176,-0.9599476201534037
11
+ 0.416145155592091,-0.8423936940912042
12
+ -0.9588310114083951,-0.5431689029988238
13
+ 0.9398197043239886,-0.5440918837654212
14
+ 0.6648852816008435,-0.6633160228877684
15
+ -0.5753217786434477,-0.7273325440876017
16
+ -0.6363500655857988,-0.6831457922480837
17
+ -0.6331909802931324,-0.6853819519705832
18
+ -0.39151551408092455,-0.8588685461295467
19
+ 0.04951286326447568,-0.9975514831941417
20
+ -0.13610996271576847,-0.9816464680708747
21
+ -0.4175417196039162,-0.8414428810237741
22
+ 0.22370578944475894,-0.9512230869588283
23
+ -0.7210122786959163,-0.6268666215487173
24
+ -0.4157107029295637,-0.8426891332699092
25
+ -0.2672763134126166,-0.9311584013772981
26
+ -0.08786003156592814,-0.9923104665144623
27
+ 0.5703519227860272,-0.7309939462300948
28
+ -0.6006524356832805,-0.7087806446206153
29
+ 0.02846887682722321,-0.99918985155515
30
+ 0.18482913772408494,-0.9664265762992899
31
+ -0.9070991745600046,-0.5484026596371498
32
+ 0.21508970380287673,-0.9548185804730333
33
+ -0.6589517526254169,-0.6673796194282944
34
+ -0.869896814029441,-0.5572462380695584
35
+ 0.8977710745066665,-0.550238281253637
36
+ 0.9312640661491187,-0.5448920303797724
37
+ 0.6167946962329223,-0.6970965407107139
38
+ -0.39077246165325863,-0.8593564758850973
39
+ -0.8046557719872323,-0.58178568969466
40
+ 0.3684660530243138,-0.8737340760964288
41
+ -0.1196950125207974,-0.9857760980914346
42
+ -0.7559235303104423,-0.6064732977627244
43
+ -0.00964617977745963,-0.9999069555448398
44
+ -0.9312229577695632,-0.5448964424582949
45
+ 0.8186408041575641,-0.5756409511274684
46
+ -0.48244003679996617,-0.7956727484647694
47
+ 0.32504456870796394,-0.900064608947692
48
+ -0.3765778478211781,-0.8685675799244211
49
+ 0.040136042355621626,-0.9983903961284725
50
+ 0.0934205586865593,-0.991310765594519
51
+ -0.6302910889489459,-0.687440925563
52
+ 0.9391692555291171,-0.5441444192356306
53
+ 0.5502656467222291,-0.7458300065420963
54
+ 0.8789978831283782,-0.5547129310790808
55
+ 0.7896547008552977,-0.5888643062279737
56
+ 0.19579995762217028,-0.9624041296078009
57
+ 0.8437484700462337,-0.5657849821252405
58
+ -0.823014995896161,-0.5738128862301527
59
+ -0.6080342751617096,-0.7034212209476314
60
+ -0.9095454221789239,-0.5479648926611902
61
+ -0.3493393384734713,-0.8856176517083943
62
+ -0.22264542062103598,-0.9516723889435259
63
+ -0.4573019364522082,-0.8137278342896282
64
+ 0.6574750183038587,-0.668396180415866
65
+ -0.28649334661282144,-0.9213553931307055
66
+ -0.4381309806252385,-0.8272380555912469
67
+ 0.08539216631649693,-0.9927348114985
68
+ -0.7181515500504747,-0.6286238526457859
69
+ 0.6043939615080793,-0.7060609958180059
70
+ -0.8508987126404584,-0.56326927462963
71
+ 0.9737738732010346,-0.5432897447861715
72
+ 0.5444895385933148,-0.7501016086636129
73
+ -0.6025686369316552,-0.7073869979373362
74
+ -0.9889557657527952,-0.5441925947478524
75
+ 0.6309228569096683,-0.6869918579399132
76
+ 0.41371468769523423,-0.8440443466281196
77
+ 0.4580143360819746,-0.8132211009838715
78
+ 0.5425406933718915,-0.7515428281198429
79
+ -0.8519106965318193,-0.5629240207718037
80
+ -0.2830685429114548,-0.9231433578147953
81
+ -0.7682618809497406,-0.5997746863584423
82
+ 0.726206851751187,-0.6237072834805513
83
+ 0.24659625365511584,-0.9410661846934961
84
+ -0.3382039502947016,-0.8923327139993233
85
+ -0.8728832994279527,-0.5563892940528554
86
+ -0.3780353565686756,-0.8676316179551925
87
+ -0.3496333559465059,-0.8854382694554389
88
+ 0.45921235667612814,-0.8123682362870449
89
+ 0.27511494271042625,-0.927227629948057
90
+ 0.7744254851526531,-0.5965370771413236
91
+ -0.05557014967610141,-0.9969167301387304
92
+ -0.7608115081233966,-0.603785604256095
93
+ 0.42648957444599,-0.8353116478277598
94
+ 0.5215700972337949,-0.7670333944975515
95
+ 0.12255439513899247,-0.9850936334783424
96
+ 0.541934359909122,-0.7519912093984299
97
+ -0.012408807271218514,-0.999846033357251
98
+ 0.045465658763988115,-0.997935011480979
99
+ -0.14491796328290074,-0.9792204513367526
100
+ -0.9491617465118096,-0.5434895142161542
101
+ -0.7842171460133911,-0.5915491851197053
1_data/benchmarks/nguyen/nguyen_5.meta.txt ADDED
@@ -0,0 +1,6 @@
+ name: nguyen_5
+ equation: sin(x**2)*cos(x) - 1
+ latex: \sin(x^2) \cos(x) - 1
+ n_vars: 1
+ range: (-1, 1)
+ n_samples: 100
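Each `.meta.txt` file pairs a benchmark CSV with its generating equation, variable count, and sampling range. The generator script is not part of this diff, so the following is only a sketch of how a CSV in this format could be reproduced; the sampling scheme and seed are assumptions (the leading values in these files match numpy's legacy `RandomState(42)` uniform draws).

```python
# Sketch: regenerate a Nguyen-style benchmark CSV from its .meta.txt fields.
# Assumptions: uniform sampling via numpy's legacy RandomState with seed 42
# (consistent with the first values observed in the CSVs above).
import csv
import numpy as np

def make_benchmark(path, fn, n_vars, lo, hi, n_samples=100, seed=42):
    rng = np.random.RandomState(seed)
    # Draw n_samples points uniformly in [lo, hi) for each variable.
    X = lo + (hi - lo) * rng.rand(n_samples, n_vars)
    y = fn(*(X[:, i] for i in range(n_vars)))
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([f"x_{i + 1}" for i in range(n_vars)] + ["y"])
        for row, yi in zip(X, y):
            w.writerow(list(row) + [yi])

# Example: nguyen_5, sin(x**2)*cos(x) - 1 sampled on (-1, 1).
make_benchmark("nguyen_5.csv", lambda x: np.sin(x**2) * np.cos(x) - 1,
               n_vars=1, lo=-1.0, hi=1.0)
```

With these assumptions the first row comes out as `x_1 = -0.250919762305275`, matching the single-variable benchmarks above.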
1_data/benchmarks/nguyen/nguyen_6.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ -0.250919762305275,-0.43514929053571705
3
+ 0.9014286128198323,1.7739777070144762
4
+ 0.4639878836228102,1.0757452027433314
5
+ 0.1973169683940732,0.4300984433827959
6
+ -0.687962719115127,-0.8479896478976541
7
+ -0.6880109593275947,-0.8480091939305666
8
+ -0.8838327756636011,-0.8756673756307944
9
+ 0.7323522915498704,1.623333209866147
10
+ 0.2022300234864176,0.4415932453735026
11
+ 0.416145155592091,0.9600350888652831
12
+ -0.9588310114083951,-0.8579844227187403
13
+ 0.9398197043239886,1.7757964620012405
14
+ 0.6648852816008435,1.511310612392346
15
+ -0.5753217786434477,-0.7860078048335528
16
+ -0.6363500655857988,-0.8236127344518008
17
+ -0.6331909802931324,-0.8218977378877982
18
+ -0.39151551408092455,-0.6175737622323656
19
+ 0.04951286326447568,0.10143363891812185
20
+ -0.13610996271576847,-0.2530033664640832
21
+ -0.4175417196039162,-0.646324893603273
22
+ 0.22370578944475894,0.4921883475197036
23
+ -0.7210122786959163,-0.859945147297601
24
+ -0.4157107029295637,-0.6443541593670672
25
+ -0.2672763134126166,-0.4586956789581246
26
+ -0.08786003156592814,-0.16780192738539715
27
+ 0.5703519227860272,1.3205458271668318
28
+ -0.6006524356832805,-0.8027562953942011
29
+ 0.02846887682722321,0.0577402019392117
30
+ 0.18482913772408494,0.40102336487359014
31
+ -0.9070991745600046,-0.8718906080781982
32
+ 0.21508970380287673,0.4718231825909881
33
+ -0.6589517526254169,-0.8351357987774414
34
+ -0.869896814029441,-0.8771972870866676
35
+ 0.8977710745066665,1.773112261855638
36
+ 0.9312640661491187,1.7765585415628578
37
+ 0.6167946962329223,1.4183947871732054
38
+ -0.39077246165325863,-0.616729624340086
39
+ -0.8046557719872323,-0.8771304085531529
40
+ 0.3684660530243138,0.8433211553872129
41
+ -0.1196950125207974,-0.22458265836212168
42
+ -0.7559235303104423,-0.8694190928201299
43
+ -0.00964617977745963,-0.019199015870959124
44
+ -0.9312229577695632,-0.8663534501230806
45
+ 0.8186408041575641,1.7268591780806577
46
+ -0.48244003679996617,-0.7110472800982361
47
+ 0.32504456870796394,0.7368566710350718
48
+ -0.3765778478211781,-0.6003566284265445
49
+ 0.040136042355621626,0.08186008651372706
50
+ 0.0934205586865593,0.19525514485076978
51
+ -0.6302910889489459,-0.8203010105715378
52
+ 0.9391692555291171,1.7758783124934705
53
+ 0.5502656467222291,1.2762087516369758
54
+ 0.8789978831283782,1.7668343124729904
55
+ 0.7896547008552977,1.6977190241718285
56
+ 0.19579995762217028,0.42655545364748826
57
+ 0.8437484700462337,1.7470252942741467
58
+ -0.823014995896161,-0.8783461682408447
59
+ -0.6080342751617096,-0.8073339482620789
60
+ -0.9095454221789239,-0.8714044274820212
61
+ -0.3493393384734713,-0.5676262516988537
62
+ -0.22264542062103598,-0.39302217898548675
63
+ -0.4573019364522082,-0.6871659884222171
64
+ 0.6574750183038587,1.497630717093317
65
+ -0.28649334661282144,-0.48558453002226076
66
+ -0.4381309806252385,-0.667941100272472
67
+ 0.08539216631649693,0.17783977455509073
68
+ -0.7181515500504747,-0.8590244862623294
69
+ 0.6043939615080793,1.3929716850978804
70
+ -0.8508987126404584,-0.8784032535275912
71
+ 0.9737738732010346,1.7659692545236667
72
+ 0.5444895385933148,1.2632639852448593
73
+ -0.6025686369316552,-0.8039577784468424
74
+ -0.9889557657527952,-0.8463746041245153
75
+ 0.6309228569096683,1.4466669850764435
76
+ 0.41371468769523423,0.9541081599557866
77
+ 0.4580143360819746,1.0614213526661842
78
+ 0.5425406933718915,1.2588779524345566
79
+ -0.8519106965318193,-0.8783645297620329
80
+ -0.2830685429114548,-0.48085397416735837
81
+ -0.7682618809497406,-0.8719829040159035
82
+ 0.726206851751187,1.614146688730815
83
+ 0.24659625365511584,0.5466918275185086
84
+ -0.3382039502947016,-0.5537512584147581
85
+ -0.8728832994279527,-0.8769154903633847
86
+ -0.3780353565686756,-0.602059443902563
87
+ -0.3496333559465059,-0.5679887252654134
88
+ 0.45921235667612814,1.06429743105736
89
+ 0.27511494271042625,0.6153097358245831
90
+ 0.7744254851526531,1.6800348672828789
91
+ -0.05557014967610141,-0.10799957252353272
92
+ -0.7608115081233966,-0.8704840611139395
93
+ 0.42648957444599,0.985218668558129
94
+ 0.5215700972337949,1.2111284238136855
95
+ 0.12255439513899247,0.25938825702662677
96
+ 0.541934359909122,1.2575114815069453
97
+ -0.012408807271218514,-0.024663010860991184
98
+ 0.045465658763988115,0.09296488443027026
99
+ -0.14491796328290074,-0.26801111295446844
100
+ -0.9491617465118096,-0.8611626229977559
101
+ -0.7842171460133911,-0.8746853295473812
1_data/benchmarks/nguyen/nguyen_6.meta.txt ADDED
@@ -0,0 +1,6 @@
1
+ name: nguyen_6
2
+ equation: sin(x) + sin(x + x**2)
3
+ latex: \sin(x) + \sin(x + x^2)
4
+ n_vars: 1
5
+ range: (-1, 1)
6
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_7.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ 0.749080237694725,1.0044943539710771
3
+ 1.9014286128198323,2.5946084456041874
4
+ 1.4639878836228102,2.04704177271813
5
+ 1.1973169683940732,1.6765955177464844
6
+ 0.31203728088487304,0.3644950207193922
7
+ 0.3119890406724053,0.36443082009006955
8
+ 0.11616722433639892,0.12330527518118403
9
+ 1.7323522915498704,2.3917183259947485
10
+ 1.2022300234864176,1.683661630559711
11
+ 1.416145155592091,1.9826063566319803
12
+ 0.041168988591604894,0.042037560323179395
13
+ 1.9398197043239886,2.6392050827257325
14
+ 1.6648852816008435,2.307724752047451
15
+ 0.4246782213565523,0.5197583318072219
16
+ 0.36364993441420124,0.4343639775100977
17
+ 0.36680901970686763,0.43871392942129606
18
+ 0.6084844859190754,0.7902880866466716
19
+ 1.0495128632644757,1.4602426871151915
20
+ 0.8638900372842315,1.1801684194501334
21
+ 0.5824582803960838,0.7510949783115568
22
+ 1.223705789444759,1.7144476360772682
23
+ 0.27898772130408367,0.3210225319850415
24
+ 0.5842892970704363,0.7538452725717105
25
+ 0.7327236865873834,0.9794515463618694
26
+ 0.9121399684340719,1.2536309133704537
27
+ 1.5703519227860272,2.187045504912016
28
+ 0.39934756431671947,0.4839764178477628
29
+ 1.0284688768272232,1.4288935535974223
30
+ 1.184829137724085,1.6585968019236683
31
+ 0.09290082543999545,0.09742900194019119
32
+ 1.2150897038028767,1.702116152706024
33
+ 0.34104824737458306,0.4034836854209194
34
+ 0.13010318597055903,0.13909411999883536
35
+ 1.8977710745066665,2.5903318010915215
36
+ 1.9312640661491187,2.629312630527168
37
+ 1.6167946962329223,2.2467723049533905
38
+ 0.6092275383467414,0.7914100534079171
39
+ 0.19534422801276774,0.21588350693408828
40
+ 1.3684660530243138,1.917494634319942
41
+ 0.8803049874792026,1.2051988477172373
42
+ 0.24407646968955765,0.27625976490939214
43
+ 0.9903538202225404,1.3718135735019552
44
+ 0.06877704223043679,0.07123417253294272
45
+ 1.8186408041575641,2.4966018816848856
46
+ 0.5175599632000338,0.6544407561341439
47
+ 1.325044568707964,1.8574263454835047
48
+ 0.6234221521788219,0.8128721430860975
49
+ 1.0401360423556216,1.4462892592771355
50
+ 1.0934205586865593,1.5252403780918733
51
+ 0.3697089110510541,0.44271409262957717
52
+ 1.9391692555291171,2.6384539264232347
53
+ 1.550265646722229,2.160950005266661
54
+ 1.8789978831283782,2.5683039510702823
55
+ 1.7896547008552977,2.461684010396808
56
+ 1.1957999576221703,1.6744119782815021
57
+ 1.8437484700462337,2.5265931477776116
58
+ 0.176985004103839,0.19379920212879384
59
+ 0.3919657248382904,0.47363661139990393
60
+ 0.09045457782107613,0.09474339248156945
61
+ 0.6506606615265287,0.8541953881759485
62
+ 0.777354579378964,1.0478012138309305
63
+ 0.5426980635477918,0.6916737607221609
64
+ 1.6574750183038587,2.298391579491666
65
+ 0.7135066533871786,0.9500499213476303
66
+ 0.5618690193747615,0.7202496008381812
67
+ 1.085392166316497,1.513398916251658
68
+ 0.2818484499495253,0.3247441814248926
69
+ 1.6043939615080793,2.230907799621039
70
+ 0.14910128735954165,0.1609678309146809
71
+ 1.9737738732010346,2.6782060947078543
72
+ 1.5444895385933148,2.153416282610123
73
+ 0.3974313630683448,0.4812884691333153
74
+ 0.011044234247204798,0.011105659717543844
75
+ 1.6309228569096683,2.2647730643008543
76
+ 1.4137146876952342,1.9793088424715877
77
+ 1.4580143360819746,2.0390459508436884
78
+ 1.5425406933718915,2.150871440176907
79
+ 0.14808930346818072,0.15979251395522687
80
+ 0.7169314570885452,0.9552876826431883
81
+ 0.23173811905025943,0.26073648199843447
82
+ 1.726206851751187,2.3841402894128048
83
+ 1.2465962536551158,1.7470779213505572
84
+ 0.6617960497052984,0.8711341615273885
85
+ 0.12711670057204727,0.13569227346664953
86
+ 0.6219646434313244,0.8106659478651683
87
+ 0.6503666440534941,0.8537484672462254
88
+ 1.4592123566761281,2.0406506801223507
89
+ 1.2751149427104262,1.7874611691912827
90
+ 1.774425485152653,2.4432111990381067
91
+ 0.9444298503238986,1.302575647854838
92
+ 0.2391884918766034,0.27009115923794424
93
+ 1.42648957444599,1.996615009036411
94
+ 1.5215700972337949,2.1233923437150204
95
+ 1.1225543951389925,1.5680418793142075
96
+ 1.541934359909122,2.150079370436032
97
+ 0.9875911927287815,1.3676621443430883
98
+ 1.0454656587639881,1.454223169438932
99
+ 0.8550820367170993,1.1667236792399285
100
+ 0.05083825348819038,0.05216937619240417
101
+ 0.2157828539866089,0.24089892921884035
1_data/benchmarks/nguyen/nguyen_7.meta.txt ADDED
@@ -0,0 +1,6 @@
1
+ name: nguyen_7
2
+ equation: log(x + 1) + log(x**2 + 1)
3
+ latex: \ln(x+1) + \ln(x^2+1)
4
+ n_vars: 1
5
+ range: (0, 2)
6
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_8.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,y
2
+ 1.49816047538945,1.2239936582308955
3
+ 3.8028572256396647,1.9500915941667112
4
+ 2.9279757672456204,1.7111328900017146
5
+ 2.3946339367881464,1.547460479879259
6
+ 0.6240745617697461,0.7899838996902064
7
+ 0.6239780813448106,0.789922832525311
8
+ 0.23233444867279784,0.48201083875033124
9
+ 3.4647045830997407,1.8613716939665061
10
+ 2.404460046972835,1.5506321443117432
11
+ 2.832290311184182,1.6829409707961185
12
+ 0.08233797718320979,0.2869459481909612
13
+ 3.8796394086479773,1.9696800269708725
14
+ 3.329770563201687,1.8247658927110861
15
+ 0.8493564427131046,0.9216053616994123
16
+ 0.7272998688284025,0.8528187784215369
17
+ 0.7336180394137353,0.8565150549837027
18
+ 1.216968971838151,1.103163166461857
19
+ 2.0990257265289514,1.4488014793369557
20
+ 1.727780074568463,1.3144504838785154
21
+ 1.1649165607921677,1.0793130040874
22
+ 2.447411578889518,1.564420524951497
23
+ 0.5579754426081673,0.7469775382219785
24
+ 1.1685785941408726,1.0810081378698648
25
+ 1.4654473731747668,1.2105566377393364
26
+ 1.8242799368681437,1.3506590749956644
27
+ 3.1407038455720544,1.7722031050565437
28
+ 0.7986951286334389,0.893697448040129
29
+ 2.0569377536544464,1.4342028286314479
30
+ 2.36965827544817,1.5393694408582268
31
+ 0.1858016508799909,0.4310471562137847
32
+ 2.4301794076057535,1.5589032707662633
33
+ 0.6820964947491661,0.825891333499248
34
+ 0.26020637194111806,0.5101042755565944
35
+ 3.795542149013333,1.948215118772394
36
+ 3.8625281322982374,1.965331557854358
37
+ 3.2335893924658445,1.7982183939849588
38
+ 1.2184550766934827,1.1038365262544463
39
+ 0.3906884560255355,0.6250507627589422
40
+ 2.7369321060486276,1.6543675849244108
41
+ 1.7606099749584052,1.3268797891890602
42
+ 0.4881529393791153,0.698679425329754
43
+ 1.9807076404450807,1.4073761545674564
44
+ 0.13755408446087358,0.370882844657007
45
+ 3.6372816083151283,1.9071658575790225
46
+ 1.0351199264000677,1.017408436371582
47
+ 2.650089137415928,1.6279094377194108
48
+ 1.2468443043576438,1.1166218269215606
49
+ 2.0802720847112433,1.4423148355027218
50
+ 2.1868411173731186,1.4787971860174467
51
+ 0.7394178221021082,0.8598940760943223
52
+ 3.8783385110582342,1.9693497685932366
53
+ 3.100531293444458,1.7608325569015522
54
+ 3.7579957662567565,1.9385550717626663
55
+ 3.5793094017105953,1.8919062877718325
56
+ 2.3915999152443406,1.546479846375096
57
+ 3.6874969400924673,1.920285640234928
58
+ 0.353970008207678,0.5949537866151269
59
+ 0.7839314496765808,0.8853990341515969
60
+ 0.18090915564215226,0.4253341693799738
61
+ 1.3013213230530574,1.1407547164281449
62
+ 1.554709158757928,1.2468797691669906
63
+ 1.0853961270955836,1.0418234625384397
64
+ 3.3149500366077174,1.8207004247288232
65
+ 1.4270133067743571,1.194576622395716
66
+ 1.123738038749523,1.0600651106179861
67
+ 2.170784332632994,1.473358182056554
68
+ 0.5636968998990506,0.750797509252029
69
+ 3.2087879230161587,1.7913089970789962
70
+ 0.2982025747190833,0.5460792751232034
71
+ 3.947547746402069,1.9868436643083092
72
+ 3.0889790771866297,1.7575491677863893
73
+ 0.7948627261366896,0.8915507423229985
74
+ 0.022088468494409597,0.14862189776210502
75
+ 3.2618457138193366,1.806058059371109
76
+ 2.8274293753904685,1.6814961716847494
77
+ 2.9160286721639492,1.7076383317798736
78
+ 3.085081386743783,1.756439975274926
79
+ 0.29617860693636144,0.5442229386348589
80
+ 1.4338629141770904,1.197440150561643
81
+ 0.46347623810051886,0.6807908916110136
82
+ 3.452413703502374,1.8580671956370076
83
+ 2.4931925073102317,1.578984644418758
84
+ 1.3235920994105967,1.1504747278452476
85
+ 0.25423340114409454,0.5042156296110768
86
+ 1.2439292868626488,1.1153157789893626
87
+ 1.3007332881069882,1.1404969478727194
88
+ 2.9184247133522563,1.7083397534894094
89
+ 2.5502298854208525,1.5969439205622884
90
+ 3.548850970305306,1.8838394226433701
91
+ 1.8888597006477972,1.3743579230490859
92
+ 0.4783769837532068,0.6916480201324998
93
+ 2.85297914889198,1.689076418902348
94
+ 3.0431401944675898,1.7444598575110835
95
+ 2.245108790277985,1.4983687097233394
96
+ 3.083868719818244,1.7560947354337817
97
+ 1.975182385457563,1.4054118205912327
98
+ 2.0909313175279762,1.4460052965075807
99
+ 1.7101640734341985,1.3077324166029527
100
+ 0.10167650697638075,0.3188675382919697
101
+ 0.4315657079732178,0.6569366087935866
1_data/benchmarks/nguyen/nguyen_8.meta.txt ADDED
@@ -0,0 +1,6 @@
1
+ name: nguyen_8
2
+ equation: sqrt(x)
3
+ latex: \sqrt{x}
4
+ n_vars: 1
5
+ range: (0, 4)
6
+ n_samples: 100
1_data/benchmarks/nguyen/nguyen_9.csv ADDED
@@ -0,0 +1,101 @@
1
+ x_1,x_2,y
2
+ -0.250919762305275,0.9014286128198323,0.47776420484711724
3
+ 0.4639878836228102,0.1973169683940732,0.48644207085812474
4
+ -0.687962719115127,-0.6880109593275947,-0.17908604352731888
5
+ -0.8838327756636011,0.7323522915498704,-0.26218203193510137
6
+ 0.2022300234864176,0.416145155592091,0.37316689293318206
7
+ -0.9588310114083951,0.9398197043239886,-0.0457079947346013
8
+ 0.6648852816008435,-0.5753217786434477,0.9419531756947639
9
+ -0.6363500655857988,-0.6331909802931324,-0.2039883598325784
10
+ -0.39151551408092455,0.04951286326447568,-0.37913816951228
11
+ -0.13610996271576847,-0.4175417196039162,0.037769159921857864
12
+ 0.22370578944475894,-0.7210122786959163,0.718602102079344
13
+ -0.4157107029295637,-0.2672763134126166,-0.33246432027777695
14
+ -0.08786003156592814,0.5703519227860272,0.23184727590209103
15
+ -0.6006524356832805,0.02846887682722321,-0.5643703547272768
16
+ 0.18482913772408494,-0.9070991745600046,0.9168514297790041
17
+ 0.21508970380287673,-0.6589517526254169,0.6341356329828418
18
+ -0.869896814029441,0.8977710745066665,-0.04274391665004529
19
+ 0.9312640661491187,0.6167946962329223,1.1737000528120816
20
+ -0.39077246165325863,-0.8046557719872323,0.2222683475391934
21
+ 0.3684660530243138,-0.1196950125207974,0.37451127227879794
22
+ -0.7559235303104423,-0.00964617977745963,-0.6858679083778902
23
+ -0.9312229577695632,0.8186408041575641,-0.1812290712696445
24
+ -0.48244003679996617,0.32504456870796394,-0.35848458267243266
25
+ -0.3765778478211781,0.040136042355621626,-0.3661293707647741
26
+ 0.0934205586865593,-0.6302910889489459,0.4801842308956415
27
+ 0.9391692555291171,0.5502656467222291,1.1052544786389042
28
+ 0.8789978831283782,0.7896547008552977,1.3540244447960008
29
+ 0.19579995762217028,0.8437484700462337,0.8478334428521738
30
+ -0.823014995896161,-0.6080342751617096,-0.3718583841132946
31
+ -0.9095454221789239,-0.3493393384734713,-0.6674893880400039
32
+ -0.22264542062103598,-0.4573019364522082,-0.01320641283786203
33
+ 0.6574750183038587,-0.28649334661282144,0.6931064937612617
34
+ -0.4381309806252385,0.08539216631649693,-0.4169559662599447
35
+ -0.7181515500504747,0.6043939615080793,-0.30077177092335655
36
+ -0.8508987126404584,0.9737738732010346,0.060514655114857874
37
+ 0.5444895385933148,-0.6025686369316552,0.8731450114593164
38
+ -0.9889557657527952,0.6309228569096683,-0.44781844397033194
39
+ 0.41371468769523423,0.4580143360819746,0.6102553103342886
40
+ 0.5425406933718915,-0.8519106965318193,1.1800115078177271
41
+ -0.2830685429114548,-0.7682618809497406,0.27724568254954496
42
+ 0.726206851751187,0.24659625365511584,0.7248105293625019
43
+ -0.3382039502947016,-0.8728832994279527,0.3585223402710981
44
+ -0.3780353565686756,-0.3496333559465059,-0.24715600583220282
45
+ 0.45921235667612814,0.27511494271042625,0.5188581859931962
46
+ 0.7744254851526531,-0.05557014967610141,0.7023935510282614
47
+ -0.7608115081233966,0.42648957444599,-0.5086174096415731
48
+ 0.5215700972337949,0.12255439513899247,0.5132611003903497
49
+ 0.541934359909122,-0.012408807271218514,0.515948124519243
50
+ 0.045465658763988115,-0.14491796328290074,0.06644966885563074
51
+ -0.9491617465118096,-0.7842171460133911,-0.23597193139241113
52
+ -0.9371416286265315,0.2728208225275608,-0.7315064751214
53
+ -0.37128803784734665,0.01714138232940554,-0.362522177588772
54
+ 0.815132947852186,-0.5014155417022501,0.9765939645263036
55
+ -0.17923415392874054,0.5111022770860973,0.07998865766083815
56
+ -0.5424036690167551,-0.846040180342414,0.14001304604376585
57
+ -0.42049709417246395,-0.6775574254919912,0.034912901016997455
58
+ 0.8593953046851461,0.6162407591288339,1.1281386729129697
59
+ 0.26680751302084693,0.7429211803754354,0.7879864844362704
60
+ 0.6073441537982289,-0.6268598822279283,0.95360688133832
61
+ 0.7851179969799555,0.07868448383130144,0.7130998543064418
62
+ 0.6148803103281251,0.7921825998469865,1.164026731337986
63
+ -0.36399305005627225,-0.7798961509446465,0.21541385909859811
64
+ -0.5441296749161166,-0.14578442274748737,-0.496422157556886
65
+ 0.6360295318449862,0.7214611665126869,1.0913254238549446
66
+ -0.9860957389376186,0.021494605155131463,-0.8334153656803461
67
+ -0.16517799370244202,-0.5557843790585395,0.1395794327899027
68
+ -0.7602692653326344,-0.3247696571927441,-0.5838367239902751
69
+ 0.8858194078250383,-0.35359413595848954,0.8991369711397181
70
+ 0.037581243486732197,0.4060379177903557,0.2016933272194322
71
+ -0.27274079524141204,0.9435641654419213,0.5078969955272314
72
+ 0.9248945898842225,-0.4964354083492717,1.0425182593662667
73
+ -0.005502988215229099,-0.3982433803664607,0.15243079003847154
74
+ -0.4303190112450648,-0.9262261052909344,0.3393066204879884
75
+ 0.2191286679597937,0.005358046457722976,0.217407918292659
76
+ -0.8970424975000213,-0.44270707152677713,-0.586747836651068
77
+ 0.8165317719333074,-0.5208762186660552,0.9967710464625614
78
+ -0.7102102558175538,-0.02109447944487397,-0.6515482295602607
79
+ 0.9713009082212014,-0.5158894569769992,1.0886315849160677
80
+ 0.3442710948117571,0.5232392306574352,0.6078825279565068
81
+ -0.5247249120152007,0.45643269722371915,-0.29414785755880746
82
+ -0.26443373456149355,0.26461166118715895,-0.1914005964405623
83
+ 0.2670594215217894,0.07154936814951696,0.26901553332662537
84
+ -0.8194204598911834,0.6706049911784759,-0.2960449968694709
85
+ -0.35843987005652833,-0.6269629792002915,0.03222402345128389
86
+ -0.9184497168904722,0.18178588637648363,-0.7616213774179811
87
+ 0.35512872368456483,-0.9668243421442877,1.152161247302312
88
+ 0.024186116598561958,-0.5470084496041241,0.31895703499280215
89
+ 0.2903455808188997,-0.6512671419900171,0.6978285972621991
90
+ 0.3818754762049319,-0.22652930739892518,0.4239545131674068
91
+ 0.873459977473469,-0.7249581117080135,1.268256638618119
92
+ -0.317867297899483,-0.7730529575188219,0.2501276235695559
93
+ 0.8493872365571256,0.754678706761962,1.290120526044866
94
+ -0.4841167445696888,0.3199680920683581,-0.36322596002563673
95
+ 0.6344444004024317,0.11040162319892466,0.6049183419335217
96
+ 0.05930115671201297,-0.5162954181990966,0.32268182428446945
97
+ -0.8137944643882016,0.7944315159066535,-0.13684769483552173
98
+ 0.8008361143266609,0.2662029145465359,0.7887430634679737
99
+ -0.32194041790259864,-0.3015808507746782,-0.22558221167734632
100
+ 0.45191135774047875,0.7942205199051542,1.0264656897505602
101
+ 0.7741728485302346,0.5597510917152477,1.0073448206218134
1_data/benchmarks/nguyen/nguyen_9.meta.txt ADDED
@@ -0,0 +1,6 @@
1
+ name: nguyen_9
2
+ equation: sin(x) + sin(y**2)
3
+ latex: \sin(x) + \sin(y^2)
4
+ n_vars: 2
5
+ range: (-1, 1)
6
+ n_samples: 100
1_data/processed/700K_prefix_converted/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e6c9ddae5a332af4ee05d1aebbb913757316ce685a179712212e6e8334ba5f1
3
+ size 7296336
1_data/processed/700K_prefix_converted/dataset_info.json ADDED
@@ -0,0 +1,82 @@
1
+ {
2
+ "builder_name": "csv",
3
+ "citation": "",
4
+ "config_name": "default",
5
+ "dataset_name": "sintetico_natural",
6
+ "dataset_size": 5704135,
7
+ "description": "",
8
+ "download_checksums": {
9
+ "hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/test_mini_test.csv": {
10
+ "num_bytes": 560871,
11
+ "checksum": null
12
+ },
13
+ "hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/train_mini_test.csv": {
14
+ "num_bytes": 4378438,
15
+ "checksum": null
16
+ },
17
+ "hf://datasets/augustocsc/sintetico_natural@fe48ddc600c4674bcff395308dc8d7c32aed6135/mini_test/val_mini_test.csv": {
18
+ "num_bytes": 545160,
19
+ "checksum": null
20
+ }
21
+ },
22
+ "download_size": 5484469,
23
+ "features": {
24
+ "infix_expr_n": {
25
+ "dtype": "string",
26
+ "_type": "Value"
27
+ },
28
+ "infix_expr_c": {
29
+ "dtype": "string",
30
+ "_type": "Value"
31
+ },
32
+ "expression_objects": {
33
+ "dtype": "string",
34
+ "_type": "Value"
35
+ },
36
+ "prefix_expr_c": {
37
+ "dtype": "string",
38
+ "_type": "Value"
39
+ },
40
+ "prefix_expr_n": {
41
+ "dtype": "string",
42
+ "_type": "Value"
43
+ },
44
+ "i_prompt_n": {
45
+ "dtype": "string",
46
+ "_type": "Value"
47
+ },
48
+ "p_prompt_n": {
49
+ "dtype": "string",
50
+ "_type": "Value"
51
+ },
52
+ "skeleton": {
53
+ "dtype": "string",
54
+ "_type": "Value"
55
+ },
56
+ "p_prompt_n_converted": {
57
+ "dtype": "string",
58
+ "_type": "Value"
59
+ },
60
+ "conversion_success": {
61
+ "dtype": "bool",
62
+ "_type": "Value"
63
+ }
64
+ },
65
+ "homepage": "",
66
+ "license": "",
67
+ "size_in_bytes": 11188604,
68
+ "splits": {
69
+ "test": {
70
+ "name": "test",
71
+ "num_bytes": 5704135,
72
+ "num_examples": 12221,
73
+ "dataset_name": "sintetico_natural"
74
+ }
75
+ },
76
+ "version": {
77
+ "version_str": "0.0.0",
78
+ "major": 0,
79
+ "minor": 0,
80
+ "patch": 0
81
+ }
82
+ }
1_data/processed/700K_prefix_converted/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "d4d2d5394689c8df",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": "test"
13
+ }
1_data/processed/PREFIX_CONVERSION_README.md ADDED
@@ -0,0 +1,214 @@
1
+ # Infix → Prefix Dataset Conversion
2
+
3
+ ## Goal
4
+
5
+ Convert the `augustocsc/sintetico_natural` dataset from **infix** notation to **prefix (Polish) notation**, keeping the same mathematical expressions but in prefixed form.
6
+
7
+ ## Motivation
8
+
9
+ The original dataset contains:
10
+ - `i_prompt_n`: prompts with expressions in infix notation (e.g., `x_1 + x_2`)
11
+ - `p_prompt_n`: prompts with expressions in prefix notation (e.g., `+ x_1 x_2`)
12
+
13
+ **PROBLEM**: the expressions in `i_prompt_n` and `p_prompt_n` are DIFFERENT! They are not the same expression converted.
14
+
15
+ **SOLUTION**: automatically convert the expressions from `i_prompt_n` to prefix notation, creating `p_prompt_n_converted`.
16
+
17
+ ## Conversion Script
18
+
19
+ **File**: `scripts/data/convert_infix_to_prefix.py`
20
+
21
+ **What it does**:
22
+ 1. Reads infix expressions from the `i_prompt_n` field
23
+ 2. Parses them with SymPy
24
+ 3. Converts them to prefix (Polish) notation
25
+ 4. Keeps the same variables and operators as the original prompt
26
+ 5. Saves the result in a new `p_prompt_n_converted` column
27
+
28
+ ## Conversion Examples
29
+
30
+ ### Example 1
31
+ **INFIX**:
32
+ ```
33
+ vars: x_1, x_2, x_3, x_4, x_5
34
+ oper: *, +, -, /, abs, asin, cos, exp, log, sin, sqrt, tan
35
+ cons: C
36
+ expr: x_2 - (x_5 - C)*(x_4 + exp(C*x_2) + C)
37
+ ```
38
+
39
+ **PREFIX**:
40
+ ```
41
+ vars: x_1, x_2, x_3, x_4, x_5
42
+ oper: *, +, -, /, abs, asin, cos, exp, log, sin, sqrt, tan
43
+ cons: C
44
+ expr: - x_2 * - x_5 C + + x_4 exp * C x_2 C
45
+ ```
46
+
47
+ ### Example 2
48
+ **INFIX**:
49
+ ```
50
+ vars: x_1, x_2, x_3
51
+ oper: +, -, /, abs, cos, exp
52
+ cons: C
53
+ expr: (x_1 - C)/(x_2 + C)
54
+ ```
55
+
56
+ **PREFIX**:
57
+ ```
58
+ vars: x_1, x_2, x_3
59
+ oper: +, -, /, abs, cos, exp
60
+ cons: C
61
+ expr: / - x_1 C + x_2 C
62
+ ```
63
+
64
+ ### Example 3
65
+ **INFIX**:
66
+ ```
67
+ vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10
68
+ oper: *, +, /, asin, sin, tan
69
+ cons: C
70
+ expr: (tan(x_7) + C)*(asin(x_5) + C)
71
+ ```
72
+
73
+ **PREFIX**:
74
+ ```
75
+ vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10
76
+ oper: *, +, /, asin, sin, tan
77
+ cons: C
78
+ expr: * + tan x_7 C + asin x_5 C
79
+ ```
80
+
81
+ ## Conversion Rules
82
+
83
+ ### Binary Operators
84
+ - Infix: `a + b` → Prefix: `+ a b`
85
+ - Infix: `a - b` → Prefix: `- a b`
86
+ - Infix: `a * b` → Prefix: `* a b`
87
+ - Infix: `a / b` → Prefix: `/ a b`
88
+ - Infix: `a ** b` → Prefix: `** a b`
89
+
90
+ ### Unary Functions
91
+ - Infix: `sin(x)` → Prefix: `sin x`
92
+ - Infix: `exp(x)` → Prefix: `exp x`
93
+ - Infix: `log(x)` → Prefix: `log x`
94
+
95
+ ### Complex Expressions
96
+ - Infix: `sin(x**2)` → Prefix: `sin ** x 2`
97
+ - Infix: `x*(y + z)` → Prefix: `* x + y z`
98
+ - Infix: `(a + b)*(c + d)` → Prefix: `* + a b + c d`
99
+
100
+ ### Special Cases
101
+ - **Negation**: `-x` → `* -1 x`
102
+ - **Multiple additions**: `a + b + c` → `+ + a b c` (left-nested)
103
+ - **Division**: `x/y` → `/ x y` (SymPy represents this as `x * y**-1`; the conversion detects and fixes it)
104
+
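The rules above amount to a recursive tree walk over the parsed expression. The real script parses with SymPy; the following simplified, stdlib-only sketch uses Python's `ast` module instead, and its function names (`to_prefix`, `infix_to_prefix`) are illustrative, not the script's actual API:

```python
import ast

# Binary operator tokens, mirroring the conversion rules above
OPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/", ast.Pow: "**"}

def to_prefix(node):
    if isinstance(node, ast.Expression):
        return to_prefix(node.body)
    if isinstance(node, ast.BinOp):
        # Left-associative parsing yields left-nested prefix: a + b + c -> + + a b c
        return f"{OPS[type(node.op)]} {to_prefix(node.left)} {to_prefix(node.right)}"
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return f"* -1 {to_prefix(node.operand)}"  # negation rule: -x -> * -1 x
    if isinstance(node, ast.Call):
        # Unary function application: sin(x) -> sin x
        return f"{node.func.id} {to_prefix(node.args[0])}"
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Constant):
        return str(node.value)
    raise ValueError(f"unsupported node: {ast.dump(node)}")

def infix_to_prefix(expr: str) -> str:
    return to_prefix(ast.parse(expr, mode="eval"))

print(infix_to_prefix("(x_1 - C)/(x_2 + C)"))  # / - x_1 C + x_2 C
```

This sketch already reproduces Example 2 above; the SymPy-based script additionally has to undo SymPy's internal rewrites of subtraction and division noted under Special Cases.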
105
+ ## Script Usage
106
+
107
+ ### Test (10 examples)
108
+ ```bash
109
+ python scripts/data/convert_infix_to_prefix.py --test_only
110
+ ```
111
+
112
+ ### Full Conversion
113
+ ```bash
114
+ python scripts/data/convert_infix_to_prefix.py \
115
+ --split test \
116
+ --output_path ./1_data/processed/700K_prefix_converted
117
+ ```
118
+
119
+ ### Upload to HuggingFace
120
+ ```bash
121
+ python scripts/data/convert_infix_to_prefix.py \
122
+ --split test \
123
+ --output_path ./1_data/processed/700K_prefix_converted \
124
+ --upload \
125
+ --repo_id augustocsc/sintetico_natural_prefix_converted
126
+ ```
127
+
128
+ ## Dataset Output
129
+
130
+ **Added columns**:
131
+ - `p_prompt_n_converted`: prefix prompt converted from the infix one
132
+ - `conversion_success`: boolean indicating whether the conversion succeeded
133
+
134
+ **Expected success rate**: ~99%+
135
+
136
+ ## Advantages
137
+
138
+ 1. **Comparability**: prefix models can now be trained on the SAME expressions as infix models
139
+ 2. **Consistency**: keeps the vars/operators/constants of the original prompt
140
+ 3. **Reproducibility**: the conversion is automatic and deterministic
141
+ 4. **Scalability**: easy to apply to new datasets
142
+
143
+ ## Training with the Converted Dataset
144
+
145
+ ### Using the local dataset
146
+ ```bash
147
+ python 2_training/supervised/train.py \
148
+ --dataset_path ./1_data/processed/700K_prefix_converted \
149
+ --data_column p_prompt_n_converted \
150
+ --approach prefix \
151
+ --output_dir ./output/gpt2_prefix_converted
152
+ ```
153
+
154
+ ### Comparison: Infix vs Prefix (same expression)
155
+ You can now train two models on the SAME expressions:
156
+ - Model A: `--data_column i_prompt_n --approach infix`
157
+ - Model B: `--data_column p_prompt_n_converted --approach prefix`
158
+
159
+ and directly compare which notation the model learns better.
160
+
161
+ ## Validation
162
+
163
+ To verify the conversion is correct:
164
+ 1. Parse the converted prefix expression
165
+ 2. Evaluate it on test points
166
+ 3. Compare against the evaluation of the original infix expression
167
+ 4. The R² score should be ~1.0 (same expression, same results)
168
+
169
+ ```python
170
+ from classes.expression import Expression
171
+
172
+ # Original infix expression
173
+ expr_infix = Expression("x_1 + x_2", is_prefix=False)
174
+
175
+ # Converted prefix expression
176
+ expr_prefix = Expression("+ x_1 x_2", is_prefix=True)
177
+
178
+ # Test it
179
+ import numpy as np
180
+ x = np.array([[1.0, 2.0], [3.0, 4.0]])
181
+
182
+ result_infix = expr_infix.evaluate(x)
183
+ result_prefix = expr_prefix.evaluate(x)
184
+
185
+ # They must be identical
186
+ assert np.allclose(result_infix, result_prefix)
187
+ ```
188
+
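When the project's `Expression` class is not available, the same round-trip idea can be checked with a tiny standalone prefix evaluator. This is a hedged sketch (the function name `eval_prefix` and the operator tables are hypothetical, not part of the project), shown here evaluating the prefix form of the Nguyen-5 equation `sin(x**2)*cos(x) - 1`:

```python
import math

# Supported unary functions and binary operators (illustrative subset)
FUNCS = {"sin": math.sin, "cos": math.cos, "exp": math.exp,
         "log": math.log, "sqrt": math.sqrt, "tan": math.tan, "abs": abs}
BINOPS = {"+": lambda a, b: a + b, "-": lambda a, b: a - b,
          "*": lambda a, b: a * b, "/": lambda a, b: a / b,
          "**": lambda a, b: a ** b}

def eval_prefix(tokens, env):
    """Consume tokens left to right, recursing once per operand."""
    tok = tokens.pop(0)
    if tok in BINOPS:
        a = eval_prefix(tokens, env)
        b = eval_prefix(tokens, env)
        return BINOPS[tok](a, b)
    if tok in FUNCS:
        return FUNCS[tok](eval_prefix(tokens, env))
    return env[tok] if tok in env else float(tok)

# Prefix form of sin(x**2)*cos(x) - 1 evaluated at x = 0.5
val = eval_prefix("- * sin ** x 2 cos x 1".split(), {"x": 0.5})
assert abs(val - (math.sin(0.25) * math.cos(0.5) - 1)) < 1e-12
```

Running the converted prefix string through such an evaluator on the same test points as the infix original is exactly the equality check described above.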
189
+ ## Known Limitations
190
+
191
+ 1. **Very complex expressions**: there may be edge cases with deep nesting
192
+ 2. **Custom operators**: operators not mapped in SymPy are unsupported
193
+ 3. **Automatic simplification**: SymPy may simplify some expressions, changing their form but not their value
194
+
195
+ ## Status
196
+
197
+ - [x] Conversion script implemented
198
+ - [x] Tested on 10 examples (100% success)
199
+ - [ ] Full dataset conversion (~12k examples)
200
+ - [ ] Upload to HuggingFace Hub
201
+ - [ ] Test model training
202
+ - [ ] Infix vs prefix comparison
203
+
204
+ ## References
205
+
206
+ - **Polish Notation**: https://en.wikipedia.org/wiki/Polish_notation
207
+ - **SymPy Documentation**: https://docs.sympy.org/
208
+ - **Original Dataset**: https://huggingface.co/datasets/augustocsc/sintetico_natural
209
+
210
+ ---
211
+
212
+ **Created**: 2026-02-09
213
+ **Author**: Claude Sonnet 4.5 (co-authored)
214
+ **Last Updated**: 2026-02-09
2_training/README.md ADDED
@@ -0,0 +1,205 @@
1
+ # 2_training/ - Treinamento e Fine-tuning
2
+
3
+ Este diretório contém todos os scripts e configurações para treinamento de modelos LLM para symbolic regression.
4
+
5
+ ## Estrutura
6
+
7
+ ```
8
+ 2_training/
9
+ ├── supervised/ # Fine-tuning supervisionado
10
+ │ ├── train_with_json.py # Principal: treino com formato JSON
11
+ │ ├── train.py # Script base
12
+ │ ├── train_experiment.py # Experimentos controlados
13
+ │ └── iterative_sampling_sft.py # SFT iterativo
14
+
15
+ ├── reinforcement/ # Reinforcement Learning
16
+ │ ├── ppo_symbolic.py # Proximal Policy Optimization
17
+ │ ├── grpo_symbolic.py # Group Relative PO
18
+ │ ├── reinforce_*.py # REINFORCE algorithm
19
+ │ └── best_of_n_experiment.py # Best-of-N sampling
20
+
21
+ └── configs/ # Configurações
22
+ └── wandb_config.py # Wandb naming standards
23
+ ```
24
+
25
+ ## Métodos de Treinamento
26
+
27
+ ### 1. Supervised Fine-tuning (Recomendado)
28
+
29
+ Script principal: `supervised/train_with_json.py`
30
+
31
+ **Características**:
32
+ - LoRA fine-tuning (apenas 294K parâmetros treináveis)
33
+ - Formato JSON estruturado (80% valid rate)
34
+ - Early stopping automático
35
+ - Split train/validation 90/10
36
+ - Integração com Wandb
37
+
38
+ **Uso**:
39
+ ```bash
40
+ cd supervised
41
+ python train_with_json.py \
42
+ --model_size gpt2-medium \
43
+ --dataset_path ../../1_data/processed/700K \
44
+ --output_dir ../../models/gpt2/medium_test \
45
+ --num_train_epochs 3 \
46
+ --per_device_train_batch_size 4
47
+ ```
48
+
49
+ **Modelos suportados**:
50
+ - `gpt2` (124M params)
51
+ - `gpt2-medium` (355M params)
52
+ - `gpt2-large` (774M params)
53
+ - GPT-Neo, LLaMA, Phi (futuro)
54
+
55
+ ### 2. Reinforcement Learning
56
+
57
+ #### PPO (Proximal Policy Optimization)
58
+ Script: `reinforcement/ppo_symbolic.py`
59
+
60
+ **Quando usar**:
61
+ - Problemas complexos (Nguyen 4+)
62
+ - Otimização de R² score
63
+ - Após supervised fine-tuning
64
+
65
+ **Uso**:
66
+ ```bash
67
+ cd reinforcement
68
+ python ppo_symbolic.py \
69
+ --model_path ../../models/gpt2/base_700k_json \
70
+ --dataset ../../1_data/benchmarks/nguyen/nguyen_5.csv \
71
+ --epochs 20
72
+ ```
73
+
74
+ #### GRPO (Group Relative Policy Optimization)
75
+ Script: `reinforcement/grpo_symbolic.py`
76
+
77
+ **Vantagens**:
78
+ - Mais estável que PPO
79
+ - Melhor para multi-modal rewards
80
+ - Baseado em DeepSeek-R1
81
+
82
+ #### REINFORCE
83
+ Script: `reinforcement/reinforce_symbolic.py`
84
+
85
+ **Características**:
86
+ - Simples e eficaz
87
+ - EMA baseline
88
+ - Bom para benchmarks fáceis (Nguyen 1-3)
89
+
90
+ ## Configuração LoRA
91
+
92
+ Configuração padrão (todos os modelos):
93
+ ```python
94
+ {
95
+ "r": 8,
96
+ "lora_alpha": 32,
97
+ "target_modules": ["c_attn"],
98
+ "lora_dropout": 0.05
99
+ }
100
+ ```
101
+
102
+ **Resultado**: ~294K parâmetros treináveis (vs 124M-774M total)
103
+
104
+ ## Hiperparâmetros Recomendados
105
+
106
+ ### Por Tamanho de Modelo
107
+
108
+ | Modelo | Batch Size | Instance | VRAM | Tempo |
109
+ |--------|-----------|----------|------|-------|
110
+ | GPT-2 Base | 8 | g5.xlarge | 24GB | 2-3h |
111
+ | GPT-2 Medium | 4 | g5.xlarge | 24GB | 3-4h |
112
+ | GPT-2 Large | 2 | g5.2xlarge | 48GB | 4-5h |
113
+
114
+ ### Outros Hiperparâmetros
115
+ ```python
116
+ learning_rate = 5e-5
117
+ num_train_epochs = 3
118
+ gradient_accumulation_steps = 4
119
+ warmup_steps = 500
120
+ weight_decay = 0.01
121
+ early_stopping_patience = 3
122
+ seed = 42
123
+ ```
124
+
125
+ ## Formato de Dados
126
+
127
+ ### JSON Format (Recomendado)
128
+ ```json
129
+ {"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "sin(x_1 + C*x_2)"}
130
+ ```
131
+
132
+ **Vantagens**:
133
+ - 80% valid expression rate
134
+ - Structured boundaries
135
+ - Lower loss (0.343 vs 0.415)
136
+
137
+ ## Wandb Tracking
138
+
139
+ Naming standard: `seriguela-{type}-{model}-{dataset}-{timestamp}`
140
+
141
+ Exemplos:
142
+ ```python
143
+ # Supervised
144
+ seriguela-supervised-medium-700k-20260204-120000
145
+
146
+ # PPO
147
+ seriguela-ppo-large-nguyen5-20260204-120000
148
+ ```
149
+
150
+ ## Deploy AWS
151
+
152
+ Scripts disponíveis em: `../../scripts/aws/`
153
+
154
+ Lançar treinamento:
155
+ ```bash
156
+ # Medium model
157
+ bash ../../scripts/aws/launch_medium_training.sh \
158
+ --wandb-key YOUR_KEY \
159
+ --hf-token YOUR_TOKEN
160
+
161
+ # Large model
162
+ bash ../../scripts/aws/launch_large_training.sh \
163
+ --wandb-key YOUR_KEY \
164
+ --hf-token YOUR_TOKEN
165
+ ```
166
+
167
+ ## Troubleshooting
168
+
169
+ ### OOM (Out of Memory)
170
+ - Reduzir `per_device_train_batch_size`
171
+ - Usar `gradient_accumulation_steps` maior
172
+ - Usar instância maior (g5.2xlarge)
173
+
174
+ ### Low Valid Rate
175
+ - Verificar formato de dados (deve ser JSON)
176
+ - Aumentar `num_train_epochs`
177
+ - Verificar conversão de dados
178
+
179
+ ### Early Stopping Prematuro
180
+ - Aumentar `early_stopping_patience`
181
+ - Verificar validation loss
182
+
183
+ ## Próximos Modelos (Planejados)
184
+
185
+ ### GPT-Neo (EleutherAI)
186
+ - 125M, 1.3B, 2.7B params
187
+ - Similar ao GPT-2
188
+ - Compatível com mesma pipeline
189
+
190
+ ### LLaMA 2/3 (Meta)
191
+ - 7B, 13B, 70B params
192
+ - Melhor performance
193
+ - Requer mais VRAM
194
+
195
+ ### Phi-2/3 (Microsoft)
196
+ - 2.7B params
197
+ - Otimizado para reasoning
198
+ - Bom para symbolic tasks
199
+
200
+ ## Referências
201
+
202
+ - LoRA: https://arxiv.org/abs/2106.09685
203
+ - PPO: https://arxiv.org/abs/1707.06347
204
+ - GRPO: DeepSeek-R1 technical report
205
+ - Dataset: https://huggingface.co/datasets/augustocsc/sintetico_natural
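The "~294K trainable parameters" figure quoted in the README can be reproduced from GPT-2's shapes. A back-of-the-envelope check, assuming the default config (r=8 LoRA on `c_attn` only, GPT-2 small):

```python
# GPT-2 small: 12 transformer blocks, hidden size 768; c_attn projects 768 -> 3*768
hidden, n_layers, r = 768, 12, 8
in_features, out_features = hidden, 3 * hidden
# Each LoRA pair adds A (r x in_features) plus B (out_features x r) parameters
per_layer = r * in_features + out_features * r
total = per_layer * n_layers
print(total)  # 294912, i.e. the ~294K quoted above
```

The same arithmetic explains why adding `c_proj` to `target_modules` (as the medium/large configs below do) raises the trainable-parameter count.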
2_training/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
+ """
+ Seriguela - Configuration utilities.
+
+ This module provides standardized configuration utilities for
+ experiment tracking and naming conventions.
+ """
+
+ from .wandb_config import (
+     generate_run_name,
+     get_wandb_project_name,
+     get_run_tags,
+     setup_wandb_env,
+ )
+
+ __all__ = [
+     'generate_run_name',
+     'get_wandb_project_name',
+     'get_run_tags',
+     'setup_wandb_env',
+ ]
+
+ __version__ = '1.0.0'
2_training/configs/eval_dataset_download.sh ADDED
@@ -0,0 +1,6 @@
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy_dummy
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium_dummy
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard_dummy
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium
+ git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard
2_training/configs/model_config.json ADDED
@@ -0,0 +1 @@
+ {}
2_training/configs/peft_config.json ADDED
@@ -0,0 +1 @@
+ {}
2_training/configs/training.sh ADDED
@@ -0,0 +1,82 @@
+ CUDA_VISIBLE_DEVICES=0 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
+     --dataset_repo_id augustocsc/sintetico_natural \
+     --data_dir 500k \
+     --output_dir ./output \
+     --push_to_hub \
+     --hub_model_id augustocsc/Se124M500KInfPrompt_EOS \
+     --source_data_column i_prompt \
+     --report_to wandb \
+     --run_name Se124M500KInfPrompt_EOS \
+     --model_name_or_path gpt2 \
+     --bf16 \
+     --eval_strategy steps \
+     --num_train_epochs 3 \
+     --per_device_train_batch_size 16 \
+     --per_device_eval_batch_size 16 \
+     --gradient_accumulation_steps 4 \
+     --dataloader_num_workers 8 \
+     --learning_rate 5e-5 \
+     --warmup_ratio 0.03 \
+     --weight_decay 0.01 \
+     --max_grad_norm 1.0 \
+     --lr_scheduler_type cosine \
+     --optim adamw_torch_fused \
+     --logging_steps 20 \
+     --eval_steps 500 \
+     --save_steps 1000 \
+     --save_total_limit 3
+
+
+ # CUDA_VISIBLE_DEVICES=1 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
+ #     --dataset_repo_id augustocsc/sintetico_final \
+ #     --data_dir 100k \
+ #     --output_dir ./output \
+ #     --push_to_hub \
+ #     --hub_model_id augustocsc/Se124M100KInfPrompt_NT \
+ #     --source_data_column i_prompt \
+ #     --report_to wandb \
+ #     --run_name Se124M100KInfPrompt_NT \
+ #     --bf16 \
+ #     --eval_strategy steps \
+ #     --num_train_epochs 3 \
+ #     --per_device_train_batch_size 16 \
+ #     --per_device_eval_batch_size 16 \
+ #     --gradient_accumulation_steps 2 \
+ #     --dataloader_num_workers 8 \
+ #     --learning_rate 2e-5 \
+ #     --warmup_ratio 0.03 \
+ #     --weight_decay 0.01 \
+ #     --max_grad_norm 1.0 \
+ #     --lr_scheduler_type cosine \
+ #     --optim adamw_torch_fused \
+ #     --logging_steps 20 \
+ #     --eval_steps 500 \
+ #     --save_steps 1000 \
+ #     --save_total_limit 3
+
+ # CUDA_VISIBLE_DEVICES=0 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
+ #     --dataset_repo_id augustocsc/sintetico_final \
+ #     --data_dir 100k \
+ #     --output_dir ./output \
+ #     --push_to_hub \
+ #     --hub_model_id augustocsc/Se124M100KInfPrompt_WT \
+ #     --source_data_column i_prompt \
+ #     --report_to wandb \
+ #     --run_name Se124M100KInfPrompt_WT \
+ #     --bf16 \
+ #     --eval_strategy steps \
+ #     --num_train_epochs 3 \
+ #     --per_device_train_batch_size 16 \
+ #     --per_device_eval_batch_size 16 \
+ #     --gradient_accumulation_steps 2 \
+ #     --dataloader_num_workers 8 \
+ #     --learning_rate 2e-5 \
+ #     --warmup_ratio 0.03 \
+ #     --weight_decay 0.01 \
+ #     --max_grad_norm 1.0 \
+ #     --lr_scheduler_type cosine \
+ #     --optim adamw_torch_fused \
+ #     --logging_steps 20 \
+ #     --eval_steps 500 \
+ #     --save_steps 1000 \
+ #     --save_total_limit 3
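For reference, `--warmup_ratio 0.03` in the uncommented run above translates into an absolute warmup step count roughly as follows (a sketch using the script's own values; the Trainer computes this internally from the actual dataloader length):

```python
# 500k examples, batch 16, gradient accumulation 4, 3 epochs (values from the script above)
n_examples, batch, accum, epochs = 500_000, 16, 4, 3
steps_per_epoch = -(-n_examples // (batch * accum))  # ceil division
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * 0.03)
print(steps_per_epoch, total_steps, warmup_steps)  # 7813 23439 703
```

This is why the ratio form is convenient: switching to the commented 100k runs below changes the warmup length automatically, with no flag edits.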
2_training/configs/training_args.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "output_dir": "./output",
+   "overwrite_output_dir": true,
+   "num_train_epochs": 50,
+   "per_device_train_batch_size": 8,
+   "gradient_accumulation_steps": 1,
+   "learning_rate": 5e-5,
+   "weight_decay": 0.01,
+   "warmup_steps": 0,
+   "fp16": true,
+   "seed": 42,
+   "per_device_eval_batch_size": 8,
+   "eval_strategy": "epoch",
+   "metric_for_best_model": "eval_loss",
+   "greater_is_better": false,
+   "eval_steps": null,
+   "load_best_model_at_end": true,
+   "save_strategy": "epoch",
+   "save_steps": null,
+   "save_total_limit": 2,
+   "logging_dir": "./output/logs",
+   "logging_steps": 100,
+   "report_to": "wandb",
+   "run_name": "Se124M100K",
+   "push_to_hub": true,
+   "hub_model_id": "augustocsc/Se124M100K",
+   "hub_token": null
+ }
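A quick consistency check on the arguments above. This is a hedged sketch of the constraints `transformers` enforces: `load_best_model_at_end=True` requires matching eval/save strategies, and the `*_steps` fields are unused under an `epoch` strategy:

```python
# Mirror of the relevant fields from training_args.json above
cfg = {
    "eval_strategy": "epoch", "eval_steps": None,
    "save_strategy": "epoch", "save_steps": None,
    "load_best_model_at_end": True, "metric_for_best_model": "eval_loss",
}

def check(cfg):
    for kind in ("eval", "save"):
        if cfg[f"{kind}_strategy"] == "epoch":
            # step counts are meaningless under an epoch-based strategy
            assert cfg[f"{kind}_steps"] is None
    if cfg["load_best_model_at_end"]:
        # Trainer requires eval and save strategies to match in this case
        assert cfg["eval_strategy"] == cfg["save_strategy"]
    return True

print(check(cfg))  # True
```

Running a check like this before instantiating `TrainingArguments` fails fast on a misconfigured JSON instead of mid-training.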
2_training/configs/training_large.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "model_config": {
+     "model_name_or_path": "gpt2-large",
+     "model_size": "774M",
+     "description": "GPT-2 Large - 774M parameters"
+   },
+   "training_args": {
+     "num_train_epochs": 2,
+     "per_device_train_batch_size": 4,
+     "per_device_eval_batch_size": 4,
+     "gradient_accumulation_steps": 16,
+     "effective_batch_size": 64,
+     "learning_rate": 2e-5,
+     "weight_decay": 0.01,
+     "warmup_steps": 100,
+     "max_grad_norm": 1.0,
+     "lr_scheduler_type": "cosine",
+     "fp16": true,
+     "seed": 42,
+     "block_size": 128
+   },
+   "evaluation_args": {
+     "eval_strategy": "epoch",
+     "eval_steps": null,
+     "metric_for_best_model": "eval_loss",
+     "greater_is_better": false,
+     "load_best_model_at_end": true
+   },
+   "save_args": {
+     "save_strategy": "epoch",
+     "save_steps": null,
+     "save_total_limit": 2
+   },
+   "logging_args": {
+     "logging_dir": "./output/logs",
+     "logging_steps": 50,
+     "report_to": "wandb"
+   },
+   "lora_config": {
+     "r": 8,
+     "lora_alpha": 32,
+     "target_modules": ["c_attn", "c_proj"],
+     "lora_dropout": 0.05,
+     "bias": "none",
+     "task_type": "CAUSAL_LM"
+   },
+   "dataset_config": {
+     "dataset_repo_id": "augustocsc/sintetico_natural",
+     "data_dir": "700K",
+     "data_columns": {
+       "infix": "i_prompt_n",
+       "prefix": "p_prompt_n"
+     }
+   },
+   "hub_config": {
+     "push_to_hub": true,
+     "hub_model_id_template": "augustocsc/Se774M_700K_{format}",
+     "formats": ["infix", "prefix"]
+   },
+   "estimated_time": {
+     "per_epoch_minutes": 180,
+     "total_hours": 6,
+     "notes": "Estimated for AWS g5.xlarge with A10G GPU. May need gradient checkpointing for memory optimization."
+   }
+ }
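The `effective_batch_size: 64` field above is informational (4 × 16). At that batch size, the 700K split works out to roughly the following optimizer steps per epoch (a sketch; `drop_last` behavior and multi-GPU setups change the exact count):

```python
# Values from the training_args block above
n_examples = 700_000
effective_batch = 4 * 16  # per_device_train_batch_size * gradient_accumulation_steps
steps_per_epoch = -(-n_examples // effective_batch)  # ceil division
print(steps_per_epoch)  # 10938
```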
2_training/configs/training_medium.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "model_config": {
+     "model_name_or_path": "gpt2-medium",
+     "model_size": "355M",
+     "description": "GPT-2 Medium - 355M parameters"
+   },
+   "training_args": {
+     "num_train_epochs": 2,
+     "per_device_train_batch_size": 8,
+     "per_device_eval_batch_size": 8,
+     "gradient_accumulation_steps": 8,
+     "effective_batch_size": 64,
+     "learning_rate": 3e-5,
+     "weight_decay": 0.01,
+     "warmup_steps": 100,
+     "max_grad_norm": 1.0,
+     "lr_scheduler_type": "cosine",
+     "fp16": true,
+     "seed": 42,
+     "block_size": 128
+   },
+   "evaluation_args": {
+     "eval_strategy": "epoch",
+     "eval_steps": null,
+     "metric_for_best_model": "eval_loss",
+     "greater_is_better": false,
+     "load_best_model_at_end": true
+   },
+   "save_args": {
+     "save_strategy": "epoch",
+     "save_steps": null,
+     "save_total_limit": 2
+   },
+   "logging_args": {
+     "logging_dir": "./output/logs",
+     "logging_steps": 50,
+     "report_to": "wandb"
+   },
+   "lora_config": {
+     "r": 8,
+     "lora_alpha": 32,
+     "target_modules": ["c_attn", "c_proj"],
+     "lora_dropout": 0.05,
+     "bias": "none",
+     "task_type": "CAUSAL_LM"
+   },
+   "dataset_config": {
+     "dataset_repo_id": "augustocsc/sintetico_natural",
+     "data_dir": "700K",
+     "data_columns": {
+       "infix": "i_prompt_n",
+       "prefix": "p_prompt_n"
+     }
+   },
+   "hub_config": {
+     "push_to_hub": true,
+     "hub_model_id_template": "augustocsc/Se355M_700K_{format}",
+     "formats": ["infix", "prefix"]
+   },
+   "estimated_time": {
+     "per_epoch_minutes": 90,
+     "total_hours": 3,
+     "notes": "Estimated for AWS g5.xlarge with A10G GPU"
+   }
+ }
2_training/configs/training_small.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "model_config": {
+     "model_name_or_path": "gpt2",
+     "model_size": "124M",
+     "description": "GPT-2 Small - 124M parameters"
+   },
+   "training_args": {
+     "num_train_epochs": 3,
+     "per_device_train_batch_size": 16,
+     "per_device_eval_batch_size": 16,
+     "gradient_accumulation_steps": 4,
+     "effective_batch_size": 64,
+     "learning_rate": 5e-5,
+     "weight_decay": 0.01,
+     "warmup_steps": 100,
+     "max_grad_norm": 1.0,
+     "lr_scheduler_type": "cosine",
+     "fp16": true,
+     "seed": 42,
+     "block_size": 128
+   },
+   "evaluation_args": {
+     "eval_strategy": "epoch",
+     "eval_steps": null,
+     "metric_for_best_model": "eval_loss",
+     "greater_is_better": false,
+     "load_best_model_at_end": true
+   },
+   "save_args": {
+     "save_strategy": "epoch",
+     "save_steps": null,
+     "save_total_limit": 2
+   },
+   "logging_args": {
+     "logging_dir": "./output/logs",
+     "logging_steps": 50,
+     "report_to": "wandb"
+   },
+   "lora_config": {
+     "r": 8,
+     "lora_alpha": 32,
+     "target_modules": ["c_attn", "c_proj"],
+     "lora_dropout": 0.05,
+     "bias": "none",
+     "task_type": "CAUSAL_LM"
+   },
+   "dataset_config": {
+     "dataset_repo_id": "augustocsc/sintetico_natural",
+     "data_dir": "700K",
+     "data_columns": {
+       "infix": "i_prompt_n",
+       "prefix": "p_prompt_n"
+     }
+   },
+   "hub_config": {
+     "push_to_hub": true,
+     "hub_model_id_template": "augustocsc/Se124M_700K_{format}",
+     "formats": ["infix", "prefix"]
+   },
+   "estimated_time": {
+     "per_epoch_minutes": 40,
+     "total_hours": 2,
+     "notes": "Estimated for AWS g5.xlarge with A10G GPU"
+   }
+ }
2_training/configs/training_v3.json ADDED
@@ -0,0 +1,78 @@
+ {
+   "model_config": {
+     "model_name_or_path": "gpt2",
+     "model_size": "124M",
+     "description": "GPT-2 Small (124M) - v3 with proper end markers"
+   },
+   "training_args": {
+     "num_train_epochs": 3,
+     "per_device_train_batch_size": 8,
+     "per_device_eval_batch_size": 8,
+     "gradient_accumulation_steps": 4,
+     "effective_batch_size": 32,
+     "learning_rate": 5e-5,
+     "weight_decay": 0.01,
+     "warmup_steps": 100,
+     "max_grad_norm": 1.0,
+     "lr_scheduler_type": "cosine",
+     "fp16": true,
+     "seed": 42,
+     "block_size": 128
+   },
+   "evaluation_args": {
+     "eval_strategy": "epoch",
+     "eval_steps": null,
+     "metric_for_best_model": "eval_loss",
+     "greater_is_better": false,
+     "load_best_model_at_end": true
+   },
+   "save_args": {
+     "save_strategy": "epoch",
+     "save_steps": null,
+     "save_total_limit": 2
+   },
+   "logging_args": {
+     "logging_dir": "./output/logs",
+     "logging_steps": 50,
+     "report_to": "wandb"
+   },
+   "lora_config": {
+     "r": 8,
+     "lora_alpha": 32,
+     "target_modules": ["c_attn"],
+     "lora_dropout": 0.05,
+     "bias": "none",
+     "task_type": "CAUSAL_LM"
+   },
+   "dataset_config": {
+     "use_local_csvs": true,
+     "train_file": "./data/processed/700K_fixed/train_700K.csv",
+     "validation_file": "./data/processed/700K_fixed/validation_700K.csv",
+     "test_file": "./data/processed/700K_fixed/test_700K.csv",
+     "data_column": "text"
+   },
+   "hub_config": {
+     "push_to_hub": true,
+     "hub_model_id": "augustocsc/Se124M_700K_infix_v3"
+   },
+   "special_tokens": {
+     "start_token": "<|startofex|>",
+     "end_token": "<|endofex|>",
+     "notes": "End token configured as EOS token for proper stopping"
+   },
+   "estimated_time": {
+     "per_epoch_minutes": 45,
+     "total_hours": 2.25,
+     "notes": "Estimated for AWS g5.xlarge with A10G GPU, GPT-2 Small, 3 epochs"
+   },
+   "version_info": {
+     "model_version": "v3",
+     "improvements": [
+       "Training data includes proper <|endofex|> markers",
+       "100% validation rate on prepared dataset",
+       "Addresses v1 non-stopping issue and v2 garbage generation",
+       "Uses local CSVs with validated end markers"
+     ],
+     "training_date": "2026-02-01"
+   }
+ }
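The v3 fix described in `version_info` above wraps every training example in explicit markers so that generation can stop on `<|endofex|>` (configured as the EOS token). The data-preparation step can be sketched minimally (`wrap_example` is an illustrative helper, not the repo's actual preprocessing function):

```python
START, END = "<|startofex|>", "<|endofex|>"

def wrap_example(expr: str) -> str:
    """v3 data preparation: every sample carries an explicit end marker."""
    return f"{START}{expr}{END}"

sample = wrap_example("x_1 + C*x_2")
print(sample)  # <|startofex|>x_1 + C*x_2<|endofex|>
assert sample.endswith(END)  # END doubles as the EOS token, so generation stops
```

Without the end marker in the training text, the model never learns to emit EOS, which is exactly the v1 non-stopping failure this config addresses.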
2_training/configs/wandb_config.py ADDED
@@ -0,0 +1,221 @@
+ """
+ Wandb Configuration and Naming Standards for Seriguela Project
+
+ This module provides standardized naming conventions for Wandb experiment tracking.
+ """
+
+ import os
+ from datetime import datetime
+ from typing import Optional
+
+
+ # Default Wandb project name
+ DEFAULT_PROJECT = "seriguela"
+
+ # Alternative project name for experiments
+ EXPERIMENTS_PROJECT = "seriguela-experiments"
+
+
+ def get_wandb_project_name(use_experiments: bool = False) -> str:
+     """
+     Get the standard Wandb project name.
+
+     Args:
+         use_experiments: If True, use the experiments project name
+
+     Returns:
+         Project name string
+     """
+     return EXPERIMENTS_PROJECT if use_experiments else DEFAULT_PROJECT
+
+
+ def generate_run_name(
+     experiment_type: str,
+     model_size: str = "base",
+     dataset: Optional[str] = None,
+     extra_info: Optional[str] = None,
+     include_timestamp: bool = True
+ ) -> str:
+     """
+     Generate a standardized Wandb run name.
+
+     Naming Convention: seriguela-{type}-{model}-{dataset}-{extra}-{timestamp}
+
+     Args:
+         experiment_type: Type of experiment (supervised, ppo, grpo, reinforce, iterative-sft)
+         model_size: Model size (base, medium, large) or full name (gpt2, gpt2-medium)
+         dataset: Dataset identifier (700K, nguyen5, nguyen7, etc)
+         extra_info: Additional information (optional)
+         include_timestamp: Whether to include a timestamp suffix
+
+     Returns:
+         Formatted run name
+
+     Examples:
+         >>> generate_run_name("supervised", "medium", "700K")
+         'seriguela-supervised-medium-700k-20260203-143022'
+
+         >>> generate_run_name("ppo", "base", "nguyen5", "lr3e5")
+         'seriguela-ppo-base-nguyen5-lr3e5-20260203-143022'
+
+         >>> generate_run_name("grpo", "large", "nguyen7", include_timestamp=False)
+         'seriguela-grpo-large-nguyen7'
+     """
+     # Normalize model size
+     model_map = {
+         "gpt2": "base",
+         "gpt2-base": "base",
+         "124m": "base",
+         "gpt2-medium": "medium",
+         "355m": "medium",
+         "gpt2-large": "large",
+         "774m": "large"
+     }
+     model_size = model_map.get(model_size.lower(), model_size.lower())
+
+     # Build run name parts
+     parts = ["seriguela", experiment_type.lower()]
+
+     # Add model size
+     parts.append(model_size)
+
+     # Add dataset if provided
+     if dataset:
+         parts.append(dataset.lower().replace("_", "").replace("-", ""))
+
+     # Add extra info if provided
+     if extra_info:
+         parts.append(extra_info.lower().replace("_", ""))
+
+     # Add timestamp if requested
+     if include_timestamp:
+         timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+         parts.append(timestamp)
+
+     return "-".join(parts)
+
+
+ def get_run_tags(
+     experiment_type: str,
+     model_size: str,
+     dataset: Optional[str] = None,
+     success: Optional[bool] = None
+ ) -> list:
+     """
+     Generate standardized tags for Wandb runs.
+
+     Args:
+         experiment_type: Type of experiment
+         model_size: Model size
+         dataset: Dataset name
+         success: Whether the experiment was successful (optional)
+
+     Returns:
+         List of tags
+
+     Examples:
+         >>> get_run_tags("ppo", "medium", "nguyen5", True)
+         ['ppo', 'gpt2-medium', 'nguyen5', 'rl', 'success']
+     """
+     tags = [experiment_type.lower()]
+
+     # Add model size
+     if model_size.lower() in ["base", "124m", "gpt2"]:
+         tags.append("gpt2-base")
+     elif model_size.lower() in ["medium", "355m", "gpt2-medium"]:
+         tags.append("gpt2-medium")
+     elif model_size.lower() in ["large", "774m", "gpt2-large"]:
+         tags.append("gpt2-large")
+     else:
+         tags.append(model_size.lower())
+
+     # Add dataset
+     if dataset:
+         tags.append(dataset.lower())
+
+     # Add category based on experiment type
+     if experiment_type.lower() in ["ppo", "grpo", "reinforce"]:
+         tags.append("rl")
+     elif experiment_type.lower() in ["supervised", "sft"]:
+         tags.append("supervised")
+     elif experiment_type.lower() == "iterative-sft":
+         tags.append("iterative")
+
+     # Add success tag if provided
+     if success is not None:
+         tags.append("success" if success else "failed")
+
+     return tags
+
+
+ # Common experiment types
+ EXPERIMENT_TYPES = {
+     "SUPERVISED": "supervised",
+     "SFT": "sft",
+     "PPO": "ppo",
+     "GRPO": "grpo",
+     "REINFORCE": "reinforce",
+     "ITERATIVE_SFT": "iterative-sft",
+     "BEST_OF_N": "best-of-n",
+     "EVALUATION": "eval"
+ }
+
+ # Common datasets
+ DATASETS = {
+     "MAIN_700K": "700K",
+     "NGUYEN_1": "nguyen1",
+     "NGUYEN_5": "nguyen5",
+     "NGUYEN_7": "nguyen7",
+     "NGUYEN_10": "nguyen10",
+     "CUSTOM": "custom"
+ }
+
+
+ def setup_wandb_env():
+     """
+     Set up the Wandb environment from a credentials file.
+     Reads from ~/.tokens.txt if available.
+     """
+     tokens_file = os.path.expanduser("~/.tokens.txt")
+     if os.path.exists(tokens_file):
+         with open(tokens_file) as f:
+             for line in f:
+                 if "=" in line and not line.strip().startswith("#"):
+                     key, value = line.strip().split("=", 1)
+                     key = key.strip()
+                     value = value.strip()
+                     if key.lower() == "wandb":
+                         os.environ["WANDB_API_KEY"] = value
+                         print(f"[OK] Wandb API key loaded from {tokens_file}")
+                         return True
+
+     # Check if already in the environment
+     if "WANDB_API_KEY" in os.environ:
+         print("[OK] Wandb API key found in environment")
+         return True
+
+     print("[WARN] Wandb API key not found. Run 'wandb login' or add it to ~/.tokens.txt")
+     return False
+
+
+ if __name__ == "__main__":
+     # Example usage
+     print("Wandb Configuration Examples:\n")
+
+     print("1. Supervised training on the 700K dataset:")
+     print(f"   {generate_run_name('supervised', 'medium', '700K')}\n")
+
+     print("2. PPO on the Nguyen-5 benchmark:")
+     print(f"   {generate_run_name('ppo', 'base', 'nguyen5')}\n")
+
+     print("3. GRPO with a custom learning rate:")
+     print(f"   {generate_run_name('grpo', 'large', 'nguyen7', 'lr5e5')}\n")
+
+     print("4. Evaluation run (no timestamp):")
+     print(f"   {generate_run_name('eval', 'medium', 'nguyen5', include_timestamp=False)}\n")
+
+     print("5. Tags example:")
+     print(f"   {get_run_tags('ppo', 'medium', 'nguyen5', True)}\n")
+
+     print("6. Set up the Wandb environment:")
+     setup_wandb_env()
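Inverting the naming scheme is occasionally useful when aggregating finished runs. A hedged sketch of such a parser (illustrative only; not part of `wandb_config.py`):

```python
def parse_run_name(name: str) -> dict:
    """Invert seriguela-{type}-{model}-{dataset}-{timestamp} (illustrative only)."""
    parts = name.split("-")
    assert parts[0] == "seriguela", "not a seriguela run name"
    fields = {"experiment_type": parts[1], "model_size": parts[2]}
    # A trailing YYYYMMDD-HHMMSS pair is the timestamp added by generate_run_name
    if len(parts) >= 5 and parts[-2].isdigit() and parts[-1].isdigit():
        fields["timestamp"] = f"{parts[-2]}-{parts[-1]}"
        middle = parts[3:-2]
    else:
        middle = parts[3:]
    if middle:
        fields["dataset"] = middle[0]
    return fields

print(parse_run_name("seriguela-ppo-base-nguyen5-20260203-143022"))
```

Note that the scheme is only unambiguously invertible because `generate_run_name` strips hyphens from the dataset and extra-info fields before joining.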
2_training/reinforcement/best_of_n_experiment.py ADDED
@@ -0,0 +1,398 @@
+ #!/usr/bin/env python3
+ """
+ Best-of-N Sampling Experiment for Symbolic Regression
+
+ Instead of PPO (which has API compatibility issues with TRL 0.16+),
+ this script tests whether the base model can find correct expressions
+ through random sampling. If the model generates the correct expression
+ even occasionally, PPO should be able to learn to find it consistently.
+
+ This is a diagnostic experiment to understand model capabilities.
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import logging
+ import datetime
+ from pathlib import Path
+ from collections import defaultdict
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ # Add project root to path
+ PROJECT_ROOT = Path(__file__).parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ sys.path.insert(0, str(PROJECT_ROOT / "classes"))
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+
+ from expression import Expression
+ from dataset import RegressionDataset
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class BestOfNSampler:
+     """Generate N expressions and find the best one for a given dataset."""
+
+     def __init__(self, model_path: str, device: str = None):
+         self.model_path = model_path
+
+         # Device setup
+         if device:
+             self.device = torch.device(device)
+         else:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+         self._load_model()
+
+     def _load_model(self):
+         """Load the JSON format model with LoRA adapters."""
+         logger.info(f"Loading model from {self.model_path}")
+
+         # Load tokenizer from the trained model (has special tokens)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         logger.info(f"Tokenizer loaded with vocab size: {len(self.tokenizer)}")
+
+         # Load base GPT-2
+         base_model = AutoModelForCausalLM.from_pretrained(
+             "gpt2",
+             torch_dtype=torch.float16,
+         )
+
+         # Resize embeddings to match the tokenizer (handles special tokens)
+         if len(self.tokenizer) != base_model.config.vocab_size:
+             logger.info(f"Resizing embeddings: {base_model.config.vocab_size} -> {len(self.tokenizer)}")
+             base_model.resize_token_embeddings(len(self.tokenizer))
+
+         # Load the LoRA adapter
+         try:
+             model_with_lora = PeftModel.from_pretrained(base_model, self.model_path)
+             self.model = model_with_lora.merge_and_unload()
+             logger.info("LoRA adapter loaded and merged")
+         except Exception as e:
+             logger.warning(f"Could not load as PEFT model: {e}")
+             self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
+
+         self.model = self.model.to(self.device)
+         self.model.eval()
+         logger.info("Model loaded successfully")
+
+     def build_prompt(self, n_vars: int) -> str:
+         """Build a JSON format prompt matching the training data."""
+         vars_list = [f"x_{i+1}" for i in range(n_vars)]
+         ops_list = ["+", "-", "*", "sin", "cos"]
+
+         prompt = json.dumps({
+             "vars": vars_list,
+             "ops": ops_list,
+             "cons": None,
+             "expr": ""
+         })[:-3]  # Strip the closing '""}' so the model completes the "expr" value
+
+         return prompt
+
+     def extract_expression(self, generated_text: str) -> str:
+         """Extract the expression from JSON format output.
+
+         Handles three output formats:
+         1. Standard JSON: "expr": "value"}
+         2. Model output: "expr": value"} (no quotes around the value)
+         3. Compact JSON: "expr":"value"}
+         """
+         try:
+             # Case 1: Standard JSON with quotes around the expression value
+             if '"expr": "' in generated_text:
+                 expr_start = generated_text.index('"expr": "') + len('"expr": "')
+                 remaining = generated_text[expr_start:]
+                 # Find the closing "}
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 # Fallback: find the first quote
+                 if '"' in remaining:
+                     return remaining[:remaining.index('"')].strip()
+                 return remaining.strip()
+
+             # Case 2: Model output WITHOUT quotes: "expr": value"}
+             # This is what the model actually generates
+             if '"expr": ' in generated_text:
+                 expr_start = generated_text.index('"expr": ') + len('"expr": ')
+                 remaining = generated_text[expr_start:]
+                 # Find the closing "} which ends the JSON object
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 # Fallback: find "{ which starts the next object
+                 if '"{' in remaining:
+                     return remaining[:remaining.index('"{')].strip().rstrip('}')
+                 return remaining.strip()
+
+             # Case 3: Compact JSON without a space
+             if '"expr":"' in generated_text:
+                 expr_start = generated_text.index('"expr":"') + len('"expr":"')
+                 remaining = generated_text[expr_start:]
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 if '"' in remaining:
+                     return remaining[:remaining.index('"')].strip()
+                 return remaining.strip()
+
+         except (ValueError, IndexError):
+             pass
+
+         # Last resort: split on "expr" and clean up
+         fallback = generated_text.split('"expr"')[-1].strip(' ":}')
+         if '"}' in fallback:
+             fallback = fallback[:fallback.index('"}')]
+         return fallback.strip()
+
+     def compute_r2(self, expression_str: str, X: np.ndarray, y: np.ndarray) -> float:
+         """Compute the R² score for an expression."""
+         if not expression_str or expression_str.isspace():
+             return -np.inf
+
+         # Replace the constant placeholder C with 1
+         if 'C' in expression_str:
+             expression_str = expression_str.replace('C', '1')
+
+         try:
+             expr = Expression(expression_str, is_prefix=False)
+
+             if not expr.is_valid_on_dataset(X):
+                 return -np.inf
+
+             y_pred = expr.evaluate(X)
+
+             if not np.all(np.isfinite(y_pred)):
+                 return -np.inf
+
+             ss_res = np.sum((y - y_pred) ** 2)
+             ss_tot = np.sum((y - np.mean(y)) ** 2)
+
+             if ss_tot == 0:
+                 return 0.0
+
+             return 1 - (ss_res / ss_tot)
+         except Exception:
+             return -np.inf
+
+     def sample_expressions(self, n_vars: int, n_samples: int = 100,
+                            temperature: float = 0.7) -> list:
+         """Generate N expression samples."""
+         prompt = self.build_prompt(n_vars)
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+         expressions = []
+
+         debug_count = 0
+         for _ in tqdm(range(n_samples), desc="Sampling expressions"):
+             with torch.no_grad():
+                 output = self.model.generate(
+                     **inputs,
+                     max_new_tokens=50,
+                     do_sample=True,
+                     top_k=50,
205
+ top_p=0.9,
206
+ temperature=temperature,
207
+ pad_token_id=self.tokenizer.pad_token_id,
208
+ )
209
+
210
+ text = self.tokenizer.decode(output[0], skip_special_tokens=True)
211
+ expr_str = self.extract_expression(text)
212
+
213
+ # Debug: print first 5 extractions
214
+ if debug_count < 5:
215
+ logger.info(f"DEBUG [{debug_count}] raw text (last 80 chars): ...{text[-80:]}")
216
+ logger.info(f"DEBUG [{debug_count}] extracted: '{expr_str}'")
217
+ debug_count += 1
218
+
219
+ expressions.append(expr_str)
220
+
221
+ return expressions
222
+
223
+ def find_best_expression(self, X: np.ndarray, y: np.ndarray,
224
+ n_samples: int = 500, temperature: float = 0.7):
225
+ """Sample N expressions and find the best one for the dataset."""
226
+ n_vars = X.shape[1]
227
+
228
+ logger.info(f"Sampling {n_samples} expressions for {n_vars}-variable dataset...")
229
+ expressions = self.sample_expressions(n_vars, n_samples, temperature)
230
+
231
+ # Compute R² for each
232
+ results = []
233
+ unique_expressions = set()
234
+
235
+ for expr_str in tqdm(expressions, desc="Computing R² scores"):
236
+ if expr_str in unique_expressions:
237
+ continue
238
+ unique_expressions.add(expr_str)
239
+
240
+ r2 = self.compute_r2(expr_str, X, y)
241
+ results.append({
242
+ "expression": expr_str,
243
+ "r2": float(r2) if np.isfinite(r2) else None,
244
+ "is_valid": bool(np.isfinite(r2) and r2 > -1),
245
+ })
246
+
247
+ # Sort by R²
248
+ results.sort(key=lambda x: x["r2"] if x["r2"] is not None else -np.inf, reverse=True)
249
+
250
+ # Statistics
251
+ valid_count = sum(1 for r in results if r["is_valid"])
252
+ valid_r2s = [r["r2"] for r in results if r["r2"] is not None and r["r2"] > -1]
253
+
254
+ return {
255
+ "n_samples": n_samples,
256
+ "unique_expressions": len(unique_expressions),
257
+ "valid_count": valid_count,
258
+ "valid_rate": valid_count / len(unique_expressions) if unique_expressions else 0,
259
+ "best_r2": results[0]["r2"] if results and results[0]["r2"] else None,
260
+ "best_expression": results[0]["expression"] if results else None,
261
+ "mean_r2": float(np.mean(valid_r2s)) if valid_r2s else None,
262
+ "median_r2": float(np.median(valid_r2s)) if valid_r2s else None,
263
+ "top_10": results[:10],
264
+ }
265
+
266
+
267
+ def run_experiment(model_path: str, datasets_dir: str, n_samples: int = 500,
268
+ output_dir: str = "./output/best_of_n"):
269
+ """Run Best-of-N experiment on multiple datasets."""
270
+
271
+ output_dir = Path(output_dir)
272
+ output_dir.mkdir(parents=True, exist_ok=True)
273
+
274
+ # Test datasets
275
+ test_datasets = {
276
+ "add_x1_x2": {"formula": "x_1 + x_2", "difficulty": "easy"},
277
+ "mul_x1_x2": {"formula": "x_1 * x_2", "difficulty": "easy"},
278
+ "sub_x1_x2": {"formula": "x_1 - x_2", "difficulty": "easy"},
279
+ "sin_x1": {"formula": "sin(x_1)", "difficulty": "medium"},
280
+ "cos_x1": {"formula": "cos(x_1)", "difficulty": "medium"},
281
+ "square_x1": {"formula": "x_1 * x_1", "difficulty": "medium"},
282
+ "sin_x1_plus_x2": {"formula": "sin(x_1) + x_2", "difficulty": "hard"},
283
+ "x1_mul_sin_x2": {"formula": "x_1 * sin(x_2)", "difficulty": "hard"},
284
+ }
285
+
286
+ # Initialize sampler
287
+ sampler = BestOfNSampler(model_path)
288
+
289
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
290
+
291
+ results = {
292
+ "timestamp": timestamp,
293
+ "model_path": model_path,
294
+ "n_samples": n_samples,
295
+ "datasets": {},
296
+ }
297
+
298
+ print("\n" + "=" * 70)
299
+ print("BEST-OF-N SAMPLING EXPERIMENT")
300
+ print("=" * 70)
301
+ print(f"Model: {model_path}")
302
+ print(f"Samples per dataset: {n_samples}")
303
+ print("=" * 70)
304
+
305
+ for dataset_name, info in test_datasets.items():
306
+ dataset_path = Path(datasets_dir) / f"{dataset_name}.csv"
307
+
308
+ if not dataset_path.exists():
309
+ logger.warning(f"Dataset not found: {dataset_path}")
310
+ continue
311
+
312
+ print(f"\n{'='*70}")
313
+ print(f"Dataset: {dataset_name}")
314
+ print(f"Ground truth: {info['formula']}")
315
+ print(f"Difficulty: {info['difficulty']}")
316
+ print(f"{'='*70}")
317
+
318
+ # Load dataset
319
+ reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
320
+ X, y = reg.get_numpy()
321
+
322
+ # Run Best-of-N
323
+ result = sampler.find_best_expression(X, y, n_samples)
324
+ result["ground_truth"] = info["formula"]
325
+ result["difficulty"] = info["difficulty"]
326
+
327
+ results["datasets"][dataset_name] = result
328
+
329
+ # Print results
330
+ print(f"\nResults:")
331
+ print(f" Valid expressions: {result['valid_count']}/{result['unique_expressions']} ({result['valid_rate']:.1%})")
332
+ print(f" Best R²: {result['best_r2']:.4f}" if result['best_r2'] else " Best R²: N/A")
333
+ print(f" Best expression: {result['best_expression']}")
334
+
335
+ if result['best_r2'] and result['best_r2'] > 0.99:
336
+ print(f" ✅ FOUND NEAR-PERFECT MATCH!")
337
+ elif result['best_r2'] and result['best_r2'] > 0.9:
338
+ print(f" ⚠️ Found good match (R² > 0.9)")
339
+ else:
340
+ print(f" ❌ No good match found")
341
+
342
+ print("\n Top 5 expressions:")
343
+ for i, expr in enumerate(result['top_10'][:5]):
344
+ r2_str = f"{expr['r2']:.4f}" if expr['r2'] else "N/A"
345
+ print(f" {i+1}. {expr['expression'][:40]:<40} R²={r2_str}")
346
+
347
+ # Save results
348
+ results_file = output_dir / f"best_of_n_results_{timestamp}.json"
349
+ with open(results_file, 'w') as f:
350
+ json.dump(results, f, indent=2)
351
+
352
+ print("\n" + "=" * 70)
353
+ print("SUMMARY")
354
+ print("=" * 70)
355
+
356
+ # Summary table
357
+ print(f"\n{'Dataset':<25} {'Difficulty':<10} {'Best R²':<10} {'Found?':<10}")
358
+ print("-" * 60)
359
+
360
+ success_count = 0
361
+ for name, res in results["datasets"].items():
362
+ r2 = res["best_r2"]
363
+ r2_str = f"{r2:.4f}" if r2 else "N/A"
364
+ found = "✅" if r2 and r2 > 0.99 else ("⚠️" if r2 and r2 > 0.9 else "❌")
365
+ if r2 and r2 > 0.99:
366
+ success_count += 1
367
+ print(f"{name:<25} {res['difficulty']:<10} {r2_str:<10} {found:<10}")
368
+
369
+ print("-" * 60)
370
+ print(f"Success rate (R² > 0.99): {success_count}/{len(results['datasets'])}")
371
+ print(f"\nResults saved to: {results_file}")
372
+
373
+ return results
374
+
375
+
376
+ def main():
377
+ parser = argparse.ArgumentParser(description="Best-of-N Sampling Experiment")
378
+ parser.add_argument("--model_path", type=str, default="./output/exp_a_json",
379
+ help="Path to trained model")
380
+ parser.add_argument("--datasets_dir", type=str, default="./data/ppo_test",
381
+ help="Directory containing test datasets")
382
+ parser.add_argument("--n_samples", type=int, default=500,
383
+ help="Number of samples per dataset")
384
+ parser.add_argument("--output_dir", type=str, default="./output/best_of_n",
385
+ help="Output directory for results")
386
+
387
+ args = parser.parse_args()
388
+
389
+ run_experiment(
390
+ model_path=args.model_path,
391
+ datasets_dir=args.datasets_dir,
392
+ n_samples=args.n_samples,
393
+ output_dir=args.output_dir,
394
+ )
395
+
396
+
397
+ if __name__ == "__main__":
398
+ main()
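The ranking logic above boils down to scoring every sampled candidate with the coefficient of determination and keeping the maximum. A minimal self-contained sketch of that Best-of-N selection, using plain NumPy callables as stand-ins for model-sampled expressions (the helper name `r2_score` and the toy candidates are illustrative, not repo code):

```python
import numpy as np

def r2_score(y, y_pred):
    """R² = 1 - SS_res / SS_tot; defined as 0.0 when y is constant."""
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return 0.0 if ss_tot == 0 else 1.0 - ss_res / ss_tot

# Toy dataset with ground truth y = x_1 + x_2
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(100, 2))
y = X[:, 0] + X[:, 1]

# Stand-ins for sampled expressions: name -> predictions on X
candidates = {
    "x_1 + x_2": X[:, 0] + X[:, 1],
    "x_1 * x_2": X[:, 0] * X[:, 1],
    "sin(x_1)": np.sin(X[:, 0]),
}

scores = {name: r2_score(y, pred) for name, pred in candidates.items()}
best = max(scores, key=scores.get)
print(best)  # x_1 + x_2
```

Because the exact recovery `x_1 + x_2` has zero residual, its score is exactly 1.0, which is why the experiment uses R² > 0.99 as the "near-perfect match" threshold.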
2_training/reinforcement/debug_reinforce.py ADDED
@@ -0,0 +1,294 @@
+ #!/usr/bin/env python3
+ """
+ Debug version of REINFORCE that saves ALL expressions (valid and invalid).
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ from pathlib import Path
+ from typing import List, Dict
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # Add project root to path
+ PROJECT_ROOT = Path(__file__).parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ sys.path.insert(0, str(PROJECT_ROOT / "classes"))
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, LoraConfig, get_peft_model
+ from expression import Expression
+
+
+ class DebugREINFORCE:
+     """REINFORCE that logs all expressions."""
+
+     def __init__(self, model_path: str, X: np.ndarray, y: np.ndarray, device: str = None):
+         self.X = X
+         self.y = y
+         self.n_vars = X.shape[1]
+
+         if device:
+             self.device = torch.device(device)
+         else:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Load model
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         try:
+             base_model = AutoModelForCausalLM.from_pretrained("gpt2")
+             if len(self.tokenizer) != base_model.config.vocab_size:
+                 base_model.resize_token_embeddings(len(self.tokenizer))
+             model_with_lora = PeftModel.from_pretrained(base_model, model_path)
+             self.model = model_with_lora.merge_and_unload()
+         except Exception:
+             self.model = AutoModelForCausalLM.from_pretrained(model_path)
+
+         # Add a fresh LoRA adapter for RL fine-tuning
+         lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], lora_dropout=0.05, bias="none")
+         self.model = get_peft_model(self.model, lora_config)
+         self.model = self.model.to(self.device)
+         self.model.train()
+
+         # Build prompt (strip trailing '"}' so it ends at the opening quote of expr)
+         vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
+         ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
+         self.prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]
+         self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
+
+         # Optimizer
+         self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)
+
+         # Exponential moving average baseline
+         self.baseline = 0.0
+         self.baseline_decay = 0.9
+
+         # Log of ALL generated expressions
+         self.all_expressions = []
+
+     def extract_expression(self, text: str) -> str:
+         """Extract the expression value from generated JSON text."""
+         try:
+             if '"expr": "' in text:
+                 start = text.index('"expr": "') + len('"expr": "')
+                 remaining = text[start:]
+                 for terminator in ['"}', '"']:
+                     if terminator in remaining:
+                         return remaining[:remaining.index(terminator)].strip()
+         except Exception:
+             pass
+         return text.strip()
+
+     def compute_r2(self, expression_str: str) -> dict:
+         """Compute R^2 and detailed error info for one expression."""
+         result = {
+             "expression": expression_str,
+             "r2": -1.0,
+             "is_valid": False,
+             "error_type": None,
+             "error_message": None,
+         }
+
+         if not expression_str or expression_str.isspace():
+             result["error_type"] = "empty"
+             return result
+
+         # Substitute the constant placeholder before parsing
+         test_expr = expression_str.replace('C', '1')
+
+         try:
+             expr = Expression(test_expr, is_prefix=False)
+
+             if not expr.is_valid_on_dataset(self.X):
+                 result["error_type"] = "invalid_on_dataset"
+                 result["error_message"] = "NaN/Inf on dataset"
+                 return result
+
+             y_pred = expr.evaluate(self.X)
+
+             if not np.all(np.isfinite(y_pred)):
+                 result["error_type"] = "non_finite_output"
+                 return result
+
+             ss_res = np.sum((self.y - y_pred) ** 2)
+             ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
+
+             if ss_tot == 0:
+                 r2 = 0.0
+             else:
+                 r2 = 1 - (ss_res / ss_tot)
+
+             result["r2"] = float(np.clip(r2, -1.0, 1.0))
+             result["is_valid"] = True
+
+         except Exception as e:
+             result["error_type"] = "parse_error"
+             result["error_message"] = str(e)[:100]
+
+         return result
+
+     def generate_batch(self, batch_size: int = 16, max_new_tokens: int = 50):
+         """Generate a batch of expressions and evaluate them."""
+         results = []
+
+         for _ in range(batch_size):
+             generated_ids = self.prompt_ids.clone()
+             generated_tokens = []
+
+             with torch.no_grad():
+                 for _ in range(max_new_tokens):
+                     outputs = self.model(generated_ids)
+                     logits = outputs.logits[:, -1, :] / 0.7
+
+                     probs = F.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+
+                     generated_tokens.append(next_token.item())
+                     generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+                     if next_token.item() == self.tokenizer.eos_token_id:
+                         break
+
+                     text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+                     if '"}' in text[len(self.prompt):]:
+                         break
+
+             text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+             expr_str = self.extract_expression(text)
+
+             # Evaluate with detailed info
+             eval_result = self.compute_r2(expr_str)
+
+             # Recompute log probs WITH gradients for the REINFORCE update
+             if len(generated_tokens) > 0:
+                 full_ids = torch.cat([self.prompt_ids, torch.tensor([generated_tokens], device=self.device)], dim=1)
+                 outputs = self.model(full_ids[:, :-1])
+                 logits = outputs.logits / 0.7
+                 prompt_len = self.prompt_ids.shape[1]
+                 gen_logits = logits[:, prompt_len-1:, :]
+                 log_probs_all = F.log_softmax(gen_logits, dim=-1)
+                 target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
+                 selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
+                 total_log_prob = selected_log_probs.sum()
+             else:
+                 total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
+
+             eval_result["log_prob"] = total_log_prob
+             results.append(eval_result)
+
+             # Log ALL expressions
+             self.all_expressions.append(eval_result.copy())
+
+         return results
+
+     def train_step(self, batch_size: int = 16):
+         """One REINFORCE training step."""
+         results = self.generate_batch(batch_size)
+
+         # Rewards: R^2 for valid expressions, small penalty otherwise
+         rewards = [r["r2"] if r["is_valid"] else -0.1 for r in results]
+
+         # Update EMA baseline from valid samples only (use the validity flag
+         # rather than a reward threshold, since valid R^2 can be below -0.1)
+         valid_rewards = [r["r2"] for r in results if r["is_valid"]]
+         if valid_rewards:
+             mean_reward = np.mean(valid_rewards)
+             self.baseline = self.baseline_decay * self.baseline + (1 - self.baseline_decay) * mean_reward
+
+         # Advantages
+         advantages = [r - self.baseline for r in rewards]
+
+         # Policy gradient update
+         self.optimizer.zero_grad()
+         policy_loss = torch.tensor(0.0, device=self.device)
+
+         for result, advantage in zip(results, advantages):
+             if result["is_valid"] or result["error_type"] == "parse_error":
+                 policy_loss = policy_loss - result["log_prob"] * advantage
+
+         if len(results) > 0:
+             policy_loss = policy_loss / len(results)
+             policy_loss.backward()
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             self.optimizer.step()
+
+         # Stats
+         valid_count = sum(1 for r in results if r["is_valid"])
+         valid_r2 = [r["r2"] for r in results if r["is_valid"]]
+
+         return {
+             "valid_count": valid_count,
+             "total_count": len(results),
+             "mean_r2": np.mean(valid_r2) if valid_r2 else -1.0,
+             "max_r2": max(r["r2"] for r in results),
+             "baseline": self.baseline,
+         }
+
+     def run(self, epochs: int = 10):
+         """Run training."""
+         print(f"Running debug REINFORCE for {epochs} epochs...")
+         print()
+
+         for epoch in range(1, epochs + 1):
+             stats = self.train_step()
+             print(f"Epoch {epoch:2d} | Valid: {stats['valid_count']}/{stats['total_count']} | Mean R²: {stats['mean_r2']:.4f} | Max R²: {stats['max_r2']:.4f}")
+
+         # Save ALL expressions
+         output_file = "debug_expressions.json"
+         with open(output_file, "w") as f:
+             json.dump({"all_expressions": self.all_expressions}, f, indent=2, default=str)
+
+         print()
+         print(f"Saved {len(self.all_expressions)} expressions to {output_file}")
+
+         # Analyze
+         valid = [e for e in self.all_expressions if e["is_valid"]]
+         invalid = [e for e in self.all_expressions if not e["is_valid"]]
+
+         print()
+         print("SUMMARY:")
+         print(f"  Total: {len(self.all_expressions)}")
+         print(f"  Valid: {len(valid)} ({100*len(valid)/len(self.all_expressions):.1f}%)")
+         print(f"  Invalid: {len(invalid)} ({100*len(invalid)/len(self.all_expressions):.1f}%)")
+
+         if invalid:
+             error_types = {}
+             for e in invalid:
+                 et = e.get("error_type", "unknown")
+                 error_types[et] = error_types.get(et, 0) + 1
+
+             print()
+             print("Invalid expression types:")
+             for et, count in sorted(error_types.items(), key=lambda x: -x[1]):
+                 print(f"  {et}: {count} ({100*count/len(invalid):.1f}%)")
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, required=True)
+     parser.add_argument("--dataset", type=str, required=True)
+     parser.add_argument("--epochs", type=int, default=10)
+     args = parser.parse_args()
+
+     # Load dataset
+     import pandas as pd
+     df = pd.read_csv(args.dataset)
+     x_cols = [c for c in df.columns if c.startswith('x_')]
+     X = df[x_cols].values
+     y = df['y'].values
+
+     print(f"Dataset: {args.dataset}")
+     print(f"  Samples: {len(df)}, Variables: {len(x_cols)}")
+     print()
+
+     # Run
+     reinforce = DebugREINFORCE(args.model_path, X, y)
+     reinforce.run(epochs=args.epochs)
+
+
+ if __name__ == "__main__":
+     main()
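The variance-reduction trick in `train_step` is an exponential moving average baseline: each batch's mean reward is folded into a running estimate, and advantages are rewards minus that estimate. A minimal sketch of just that arithmetic, with the decay of 0.9 used above (`update_baseline` is an illustrative helper, not repo code):

```python
import numpy as np

def update_baseline(baseline, batch_rewards, decay=0.9):
    """EMA update: new = decay * old + (1 - decay) * mean(batch)."""
    return decay * baseline + (1 - decay) * float(np.mean(batch_rewards))

baseline = 0.0
baselines, advantages = [], []
for batch in [[0.2, 0.4], [0.6, 0.8]]:
    baseline = update_baseline(baseline, batch)
    baselines.append(baseline)
    # Advantage = reward - baseline; positive advantages are reinforced
    advantages.append([r - baseline for r in batch])

print(baselines)
```

With decay 0.9 the baseline reacts slowly (0.0 → 0.03 → 0.097 here), so a sudden batch of good rewards still yields large positive advantages before the baseline catches up.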
2_training/reinforcement/grpo_experiment.py ADDED
@@ -0,0 +1,344 @@
+ #!/usr/bin/env python3
+ """
+ GRPO Experiment for Symbolic Regression
+
+ GRPO (Group Relative Policy Optimization) supports custom reward functions
+ via the reward_funcs parameter, making it ideal for symbolic regression,
+ where we compute R^2 scores as rewards.
+
+ This is the recommended approach for TRL 0.27+ since the experimental PPO
+ trainer has compatibility issues.
+
+ Usage:
+     python scripts/grpo_experiment.py --dataset ./data/ppo_test/sin_x1.csv
+ """
+
+ import os
+ os.environ['TRL_EXPERIMENTAL_SILENCE'] = '1'
+
+ import sys
+ import json
+ import argparse
+ import logging
+ import datetime
+ from pathlib import Path
+ from typing import List
+
+ import numpy as np
+ import torch
+
+ # Add project root to path
+ PROJECT_ROOT = Path(__file__).parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ sys.path.insert(0, str(PROJECT_ROOT / "classes"))
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from trl import GRPOConfig, GRPOTrainer
+ from datasets import Dataset
+ from peft import PeftModel
+
+ from expression import Expression
+ from dataset import RegressionDataset
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class SymbolicRegressionReward:
+     """
+     Reward function for symbolic regression.
+     Computes the R^2 score for each generated expression.
+     """
+
+     def __init__(self, X: np.ndarray, y: np.ndarray, tokenizer):
+         self.X = X
+         self.y = y
+         self.tokenizer = tokenizer
+         self.n_vars = X.shape[1]
+         self.best_r2 = -np.inf
+         self.best_expression = None
+         self.history = []
+
+     def extract_expression(self, text: str) -> str:
+         """Extract expression from JSON format output."""
+         try:
+             # Case 1: Standard JSON with quotes
+             if '"expr": "' in text:
+                 start = text.index('"expr": "') + len('"expr": "')
+                 remaining = text[start:]
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 if '"' in remaining:
+                     return remaining[:remaining.index('"')].strip()
+                 return remaining.strip()
+
+             # Case 2: Model output without quotes
+             if '"expr": ' in text:
+                 start = text.index('"expr": ') + len('"expr": ')
+                 remaining = text[start:]
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 return remaining.strip()
+
+         except (ValueError, IndexError):
+             pass
+
+         return text.split('"expr"')[-1].strip(' ":}')
+
+     def compute_r2(self, expression_str: str) -> float:
+         """Compute the R^2 score for an expression, clipped to [-1, 1]."""
+         if not expression_str or expression_str.isspace():
+             return -1.0
+
+         # Substitute the constant placeholder C with 1
+         if 'C' in expression_str:
+             expression_str = expression_str.replace('C', '1')
+
+         try:
+             expr = Expression(expression_str, is_prefix=False)
+
+             if not expr.is_valid_on_dataset(self.X):
+                 return -1.0
+
+             y_pred = expr.evaluate(self.X)
+
+             if not np.all(np.isfinite(y_pred)):
+                 return -1.0
+
+             ss_res = np.sum((self.y - y_pred) ** 2)
+             ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
+
+             if ss_tot == 0:
+                 return 0.0
+
+             r2 = 1 - (ss_res / ss_tot)
+             return float(np.clip(r2, -1.0, 1.0))
+         except Exception:
+             return -1.0
+
+     def __call__(self, completions: List[str], **kwargs) -> List[float]:
+         """
+         Compute rewards for a batch of completions.
+
+         Args:
+             completions: List of generated completion strings
+
+         Returns:
+             List of R^2 scores
+         """
+         rewards = []
+
+         for completion in completions:
+             # Extract expression from completion
+             expr_str = self.extract_expression(completion)
+
+             # Compute R^2
+             r2 = self.compute_r2(expr_str)
+             rewards.append(r2)
+
+             # Track best
+             if r2 > self.best_r2:
+                 self.best_r2 = r2
+                 self.best_expression = expr_str
+                 logger.info(f"New best R^2: {r2:.4f} - {expr_str}")
+
+         # Log batch statistics (cast to float so the history is JSON-serializable)
+         valid_rewards = [r for r in rewards if r > -1.0]
+         if valid_rewards:
+             self.history.append({
+                 "mean_r2": float(np.mean(valid_rewards)),
+                 "max_r2": float(max(valid_rewards)),
+                 "valid_rate": len(valid_rewards) / len(rewards),
+             })
+
+         return rewards
+
+
+ def build_prompt(n_vars: int) -> str:
+     """Build JSON format prompt matching training data."""
+     vars_list = [f"x_{i+1}" for i in range(n_vars)]
+     ops_list = ["+", "-", "*", "sin", "cos"]
+
+     prompt = json.dumps({
+         "vars": vars_list,
+         "ops": ops_list,
+         "cons": None,
+         "expr": ""
+     })[:-3]  # Strip the trailing '""}' so the model completes the value
+
+     return prompt
+
+
+ def run_grpo_experiment(
+     model_path: str,
+     dataset_path: str,
+     output_dir: str = "./output/grpo_results",
+     num_episodes: int = 100,
+     batch_size: int = 4,
+     learning_rate: float = 1e-5,
+     use_cpu: bool = False,
+ ):
+     """Run a GRPO experiment with the custom R^2 reward function."""
+
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Device setup
+     device = "cpu" if use_cpu else ("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"Using device: {device}")
+
+     # Load dataset
+     logger.info(f"Loading dataset from {dataset_path}")
+     dataset_path = Path(dataset_path)
+     reg = RegressionDataset(str(dataset_path.parent), dataset_path.name)
+     X, y = reg.get_numpy()
+     n_vars = X.shape[1]
+     logger.info(f"Dataset: {X.shape[0]} samples, {n_vars} variables")
+
+     # Load tokenizer and model
+     logger.info(f"Loading model from {model_path}")
+
+     # Check if model_path is a local path or a HuggingFace Hub model
+     if Path(model_path).exists():
+         # Load tokenizer from the trained model
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         tokenizer.pad_token = tokenizer.eos_token
+
+         # Load base model and LoRA adapter
+         base_model = AutoModelForCausalLM.from_pretrained("gpt2")
+         if len(tokenizer) != base_model.config.vocab_size:
+             base_model.resize_token_embeddings(len(tokenizer))
+
+         try:
+             model_with_lora = PeftModel.from_pretrained(base_model, model_path)
+             model = model_with_lora.merge_and_unload()
+             logger.info("LoRA adapter loaded and merged")
+         except Exception as e:
+             logger.warning(f"Could not load LoRA: {e}")
+             model = AutoModelForCausalLM.from_pretrained(model_path)
+     else:
+         # Load from the HuggingFace Hub
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+         tokenizer.pad_token = tokenizer.eos_token
+         model = AutoModelForCausalLM.from_pretrained(model_path)
+
+     logger.info("Model loaded successfully")
+
+     # Build prompt and create dataset
+     prompt = build_prompt(n_vars)
+     logger.info(f"Prompt: {prompt}")
+
+     train_dataset = Dataset.from_dict({"prompt": [prompt] * num_episodes})
+
+     # Create reward function
+     reward_func = SymbolicRegressionReward(X, y, tokenizer)
+
+     # GRPO config
+     grpo_config = GRPOConfig(
+         output_dir=str(output_dir),
+         learning_rate=learning_rate,
+         per_device_train_batch_size=batch_size,
+         num_generations=batch_size,  # Generate batch_size samples per prompt
+         max_completion_length=50,
+         num_train_epochs=1,
+         report_to=[],
+         use_cpu=(device == "cpu"),
+         bf16=(device != "cpu"),
+         logging_steps=10,
+         save_strategy="epoch",
+     )
+
+     # Create trainer
+     logger.info("Creating GRPO Trainer...")
+     trainer = GRPOTrainer(
+         model=model,
+         args=grpo_config,
+         processing_class=tokenizer,
+         train_dataset=train_dataset,
+         reward_funcs=reward_func,
+     )
+
+     # Train
+     logger.info("=" * 60)
+     logger.info("GRPO SYMBOLIC REGRESSION EXPERIMENT")
+     logger.info("=" * 60)
+     logger.info(f"Dataset: {dataset_path}")
+     logger.info(f"Model: {model_path}")
+     logger.info(f"Episodes: {num_episodes}")
+     logger.info("=" * 60)
+
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     try:
+         trainer.train()
+         logger.info("Training completed!")
+     except Exception as e:
+         logger.error(f"Training failed: {e}")
+         import traceback
+         traceback.print_exc()
+
+     # Results
+     logger.info("\n" + "=" * 60)
+     logger.info("RESULTS")
+     logger.info("=" * 60)
+     logger.info(f"Best R^2: {reward_func.best_r2:.4f}")
+     logger.info(f"Best expression: {reward_func.best_expression}")
+
+     # Save results
+     results = {
+         "timestamp": timestamp,
+         "model_path": model_path,
+         "dataset_path": str(dataset_path),
+         "best_r2": float(reward_func.best_r2),
+         "best_expression": reward_func.best_expression,
+         "history": reward_func.history,
+     }
+
+     results_file = output_dir / f"grpo_results_{timestamp}.json"
+     with open(results_file, 'w') as f:
+         json.dump(results, f, indent=2)
+
+     logger.info(f"Results saved to: {results_file}")
+
+     # Save model
+     trainer.save_model(str(output_dir / "final_model"))
+
+     return results
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="GRPO Symbolic Regression")
+     parser.add_argument("--model_path", type=str, default="gpt2",
+                         help="Path to model (local or HuggingFace)")
+     parser.add_argument("--dataset", type=str, default="./data/ppo_test/sin_x1.csv",
+                         help="Path to test dataset CSV")
+     parser.add_argument("--output_dir", type=str, default="./output/grpo_results",
+                         help="Output directory")
+     parser.add_argument("--num_episodes", type=int, default=100,
+                         help="Number of training episodes")
+     parser.add_argument("--batch_size", type=int, default=4,
+                         help="Batch size")
+     parser.add_argument("--lr", type=float, default=1e-5,
+                         help="Learning rate")
+     parser.add_argument("--cpu", action="store_true",
+                         help="Force CPU usage")
+
+     args = parser.parse_args()
+
+     run_grpo_experiment(
+         model_path=args.model_path,
+         dataset_path=args.dataset,
+         output_dir=args.output_dir,
+         num_episodes=args.num_episodes,
+         batch_size=args.batch_size,
+         learning_rate=args.lr,
+         use_cpu=args.cpu,
+     )
+
+
+ if __name__ == "__main__":
+     main()
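The "group relative" part of GRPO that the docstring refers to happens inside the trainer: the `num_generations` completions sampled for one prompt form a group, and each completion's advantage is its reward normalized against that group's statistics. A minimal sketch of that normalization, under the common formulation (r - mean) / (std + eps); the helper name `group_relative_advantages` is illustrative, not a TRL API:

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-8):
    """Advantage of each completion relative to its own generation group."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# One prompt, num_generations=4 completions scored by R^2
adv = group_relative_advantages([0.9, 0.1, -1.0, 0.5])
```

This is why the reward function above only needs to return raw R² scores per completion: no separate value network or running baseline is required, because each group is its own baseline.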
2_training/reinforcement/grpo_improved.py ADDED
@@ -0,0 +1,625 @@
1
+ #!/usr/bin/env python3
+ """
+ Improved GRPO (Group Relative Policy Optimization) for Symbolic Regression
+
+ Improvements over basic GRPO:
+ 1. Filter invalid expressions before computing group statistics
+ 2. Reward shaping with softer penalties
+ 3. Hybrid baseline: group stats + exponential moving average
+ 4. Entropy bonus for exploration
+ 5. Advantage clipping to prevent extreme updates
+ 6. Minimum valid ratio check before updates
+ 7. Temperature annealing for better exploration/exploitation
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import logging
+ import datetime
+ from pathlib import Path
+ from typing import List, Dict, Tuple
+ from collections import deque
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # Add project root to path
+ PROJECT_ROOT = Path(__file__).parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ sys.path.insert(0, str(PROJECT_ROOT / "classes"))
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, LoraConfig, get_peft_model
+
+ from expression import Expression
+ from dataset import RegressionDataset
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class ImprovedGRPO:
+     """Improved GRPO for symbolic regression."""
+
+     def __init__(
+         self,
+         model_path: str,
+         X: np.ndarray,
+         y: np.ndarray,
+         output_dir: str = "./output/grpo",
+         learning_rate: float = 5e-5,
+         device: str = None,
+         group_size: int = 16,  # Larger groups for better statistics
+         entropy_coef: float = 0.01,
+         advantage_clip: float = 2.0,  # Clip extreme advantages
+         min_valid_ratio: float = 0.2,  # Minimum valid expressions to update
+     ):
+         self.X = X
+         self.y = y
+         self.n_vars = X.shape[1]
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         self.learning_rate = learning_rate
+         self.group_size = group_size
+         self.entropy_coef = entropy_coef
+         self.advantage_clip = advantage_clip
+         self.min_valid_ratio = min_valid_ratio
+
+         # Device
+         if device:
+             self.device = torch.device(device)
+         else:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+         # Load model
+         self._load_model(model_path)
+
+         # Build prompt
+         self.prompt = self._build_prompt()
+         self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
+
+         # Optimizer
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=0.01
+         )
+
+         # Scheduler
+         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+             self.optimizer, T_0=10, T_mult=2
+         )
+
+         # Tracking
+         self.best_r2 = -np.inf
+         self.best_expression = None
+         self.history = []
+         self.discovered_expressions: Dict[str, float] = {}
+
+         # Hybrid baseline: EMA of valid rewards
+         self.ema_baseline = 0.0
+         self.ema_decay = 0.9
+         self.reward_buffer = deque(maxlen=100)
+
+         # Temperature annealing
+         self.initial_temp = 0.8
+         self.min_temp = 0.5
+         self.current_temp = self.initial_temp
+
+     def _load_model(self, model_path: str):
+         """Load model and tokenizer."""
+         logger.info(f"Loading model from {model_path}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         try:
+             logger.info("Attempting to load as LoRA adapter...")
+             base_model = AutoModelForCausalLM.from_pretrained("gpt2")
+             if len(self.tokenizer) != base_model.config.vocab_size:
+                 base_model.resize_token_embeddings(len(self.tokenizer))
+                 logger.info(f"Resized embeddings to {len(self.tokenizer)}")
+
+             model_with_lora = PeftModel.from_pretrained(base_model, model_path)
+             self.model = model_with_lora.merge_and_unload()
+             logger.info("LoRA adapter loaded and merged successfully")
+         except Exception as e:
+             logger.info(f"LoRA load failed ({e}), loading as standalone model...")
+             self.model = AutoModelForCausalLM.from_pretrained(model_path)
+
+         # Add LoRA for training
+         lora_config = LoraConfig(
+             r=8,
+             lora_alpha=16,
+             target_modules=["c_attn"],
+             lora_dropout=0.05,
+             bias="none",
+         )
+         self.model = get_peft_model(self.model, lora_config)
+         self.model = self.model.to(self.device)
+         self.model.train()
+
+         trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+         logger.info(f"Model loaded with {trainable} trainable params")
+
+     def _build_prompt(self, ops: list = None) -> str:
+         """Build JSON format prompt."""
+         vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
+
+         if ops is None:
+             ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
+         else:
+             ops_list = ops
+
+         prompt = json.dumps({
+             "vars": vars_list,
+             "ops": ops_list,
+             "cons": "C",
+             "expr": ""
+         })
+         prompt = prompt[:-2]  # Drop the trailing '"}' so the model completes the open "expr" field
+         return prompt
+
+     def extract_expression(self, text: str) -> str:
+         """Extract expression from generated text."""
+         try:
+             eos_token = "<|endoftext|>"
+             if eos_token in text:
+                 text = text[:text.index(eos_token)]
+
+             if '"expr": "' in text:
+                 start = text.index('"expr": "') + len('"expr": "')
+                 remaining = text[start:]
+                 for terminator in ['"}', '"']:
+                     if terminator in remaining:
+                         return remaining[:remaining.index(terminator)].strip()
+                 return remaining.strip()
+
+             if '"expr": ' in text:
+                 start = text.index('"expr": ') + len('"expr": ')
+                 remaining = text[start:]
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 return remaining.strip(' "')
+
+         except (ValueError, IndexError):
+             pass
+
+         if '"expr"' in text:
+             return text.split('"expr"')[-1].strip(' ":{}')
+         return text.strip()
+
+     def compute_r2(self, expression_str: str) -> Tuple[float, bool]:
+         """Compute R^2 score."""
+         if not expression_str or expression_str.isspace():
+             return -1.0, False
+
+         if 'C' in expression_str:
+             expression_str = expression_str.replace('C', '1')
+
+         try:
+             expr = Expression(expression_str, is_prefix=False)
+             if not expr.is_valid_on_dataset(self.X):
+                 return -1.0, False
+
+             y_pred = expr.evaluate(self.X)
+             if not np.all(np.isfinite(y_pred)):
+                 return -1.0, False
+
+             ss_res = np.sum((self.y - y_pred) ** 2)
+             ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
+
+             if ss_tot == 0:
+                 return 0.0, True
+
+             r2 = 1 - (ss_res / ss_tot)
+             return float(np.clip(r2, -1.0, 1.0)), True
+         except Exception:
+             return -1.0, False
+
+     def shape_reward(self, r2: float, is_valid: bool) -> float:
+         """Shape reward for better learning signal."""
+         if not is_valid:
+             return -0.1  # Small penalty, not -1.0
+
+         # Bonus for high R²
+         if r2 >= 0.99:
+             return 2.0  # Big bonus for near-perfect
+         elif r2 >= 0.9:
+             return r2 * 1.5
+         elif r2 >= 0.5:
+             return r2 * 1.2
+         elif r2 >= 0:
+             return r2
+         else:
+             return r2 * 0.5  # Reduce negative penalty
+
+     def generate_group(self, max_new_tokens: int = 50) -> List[Dict]:
+         """Generate a group of expressions."""
+         results = []
+
+         for _ in range(self.group_size):
+             generated_ids = self.prompt_ids.clone()
+             generated_tokens = []
+
+             # Phase 1: Generate tokens
+             with torch.no_grad():
+                 for _ in range(max_new_tokens):
+                     outputs = self.model(generated_ids)
+                     logits = outputs.logits[:, -1, :] / self.current_temp
+
+                     probs = F.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+                     generated_tokens.append(next_token.item())
+
+                     generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+                     if next_token.item() == self.tokenizer.eos_token_id:
+                         break
+
+                     text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+                     if '"}' in text[len(self.prompt):]:
+                         break
+
+             # Decode and evaluate
+             text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+             expr_str = self.extract_expression(text)
+             r2, is_valid = self.compute_r2(expr_str)
+             reward = self.shape_reward(r2, is_valid)
+
+             # Phase 2: Compute log probs with gradients
+             if len(generated_tokens) > 0:
+                 full_ids = torch.cat([
+                     self.prompt_ids,
+                     torch.tensor([generated_tokens], device=self.device)
+                 ], dim=1)
+
+                 outputs = self.model(full_ids[:, :-1])
+                 logits = outputs.logits / self.current_temp
+
+                 prompt_len = self.prompt_ids.shape[1]
+                 gen_logits = logits[:, prompt_len-1:, :]
+
+                 log_probs_all = F.log_softmax(gen_logits, dim=-1)
+                 probs_all = F.softmax(gen_logits, dim=-1)
+
+                 target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
+                 selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
+                 total_log_prob = selected_log_probs.sum()
+
+                 # Entropy for exploration
+                 entropy_per_pos = -(probs_all * log_probs_all).sum(dim=-1)
+                 total_entropy = entropy_per_pos.mean()
+             else:
+                 total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
+                 total_entropy = torch.tensor(0.0, device=self.device)
+
+             results.append({
+                 "text": text,
+                 "expression": expr_str,
+                 "r2": r2,
+                 "is_valid": is_valid,
+                 "reward": reward,
+                 "log_prob": total_log_prob,
+                 "entropy": total_entropy,
+             })
+
+             # Track best
+             if is_valid:
+                 self.discovered_expressions[expr_str] = max(
+                     self.discovered_expressions.get(expr_str, -np.inf), r2
+                 )
+                 self.reward_buffer.append(reward)
+
+                 if r2 > self.best_r2:
+                     self.best_r2 = r2
+                     self.best_expression = expr_str
+
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+
+         return results
+
+     def compute_advantages(self, results: List[Dict]) -> Tuple[List[float], dict]:
+         """
+         Compute improved GRPO advantages.
+
+         Key improvement: Only use VALID expressions for group statistics.
+         Invalid expressions get a fixed small negative advantage.
+         """
+         valid_results = [r for r in results if r["is_valid"]]
+         valid_rewards = [r["reward"] for r in valid_results]
+
+         stats = {
+             "valid_count": len(valid_results),
+             "total_count": len(results),
+             "valid_ratio": len(valid_results) / len(results),
+         }
+
+         # If too few valid expressions, use EMA baseline only
+         if len(valid_rewards) < 2:
+             advantages = []
+             for r in results:
+                 if r["is_valid"]:
+                     adv = r["reward"] - self.ema_baseline
+                 else:
+                     adv = -0.5  # Fixed penalty for invalid
+                 advantages.append(adv)
+             stats["method"] = "ema_only"
+             return advantages, stats
+
+         # Compute group statistics from valid expressions only
+         group_mean = np.mean(valid_rewards)
+         group_std = np.std(valid_rewards)
+
+         # Update EMA baseline
+         self.ema_baseline = self.ema_decay * self.ema_baseline + (1 - self.ema_decay) * group_mean
+
+         # Hybrid baseline: combine group mean with EMA
+         hybrid_baseline = 0.7 * group_mean + 0.3 * self.ema_baseline
+
+         # Avoid division by zero
+         if group_std < 1e-8:
+             group_std = 1.0
+
+         # Compute advantages
+         advantages = []
+         for r in results:
+             if r["is_valid"]:
+                 # Normalized advantage for valid expressions
+                 adv = (r["reward"] - hybrid_baseline) / group_std
+                 # Clip to prevent extreme updates
+                 adv = np.clip(adv, -self.advantage_clip, self.advantage_clip)
+             else:
+                 # Small fixed penalty for invalid (doesn't pollute group stats)
+                 adv = -0.3
+             advantages.append(adv)
+
+         stats["method"] = "hybrid"
+         stats["group_mean"] = group_mean
+         stats["group_std"] = group_std
+         stats["ema_baseline"] = self.ema_baseline
+
+         return advantages, stats
+
+     def train_step(self, num_groups: int = 2) -> dict:
+         """Perform one training step."""
+         self.model.train()
+
+         all_results = []
+         all_advantages = []
+         total_policy_loss = 0.0
+         total_entropy_loss = 0.0
+         skipped_groups = 0
+
+         self.optimizer.zero_grad()
+
+         for _ in range(num_groups):
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+
+             # Generate group
+             group_results = self.generate_group()
+             all_results.extend(group_results)
+
+             # Compute advantages
+             advantages, adv_stats = self.compute_advantages(group_results)
+             all_advantages.extend(advantages)
+
+             # Skip update if too few valid expressions
+             if adv_stats["valid_ratio"] < self.min_valid_ratio:
+                 skipped_groups += 1
+                 continue
+
+             # Compute loss
+             policy_loss = torch.tensor(0.0, device=self.device)
+             entropy_loss = torch.tensor(0.0, device=self.device)
+             valid_count = 0
+
+             for result, advantage in zip(group_results, advantages):
+                 if result["is_valid"] and advantage != 0:
+                     policy_loss = policy_loss - result["log_prob"] * advantage
+                     entropy_loss = entropy_loss - result["entropy"]
+                     valid_count += 1
+
+             if valid_count > 0:
+                 policy_loss = policy_loss / valid_count
+                 entropy_loss = entropy_loss / valid_count
+
+                 # Combined loss
+                 loss = policy_loss + self.entropy_coef * entropy_loss
+                 loss = loss / num_groups
+                 loss.backward()
+
+             total_policy_loss += policy_loss.item()
+             total_entropy_loss += entropy_loss.item()
+
+         # Only update if we had valid groups
+         if skipped_groups < num_groups:
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             self.optimizer.step()
+             self.scheduler.step()
+
+         # Statistics
+         r2_values = [r["r2"] for r in all_results]
+         valid_mask = [r["is_valid"] for r in all_results]
+         valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v]
+
+         return {
+             "valid_count": int(sum(valid_mask)),
+             "total_count": len(all_results),
+             "valid_rate": sum(valid_mask) / len(all_results) if all_results else 0,
+             "mean_r2": float(np.mean(valid_r2)) if valid_r2 else 0.0,
+             "max_r2": float(max(r2_values)) if r2_values else 0.0,
+             "mean_advantage": float(np.mean(all_advantages)) if all_advantages else 0.0,
+             "ema_baseline": self.ema_baseline,
+             "policy_loss": total_policy_loss / max(num_groups - skipped_groups, 1),
+             "entropy_loss": total_entropy_loss / max(num_groups - skipped_groups, 1),
+             "lr": self.scheduler.get_last_lr()[0],
+             "temperature": self.current_temp,
+             "skipped_groups": skipped_groups,
+         }
+
+     def anneal_temperature(self, epoch: int, total_epochs: int):
+         """Anneal temperature from initial to minimum."""
+         progress = epoch / total_epochs
+         self.current_temp = self.initial_temp - progress * (self.initial_temp - self.min_temp)
+
+     def run(
+         self,
+         epochs: int = 50,
+         num_groups: int = 2,
+         target_r2: float = 0.99,
+         patience: int = 20,
+     ) -> dict:
+         """Run improved GRPO training."""
+         logger.info("=" * 60)
+         logger.info("IMPROVED GRPO SYMBOLIC REGRESSION")
+         logger.info("=" * 60)
+         logger.info(f"Epochs: {epochs}")
+         logger.info(f"Group size: {self.group_size}")
+         logger.info(f"Num groups: {num_groups}")
+         logger.info(f"Effective batch: {self.group_size * num_groups}")
+         logger.info(f"Entropy coef: {self.entropy_coef}")
+         logger.info(f"Advantage clip: {self.advantage_clip}")
+         logger.info(f"Min valid ratio: {self.min_valid_ratio}")
+         logger.info(f"Target R^2: {target_r2}")
+         logger.info("=" * 60)
+
+         no_improvement_count = 0
+         best_r2_at_start = self.best_r2
+
+         for epoch in range(1, epochs + 1):
+             # Anneal temperature
+             self.anneal_temperature(epoch, epochs)
+
+             stats = self.train_step(num_groups)
+             self.history.append({
+                 "epoch": epoch,
+                 **stats,
+                 "best_r2": self.best_r2,
+             })
+
+             logger.info(
+                 f"Epoch {epoch:3d} | "
+                 f"Valid: {stats['valid_count']}/{stats['total_count']} | "
+                 f"Mean R²: {stats['mean_r2']:.4f} | "
+                 f"Best: {self.best_r2:.4f} | "
+                 f"EMA: {stats['ema_baseline']:.3f} | "
+                 f"Temp: {stats['temperature']:.2f} | "
+                 f"LR: {stats['lr']:.2e}"
+             )
+
+             # Check for target
+             if self.best_r2 >= target_r2:
+                 logger.info(f"Target R^2 {target_r2} reached at epoch {epoch}!")
+                 break
+
+             # Early stopping
+             if self.best_r2 > best_r2_at_start:
+                 best_r2_at_start = self.best_r2
+                 no_improvement_count = 0
+             else:
+                 no_improvement_count += 1
+
+             if no_improvement_count >= patience:
+                 logger.info(f"No improvement for {patience} epochs. Early stopping.")
+                 break
+
+         # Final results
+         logger.info("")
+         logger.info("=" * 60)
+         logger.info("FINAL RESULTS")
+         logger.info("=" * 60)
+         logger.info(f"Best R^2: {self.best_r2:.4f}")
+         logger.info(f"Best expression: {self.best_expression}")
+         logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}")
+
+         top_exprs = sorted(
+             self.discovered_expressions.items(),
+             key=lambda x: x[1],
+             reverse=True
+         )[:5]
+         logger.info("Top 5 expressions:")
+         for expr, r2 in top_exprs:
+             logger.info(f"  R²={r2:.4f}: {expr}")
+
+         # Save results
+         results = {
+             "algorithm": "ImprovedGRPO",
+             "best_r2": self.best_r2,
+             "best_expression": self.best_expression,
+             "history": self.history,
+             "discovered_expressions": dict(list(self.discovered_expressions.items())[:100]),
+             "config": {
+                 "group_size": self.group_size,
+                 "num_groups": num_groups,
+                 "learning_rate": self.learning_rate,
+                 "entropy_coef": self.entropy_coef,
+                 "advantage_clip": self.advantage_clip,
+                 "min_valid_ratio": self.min_valid_ratio,
+             }
+         }
+
+         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+         output_path = self.output_dir / f"results_grpo_improved_{timestamp}.json"
+         with open(output_path, "w") as f:
+             json.dump(results, f, indent=2)
+         logger.info(f"Results saved to: {output_path}")
+
+         return results
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Improved GRPO for Symbolic Regression")
+     parser.add_argument("--model_path", type=str, required=True)
+     parser.add_argument("--dataset", type=str, required=True)
+     parser.add_argument("--output_dir", type=str, default="./output/grpo")
+     parser.add_argument("--epochs", type=int, default=50)
+     parser.add_argument("--group_size", type=int, default=16)
+     parser.add_argument("--num_groups", type=int, default=2)
+     parser.add_argument("--learning_rate", type=float, default=5e-5)
+     parser.add_argument("--target_r2", type=float, default=0.99)
+     parser.add_argument("--entropy_coef", type=float, default=0.01)
+     args = parser.parse_args()
+
+     # Load dataset
+     import pandas as pd
+     df = pd.read_csv(args.dataset)
+
+     x_cols = [c for c in df.columns if c.startswith('x_')]
+     X = df[x_cols].values
+     y = df['y'].values
+
+     logger.info(f"Loaded dataset: {args.dataset}")
+     logger.info(f"  Samples: {len(df)}, Variables: {len(x_cols)}")
+
+     # Create trainer
+     grpo = ImprovedGRPO(
+         model_path=args.model_path,
+         X=X,
+         y=y,
+         output_dir=args.output_dir,
+         learning_rate=args.learning_rate,
+         group_size=args.group_size,
+         entropy_coef=args.entropy_coef,
+     )
+
+     # Run training
+     results = grpo.run(
+         epochs=args.epochs,
+         num_groups=args.num_groups,
+         target_r2=args.target_r2,
+     )
+
+
+ if __name__ == "__main__":
+     main()
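The valid-only group statistics, hybrid baseline, and advantage clipping used by `ImprovedGRPO.compute_advantages` can be exercised in isolation on toy rewards (a minimal standalone sketch, not part of the committed files; all numbers are invented):

```python
import numpy as np

def improved_advantages(rewards, valid, ema, clip=2.0, invalid_adv=-0.3):
    """Group stats from valid samples only, hybrid (group/EMA) baseline, clipping."""
    valid_rewards = [r for r, v in zip(rewards, valid) if v]
    mean, std = float(np.mean(valid_rewards)), float(np.std(valid_rewards))
    if std < 1e-8:  # avoid division by zero on a uniform group
        std = 1.0
    baseline = 0.7 * mean + 0.3 * ema  # blend group mean with running EMA
    return [
        float(np.clip((r - baseline) / std, -clip, clip)) if v else invalid_adv
        for r, v in zip(rewards, valid)
    ]

# Two valid samples and one invalid one; the invalid sample gets a
# fixed penalty and never contaminates the mean/std of the group.
advs = improved_advantages([1.5, 0.4, -0.1], [True, True, False], ema=0.2)
```

The design point this illustrates: because invalid generations are excluded from the statistics, a batch full of malformed expressions cannot drag the baseline down and reward the rest for merely parsing.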
2_training/reinforcement/grpo_symbolic.py ADDED
@@ -0,0 +1,539 @@
+ #!/usr/bin/env python3
+ """
+ GRPO (Group Relative Policy Optimization) for Symbolic Regression
+
+ Based on DeepSeek-R1 approach:
+ - Generate a group of N samples
+ - Compute advantages relative to group mean/std
+ - No external baseline needed
+
+ Comparison with REINFORCE:
+ - REINFORCE: advantage = reward - moving_average_baseline
+ - GRPO: advantage = (reward - group_mean) / group_std
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import logging
+ import datetime
+ from pathlib import Path
+ from typing import List, Dict, Tuple
+ from copy import deepcopy
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # Add project root to path
+ PROJECT_ROOT = Path(__file__).parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ sys.path.insert(0, str(PROJECT_ROOT / "classes"))
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, LoraConfig, get_peft_model
+
+ from expression import Expression
+ from dataset import RegressionDataset
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class GRPO:
+     """Group Relative Policy Optimization for symbolic regression."""
+
+     def __init__(
+         self,
+         model_path: str,
+         X: np.ndarray,
+         y: np.ndarray,
+         output_dir: str = "./output/grpo",
+         learning_rate: float = 5e-5,
+         device: str = None,
+         group_size: int = 8,  # Number of samples per group
+         kl_coef: float = 0.01,  # KL penalty coefficient
+         clip_range: float = 0.2,  # PPO-style clipping (optional)
+     ):
+         self.X = X
+         self.y = y
+         self.n_vars = X.shape[1]
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         self.learning_rate = learning_rate
+         self.group_size = group_size
+         self.kl_coef = kl_coef
+         self.clip_range = clip_range
+
+         # Device
+         if device:
+             self.device = torch.device(device)
+         else:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+         # Load model
+         self._load_model(model_path)
+
+         # Keep reference model for KL penalty
+         self.ref_model = None  # Will be set after first update
+
+         # Build prompt
+         self.prompt = self._build_prompt()
+         self.prompt_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"].to(self.device)
+
+         # Optimizer
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=learning_rate,
+             weight_decay=0.01
+         )
+
+         # Scheduler
+         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+             self.optimizer, T_0=10, T_mult=2
+         )
+
+         # Tracking
+         self.best_r2 = -np.inf
+         self.best_expression = None
+         self.history = []
+         self.discovered_expressions: Dict[str, float] = {}
+
+     def _load_model(self, model_path: str):
+         """Load model and tokenizer."""
+         logger.info(f"Loading model from {model_path}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         try:
+             logger.info("Attempting to load as LoRA adapter...")
+             base_model = AutoModelForCausalLM.from_pretrained("gpt2")
+             if len(self.tokenizer) != base_model.config.vocab_size:
+                 base_model.resize_token_embeddings(len(self.tokenizer))
+                 logger.info(f"Resized embeddings to {len(self.tokenizer)}")
+
+             model_with_lora = PeftModel.from_pretrained(base_model, model_path)
+             self.model = model_with_lora.merge_and_unload()
+             logger.info("LoRA adapter loaded and merged successfully")
+         except Exception as e:
+             logger.info(f"LoRA load failed ({e}), loading as standalone model...")
+             self.model = AutoModelForCausalLM.from_pretrained(model_path)
+
+         # Add LoRA for training
+         lora_config = LoraConfig(
+             r=8,
+             lora_alpha=16,
+             target_modules=["c_attn"],
+             lora_dropout=0.05,
+             bias="none",
+         )
+         self.model = get_peft_model(self.model, lora_config)
+         self.model = self.model.to(self.device)
+         self.model.train()
+
+         trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+         logger.info(f"Model loaded with {trainable} trainable params")
+
+     def _build_prompt(self, ops: list = None) -> str:
+         """Build JSON format prompt."""
+         vars_list = [f"x_{i+1}" for i in range(self.n_vars)]
+
+         if ops is None:
+             ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
+         else:
+             ops_list = ops
+
+         prompt = json.dumps({
+             "vars": vars_list,
+             "ops": ops_list,
+             "cons": "C",
+             "expr": ""
+         })
+         prompt = prompt[:-2]  # Drop the trailing '"}' so the model completes the open "expr" field
+         return prompt
+
+     def extract_expression(self, text: str) -> str:
+         """Extract expression from generated text."""
+         try:
+             eos_token = "<|endoftext|>"
+             if eos_token in text:
+                 text = text[:text.index(eos_token)]
+
+             if '"expr": "' in text:
+                 start = text.index('"expr": "') + len('"expr": "')
+                 remaining = text[start:]
+                 for terminator in ['"}', '"']:
+                     if terminator in remaining:
+                         return remaining[:remaining.index(terminator)].strip()
+                 return remaining.strip()
+
+             if '"expr": ' in text:
+                 start = text.index('"expr": ') + len('"expr": ')
+                 remaining = text[start:]
+                 if '"}' in remaining:
+                     return remaining[:remaining.index('"}')].strip()
+                 return remaining.strip(' "')
+
+         except (ValueError, IndexError):
+             pass
+
+         if '"expr"' in text:
+             return text.split('"expr"')[-1].strip(' ":{}')
+         return text.strip()
+
+     def compute_r2(self, expression_str: str) -> Tuple[float, bool]:
+         """Compute R^2 score. Returns (score, is_valid)."""
+         if not expression_str or expression_str.isspace():
+             return -1.0, False
+
+         if 'C' in expression_str:
+             expression_str = expression_str.replace('C', '1')
+
+         try:
+             expr = Expression(expression_str, is_prefix=False)
+             if not expr.is_valid_on_dataset(self.X):
+                 return -1.0, False
+
+             y_pred = expr.evaluate(self.X)
+             if not np.all(np.isfinite(y_pred)):
+                 return -1.0, False
+
+             ss_res = np.sum((self.y - y_pred) ** 2)
+             ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
+
+             if ss_tot == 0:
+                 return 0.0, True
+
+             r2 = 1 - (ss_res / ss_tot)
+             return float(np.clip(r2, -1.0, 1.0)), True
+         except Exception:
+             return -1.0, False
+
+     def generate_group(
+         self,
+         temperature: float = 0.7,
+         max_new_tokens: int = 50
+     ) -> List[Dict]:
+         """Generate a group of expressions."""
+         results = []
+
+         for _ in range(self.group_size):
+             generated_ids = self.prompt_ids.clone()
+             generated_tokens = []
+
+             # Phase 1: Generate tokens without gradients
+             with torch.no_grad():
+                 for _ in range(max_new_tokens):
+                     outputs = self.model(generated_ids)
+                     logits = outputs.logits[:, -1, :] / temperature
+
+                     probs = F.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+                     generated_tokens.append(next_token.item())
+
+                     generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+                     if next_token.item() == self.tokenizer.eos_token_id:
+                         break
+
+                     text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+                     if '"}' in text[len(self.prompt):]:
+                         break
+
+             # Decode and extract expression
+             text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+             expr_str = self.extract_expression(text)
+             r2, is_valid = self.compute_r2(expr_str)
+
+             # Phase 2: Efficient log prob computation
+             if len(generated_tokens) > 0:
+                 full_ids = torch.cat([
+                     self.prompt_ids,
+                     torch.tensor([generated_tokens], device=self.device)
+                 ], dim=1)
+
+                 outputs = self.model(full_ids[:, :-1])
+                 logits = outputs.logits / temperature
+
+                 prompt_len = self.prompt_ids.shape[1]
+                 gen_logits = logits[:, prompt_len-1:, :]
+
+                 log_probs_all = F.log_softmax(gen_logits, dim=-1)
+
+                 target_tokens = torch.tensor(generated_tokens, device=self.device).unsqueeze(0)
+                 selected_log_probs = log_probs_all.gather(2, target_tokens.unsqueeze(-1)).squeeze(-1)
+                 total_log_prob = selected_log_probs.sum()
+             else:
+                 total_log_prob = torch.tensor(0.0, device=self.device, requires_grad=True)
+
+             results.append({
+                 "text": text,
+                 "expression": expr_str,
+                 "r2": r2,
+                 "is_valid": is_valid,
+                 "log_prob": total_log_prob,
+                 "generated_tokens": generated_tokens,
+             })
+
+             # Track best
+             if is_valid:
+                 self.discovered_expressions[expr_str] = max(
+                     self.discovered_expressions.get(expr_str, -np.inf), r2
+                 )
+
+                 if r2 > self.best_r2:
+                     self.best_r2 = r2
+                     self.best_expression = expr_str
+
+             # Clear cache
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+
+         return results
+
+ def compute_group_advantages(self, results: List[Dict]) -> tuple:
+ """
+ Compute GRPO advantages: (reward - mean) / std
+
+ This is the key difference from REINFORCE:
+ - REINFORCE uses an external moving-average baseline
+ - GRPO uses within-group statistics
+
+ Returns (advantages, group mean reward, group reward std).
+ """
+ # Get rewards (R² values, with penalty for invalid)
+ rewards = []
+ for r in results:
+ if r["is_valid"]:
+ rewards.append(r["r2"])
+ else:
+ rewards.append(-0.1) # Small penalty for invalid
+
+ rewards = np.array(rewards)
+
+ # Compute group statistics
+ mean_reward = np.mean(rewards)
+ std_reward = np.std(rewards)
+
+ # Avoid division by zero when all rewards in the group are identical
+ if std_reward < 1e-8:
+ std_reward = 1.0
+
+ # Compute normalized advantages
+ advantages = (rewards - mean_reward) / std_reward
+
+ return advantages.tolist(), mean_reward, std_reward
+
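The group-relative normalization is easy to sanity-check in isolation. A standalone sketch of the same computation (the `group_advantages` helper is illustrative; the trainer's method additionally applies the invalid-expression penalty):

```python
import numpy as np

def group_advantages(rewards, eps=1e-8):
    """GRPO-style advantages: (reward - group mean) / group std."""
    rewards = np.asarray(rewards, dtype=float)
    mean, std = rewards.mean(), rewards.std()
    if std < eps:
        std = 1.0  # degenerate group: all rewards identical
    return (rewards - mean) / std

# Within a group, advantages are zero-mean with unit variance, so samples
# above the group average are reinforced and those below are discouraged.
adv = group_advantages([0.9, 0.5, -0.1, 0.7])
```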
+ def train_step(self, num_groups: int = 4) -> dict:
+ """
+ Perform one GRPO training step.
+
+ Args:
+ num_groups: Number of groups to sample (effective batch = num_groups * group_size)
+ """
+ self.model.train()
+
+ all_results = []
+ all_advantages = []
+ total_loss = 0.0
+
+ self.optimizer.zero_grad()
+
+ # Generate multiple groups
+ for _ in range(num_groups):
+ if self.device.type == "cuda":
+ torch.cuda.empty_cache()
+
+ # Generate a group of samples
+ group_results = self.generate_group()
+ all_results.extend(group_results)
+
+ # Compute group-relative advantages
+ advantages, group_mean, group_std = self.compute_group_advantages(group_results)
+ all_advantages.extend(advantages)
+
+ # Compute loss for this group
+ group_loss = torch.tensor(0.0, device=self.device)
+ valid_count = 0
+
+ for result, advantage in zip(group_results, advantages):
+ if result["is_valid"]:
+ # Policy gradient loss with advantage
+ group_loss = group_loss - result["log_prob"] * advantage
+ valid_count += 1
+
+ if valid_count > 0:
+ group_loss = group_loss / valid_count
+ group_loss = group_loss / num_groups # Scale for gradient accumulation
+ group_loss.backward()
+ total_loss += group_loss.item()
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ # Update
+ self.optimizer.step()
+ self.scheduler.step()
+
+ # Statistics
+ r2_values = [r["r2"] for r in all_results]
+ valid_mask = [r["is_valid"] for r in all_results]
+ valid_r2 = [r2 for r2, v in zip(r2_values, valid_mask) if v]
+
+ return {
+ "valid_count": int(sum(valid_mask)),
+ "total_count": len(all_results),
+ "valid_rate": sum(valid_mask) / len(all_results),
+ "mean_r2": float(np.mean(valid_r2)) if valid_r2 else 0.0,
+ "max_r2": float(max(r2_values)),
+ "mean_advantage": float(np.mean(all_advantages)),
+ "std_advantage": float(np.std(all_advantages)),
+ "loss": total_loss,
+ "lr": self.scheduler.get_last_lr()[0],
+ }
+
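Dividing each group's loss by `num_groups` before `backward()` means the accumulated gradient equals that of the mean group loss, so the effective learning rate stays constant when `num_groups` changes. The arithmetic behind that scaling, sketched standalone:

```python
# Each of the num_groups backward() calls contributes grad(loss_g / num_groups);
# since gradients add across calls, the accumulated gradient equals
# grad(mean group loss), regardless of how many groups were sampled.
group_losses = [2.0, 4.0, 6.0, 8.0]
num_groups = len(group_losses)

accumulated = sum(loss / num_groups for loss in group_losses)
mean_loss = sum(group_losses) / num_groups
# accumulated == mean_loss
```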
+ def run(
+ self,
+ epochs: int = 50,
+ num_groups: int = 4,
+ target_r2: float = 0.99,
+ patience: int = 20,
+ ) -> dict:
+ """Run GRPO training."""
+ logger.info("=" * 60)
+ logger.info("GRPO SYMBOLIC REGRESSION")
+ logger.info("=" * 60)
+ logger.info(f"Epochs: {epochs}")
+ logger.info(f"Group size: {self.group_size}")
+ logger.info(f"Num groups: {num_groups}")
+ logger.info(f"Effective batch: {self.group_size * num_groups}")
+ logger.info(f"Target R²: {target_r2}")
+ logger.info("=" * 60)
+
+ no_improvement_count = 0
+ best_r2_so_far = self.best_r2
+
+ for epoch in range(1, epochs + 1):
+ stats = self.train_step(num_groups)
+ self.history.append({
+ "epoch": epoch,
+ **stats,
+ "best_r2": self.best_r2,
+ })
+
+ logger.info(
+ f"Epoch {epoch:3d} | "
+ f"Valid: {stats['valid_count']}/{stats['total_count']} | "
+ f"Mean R²: {stats['mean_r2']:.4f} | "
+ f"Best: {self.best_r2:.4f} | "
+ f"Adv μ: {stats['mean_advantage']:.3f} σ: {stats['std_advantage']:.3f} | "
+ f"LR: {stats['lr']:.2e}"
+ )
+
+ # Check for target
+ if self.best_r2 >= target_r2:
+ logger.info(f"Target R² {target_r2} reached at epoch {epoch}!")
+ break
+
+ # Early stopping on stagnation
+ if self.best_r2 > best_r2_so_far:
+ best_r2_so_far = self.best_r2
+ no_improvement_count = 0
+ else:
+ no_improvement_count += 1
+
+ if no_improvement_count >= patience:
+ logger.info(f"No improvement for {patience} epochs. Early stopping.")
+ break
+
+ # Final results
+ logger.info("")
+ logger.info("=" * 60)
+ logger.info("FINAL RESULTS")
+ logger.info("=" * 60)
+ logger.info(f"Best R²: {self.best_r2:.4f}")
+ logger.info(f"Best expression: {self.best_expression}")
+ logger.info(f"Unique expressions discovered: {len(self.discovered_expressions)}")
+
+ # Top expressions
+ top_exprs = sorted(
+ self.discovered_expressions.items(),
+ key=lambda x: x[1],
+ reverse=True
+ )[:5]
+ logger.info("Top 5 expressions:")
+ for expr, r2 in top_exprs:
+ logger.info(f" R²={r2:.4f}: {expr}")
+
+ # Save results
+ results = {
+ "algorithm": "GRPO",
+ "best_r2": self.best_r2,
+ "best_expression": self.best_expression,
+ "history": self.history,
+ "discovered_expressions": dict(list(self.discovered_expressions.items())[:100]),
+ "config": {
+ "group_size": self.group_size,
+ "num_groups": num_groups,
+ "learning_rate": self.learning_rate,
+ "kl_coef": self.kl_coef,
+ }
+ }
+
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_path = self.output_dir / f"results_grpo_{timestamp}.json"
+ with open(output_path, "w") as f:
+ json.dump(results, f, indent=2)
+ logger.info(f"Results saved to: {output_path}")
+
+ return results
+
+
+ def main():
+ parser = argparse.ArgumentParser(description="GRPO for Symbolic Regression")
+ parser.add_argument("--model_path", type=str, required=True)
+ parser.add_argument("--dataset", type=str, required=True)
+ parser.add_argument("--output_dir", type=str, default="./output/grpo")
+ parser.add_argument("--epochs", type=int, default=50)
+ parser.add_argument("--group_size", type=int, default=8)
+ parser.add_argument("--num_groups", type=int, default=4)
+ parser.add_argument("--learning_rate", type=float, default=5e-5)
+ parser.add_argument("--target_r2", type=float, default=0.99)
+ args = parser.parse_args()
+
+ # Load dataset
+ import pandas as pd
+ df = pd.read_csv(args.dataset)
+
+ x_cols = [c for c in df.columns if c.startswith('x_')]
+ X = df[x_cols].values
+ y = df['y'].values
+
+ logger.info(f"Loaded dataset: {args.dataset}")
+ logger.info(f" Samples: {len(df)}, Variables: {len(x_cols)}")
+
+ # Create GRPO trainer
+ grpo = GRPO(
+ model_path=args.model_path,
+ X=X,
+ y=y,
+ output_dir=args.output_dir,
+ learning_rate=args.learning_rate,
+ group_size=args.group_size,
+ )
+
+ # Run training (results are also written to output_dir as JSON)
+ grpo.run(
+ epochs=args.epochs,
+ num_groups=args.num_groups,
+ target_r2=args.target_r2,
+ )
+
+
+ if __name__ == "__main__":
+ main()