GPT-2 Base trained on prefix dataset (682K)
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +107 -0
- ANALYSIS_REPORT.md +283 -0
- EXPERIMENT_PLAN.md +195 -0
- README.md +106 -0
- classes/__init__.py +0 -0
- classes/dataset.py +48 -0
- classes/expression.py +403 -0
- configs/eval_dataset_download.sh +6 -0
- configs/model_config.json +1 -0
- configs/peft_config.json +1 -0
- configs/training.sh +82 -0
- configs/training_args.json +29 -0
- configs/training_large.json +65 -0
- configs/training_medium.json +65 -0
- configs/training_small.json +65 -0
- configs/training_v3.json +78 -0
- create_structure.sh +171 -0
- notebooks/.gitkeep +0 -0
- notebooks/01_data_exploration.ipynb +0 -0
- notebooks/02_finetuning_avaliation.ipynb +568 -0
- notebooks/03_RL.ipynb +338 -0
- notebooks/04_merging_model.ipynb +206 -0
- out.txt +7 -0
- out2.txt +0 -0
- requirements.txt +30 -0
- scripts/aws/analyze_model.sh +203 -0
- scripts/aws/evaluate_models.sh +62 -0
- scripts/aws/launch_evaluation_instance.sh +299 -0
- scripts/aws/launch_instance.sh +196 -0
- scripts/aws/launch_instance_fixed.sh +371 -0
- scripts/aws/monitor_evaluation.sh +116 -0
- scripts/aws/monitor_training_auto.sh +179 -0
- scripts/aws/run_all_training.sh +365 -0
- scripts/aws/setup_and_train_exp_a.sh +83 -0
- scripts/aws/setup_and_train_exp_b.sh +83 -0
- scripts/aws/setup_aws.sh +87 -0
- scripts/aws/train_exp_a.sh +57 -0
- scripts/aws/train_exp_b.sh +58 -0
- scripts/aws/train_fixed_model.sh +144 -0
- scripts/aws/train_v3_model.sh +144 -0
- scripts/aws/validate_setup.sh +285 -0
- scripts/compare_models.py +271 -0
- scripts/compare_v1_v2_simple.py +240 -0
- scripts/data/data_augmentation.py +63 -0
- scripts/data/data_cleaning.py +90 -0
- scripts/data/data_processing.py +108 -0
- scripts/data/parallel_utils.py +31 -0
- scripts/data/prepare_experiment_data.py +513 -0
- scripts/data/prepare_training_data_fixed.py +408 -0
- scripts/evaluate.py +432 -0
.gitignore
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Environments
|
| 55 |
+
.env
|
| 56 |
+
.venv
|
| 57 |
+
.seriguela
|
| 58 |
+
venv/
|
| 59 |
+
ENV/
|
| 60 |
+
env/
|
| 61 |
+
env.bak/
|
| 62 |
+
venv.bak/
|
| 63 |
+
|
| 64 |
+
# IDEs / Editors
|
| 65 |
+
.idea/
|
| 66 |
+
.vscode/
|
| 67 |
+
*.suo
|
| 68 |
+
*.ntvs*
|
| 69 |
+
*.njsproj
|
| 70 |
+
*.sln
|
| 71 |
+
*.sw?
|
| 72 |
+
|
| 73 |
+
# Jupyter Notebook
|
| 74 |
+
.ipynb_checkpoints
|
| 75 |
+
|
| 76 |
+
# Output folder (geralmente grande demais para Git)
|
| 77 |
+
output/*
|
| 78 |
+
!output/.gitkeep # Não ignore um .gitkeep se precisar manter a pasta
|
| 79 |
+
|
| 80 |
+
# Dados (podem ser grandes, usar Git LFS ou armazenar fora se necessário)
|
| 81 |
+
# Note: CSV files in data/processed/ can be 100MB+ and are excluded from git
|
| 82 |
+
# Run scripts/data/prepare_training_data_fixed.py on target system to generate them
|
| 83 |
+
data/*
|
| 84 |
+
data/raw/*
|
| 85 |
+
data/processed/*
|
| 86 |
+
!data/raw/.gitkeep
|
| 87 |
+
!data/processed/.gitkeep
|
| 88 |
+
|
| 89 |
+
# OS generated files
|
| 90 |
+
.DS_Store
|
| 91 |
+
.DS_Store?
|
| 92 |
+
._*
|
| 93 |
+
.Spotlight-V100
|
| 94 |
+
.Trashes
|
| 95 |
+
ehthumbs.db
|
| 96 |
+
Thumbs.db
|
| 97 |
+
.env
|
| 98 |
+
|
| 99 |
+
wandb
|
| 100 |
+
|
| 101 |
+
# AWS credentials and keys
|
| 102 |
+
aws/keys/*.pem
|
| 103 |
+
aws/keys/*.key
|
| 104 |
+
aws/.env
|
| 105 |
+
aws/credentials
|
| 106 |
+
*.pem
|
| 107 |
+
*.key
|
ANALYSIS_REPORT.md
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Seriguela - Relatório Consolidado de Análise
|
| 2 |
+
|
| 3 |
+
**Data:** 2026-02-01
|
| 4 |
+
**Status:** ⚠️ BLOCK 2 PRECISA RETREINO
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Resumo Executivo
|
| 9 |
+
|
| 10 |
+
Projeto Seriguela tem 3 blocos:
|
| 11 |
+
1. **Block 1 - Dados:** Preparação e análise ⚠️ **CAUSA RAIZ AQUI**
|
| 12 |
+
2. **Block 2 - Treino Supervisionado:** Treinar LLM para gerar expressões ❌ PROBLEMA
|
| 13 |
+
3. **Block 3 - PPO Finetuning:** Otimizar para symbolic regression ⛔ BLOQUEADO
|
| 14 |
+
|
| 15 |
+
**Causa raiz identificada:** Dados de treino **NÃO TÊM `<|endofex|>` markers**. 0% dos 758,255 exemplos têm o marker. Modelo nunca aprendeu a parar.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## Investigação da Causa Raiz (2026-02-01)
|
| 20 |
+
|
| 21 |
+
### Descoberta 1: Validação Original Era Frouxa
|
| 22 |
+
|
| 23 |
+
Script `test_inference_configs.py` reporta **95% válidas**, mas aceita:
|
| 24 |
+
```
|
| 25 |
+
✅ VALID: C*x_1 + C*x_6 - tan(x_9) - Cainers: C9999(x
|
| 26 |
+
✅ VALID: C*x_1 + C*x_2 + C*x_1 + C Pressure, sin, sqrt, tan
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
Validação original só verifica:
|
| 30 |
+
- Tem operador? ✓
|
| 31 |
+
- Tem variável? ✓
|
| 32 |
+
- Não tem "Buyable"? ✓
|
| 33 |
+
|
| 34 |
+
**NÃO verifica:**
|
| 35 |
+
- Se usa variáveis do prompt
|
| 36 |
+
- Se pode ser parseada
|
| 37 |
+
- Se tem outros garbage tokens
|
| 38 |
+
|
| 39 |
+
### Descoberta 2: Dados de Treino SEM Markers
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
# Dataset: augustocsc/sintetico_natural (700K)
|
| 43 |
+
Total de exemplos: 758,255
|
| 44 |
+
Exemplos com <|endofex|>: 0 (0.0%)
|
| 45 |
+
Exemplos com <|startofex|>: 0 (0.0%)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
**O modelo NUNCA viu `<|endofex|>` durante treino!**
|
| 49 |
+
|
| 50 |
+
### Descoberta 3: Origem do Garbage
|
| 51 |
+
|
| 52 |
+
Garbage tokens (Stockholm, Pressure, XP, etc.) vêm do **vocabulário GPT-2 base**.
|
| 53 |
+
Como modelo não sabe parar, eventualmente gera tokens aleatórios.
|
| 54 |
+
|
| 55 |
+
### Conclusão da Investigação
|
| 56 |
+
|
| 57 |
+
| Problema | Causa |
|
| 58 |
+
|----------|-------|
|
| 59 |
+
| Modelo não para | Dados sem `<|endofex|>` |
|
| 60 |
+
| Garbage tokens | GPT-2 base vaza sem stopping |
|
| 61 |
+
| Variáveis erradas | Dados têm x_1-x_10, modelo não aprende restrição |
|
| 62 |
+
| 95% vs 0% válidas | Validação original era frouxa |
|
| 63 |
+
|
| 64 |
+
### Solução Necessária
|
| 65 |
+
|
| 66 |
+
1. **Preparar dados** com `<|endofex|>` em 100% dos exemplos
|
| 67 |
+
2. **Retreinar modelo** com dados corrigidos
|
| 68 |
+
3. **Validação rigorosa** durante treino
|
| 69 |
+
|
| 70 |
+
---
|
| 71 |
+
|
| 72 |
+
## Modelos Testados
|
| 73 |
+
|
| 74 |
+
| Modelo | HuggingFace Hub | Esperado | Real | Status |
|
| 75 |
+
|--------|-----------------|----------|------|--------|
|
| 76 |
+
| V1 | augustocsc/Se124M_700K_infix | 83.3% válidas | **0%** | ❌ Falha |
|
| 77 |
+
| V2 | augustocsc/Se124M_700K_infix_v2 | 90% válidas | **0%** | ❌ Falha |
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Testes Realizados
|
| 82 |
+
|
| 83 |
+
### Teste 1: Comparação V1 vs V2 (mesmo prompt)
|
| 84 |
+
|
| 85 |
+
**Prompt:**
|
| 86 |
+
```
|
| 87 |
+
vars: x_1, x_2
|
| 88 |
+
oper: *, +, -, sin, cos
|
| 89 |
+
cons: C
|
| 90 |
+
expr:
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
**Configurações ótimas usadas:**
|
| 94 |
+
- V1: temp=0.5, top_k=40, top_p=0.9, rep_penalty=1.15
|
| 95 |
+
- V2: temp=0.7, top_k=0, top_p=0.8, rep_penalty=1.0
|
| 96 |
+
|
| 97 |
+
**Resultados (20 gerações cada):**
|
| 98 |
+
|
| 99 |
+
| Métrica | V1 | V2 |
|
| 100 |
+
|---------|----|----|
|
| 101 |
+
| Expressões Válidas | 0% | 0% |
|
| 102 |
+
| Símbolos Corretos | 0% | 45% |
|
| 103 |
+
|
| 104 |
+
### Teste 2: PPO Evaluation
|
| 105 |
+
|
| 106 |
+
**Objetivo:** Verificar se modelo pode ser usado para PPO (symbolic regression)
|
| 107 |
+
|
| 108 |
+
**Resultados:**
|
| 109 |
+
- Valid Rate: 6.7% (muito baixo)
|
| 110 |
+
- Best R²: N/A (não conseguiu computar)
|
| 111 |
+
- **Conclusão:** PPO inviável com modelo atual
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## Problemas Identificados
|
| 116 |
+
|
| 117 |
+
### 1. Modelos Não Param Corretamente
|
| 118 |
+
|
| 119 |
+
**Sintoma:** Expressões continuam além do esperado
|
| 120 |
+
```
|
| 121 |
+
Esperado: C*x_1 + sin(x_2)<|endofex|>
|
| 122 |
+
Gerado: C*x_1 + sin(x_2) + C Stockholmvars: x_1, x_2, x_3...
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
**Causa:** Modelo não aprendeu a gerar `<|endofex|>`
|
| 126 |
+
|
| 127 |
+
### 2. Garbage Tokens na Saída
|
| 128 |
+
|
| 129 |
+
**Exemplos de lixo gerado:**
|
| 130 |
+
- "BuyableInstoreAndOnline"
|
| 131 |
+
- "Stockholm", "GREEN", "Muslims"
|
| 132 |
+
- "intuition", "records", "crash"
|
| 133 |
+
- "xstatics", "xid", "sinmod"
|
| 134 |
+
|
| 135 |
+
**Causa:** Dados de treino contaminados OU modelo não convergiu
|
| 136 |
+
|
| 137 |
+
### 3. Variáveis Erradas
|
| 138 |
+
|
| 139 |
+
**Sintoma:** Usa variáveis não permitidas
|
| 140 |
+
```
|
| 141 |
+
Prompt pede: x_1, x_2
|
| 142 |
+
Modelo gera: x_9, x_10, x_3, x_4
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
**Causa:** Modelo não aprendeu a respeitar o prompt
|
| 146 |
+
|
| 147 |
+
### 4. Discrepância com Documentação
|
| 148 |
+
|
| 149 |
+
**Documentação dizia:**
|
| 150 |
+
- V1: 83.3% válidas com config otimizada
|
| 151 |
+
- V2: 90% válidas com nucleus sampling
|
| 152 |
+
|
| 153 |
+
**Realidade:**
|
| 154 |
+
- V1: 0% válidas
|
| 155 |
+
- V2: 0% válidas
|
| 156 |
+
|
| 157 |
+
**Possíveis causas:**
|
| 158 |
+
1. Modelos no Hub não são os mesmos testados
|
| 159 |
+
2. Testes anteriores tinham bug
|
| 160 |
+
3. Forma de carregar modelo está errada
|
| 161 |
+
|
| 162 |
+
---
|
| 163 |
+
|
| 164 |
+
## Configurações de Inferência Testadas
|
| 165 |
+
|
| 166 |
+
### V1 Config Ótima (segundo docs)
|
| 167 |
+
```python
|
| 168 |
+
{
|
| 169 |
+
"temperature": 0.5,
|
| 170 |
+
"top_k": 40,
|
| 171 |
+
"top_p": 0.9,
|
| 172 |
+
"repetition_penalty": 1.15,
|
| 173 |
+
"max_new_tokens": 100,
|
| 174 |
+
"do_sample": True,
|
| 175 |
+
}
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
### V2 Config Ótima (segundo docs)
|
| 179 |
+
```python
|
| 180 |
+
{
|
| 181 |
+
"temperature": 0.7,
|
| 182 |
+
"top_k": 0,
|
| 183 |
+
"top_p": 0.8,
|
| 184 |
+
"repetition_penalty": 1.0,
|
| 185 |
+
"max_new_tokens": 128,
|
| 186 |
+
"do_sample": True,
|
| 187 |
+
}
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**Resultado:** Mesmo com configs ótimas, 0% válidas.
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## Forma de Carregar Modelos
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
# 1. Carregar base GPT-2
|
| 198 |
+
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
|
| 199 |
+
|
| 200 |
+
# 2. Configurar tokenizer com tokens especiais
|
| 201 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 202 |
+
tokenizer.add_special_tokens({
|
| 203 |
+
"additional_special_tokens": ["<|startofex|>", "<|endofex|>"]
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
# 3. Redimensionar embeddings
|
| 207 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 208 |
+
|
| 209 |
+
# 4. Carregar adapter LoRA
|
| 210 |
+
model = PeftModel.from_pretrained(model, "augustocsc/Se124M_700K_infix_v2")
|
| 211 |
+
|
| 212 |
+
# 5. Merge adapter no modelo base
|
| 213 |
+
model = model.merge_and_unload()
|
| 214 |
+
model.eval()
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
---
|
| 218 |
+
|
| 219 |
+
## Conclusões
|
| 220 |
+
|
| 221 |
+
### Block 2 (Treino) - PRECISA RETREINO
|
| 222 |
+
|
| 223 |
+
**Problemas no treino:**
|
| 224 |
+
1. Modelo não aprendeu `<|endofex|>` marker
|
| 225 |
+
2. Dados podem estar contaminados com garbage
|
| 226 |
+
3. Modelo não respeita variáveis do prompt
|
| 227 |
+
|
| 228 |
+
**Ações necessárias:**
|
| 229 |
+
1. Validar dados de treino (100% devem ter `<|endofex|>`)
|
| 230 |
+
2. Limpar garbage tokens dos dados
|
| 231 |
+
3. Monitorar valid rate durante treino
|
| 232 |
+
4. Só considerar treino bem-sucedido se valid rate > 80%
|
| 233 |
+
|
| 234 |
+
### Block 3 (PPO) - BLOQUEADO
|
| 235 |
+
|
| 236 |
+
**Pré-requisitos para PPO:**
|
| 237 |
+
- ✅ Base model gera >80% expressões válidas
|
| 238 |
+
- ✅ Expressões podem ser avaliadas (R² computável)
|
| 239 |
+
- ✅ Modelo para corretamente em boundaries
|
| 240 |
+
|
| 241 |
+
**Status atual:** ❌ Nenhum pré-requisito atendido
|
| 242 |
+
|
| 243 |
+
---
|
| 244 |
+
|
| 245 |
+
## Próximos Passos
|
| 246 |
+
|
| 247 |
+
1. **Investigar dados de treino**
|
| 248 |
+
- Verificar se `<|endofex|>` está presente
|
| 249 |
+
- Identificar fonte de garbage tokens
|
| 250 |
+
|
| 251 |
+
2. **Retreinar modelo (V3)**
|
| 252 |
+
- Usar dados validados
|
| 253 |
+
- Monitorar valid rate durante treino
|
| 254 |
+
- Validar antes de fazer push pro Hub
|
| 255 |
+
|
| 256 |
+
3. **Só então testar PPO**
|
| 257 |
+
- Após valid rate > 80%
|
| 258 |
+
- Com modelo que para corretamente
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
## Arquivos de Código Relevantes
|
| 263 |
+
|
| 264 |
+
- `scripts/train.py` - Script de treino
|
| 265 |
+
- `scripts/generate.py` - Geração com stopping criteria
|
| 266 |
+
- `scripts/evaluate.py` - Avaliação de modelo
|
| 267 |
+
- `scripts/compare_v1_v2_simple.py` - Comparação V1 vs V2
|
| 268 |
+
- `scripts/evaluate_ppo.py` - Avaliação para PPO
|
| 269 |
+
- `scripts/data/prepare_training_data_fixed.py` - Preparação de dados
|
| 270 |
+
- `classes/expression.py` - Parsing e validação de expressões
|
| 271 |
+
|
| 272 |
+
---
|
| 273 |
+
|
| 274 |
+
## Infraestrutura AWS
|
| 275 |
+
|
| 276 |
+
- **Instance:** g5.xlarge (NVIDIA A10G, 24GB)
|
| 277 |
+
- **Instance ID:** i-0377b6c8de3660a82
|
| 278 |
+
- **Custo:** ~$1/hora
|
| 279 |
+
- **Status atual:** Stopped (para economizar)
|
| 280 |
+
|
| 281 |
+
---
|
| 282 |
+
|
| 283 |
+
**Última atualização:** 2026-02-01
|
EXPERIMENT_PLAN.md
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Plano de Experimentos: Formatos de Treino
|
| 2 |
+
|
| 3 |
+
**Data:** 2026-02-01
|
| 4 |
+
**Objetivo:** Testar duas abordagens para resolver o problema de stopping
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Contexto
|
| 9 |
+
|
| 10 |
+
### Problema Identificado
|
| 11 |
+
- Dados de treino não têm marcador de fim (0% com qualquer marker)
|
| 12 |
+
- Modelo não aprende quando parar
|
| 13 |
+
- Gera garbage tokens do vocabulário GPT-2
|
| 14 |
+
|
| 15 |
+
### Experimentos Propostos
|
| 16 |
+
1. **EXP-A:** Formato estruturado (JSON-like)
|
| 17 |
+
2. **EXP-B:** Token EOS do GPT-2 (`<|endoftext|>`)
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## EXP-A: Formato Estruturado
|
| 22 |
+
|
| 23 |
+
### Formato dos Dados
|
| 24 |
+
```json
|
| 25 |
+
{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "expr": "C*sin(x_1) + x_2"}
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Vantagens
|
| 29 |
+
- Estrutura clara e parseável
|
| 30 |
+
- Fácil validação (JSON válido = formato correto)
|
| 31 |
+
- Modelo aprende estrutura rígida
|
| 32 |
+
|
| 33 |
+
### Desvantagens
|
| 34 |
+
- Mais tokens por exemplo
|
| 35 |
+
- Pode ser mais difícil de aprender
|
| 36 |
+
|
| 37 |
+
### Preparação de Dados
|
| 38 |
+
```python
|
| 39 |
+
# Transformar de:
|
| 40 |
+
"vars: x_1, x_2\noper: *, +, sin\ncons: C\nexpr: C*sin(x_1) + x_2"
|
| 41 |
+
|
| 42 |
+
# Para:
|
| 43 |
+
'{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "C*sin(x_1) + x_2"}'
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### Inferência
|
| 47 |
+
```python
|
| 48 |
+
prompt = '{"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "'
|
| 49 |
+
# Modelo completa com: C*sin(x_1) + x_2"}
|
| 50 |
+
# Extrair: tudo entre 'expr": "' e '"}'
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Critério de Sucesso
|
| 54 |
+
- JSON parseável em >90% dos casos
|
| 55 |
+
- Expressão extraída válida em >80% dos casos
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## EXP-B: Token EOS do GPT-2
|
| 60 |
+
|
| 61 |
+
### Formato dos Dados
|
| 62 |
+
```
|
| 63 |
+
vars: x_1, x_2
|
| 64 |
+
oper: *, +, sin
|
| 65 |
+
cons: C
|
| 66 |
+
expr: C*sin(x_1) + x_2<|endoftext|>
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Vantagens
|
| 70 |
+
- Token já existe no modelo (ID 50256)
|
| 71 |
+
- GPT-2 já entende como "fim de sequência"
|
| 72 |
+
- Não precisa resize de embeddings
|
| 73 |
+
- Formato similar ao atual
|
| 74 |
+
|
| 75 |
+
### Desvantagens
|
| 76 |
+
- Pode conflitar com outros usos do EOS
|
| 77 |
+
- Menos explícito que marker dedicado
|
| 78 |
+
|
| 79 |
+
### Preparação de Dados
|
| 80 |
+
```python
|
| 81 |
+
# Adicionar <|endoftext|> no final de cada expressão
|
| 82 |
+
text = original_text + "<|endoftext|>"
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Inferência
|
| 86 |
+
```python
|
| 87 |
+
# Usar eos_token_id como stopping criteria
|
| 88 |
+
output = model.generate(
|
| 89 |
+
**inputs,
|
| 90 |
+
eos_token_id=tokenizer.eos_token_id, # 50256
|
| 91 |
+
max_new_tokens=128
|
| 92 |
+
)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### Critério de Sucesso
|
| 96 |
+
- Modelo gera `<|endoftext|>` em >90% dos casos
|
| 97 |
+
- Expressão antes do EOS válida em >80% dos casos
|
| 98 |
+
|
| 99 |
+
---
|
| 100 |
+
|
| 101 |
+
## Plano de Execução
|
| 102 |
+
|
| 103 |
+
### Fase 1: Preparação de Dados (Local)
|
| 104 |
+
|
| 105 |
+
#### 1.1 Criar script de preparação
|
| 106 |
+
```
|
| 107 |
+
scripts/data/prepare_experiment_data.py
|
| 108 |
+
```
|
| 109 |
+
- Entrada: dataset augustocsc/sintetico_natural (700K)
|
| 110 |
+
- Saída A: data/exp_a_json/train.csv, validation.csv
|
| 111 |
+
- Saída B: data/exp_b_eos/train.csv, validation.csv
|
| 112 |
+
|
| 113 |
+
#### 1.2 Validar dados preparados
|
| 114 |
+
- Verificar formato correto em 100% dos exemplos
|
| 115 |
+
- Amostrar e inspecionar manualmente
|
| 116 |
+
|
| 117 |
+
### Fase 2: Treino (AWS)
|
| 118 |
+
|
| 119 |
+
#### 2.1 Treinar EXP-A (JSON)
|
| 120 |
+
```bash
|
| 121 |
+
python scripts/train.py \
|
| 122 |
+
--use_local_csvs \
|
| 123 |
+
--train_file ./data/exp_a_json/train.csv \
|
| 124 |
+
--output_dir ./output/exp_a_json \
|
| 125 |
+
--num_train_epochs 3
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
#### 2.2 Treinar EXP-B (EOS)
|
| 129 |
+
```bash
|
| 130 |
+
python scripts/train.py \
|
| 131 |
+
--use_local_csvs \
|
| 132 |
+
--train_file ./data/exp_b_eos/train.csv \
|
| 133 |
+
--output_dir ./output/exp_b_eos \
|
| 134 |
+
--num_train_epochs 3
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### Fase 3: Avaliação
|
| 138 |
+
|
| 139 |
+
#### 3.1 Métricas
|
| 140 |
+
- **Valid Rate:** % expressões parseáveis
|
| 141 |
+
- **Stopping Rate:** % que param corretamente (JSON fechado ou EOS)
|
| 142 |
+
- **Symbol Accuracy:** % que usam apenas símbolos do prompt
|
| 143 |
+
- **Garbage Rate:** % com tokens não-matemáticos
|
| 144 |
+
|
| 145 |
+
#### 3.2 Comparação
|
| 146 |
+
| Métrica | EXP-A (JSON) | EXP-B (EOS) |
|
| 147 |
+
|---------|--------------|-------------|
|
| 148 |
+
| Valid Rate | ? | ? |
|
| 149 |
+
| Stopping Rate | ? | ? |
|
| 150 |
+
| Symbol Accuracy | ? | ? |
|
| 151 |
+
| Garbage Rate | ? | ? |
|
| 152 |
+
|
| 153 |
+
### Fase 4: Decisão
|
| 154 |
+
|
| 155 |
+
- Se EXP-A melhor → usar formato JSON
|
| 156 |
+
- Se EXP-B melhor → usar EOS token
|
| 157 |
+
- Se ambos ruins → investigar outras opções
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## Estimativas
|
| 162 |
+
|
| 163 |
+
| Fase | Tempo | Custo AWS |
|
| 164 |
+
|------|-------|-----------|
|
| 165 |
+
| Preparação dados | 30 min | $0 |
|
| 166 |
+
| Treino EXP-A | 2-3h | ~$3 |
|
| 167 |
+
| Treino EXP-B | 2-3h | ~$3 |
|
| 168 |
+
| Avaliação | 30 min | ~$0.50 |
|
| 169 |
+
| **Total** | **6-7h** | **~$6.50** |
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
## Arquivos a Criar
|
| 174 |
+
|
| 175 |
+
```
|
| 176 |
+
scripts/data/prepare_experiment_data.py # Preparação
|
| 177 |
+
data/exp_a_json/train.csv # Dados JSON
|
| 178 |
+
data/exp_a_json/validation.csv
|
| 179 |
+
data/exp_b_eos/train.csv # Dados EOS
|
| 180 |
+
data/exp_b_eos/validation.csv
|
| 181 |
+
scripts/evaluate_experiments.py # Avaliação
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
|
| 186 |
+
## Critério de Sucesso Final
|
| 187 |
+
|
| 188 |
+
**Experimento bem-sucedido se:**
|
| 189 |
+
- Valid Rate > 80%
|
| 190 |
+
- Stopping Rate > 90%
|
| 191 |
+
- Garbage Rate < 5%
|
| 192 |
+
|
| 193 |
+
**Próximo passo após sucesso:**
|
| 194 |
+
- Usar formato vencedor para treinar modelo final
|
| 195 |
+
- Prosseguir para Block 3 (PPO)
|
README.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nome do Seu Projeto de Fine-Tuning
|
| 2 |
+
|
| 3 |
+
(Breve descrição do objetivo do projeto)
|
| 4 |
+
|
| 5 |
+
## Estrutura de Pastas
|
| 6 |
+
|
| 7 |
+
Aqui está a organização das pastas e seus propósitos:
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
seu_projeto_finetuning/
|
| 11 |
+
│
|
| 12 |
+
├── data/ # Todos os dados relacionados ao projeto
|
| 13 |
+
│ ├── raw/ # Dados originais, não processados
|
| 14 |
+
│ └── processed/ # Dados limpos, formatados e divididos (train/val/test)
|
| 15 |
+
│
|
| 16 |
+
├── scripts/ # Scripts Python principais
|
| 17 |
+
│ ├── preprocess_data.py # (Opcional) Script para limpar e formatar dados
|
| 18 |
+
│ ├── train.py # Script principal para rodar o Trainer do HF
|
| 19 |
+
│ ├── evaluate.py # (Opcional) Script para avaliação customizada
|
| 20 |
+
│ └── generate.py # (Opcional) Script para gerar texto com modelo treinado
|
| 21 |
+
│
|
| 22 |
+
├── configs/ # Arquivos de configuração (JSON, YAML, etc.)
|
| 23 |
+
│ ├── training_args.json # Argumentos de treino (passados para TrainingArguments)
|
| 24 |
+
│ ├── peft_config.json # (Se usar PEFT) Configuração LoRA, Adapter, etc.
|
| 25 |
+
│ └── model_config.json # (Opcional) Nome do modelo base, caminhos, etc.
|
| 26 |
+
│
|
| 27 |
+
├── output/ # Todos os outputs gerados (modelos, logs, resultados)
|
| 28 |
+
│ └── {nome_experimento}/ # Subpasta para cada execução/experimento
|
| 29 |
+
│ ├── checkpoints/ # Checkpoints salvos pelo Trainer
|
| 30 |
+
│ ├── final_model/ # Modelo final treinado
|
| 31 |
+
│ ├── logs/ # Logs do TensorBoard ou outros
|
| 32 |
+
│ └── ... # Outros resultados (métricas, amostras)
|
| 33 |
+
│
|
| 34 |
+
├── notebooks/ # (Opcional) Jupyter notebooks para exploração e testes
|
| 35 |
+
│
|
| 36 |
+
├── .gitignore # Especifica arquivos/pastas a serem ignorados pelo Git
|
| 37 |
+
├── requirements.txt # Dependências Python do projeto
|
| 38 |
+
└── README.md # Documentação do projeto (este arquivo)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
* **`data/`**: Contém todos os dados.
|
| 42 |
+
* `raw/`: Armazena os dados originais, sem modificações.
|
| 43 |
+
* `processed/`: Guarda os dados após limpeza, formatação e divisão (treino, validação, teste), prontos para serem usados pelo script de treinamento.
|
| 44 |
+
* **`scripts/`**: Onde fica o código Python.
|
| 45 |
+
* `train.py`: O coração do projeto, responsável por carregar dados, modelo, configurações e executar o fine-tuning com o `Trainer`.
|
| 46 |
+
* Scripts auxiliares para pré-processamento, avaliação ou geração podem ser incluídos aqui.
|
| 47 |
+
* **`configs/`**: Centraliza as configurações do projeto, como hiperparâmetros de treinamento (`training_args.json`), configurações PEFT (`peft_config.json`) ou detalhes do modelo base. Isso facilita a alteração de parâmetros sem modificar o código principal.
|
| 48 |
+
* **`output/`**: Diretório para todos os artefatos gerados durante o treinamento. É **altamente recomendado** criar uma subpasta para cada experimento (identificada por nome ou timestamp) para manter os resultados organizados (checkpoints, modelo final, logs, métricas). O `output_dir` do `TrainingArguments` deve apontar para essa subpasta específica do experimento.
|
| 49 |
+
* **`notebooks/`**: Espaço para prototipagem, análise exploratória de dados e testes rápidos usando Jupyter Notebooks.
|
| 50 |
+
* **`.gitignore`**: Configura o Git para ignorar arquivos e pastas desnecessários (ambientes virtuais, caches, outputs grandes, dados brutos grandes, etc.).
|
| 51 |
+
* **`requirements.txt`**: Lista as bibliotecas Python necessárias para que o projeto funcione, permitindo recriar o ambiente facilmente (`pip install -r requirements.txt`).
|
| 52 |
+
* **`README.md`**: Documentação essencial explicando o projeto, como configurá-lo e executá-lo.
|
| 53 |
+
|
| 54 |
+
## Como Usar
|
| 55 |
+
|
| 56 |
+
1. **Setup:** Crie um ambiente virtual e instale as dependências:
|
| 57 |
+
```bash
|
| 58 |
+
python -m venv venv
|
| 59 |
+
source venv/bin/activate # Linux/macOS
|
| 60 |
+
# venv\Scripts\activate # Windows
|
| 61 |
+
pip install -r requirements.txt
|
| 62 |
+
```
|
| 63 |
+
2. **Dados:** Coloque seus dados brutos em `data/raw/` e execute (ou crie) o script `scripts/preprocess_data.py` para gerar os arquivos em `data/processed/`.
|
| 64 |
+
3. **Configuração:** Ajuste os arquivos em `configs/` (argumentos de treino, modelo base, PEFT se aplicável).
|
| 65 |
+
4. **Treinamento:** Execute o script principal:
|
| 66 |
+
```bash
|
| 67 |
+
python scripts/train.py --args_config configs/training_args.json --model_config configs/model_config.json
|
| 68 |
+
```
|
| 69 |
+
*(Adapte os argumentos conforme necessário)*
|
| 70 |
+
|
| 71 |
+
## Dependências
|
| 72 |
+
|
| 73 |
+
As dependências Python estão listadas no arquivo `requirements.txt`.
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
A seguir, instruções para configurar o ambiente com `venv`, instalar as dependências e habilitar o uso de GPU e Weights & Biases (W&B):
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
### 🚀 Setup do Ambiente (com suporte a GPU e W&B)
|
| 81 |
+
|
| 82 |
+
Siga os passos abaixo para configurar o ambiente de desenvolvimento com `venv`, `pip`, suporte a GPU (CUDA 11.8) e monitoramento com Weights & Biases:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
# 1. Crie o ambiente virtual
|
| 86 |
+
python -m venv .seriguela
|
| 87 |
+
|
| 88 |
+
# 2. Ative o ambiente virtual
|
| 89 |
+
# No Linux/macOS:
|
| 90 |
+
source .seriguela/bin/activate
|
| 91 |
+
# No Windows:
|
| 92 |
+
.seriguela\Scripts\activate
|
| 93 |
+
|
| 94 |
+
# 3. Instale as dependências principais
|
| 95 |
+
pip install -r requirements.txt
|
| 96 |
+
|
| 97 |
+
# 4. Instale PyTorch com suporte a CUDA 11.8 (para uso com GPU)
|
| 98 |
+
pip install torch==2.2.1+cu118 torchvision==0.17.1+cu118 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu118
|
| 99 |
+
|
| 100 |
+
# 5. (Opcional) Faça login no Weights & Biases para monitorar seus experimentos
|
| 101 |
+
wandb login
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
> ⚠️ Certifique-se de que sua GPU e drivers estão atualizados e compatíveis com CUDA 11.8.
|
| 105 |
+
> 💡 Para ambientes 100% reprodutíveis, use sempre o mesmo `requirements.txt` e registre os experimentos com `wandb`.
|
| 106 |
+
*
|
classes/__init__.py
ADDED
|
File without changes
|
classes/dataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class RegressionDataset:
    """Tabular regression dataset backed by a single CSV file.

    The file is loaded eagerly at construction time and split into a feature
    matrix ``X`` and a target vector ``y``, both coerced to numeric values
    (non-numeric cells become NaN).
    """

    def __init__(self, path: str, file_name: str = 'train.csv', delimiter: str = ',', header: int = 0,
                 encoding: str = 'utf-8', target_col: str = None):
        """Load ``{path}/{file_name}`` and split it into features and target.

        Args:
            path (str): Directory that contains the CSV file.
            file_name (str): CSV file name. Defaults to 'train.csv'.
            delimiter (str): Column separator. Defaults to ','.
            header (int): Row index used for column names. Defaults to 0.
            encoding (str): File encoding. Defaults to 'utf-8'.
            target_col (str): Target column name; the last column when None.

        Raises:
            ValueError: If the file is empty or ``target_col`` is absent.
        """
        self.data = pd.read_csv(f"{path}/{file_name}", delimiter=delimiter,
                                header=header, encoding=encoding)

        if self.data.empty:
            raise ValueError("CSV file is empty.")

        # Fall back to the rightmost column as the regression target.
        if target_col is None:
            target_col = self.data.columns[-1]

        if target_col not in self.data.columns:
            raise ValueError(f"CSV must contain a column named '{target_col}'.")

        # Coerce everything to numeric; unparseable cells become NaN rather
        # than raising, so downstream code can decide how to handle them.
        feature_frame = self.data.drop(columns=[target_col])
        self.X = feature_frame.apply(pd.to_numeric, errors='coerce').values
        self.y = pd.to_numeric(self.data[target_col], errors='coerce').values

    def get_data(self):
        """Return ``(X, y)`` as float32 PyTorch tensors."""
        features = torch.tensor(self.X, dtype=torch.float32)
        targets = torch.tensor(self.y, dtype=torch.float32)
        return features, targets

    def get_numpy(self):
        """Return ``(X, y)`` as NumPy arrays (useful for sympy / R² work)."""
        return self.X, self.y
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
classes/expression.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics import r2_score, mean_squared_error
|
| 4 |
+
from sklearn.metrics import mean_absolute_error
|
| 5 |
+
from scipy.optimize import minimize
|
| 6 |
+
import math
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Expression:
    """A symbolic-regression expression with fittable constants.

    Parses an infix or prefix expression whose variables are named ``x_1``,
    ``x_2``, ... and whose free constants are written as ``C``, then supports
    vectorized numeric evaluation (``evaluate``) and constant fitting with
    scipy (``fit_constants``).
    """

    # Whitelisted numpy callables exposed to ``eval`` in evaluate(); nothing
    # else (and no builtins) is reachable from an expression string.
    SAFE_FUNCTIONS = {
        'sqrt': np.sqrt,
        'log': np.log,
        'exp': np.exp,
        'sin': np.sin,
        'cos': np.cos,
        'tan': np.tan,
        'asin': np.arcsin,  # numpy spells arcsine as arcsin
        'abs': np.abs,
        'pow': np.power,  # np.power vectorizes and propagates NaN
        # '**' is handled by Python's eval; on numpy arrays it uses np.power.
    }

    # Operand count for each supported operator/function.
    OPERATOR_ARITY = {
        '+': 2,
        '-': 2,
        '*': 2,
        '/': 2,
        '**': 2,  # power is written '**' (input '^' is normalized to it)
        'sin': 1,
        'cos': 1,
        'tan': 1,
        'log': 1,
        'sqrt': 1,
        'exp': 1
    }

    # SymPy counterparts for building expressions symbolically.
    OPERATOR_FUNCS = {
        '+': sympy.Add,
        '-': lambda x, y: x - y,
        '*': sympy.Mul,
        '/': lambda x, y: x / y,
        '**': sympy.Pow,  # sympy.Pow covers both '**' and '^'
        'sin': sympy.sin,
        'cos': sympy.cos,
        'tan': sympy.tan,
        'log': sympy.log,
        'sqrt': sympy.sqrt,
        'exp': sympy.exp
    }
|
| 51 |
+
|
| 52 |
+
def parse_prefix(self, tokens):
    """Parse a prefix-notation token list into a SymPy expression.

    Example: ['*', 'x_1', '+', 'x_2', 'C'] -> x_1*(x_2 + C)

    Args:
        tokens (list[str]): Prefix tokens (operators, function names,
            variables and constants).

    Returns:
        The unevaluated sympy expression (``sympify(..., evaluate=False)``).

    Raises:
        ValueError: On an empty list, missing operands, or leftover operands.
    """
    if not tokens:
        raise ValueError("Empty token list")

    # Unary functions vs binary operators recognized by the parser.
    UNARY_OPS = {'sin', 'cos', 'tan', 'exp', 'log', 'sqrt', 'abs', 'asin'}
    BINARY_OPS = {'+', '-', '*', '/', '**', '^'}

    stack = []

    # Scan right-to-left: operands are pushed before their operator is seen,
    # so when an operator appears its arguments sit on top of the stack with
    # the FIRST (left) operand on top.
    for token in reversed(tokens):
        if token in UNARY_OPS:
            if len(stack) < 1:
                raise ValueError(f"Not enough operands for {token}")
            arg = stack.pop()
            stack.append(f"{token}({arg})")
        elif token in BINARY_OPS:
            if len(stack) < 2:
                raise ValueError(f"Not enough operands for {token}")
            # BUG FIX: in a reverse scan the top of the stack is the LEFT
            # operand (it was pushed last). The previous version popped it
            # into `right`, silently swapping the operands of the
            # non-commutative operators '-', '/' and '**'
            # (e.g. ['-', 'a', 'b'] parsed as b - a instead of a - b).
            left = stack.pop()
            right = stack.pop()

            # Normalize '^' to Python's power operator.
            op = '**' if token == '^' else token
            stack.append(f"({left}){op}({right})")
        else:
            # Operand (variable, constant, or number): push as-is.
            stack.append(token)

    if len(stack) != 1:
        raise ValueError(f"Invalid prefix expression, {len(stack)} elements remaining")

    return sympy.sympify(stack[0], evaluate=False)
|
| 102 |
+
|
| 103 |
+
def __init__(self, expression, is_prefix=False):
    """Build an evaluable expression from an infix or prefix string.

    Args:
        expression (str): Expression text. Variables are named ``x_1``,
            ``x_2``, ... and free constants are written as ``C``.
        is_prefix (bool): When True, *expression* is a whitespace-separated
            prefix token string parsed via ``parse_prefix``.

    Raises:
        ValueError: If the expression cannot be parsed.
    """
    try:
        self.original_expression = expression  # raw input, kept for __str__/debugging

        if is_prefix:
            # Normalize '^' to '**' before tokenizing prefix input.
            tokens = expression.replace('^', '**').split()
            self.sympy_expression = self.parse_prefix(tokens)
        else:
            # sympify without simplification so the textual structure (and
            # the 'C' placeholders) is preserved verbatim.
            self.sympy_expression = sympy.sympify(expression, evaluate=False)
    except Exception as e:
        raise ValueError(f"Failed to parse expression: {e}")

    # Highest variable index seen (x_3 -> 3); determines how many columns
    # evaluate() expects in its input.
    self.max_var = 0
    for symbol in self.sympy_expression.free_symbols:
        if symbol.name.startswith('x_'):
            try:
                index = int(symbol.name.split('_')[1])
                self.max_var = max(self.max_var, index)
            except ValueError:
                # Symbols that look like x_ but are not x_<number> are ignored.
                pass  # Or raise ValueError(f"Invalid variable name: {symbol.name}") if strict

    computable_expression = str(self.sympy_expression)

    for i in range(1, self.max_var + 1):
        # Whole-word regex so x_1 is not rewritten inside x_11.
        computable_expression = re.sub(rf'\bx_{i}\b', f'x[{i-1}]', computable_expression)

    # NOTE(review): '**C' is rewritten to '**2' BEFORE constants are counted,
    # so a constant used as an exponent is pinned to 2 rather than fitted —
    # presumably to keep the optimization well-behaved; confirm intent.
    self.computable_expression = computable_expression.replace('**C', '**2')

    # Every remaining 'C' character is a free constant to be fitted.
    self.constant_count = self.computable_expression.count('C')
    self.best_constants = [1.0] * self.constant_count

    if self.constant_count > 0:
        # Replace each 'C' occurrence with an indexed 'constants[j]' lookup
        # so evaluate() can bind a vector of fitted values.
        split_expr = self.computable_expression.split('C')
        new_expr = split_expr[0]  # Start with first part

        for i in range(1, len(split_expr)):
            # Add constant reference
            new_expr += f'constants[{i-1}]'
            # Add next part
            new_expr += split_expr[i]

        self.computable_expression = new_expr
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def __str__(self):
    """Readable summary: the raw expression plus its fitted constants."""
    template = "Expression: {}, Best constants: {}"
    return template.format(self.original_expression, self.best_constants)

def sympy_str(self):
    """Return the string form of the underlying sympy expression."""
    return str(self.sympy_expression)
|
| 164 |
+
|
| 165 |
+
def is_valid_on_dataset(self, X, test_constants_list=None):
    """Check that the expression yields finite values for all rows of X.

    Args:
        X (np.ndarray): Input data, shape (n_samples, n_features).
        test_constants_list (list of lists): Optional constant vectors to
            try; defaults to a single all-ones vector of the right length.
            Example: [[1.0]*n, [0.5]*n, [2.0]*n] to test more thoroughly.

    Returns:
        bool: True if every evaluation is finite and nothing crashes.
    """
    if test_constants_list is None:
        test_constants_list = [[1.0] * self.constant_count]

    try:
        # Short-circuits at the first constant set producing a non-finite
        # value — equivalent to an explicit loop with an early return.
        return all(
            bool(np.all(np.isfinite(self.evaluate(X, candidate))))
            for candidate in test_constants_list
        )
    except Exception:
        # Any crash during evaluation counts as "invalid".
        return False
|
| 191 |
+
|
| 192 |
+
# Inside the Expression class
|
| 193 |
+
def evaluate(self, X, constants=None):
    """Numerically evaluate the expression on a batch of samples.

    Runs ``eval`` on the pre-built ``self.computable_expression`` string
    (variables rewritten to ``x[i]``, constants to ``constants[j]``) inside a
    restricted namespace exposing only the numpy helpers in SAFE_FUNCTIONS.

    Args:
        X: Array-like of shape (n_samples, n_features); a 1-D input is
            treated as a single sample.
        constants: Values for the fitted constants; defaults to
            ``self.best_constants``.

    Returns:
        np.ndarray: float predictions of shape (n_samples,), or an array of
        NaNs when evaluation fails (callers treat NaN as "invalid").
    """
    if constants is None:
        constants = self.best_constants

    try:
        # Restricted eval environment: no builtins, only the whitelisted
        # numpy functions plus the data/constants bindings added below.
        # SECURITY NOTE: this limits, but does not fully sandbox, eval();
        # expression strings must come from trusted sources.
        local_env = {
            "constants": np.array(constants),  # numpy array for broadcasting
            **self.SAFE_FUNCTIONS,
            "__builtins__": None
        }

        if not isinstance(X, np.ndarray):
            X = np.array(X)  # Ensure X is a numpy array

        # Ensure X is 2D, even if it has only one sample
        if X.ndim == 1:
            X = X.reshape(1, -1)

        # x becomes a list of columns (1D arrays of shape (n_samples,)),
        # matching the x[i] references in computable_expression.
        x_cols = [X[:, i] for i in range(X.shape[1])]
        local_env["x"] = x_cols

        try:
            # Vectorized evaluation: result has shape (n_samples,).
            y_pred_array = eval(self.computable_expression, local_env)

        except FloatingPointError as e:
            # Numeric failure; report NaNs so the caller's loss rejects it.
            return np.full(X.shape[0], np.nan)

        except Exception as e:
            # Any other evaluation error (domain error, malformed string, ...).
            return np.full(X.shape[0], np.nan)

        finally:
            np.seterr(all='warn')  # Reset numpy error handling to default

        # Ensure output is float to avoid issues with mixed types if some
        # results are int.
        return np.asarray(y_pred_array, dtype=float)

    except Exception as e:
        # Return NaNs of the expected shape so downstream loss calculations
        # don't break.
        num_samples = X.shape[0] if X.ndim > 0 else 1
        return np.full(num_samples, np.nan)  # Return NaNs on error
|
| 247 |
+
|
| 248 |
+
def fit_constants(self, X, y):
    """Fit the free constants ('C') to data and return the resulting R².

    Minimizes the MSE of ``evaluate(X, constants)`` against *y* with
    L-BFGS-B, storing the optimum in ``self.best_constants``.

    Args:
        X: Feature matrix, shape (n_samples, n_features).
        y: Target vector, shape (n_samples,).

    Returns:
        float: R² on (X, y) with the fitted constants; ``-np.inf`` when
        evaluation produces non-finite values or optimization fails; 0.0
        for a constant prediction of a non-constant target.
    """
    X = np.array(X)
    y = np.array(y)

    # Nothing to fit: just score the expression as-is.
    if self.constant_count == 0:
        try:
            y_pred = self.evaluate(X)  # Vectorized call
            if not np.all(np.isfinite(y_pred)):  # Check for NaNs/Infs
                return -np.inf
            # R² is ill-behaved when the prediction is constant but the
            # target is not; report 0.0 in that case.
            if np.all(y_pred == y_pred[0]) and len(np.unique(y)) > 1:
                return 0.0
            return r2_score(y, y_pred)
        except Exception as e:  # Broader catch for any eval issue
            return -np.inf

    def loss(current_constants):
        # MSE at the proposed constants; +inf on failure steers the
        # optimizer away from invalid regions.
        try:
            y_pred = self.evaluate(X, current_constants)

        except Exception as e:
            print(f"Exception during evaluation: {e}")
            return np.inf

        if not np.all(np.isfinite(y_pred)):
            return np.inf

        # MSE calculation
        mse = np.mean((y - y_pred) ** 2)

        return mse

    # Search box applied to every constant.
    bounds = [(-2., 2.)] * self.constant_count

    initial_guess = (
        self.best_constants
        if self.best_constants and len(self.best_constants) == self.constant_count
        else [.0] * self.constant_count  # fallback start at 0.0
        # NOTE(review): the original comment here said "Default to 1.0" but
        # the code starts at 0.0, while best_constants is initialized to 1.0
        # in __init__ — confirm which starting point was intended.
    )

    # Ensure initial_guess is a flat numpy array
    initial_guess = np.array(initial_guess, dtype=float).flatten()

    # NOTE: an earlier experiment ran scipy's differential_evolution for a
    # global search followed by L-BFGS-B refinement; it was disabled in
    # favor of the single local optimization below.

    result = minimize(loss,
                      x0=initial_guess,
                      method='L-BFGS-B',
                      bounds=bounds,
                      )

    if result.success:
        self.best_constants = result.x.tolist()
        try:
            y_pred = self.evaluate(X)  # Uses self.best_constants (vectorized)
            if not np.all(np.isfinite(y_pred)):
                return -np.inf
            # Refined R² calculation for edge cases
            if len(np.unique(y)) == 1:  # If y is constant
                if np.allclose(y_pred, y[0]):
                    return 1.0  # Perfect prediction of a constant
                else:
                    return 0.0  # Imperfect constant prediction
            return r2_score(y, y_pred)
        except Exception as e:
            return -np.inf
    else:
        return -np.inf
|
| 375 |
+
|
| 376 |
+
# from dataset import RegressionDataset
|
| 377 |
+
|
| 378 |
+
# import numpy as np
|
| 379 |
+
# import warnings
|
| 380 |
+
|
| 381 |
+
# with warnings.catch_warnings():
|
| 382 |
+
# warnings.simplefilter("ignore", category=RuntimeWarning)
|
| 383 |
+
# np.seterr(invalid='ignore')
|
| 384 |
+
|
| 385 |
+
# #reg = RegressionDataset('../data/evaluate/srsd-feynman_hard/train', 'feynman-bonus.12.txt', delimiter=' ')
|
| 386 |
+
# reg = RegressionDataset('./data/evaluate/srsd-feynman_easy/train', 'feynman-i.18.16.txt', delimiter=' ')
|
| 387 |
+
# X, y = reg.get_numpy()
|
| 388 |
+
|
| 389 |
+
# #x = np.array(X).T
|
| 390 |
+
# expression = "x_1*x_2*sin(x_4)"
|
| 391 |
+
# #expr = "0.5*x[0]*x[1]**2"
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# expr = Expression(expression)
|
| 395 |
+
# print("Expression:", expr)
|
| 396 |
+
|
| 397 |
+
# if expr.is_valid_on_dataset(X):
|
| 398 |
+
# print("Expression is valid on dataset.")
|
| 399 |
+
# score = expr.fit_constants(X, y)
|
| 400 |
+
# print("Fitted constants:", expr.best_constants)
|
| 401 |
+
# print("R2 score:", score)
|
| 402 |
+
# else:
|
| 403 |
+
# print("Expression is not valid on dataset.")
|
configs/eval_dataset_download.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Download the SRSD-Feynman evaluation datasets (easy/medium/hard, plus their
# "dummy" subsets) from the Hugging Face Hub into the current directory.
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy_dummy
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium_dummy
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard_dummy
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_easy
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_medium
git clone https://huggingface.co/datasets/yoshitomo-matsubara/srsd-feynman_hard
|
configs/model_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
configs/peft_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
configs/training.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Fine-tune GPT-2 (124M) on the 500k synthetic expression dataset (infix
# prompts, EOS variant) and push the resulting model to the Hugging Face Hub.
#
# FIX: the original command ended with a trailing '\' after the last
# argument, which made the shell continue the command onto the following
# (blank) line — fragile, and broken the moment a non-blank line follows.
CUDA_VISIBLE_DEVICES=0 python /home/augusto/symbo_repos/seringuela/scripts/train_test.py \
    --dataset_repo_id augustocsc/sintetico_natural \
    --data_dir 500k \
    --output_dir ./output \
    --push_to_hub \
    --hub_model_id augustocsc/Se124M500KInfPrompt_EOS \
    --source_data_column i_prompt \
    --report_to wandb \
    --run_name Se124M500KInfPrompt_EOS \
    --model_name_or_path gpt2 \
    --bf16 \
    --eval_strategy steps \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --dataloader_num_workers 8 \
    --learning_rate 5e-5 \
    --warmup_ratio 0.03 \
    --weight_decay 0.01 \
    --max_grad_norm 1.0 \
    --lr_scheduler_type cosine \
    --optim adamw_torch_fused \
    --logging_steps 20 \
    --eval_steps 500 \
    --save_steps 1000 \
    --save_total_limit 3

# Earlier experiment variants (kept for reference) ran the same command with:
#   --dataset_repo_id augustocsc/sintetico_final --data_dir 100k
#   --hub_model_id / --run_name Se124M100KInfPrompt_NT or Se124M100KInfPrompt_WT
#   --gradient_accumulation_steps 2 --learning_rate 2e-5
# on CUDA_VISIBLE_DEVICES=1 / 0 respectively.
|
configs/training_args.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"output_dir": "./output",
|
| 3 |
+
"overwrite_output_dir": true,
|
| 4 |
+
"num_train_epochs": 50,
|
| 5 |
+
"per_device_train_batch_size": 8,
|
| 6 |
+
"gradient_accumulation_steps": 1,
|
| 7 |
+
"learning_rate": 5e-5,
|
| 8 |
+
"weight_decay": 0.01,
|
| 9 |
+
"warmup_steps": 0,
|
| 10 |
+
"fp16": true,
|
| 11 |
+
"seed": 42,
|
| 12 |
+
"per_device_eval_batch_size": 8,
|
| 13 |
+
"eval_strategy": "epoch",
|
| 14 |
+
"metric_for_best_model": "eval_loss",
|
| 15 |
+
"greater_is_better": false,
|
| 16 |
+
"eval_steps": null,
|
| 17 |
+
"load_best_model_at_end": true,
|
| 18 |
+
"save_strategy": "epoch",
|
| 19 |
+
"save_steps": null,
|
| 20 |
+
"save_total_limit": 2,
|
| 21 |
+
"logging_dir": "./output/logs",
|
| 22 |
+
"logging_steps": 100,
|
| 23 |
+
"report_to": "wandb",
|
| 24 |
+
"run_name": "Se124M100K",
|
| 25 |
+
"push_to_hub": true,
|
| 26 |
+
"hub_model_id": "augustocsc/Se124M100K",
|
| 27 |
+
"hub_token": null
|
| 28 |
+
|
| 29 |
+
}
|
configs/training_large.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2-large",
|
| 4 |
+
"model_size": "774M",
|
| 5 |
+
"description": "GPT-2 Large - 774M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 2,
|
| 9 |
+
"per_device_train_batch_size": 4,
|
| 10 |
+
"per_device_eval_batch_size": 4,
|
| 11 |
+
"gradient_accumulation_steps": 16,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 2e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se774M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 180,
|
| 62 |
+
"total_hours": 6,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU. May need gradient checkpointing for memory optimization."
|
| 64 |
+
}
|
| 65 |
+
}
|
configs/training_medium.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2-medium",
|
| 4 |
+
"model_size": "355M",
|
| 5 |
+
"description": "GPT-2 Medium - 355M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 2,
|
| 9 |
+
"per_device_train_batch_size": 8,
|
| 10 |
+
"per_device_eval_batch_size": 8,
|
| 11 |
+
"gradient_accumulation_steps": 8,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 3e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se355M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 90,
|
| 62 |
+
"total_hours": 3,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU"
|
| 64 |
+
}
|
| 65 |
+
}
|
configs/training_small.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2",
|
| 4 |
+
"model_size": "124M",
|
| 5 |
+
"description": "GPT-2 Small - 124M parameters"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 3,
|
| 9 |
+
"per_device_train_batch_size": 16,
|
| 10 |
+
"per_device_eval_batch_size": 16,
|
| 11 |
+
"gradient_accumulation_steps": 4,
|
| 12 |
+
"effective_batch_size": 64,
|
| 13 |
+
"learning_rate": 5e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn", "c_proj"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"dataset_repo_id": "augustocsc/sintetico_natural",
|
| 49 |
+
"data_dir": "700K",
|
| 50 |
+
"data_columns": {
|
| 51 |
+
"infix": "i_prompt_n",
|
| 52 |
+
"prefix": "p_prompt_n"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"hub_config": {
|
| 56 |
+
"push_to_hub": true,
|
| 57 |
+
"hub_model_id_template": "augustocsc/Se124M_700K_{format}",
|
| 58 |
+
"formats": ["infix", "prefix"]
|
| 59 |
+
},
|
| 60 |
+
"estimated_time": {
|
| 61 |
+
"per_epoch_minutes": 40,
|
| 62 |
+
"total_hours": 2,
|
| 63 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU"
|
| 64 |
+
}
|
| 65 |
+
}
|
configs/training_v3.json
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"model_name_or_path": "gpt2",
|
| 4 |
+
"model_size": "124M",
|
| 5 |
+
"description": "GPT-2 Small (124M) - v3 with proper end markers"
|
| 6 |
+
},
|
| 7 |
+
"training_args": {
|
| 8 |
+
"num_train_epochs": 3,
|
| 9 |
+
"per_device_train_batch_size": 8,
|
| 10 |
+
"per_device_eval_batch_size": 8,
|
| 11 |
+
"gradient_accumulation_steps": 4,
|
| 12 |
+
"effective_batch_size": 32,
|
| 13 |
+
"learning_rate": 5e-5,
|
| 14 |
+
"weight_decay": 0.01,
|
| 15 |
+
"warmup_steps": 100,
|
| 16 |
+
"max_grad_norm": 1.0,
|
| 17 |
+
"lr_scheduler_type": "cosine",
|
| 18 |
+
"fp16": true,
|
| 19 |
+
"seed": 42,
|
| 20 |
+
"block_size": 128
|
| 21 |
+
},
|
| 22 |
+
"evaluation_args": {
|
| 23 |
+
"eval_strategy": "epoch",
|
| 24 |
+
"eval_steps": null,
|
| 25 |
+
"metric_for_best_model": "eval_loss",
|
| 26 |
+
"greater_is_better": false,
|
| 27 |
+
"load_best_model_at_end": true
|
| 28 |
+
},
|
| 29 |
+
"save_args": {
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"save_steps": null,
|
| 32 |
+
"save_total_limit": 2
|
| 33 |
+
},
|
| 34 |
+
"logging_args": {
|
| 35 |
+
"logging_dir": "./output/logs",
|
| 36 |
+
"logging_steps": 50,
|
| 37 |
+
"report_to": "wandb"
|
| 38 |
+
},
|
| 39 |
+
"lora_config": {
|
| 40 |
+
"r": 8,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"target_modules": ["c_attn"],
|
| 43 |
+
"lora_dropout": 0.05,
|
| 44 |
+
"bias": "none",
|
| 45 |
+
"task_type": "CAUSAL_LM"
|
| 46 |
+
},
|
| 47 |
+
"dataset_config": {
|
| 48 |
+
"use_local_csvs": true,
|
| 49 |
+
"train_file": "./data/processed/700K_fixed/train_700K.csv",
|
| 50 |
+
"validation_file": "./data/processed/700K_fixed/validation_700K.csv",
|
| 51 |
+
"test_file": "./data/processed/700K_fixed/test_700K.csv",
|
| 52 |
+
"data_column": "text"
|
| 53 |
+
},
|
| 54 |
+
"hub_config": {
|
| 55 |
+
"push_to_hub": true,
|
| 56 |
+
"hub_model_id": "augustocsc/Se124M_700K_infix_v3"
|
| 57 |
+
},
|
| 58 |
+
"special_tokens": {
|
| 59 |
+
"start_token": "<|startofex|>",
|
| 60 |
+
"end_token": "<|endofex|>",
|
| 61 |
+
"notes": "End token configured as EOS token for proper stopping"
|
| 62 |
+
},
|
| 63 |
+
"estimated_time": {
|
| 64 |
+
"per_epoch_minutes": 45,
|
| 65 |
+
"total_hours": 2.25,
|
| 66 |
+
"notes": "Estimated for AWS g5.xlarge with A10G GPU, GPT-2 Small, 3 epochs"
|
| 67 |
+
},
|
| 68 |
+
"version_info": {
|
| 69 |
+
"model_version": "v3",
|
| 70 |
+
"improvements": [
|
| 71 |
+
"Training data includes proper <|endofex|> markers",
|
| 72 |
+
"100% validation rate on prepared dataset",
|
| 73 |
+
"Addresses v1 non-stopping issue and v2 garbage generation",
|
| 74 |
+
"Uses local CSVs with validated end markers"
|
| 75 |
+
],
|
| 76 |
+
"training_date": "2026-02-01"
|
| 77 |
+
}
|
| 78 |
+
}
|
create_structure.sh
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
echo "Criando estrutura de pastas para o projeto de fine-tuning..."
|
| 4 |
+
|
| 5 |
+
# Diretórios Principais
|
| 6 |
+
mkdir -p data/raw
|
| 7 |
+
mkdir -p data/processed
|
| 8 |
+
mkdir -p scripts
|
| 9 |
+
mkdir -p configs
|
| 10 |
+
mkdir -p output
|
| 11 |
+
mkdir -p notebooks
|
| 12 |
+
|
| 13 |
+
echo "Diretórios criados."
|
| 14 |
+
|
| 15 |
+
# Arquivos Placeholder e de Configuração Inicial
|
| 16 |
+
touch data/raw/.gitkeep # Mantém a pasta no Git mesmo vazia
|
| 17 |
+
touch data/processed/.gitkeep # Mantém a pasta no Git mesmo vazia
|
| 18 |
+
|
| 19 |
+
echo "# Script para pré-processar dados (raw -> processed)" > scripts/preprocess_data.py
|
| 20 |
+
echo "# Script principal de treinamento (usa Trainer)" > scripts/train.py
|
| 21 |
+
echo "# Script para avaliação customizada" > scripts/evaluate.py
|
| 22 |
+
echo "# Script para geração de texto com modelo treinado" > scripts/generate.py
|
| 23 |
+
|
| 24 |
+
echo "{}" > configs/training_args.json # Placeholder para argumentos do Trainer
|
| 25 |
+
echo "{}" > configs/peft_config.json # Placeholder para config PEFT (se usar)
|
| 26 |
+
echo "{}" > configs/model_config.json # Placeholder para config do modelo base
|
| 27 |
+
|
| 28 |
+
touch notebooks/01_data_exploration.ipynb
|
| 29 |
+
touch notebooks/.gitkeep # Mantém a pasta no Git mesmo vazia
|
| 30 |
+
|
| 31 |
+
touch requirements.txt
|
| 32 |
+
|
| 33 |
+
echo "Arquivos placeholder criados."
|
| 34 |
+
|
| 35 |
+
# Conteúdo Inicial para .gitignore
|
| 36 |
+
echo "Gerando .gitignore..."
|
| 37 |
+
cat << EOF > .gitignore
|
| 38 |
+
# Byte-compiled / optimized / DLL files
|
| 39 |
+
__pycache__/
|
| 40 |
+
*.py[cod]
|
| 41 |
+
*$py.class
|
| 42 |
+
|
| 43 |
+
# C extensions
|
| 44 |
+
*.so
|
| 45 |
+
|
| 46 |
+
# Distribution / packaging
|
| 47 |
+
.Python
|
| 48 |
+
build/
|
| 49 |
+
develop-eggs/
|
| 50 |
+
dist/
|
| 51 |
+
downloads/
|
| 52 |
+
eggs/
|
| 53 |
+
.eggs/
|
| 54 |
+
lib/
|
| 55 |
+
lib64/
|
| 56 |
+
parts/
|
| 57 |
+
sdist/
|
| 58 |
+
var/
|
| 59 |
+
wheels/
|
| 60 |
+
pip-wheel-metadata/
|
| 61 |
+
share/python-wheels/
|
| 62 |
+
*.egg-info/
|
| 63 |
+
.installed.cfg
|
| 64 |
+
*.egg
|
| 65 |
+
MANIFEST
|
| 66 |
+
|
| 67 |
+
# PyInstaller
|
| 68 |
+
# Usually these files are written by a python script from a template
|
| 69 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 70 |
+
*.manifest
|
| 71 |
+
*.spec
|
| 72 |
+
|
| 73 |
+
# Installer logs
|
| 74 |
+
pip-log.txt
|
| 75 |
+
pip-delete-this-directory.txt
|
| 76 |
+
|
| 77 |
+
# Unit test / coverage reports
|
| 78 |
+
htmlcov/
|
| 79 |
+
.tox/
|
| 80 |
+
.nox/
|
| 81 |
+
.coverage
|
| 82 |
+
.coverage.*
|
| 83 |
+
.cache
|
| 84 |
+
nosetests.xml
|
| 85 |
+
coverage.xml
|
| 86 |
+
*.cover
|
| 87 |
+
*.py,cover
|
| 88 |
+
.hypothesis/
|
| 89 |
+
.pytest_cache/
|
| 90 |
+
|
| 91 |
+
# Environments
|
| 92 |
+
.env
|
| 93 |
+
.venv
|
| 94 |
+
venv/
|
| 95 |
+
ENV/
|
| 96 |
+
env/
|
| 97 |
+
env.bak/
|
| 98 |
+
venv.bak/
|
| 99 |
+
|
| 100 |
+
# IDEs / Editors
|
| 101 |
+
.idea/
|
| 102 |
+
.vscode/
|
| 103 |
+
*.suo
|
| 104 |
+
*.ntvs*
|
| 105 |
+
*.njsproj
|
| 106 |
+
*.sln
|
| 107 |
+
*.sw?
|
| 108 |
+
|
| 109 |
+
# Jupyter Notebook
|
| 110 |
+
.ipynb_checkpoints
|
| 111 |
+
|
| 112 |
+
# Output folder (geralmente grande demais para Git)
|
| 113 |
+
output/*
|
| 114 |
+
!output/.gitkeep # Não ignore um .gitkeep se precisar manter a pasta
|
| 115 |
+
|
| 116 |
+
# Dados (podem ser grandes, usar Git LFS ou armazenar fora se necessário)
|
| 117 |
+
data/raw/*
|
| 118 |
+
data/processed/*
|
| 119 |
+
!data/raw/.gitkeep
|
| 120 |
+
!data/processed/.gitkeep
|
| 121 |
+
|
| 122 |
+
# OS generated files
|
| 123 |
+
.DS_Store
|
| 124 |
+
.DS_Store?
|
| 125 |
+
._*
|
| 126 |
+
.Spotlight-V100
|
| 127 |
+
.Trashes
|
| 128 |
+
ehthumbs.db
|
| 129 |
+
Thumbs.db
|
| 130 |
+
EOF
|
| 131 |
+
|
| 132 |
+
# Conteúdo Inicial para README.md (será preenchido com o texto gerado abaixo)
|
| 133 |
+
echo "Gerando README.md inicial..."
|
| 134 |
+
echo "# Nome do Seu Projeto de Fine-Tuning" > README.md
|
| 135 |
+
echo "" >> README.md
|
| 136 |
+
echo "(Breve descrição do objetivo do projeto)" >> README.md
|
| 137 |
+
echo "" >> README.md
|
| 138 |
+
echo "## Estrutura de Pastas" >> README.md
|
| 139 |
+
echo "" >> README.md
|
| 140 |
+
echo "**(COPIE E COLE A EXPLICAÇÃO DA ESTRUTURA GERADA NA PRÓXIMA SEÇÃO AQUI)**" >> README.md
|
| 141 |
+
echo "" >> README.md
|
| 142 |
+
echo "## Como Usar" >> README.md
|
| 143 |
+
echo "" >> README.md
|
| 144 |
+
echo "1. **Setup:** Crie um ambiente virtual e instale as dependências:" >> README.md
|
| 145 |
+
echo " \`\`\`bash" >> README.md
|
| 146 |
+
echo " python -m venv venv" >> README.md
|
| 147 |
+
echo " source venv/bin/activate # Linux/macOS" >> README.md
|
| 148 |
+
echo " # venv\\Scripts\\activate # Windows" >> README.md
|
| 149 |
+
echo " pip install -r requirements.txt" >> README.md
|
| 150 |
+
echo " \`\`\`" >> README.md
|
| 151 |
+
echo "2. **Dados:** Coloque seus dados brutos em \`data/raw/\` e execute (ou crie) o script \`scripts/preprocess_data.py\` para gerar os arquivos em \`data/processed/\`." >> README.md
|
| 152 |
+
echo "3. **Configuração:** Ajuste os arquivos em \`configs/\` (argumentos de treino, modelo base, PEFT se aplicável)." >> README.md
|
| 153 |
+
echo "4. **Treinamento:** Execute o script principal:" >> README.md
|
| 154 |
+
echo " \`\`\`bash" >> README.md
|
| 155 |
+
echo " python scripts/train.py --args_config configs/training_args.json --model_config configs/model_config.json" >> README.md
|
| 156 |
+
echo " \`\`\`" >> README.md
|
| 157 |
+
echo " *(Adapte os argumentos conforme necessário)*" >> README.md
|
| 158 |
+
echo "" >> README.md
|
| 159 |
+
echo "## Dependências" >> README.md
|
| 160 |
+
echo "" >> README.md
|
| 161 |
+
echo "As dependências Python estão listadas no arquivo \`requirements.txt\`." >> README.md
|
| 162 |
+
|
| 163 |
+
chmod +x create_structure.sh
|
| 164 |
+
|
| 165 |
+
echo "--------------------------------------------------"
|
| 166 |
+
echo "Estrutura criada com sucesso!"
|
| 167 |
+
echo "Para usar:"
|
| 168 |
+
echo "1. Torne o script executável: chmod +x create_structure.sh"
|
| 169 |
+
echo "2. Execute o script: ./create_structure.sh"
|
| 170 |
+
echo "3. Copie a explicação da estrutura (gerada na resposta anterior) para dentro do README.md onde indicado."
|
| 171 |
+
echo "--------------------------------------------------"
|
notebooks/.gitkeep
ADDED
|
File without changes
|
notebooks/01_data_exploration.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/02_finetuning_avaliation.ipynb
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "5c6de955",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import re\n",
|
| 11 |
+
"import json\n",
|
| 12 |
+
"from collections import Counter, defaultdict\n",
|
| 13 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 14 |
+
"from peft import PeftModel\n",
|
| 15 |
+
"import sympy as sp\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"# Configuration\n",
|
| 18 |
+
"TOKENIZER_REPO = \"augustocsc/Se124M500KInfPrompt_EOS\"\n",
|
| 19 |
+
"LORA_REPO = \"augustocsc/Se124M500KInfPrompt_EOS\"\n",
|
| 20 |
+
"BASE_MODEL = \"gpt2\"\n",
|
| 21 |
+
"PROMPT = \"\"\"\n",
|
| 22 |
+
"vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10\n",
|
| 23 |
+
"oper: *, **, +, -, /\n",
|
| 24 |
+
"cons: C\n",
|
| 25 |
+
"expr:\"\"\"\n",
|
| 26 |
+
"GENERATE_BATCH = 10\n",
|
| 27 |
+
"REPEAT_TIMES = 1\n",
|
| 28 |
+
"OUTPUT_EXPR_FILE = \"generated_expressions.json\"\n",
|
| 29 |
+
"OUTPUT_ANALYSIS_FILE = \"analysis_results.json\"\n"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 2,
|
| 35 |
+
"id": "e0b08244",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [
|
| 38 |
+
{
|
| 39 |
+
"name": "stdout",
|
| 40 |
+
"output_type": "stream",
|
| 41 |
+
"text": [
|
| 42 |
+
"Loading tokenizer and model...\n"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"data": {
|
| 47 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 48 |
+
"model_id": "8db99632228d4e599ab477a436f16c3e",
|
| 49 |
+
"version_major": 2,
|
| 50 |
+
"version_minor": 0
|
| 51 |
+
},
|
| 52 |
+
"text/plain": [
|
| 53 |
+
"tokenizer_config.json: 0%| | 0.00/1.09k [00:00<?, ?B/s]"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"output_type": "display_data"
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"data": {
|
| 61 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 62 |
+
"model_id": "b94dc7909bb74e1fbf6a3a3922c88a8b",
|
| 63 |
+
"version_major": 2,
|
| 64 |
+
"version_minor": 0
|
| 65 |
+
},
|
| 66 |
+
"text/plain": [
|
| 67 |
+
"vocab.json: 0%| | 0.00/798k [00:00<?, ?B/s]"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
"metadata": {},
|
| 71 |
+
"output_type": "display_data"
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"data": {
|
| 75 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 76 |
+
"model_id": "a20a9a7738224785b2693a2239aa079d",
|
| 77 |
+
"version_major": 2,
|
| 78 |
+
"version_minor": 0
|
| 79 |
+
},
|
| 80 |
+
"text/plain": [
|
| 81 |
+
"merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"output_type": "display_data"
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"data": {
|
| 89 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 90 |
+
"model_id": "00c1fb75426b49c295e79ff1f7b92d4f",
|
| 91 |
+
"version_major": 2,
|
| 92 |
+
"version_minor": 0
|
| 93 |
+
},
|
| 94 |
+
"text/plain": [
|
| 95 |
+
"tokenizer.json: 0%| | 0.00/3.56M [00:00<?, ?B/s]"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"output_type": "display_data"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"data": {
|
| 103 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 104 |
+
"model_id": "1af7b45de60248b192b9c47b63a1c4c4",
|
| 105 |
+
"version_major": 2,
|
| 106 |
+
"version_minor": 0
|
| 107 |
+
},
|
| 108 |
+
"text/plain": [
|
| 109 |
+
"added_tokens.json: 0%| | 0.00/67.0 [00:00<?, ?B/s]"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"output_type": "display_data"
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"data": {
|
| 117 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 118 |
+
"model_id": "b9cb7f8429d0486198aa7870aa45b381",
|
| 119 |
+
"version_major": 2,
|
| 120 |
+
"version_minor": 0
|
| 121 |
+
},
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"special_tokens_map.json: 0%| | 0.00/562 [00:00<?, ?B/s]"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "stderr",
|
| 131 |
+
"output_type": "stream",
|
| 132 |
+
"text": [
|
| 133 |
+
"The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"data": {
|
| 138 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 139 |
+
"model_id": "12fa3ee725a04c57a19737d18f70f340",
|
| 140 |
+
"version_major": 2,
|
| 141 |
+
"version_minor": 0
|
| 142 |
+
},
|
| 143 |
+
"text/plain": [
|
| 144 |
+
"adapter_config.json: 0%| | 0.00/744 [00:00<?, ?B/s]"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"output_type": "display_data"
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"data": {
|
| 152 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 153 |
+
"model_id": "3b8be3854aee41d8a28bf0810a860b42",
|
| 154 |
+
"version_major": 2,
|
| 155 |
+
"version_minor": 0
|
| 156 |
+
},
|
| 157 |
+
"text/plain": [
|
| 158 |
+
"adapter_model.safetensors: 0%| | 0.00/310M [00:00<?, ?B/s]"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
"metadata": {},
|
| 162 |
+
"output_type": "display_data"
|
| 163 |
+
}
|
| 164 |
+
],
|
| 165 |
+
"source": [
|
| 166 |
+
"# Load tokenizer and model with LoRA adapter\n",
|
| 167 |
+
"print(\"Loading tokenizer and model...\")\n",
|
| 168 |
+
"tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)\n",
|
| 169 |
+
"model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)\n",
|
| 170 |
+
"model.resize_token_embeddings(len(tokenizer))\n",
|
| 171 |
+
"model = PeftModel.from_pretrained(model, LORA_REPO)\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"model.eval()\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"# Regex to extract expressions between tokens\n",
|
| 177 |
+
"pattern = re.compile(r\"<startofex>(.*?)<endofex>\", re.DOTALL)"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "code",
|
| 182 |
+
"execution_count": null,
|
| 183 |
+
"id": "c76ee26f",
|
| 184 |
+
"metadata": {},
|
| 185 |
+
"outputs": [
|
| 186 |
+
{
|
| 187 |
+
"name": "stderr",
|
| 188 |
+
"output_type": "stream",
|
| 189 |
+
"text": [
|
| 190 |
+
"Some weights of the model checkpoint at augustocsc/Se124M100KInfPrompt_EOS_Merged were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
|
| 191 |
+
"- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 192 |
+
"- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 193 |
+
"Some weights of the model checkpoint at augustocsc/Se124M100KInfPrompt_EOS_Merged were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
|
| 194 |
+
"- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 195 |
+
"- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
| 196 |
+
]
|
| 197 |
+
}
|
| 198 |
+
],
|
| 199 |
+
"source": [
|
| 200 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 201 |
+
"from peft import PeftModel\n",
|
| 202 |
+
"from trl import AutoModelForCausalLMWithValueHead\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"# Carrega o modelo base\n",
|
| 205 |
+
"base_model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"# Carrega os pesos LoRA (checkpoint treinado)\n",
|
| 208 |
+
"peft_model = PeftModel.from_pretrained(base_model, \"augustocsc/Se124M100KInfPrompt_EOS\")\n"
|
| 209 |
+
]
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"cell_type": "code",
|
| 213 |
+
"execution_count": 3,
|
| 214 |
+
"id": "ffc6e072",
|
| 215 |
+
"metadata": {},
|
| 216 |
+
"outputs": [
|
| 217 |
+
{
|
| 218 |
+
"name": "stderr",
|
| 219 |
+
"output_type": "stream",
|
| 220 |
+
"text": [
|
| 221 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"name": "stdout",
|
| 226 |
+
"output_type": "stream",
|
| 227 |
+
"text": [
|
| 228 |
+
"Run 1/1: Generating 10 samples...\n"
|
| 229 |
+
]
|
| 230 |
+
}
|
| 231 |
+
],
|
| 232 |
+
"source": [
|
| 233 |
+
"all_expressions = []\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"# Generation loop\n",
|
| 236 |
+
"for run in range(REPEAT_TIMES):\n",
|
| 237 |
+
" print(f\"Run {run+1}/{REPEAT_TIMES}: Generating {GENERATE_BATCH} samples...\")\n",
|
| 238 |
+
" inputs = tokenizer([PROMPT] * GENERATE_BATCH, return_tensors=\"pt\", padding=True)\n",
|
| 239 |
+
" outputs = model.generate(\n",
|
| 240 |
+
" **inputs,\n",
|
| 241 |
+
" max_new_tokens=75,\n",
|
| 242 |
+
" do_sample=True,\n",
|
| 243 |
+
" top_p=0.9,\n",
|
| 244 |
+
" top_k=50,\n",
|
| 245 |
+
" temperature=0.7,\n",
|
| 246 |
+
" )\n"
|
| 247 |
+
]
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"cell_type": "code",
|
| 251 |
+
"execution_count": 71,
|
| 252 |
+
"id": "be3b4bcb",
|
| 253 |
+
"metadata": {},
|
| 254 |
+
"outputs": [
|
| 255 |
+
{
|
| 256 |
+
"name": "stdout",
|
| 257 |
+
"output_type": "stream",
|
| 258 |
+
"text": [
|
| 259 |
+
"Generated expressions:\n",
|
| 260 |
+
" a_1, b_2, c_1, c_2, c_3, c_4, c_5, c_6, c_7, c_8, c_9, c_10, c_\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"A function that evaluates to a string, and returns a string.\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"A string can be any character, and can be either a double, a string, a double, a singleton, a string with multiple elements, a string with\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"op: *,\n",
|
| 271 |
+
" *, +, +, -, /\n",
|
| 272 |
+
"cons: C\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"expr: *, +, +, -, /\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"cons: C\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"expr: *, +, +, -, /\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"cons: C\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" *\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"cons: c\n",
|
| 285 |
+
"\n",
|
| 286 |
+
"type: Int\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"value: *\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"value: *\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"value: *\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"value: *\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"value: *\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"value: *\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"value: *\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"value:\n",
|
| 303 |
+
" *, **, -, /\n",
|
| 304 |
+
"oper: *, **, +, -, /\n",
|
| 305 |
+
"oper: *, **, +, -, /\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"op: [\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"op: [\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"op: [\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"op:\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"*, **, +, -, /\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_\n",
|
| 319 |
+
" *, **, *, **, *, **, *, *, **, *, **, *, **, *, **, *, **, *, *, **, *, **, *, *\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"oper\n",
|
| 322 |
+
" *, *, *, *, *, *, *\n",
|
| 323 |
+
"cons: C\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"expr: *, *, *, *, *, *, *\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"cons: C\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"expr: *, *, *, *\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"\n",
|
| 332 |
+
"vars: *, +, *, *, *, *, *, *, *, *, *, *, *, *, *, *\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"oper: *, **, +, *, *,\n"
|
| 335 |
+
]
|
| 336 |
+
}
|
| 337 |
+
],
|
| 338 |
+
"source": [
|
| 339 |
+
"# remove the prompt from the generated text and print the decoded text\n",
|
| 340 |
+
"generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
|
| 341 |
+
"generated_text = [text.replace(PROMPT, \"\") for text in generated_text]\n",
|
| 342 |
+
"all_expressions.extend(generated_text)\n",
|
| 343 |
+
"print(\"Generated expressions:\")\n",
|
| 344 |
+
"for text in generated_text:\n",
|
| 345 |
+
" print(text)\n",
|
| 346 |
+
" "
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"cell_type": "code",
|
| 351 |
+
"execution_count": 4,
|
| 352 |
+
"id": "5d8e569f",
|
| 353 |
+
"metadata": {},
|
| 354 |
+
"outputs": [
|
| 355 |
+
{
|
| 356 |
+
"name": "stdout",
|
| 357 |
+
"output_type": "stream",
|
| 358 |
+
"text": [
|
| 359 |
+
"Valid Expressions:\n",
|
| 360 |
+
"x_6 - x_3 + C*x_6 + x_7 + C\n",
|
| 361 |
+
"x_2*(x_9 + x_2)**C\n",
|
| 362 |
+
"x_9 + x_2 + C*x_7**C\n",
|
| 363 |
+
"x_9**C + x_1**C + x_2\n",
|
| 364 |
+
"C*x_1 + x_8 + x_1 + C\n",
|
| 365 |
+
"x_1**C*(x_9 + x_4**C + C)\n",
|
| 366 |
+
"x_2*(x_9 - C)**C/x_7\n",
|
| 367 |
+
"x_1*(x_8 - C)/(x_1 + x_2)\n",
|
| 368 |
+
"x_8**C*(x_2 + x_7)\n",
|
| 369 |
+
"x_1**C + x_2**C + x_9\n",
|
| 370 |
+
"\n",
|
| 371 |
+
"Invalid Expressions:\n"
|
| 372 |
+
]
|
| 373 |
+
}
|
| 374 |
+
],
|
| 375 |
+
"source": [
|
| 376 |
+
"valid_expressions = []\n",
|
| 377 |
+
"invalid_expressions = []\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"for out in outputs:\n",
|
| 380 |
+
" text = tokenizer.decode(out)\n",
|
| 381 |
+
" expr = text.split(\"expr: \")[1].split(\"<|endoftext|>\")[0].strip() # Extract the expression between \"expr: \" and <|endoftext|>\n",
|
| 382 |
+
" try:\n",
|
| 383 |
+
" sympy_expr = sp.sympify(expr, evaluate=False) # Try to parse the expression with sympy\n",
|
| 384 |
+
" valid_expressions.append(expr)\n",
|
| 385 |
+
" except Exception as e:\n",
|
| 386 |
+
" invalid_expressions.append(expr)\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"# Print valid expressions\n",
|
| 389 |
+
"print(\"Valid Expressions:\")\n",
|
| 390 |
+
"for expr in valid_expressions:\n",
|
| 391 |
+
" print(expr)\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"# Print invalid expressions\n",
|
| 394 |
+
"print(\"\\nInvalid Expressions:\")\n",
|
| 395 |
+
"for expr in invalid_expressions:\n",
|
| 396 |
+
" print(expr)"
|
| 397 |
+
]
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"cell_type": "code",
|
| 401 |
+
"execution_count": null,
|
| 402 |
+
"id": "d05f1edd",
|
| 403 |
+
"metadata": {},
|
| 404 |
+
"outputs": [
|
| 405 |
+
{
|
| 406 |
+
"ename": "AttributeError",
|
| 407 |
+
"evalue": "'AutoModelForCausalLMWithValueHead' object has no attribute 'generation_config'",
|
| 408 |
+
"output_type": "error",
|
| 409 |
+
"traceback": [
|
| 410 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 411 |
+
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
|
| 412 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/peft/peft_model.py:793\u001b[39m, in \u001b[36mPeftModel.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 792\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m793\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__getattr__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# defer to nn.Module's logic\u001b[39;00m\n\u001b[32m 794\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n",
|
| 413 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/torch/nn/modules/module.py:1928\u001b[39m, in \u001b[36mModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 1927\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[32m-> \u001b[39m\u001b[32m1928\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[32m 1929\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m object has no attribute \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1930\u001b[39m )\n",
|
| 414 |
+
"\u001b[31mAttributeError\u001b[39m: 'PeftModelForCausalLM' object has no attribute 'generation_config'",
|
| 415 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
| 416 |
+
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
|
| 417 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/peft/tuners/lora/model.py:359\u001b[39m, in \u001b[36mLoraModel.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 358\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m359\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__getattr__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# defer to nn.Module's logic\u001b[39;00m\n\u001b[32m 360\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n",
|
| 418 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/torch/nn/modules/module.py:1928\u001b[39m, in \u001b[36mModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 1927\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[32m-> \u001b[39m\u001b[32m1928\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[32m 1929\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m object has no attribute \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1930\u001b[39m )\n",
|
| 419 |
+
"\u001b[31mAttributeError\u001b[39m: 'LoraModel' object has no attribute 'generation_config'",
|
| 420 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
| 421 |
+
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
|
| 422 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[41]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Generate with beam search and early stopping\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m output = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m.\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 5\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m#max_length=100,\u001b[39;49;00m\n\u001b[32m 6\u001b[39m \u001b[43m \u001b[49m\u001b[43mnum_beams\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Enable beam search\u001b[39;49;00m\n\u001b[32m 7\u001b[39m \u001b[43m \u001b[49m\u001b[43mearly_stopping\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Stop when all beams hit EOS\u001b[39;49;00m\n\u001b[32m 8\u001b[39m \n\u001b[32m 9\u001b[39m \u001b[43m)\u001b[49m\n\u001b[32m 11\u001b[39m decoded_output = tokenizer.decode(output[\u001b[32m0\u001b[39m], skip_special_tokens=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(decoded_output)\n",
|
| 423 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/peft/peft_model.py:1867\u001b[39m, in \u001b[36mPeftModelForCausalLM.generate\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1865\u001b[39m \u001b[38;5;28mself\u001b[39m.base_model.prepare_inputs_for_generation = \u001b[38;5;28mself\u001b[39m.prepare_inputs_for_generation\n\u001b[32m 1866\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.base_model, \u001b[33m\"\u001b[39m\u001b[33mmodel\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m-> \u001b[39m\u001b[32m1867\u001b[39m \u001b[38;5;28mself\u001b[39m.base_model.model.generation_config = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgeneration_config\u001b[49m\n\u001b[32m 1868\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1869\u001b[39m \u001b[38;5;28mself\u001b[39m.base_model.generation_config = \u001b[38;5;28mself\u001b[39m.generation_config\n",
|
| 424 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/peft/peft_model.py:797\u001b[39m, in \u001b[36mPeftModel.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 795\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m name == \u001b[33m\"\u001b[39m\u001b[33mbase_model\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;66;03m# see #1892: prevent infinite recursion if class is not initialized\u001b[39;00m\n\u001b[32m 796\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m797\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.base_model, name)\n",
|
| 425 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/peft/tuners/lora/model.py:363\u001b[39m, in \u001b[36mLoraModel.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 361\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m name == \u001b[33m\"\u001b[39m\u001b[33mmodel\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;66;03m# see #1892: prevent infinite recursion if class is not initialized\u001b[39;00m\n\u001b[32m 362\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m363\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.model, name)\n",
|
| 426 |
+
"\u001b[36mFile \u001b[39m\u001b[32m~/symbo_repos/seringuela/.seriguela/lib/python3.11/site-packages/torch/nn/modules/module.py:1928\u001b[39m, in \u001b[36mModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 1926\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[32m 1927\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[32m-> \u001b[39m\u001b[32m1928\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[32m 1929\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m object has no attribute \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1930\u001b[39m )\n",
|
| 427 |
+
"\u001b[31mAttributeError\u001b[39m: 'AutoModelForCausalLMWithValueHead' object has no attribute 'generation_config'"
|
| 428 |
+
]
|
| 429 |
+
}
|
| 430 |
+
],
|
| 431 |
+
"source": [
|
| 432 |
+
"# Generate with beam search and early stopping\n",
|
| 433 |
+
"output = model.generate(\n",
|
| 434 |
+
" inputs.input_ids,\n",
|
| 435 |
+
" attention_mask=inputs.attention_mask,\n",
|
| 436 |
+
" #max_length=100,\n",
|
| 437 |
+
" num_beams=5, # Enable beam search\n",
|
| 438 |
+
" early_stopping=True, # Stop when all beams hit EOS\n",
|
| 439 |
+
"\n",
|
| 440 |
+
")\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"decoded_output = tokenizer.decode(output[0], skip_special_tokens=False)\n",
|
| 443 |
+
"print(decoded_output)"
|
| 444 |
+
]
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"cell_type": "code",
|
| 448 |
+
"execution_count": null,
|
| 449 |
+
"id": "7a9ade5c",
|
| 450 |
+
"metadata": {},
|
| 451 |
+
"outputs": [],
|
| 452 |
+
"source": [
|
| 453 |
+
"\n",
|
| 454 |
+
"# Save raw expressions\n",
|
| 455 |
+
"with open(OUTPUT_EXPR_FILE, 'w') as f:\n",
|
| 456 |
+
" json.dump(all_expressions, f, indent=2)\n",
|
| 457 |
+
"print(f\"Saved {len(all_expressions)} expressions to {OUTPUT_EXPR_FILE}\")\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"# Analysis\n",
|
| 460 |
+
"analysis = {\n",
|
| 461 |
+
" 'total_expressions': len(all_expressions),\n",
|
| 462 |
+
" 'syntactic_semantic': {\n",
|
| 463 |
+
" 'valid_equations': 0,\n",
|
| 464 |
+
" 'parse_errors': defaultdict(int),\n",
|
| 465 |
+
" },\n",
|
| 466 |
+
" 'diversity_redundancy': {},\n",
|
| 467 |
+
" 'statistical_distributions': {\n",
|
| 468 |
+
" 'variable_freq': Counter(),\n",
|
| 469 |
+
" 'operator_freq': Counter(),\n",
|
| 470 |
+
" 'avg_operators_per_eq': 0.0,\n",
|
| 471 |
+
" 'avg_variables_per_eq': 0.0,\n",
|
| 472 |
+
" }\n",
|
| 473 |
+
"}\n",
|
| 474 |
+
"\n",
|
| 475 |
+
"# Helper to compute tree depth\n",
|
| 476 |
+
"def tree_depth(expr):\n",
|
| 477 |
+
" if not expr.args:\n",
|
| 478 |
+
" return 1\n",
|
| 479 |
+
" return 1 + max(tree_depth(arg) for arg in expr.args)\n",
|
| 480 |
+
"\n",
|
| 481 |
+
"# Operators list\n",
|
| 482 |
+
"operators = ['+', '-', '*', '/', '^', 'log', 'exp', 'cos', 'sqrt', 'asin', 'sin', 'pow', 'tan', 'abs']\n",
|
| 483 |
+
"\n",
|
| 484 |
+
"depths = []\n",
|
| 485 |
+
"operator_counts = []\n",
|
| 486 |
+
"variable_counts = []\n",
|
| 487 |
+
"unique_set = set()\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"for expr in all_expressions:\n",
|
| 490 |
+
" # Parse with sympy\n",
|
| 491 |
+
" try:\n",
|
| 492 |
+
" sympy_expr = sp.sympify(expr, evaluate=False)\n",
|
| 493 |
+
" analysis['syntactic_semantic']['valid_equations'] += 1\n",
|
| 494 |
+
" depths.append(tree_depth(sympy_expr))\n",
|
| 495 |
+
" except Exception as e:\n",
|
| 496 |
+
" err_msg = str(e)\n",
|
| 497 |
+
" if 'could not parse' in err_msg:\n",
|
| 498 |
+
" analysis['syntactic_semantic']['parse_errors']['parse_failure'] += 1\n",
|
| 499 |
+
" else:\n",
|
| 500 |
+
" analysis['syntactic_semantic']['parse_errors'][err_msg] += 1\n",
|
| 501 |
+
" continue\n",
|
| 502 |
+
"\n",
|
| 503 |
+
" # Variables\n",
|
| 504 |
+
" vars_in_expr = [str(v) for v in sympy_expr.free_symbols]\n",
|
| 505 |
+
" for v in vars_in_expr:\n",
|
| 506 |
+
" analysis['statistical_distributions']['variable_freq'][v] += 1\n",
|
| 507 |
+
" variable_counts.append(len(vars_in_expr))\n",
|
| 508 |
+
"\n",
|
| 509 |
+
" # Operators\n",
|
| 510 |
+
" op_count = sum(expr.count(op) for op in operators)\n",
|
| 511 |
+
" analysis['statistical_distributions']['operator_freq'].update({op: expr.count(op) for op in operators})\n",
|
| 512 |
+
" operator_counts.append(op_count)\n",
|
| 513 |
+
"\n",
|
| 514 |
+
" # Diversity\n",
|
| 515 |
+
" unique_set.add(expr)\n",
|
| 516 |
+
"\n",
|
| 517 |
+
"# Populate diversity metrics\n",
|
| 518 |
+
"total = analysis['total_expressions']\n",
|
| 519 |
+
"unique_count = len(unique_set)\n",
|
| 520 |
+
"analysis['diversity_redundancy'] = {\n",
|
| 521 |
+
" 'unique_expressions': unique_count,\n",
|
| 522 |
+
" 'unique_proportion': unique_count / total if total else 0,\n",
|
| 523 |
+
" 'duplicate_counts': {expr: cnt for expr, cnt in Counter(all_expressions).items() if cnt > 1},\n",
|
| 524 |
+
" 'structural_diversity': {\n",
|
| 525 |
+
" 'avg_tree_depth': sum(depths) / len(depths) if depths else 0,\n",
|
| 526 |
+
" 'min_tree_depth': min(depths) if depths else 0,\n",
|
| 527 |
+
" 'max_tree_depth': max(depths) if depths else 0,\n",
|
| 528 |
+
" }\n",
|
| 529 |
+
"}\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"# Statistical distributions averages\n",
|
| 532 |
+
"analysis['statistical_distributions']['avg_operators_per_eq'] = sum(operator_counts) / len(operator_counts) if operator_counts else 0\n",
|
| 533 |
+
"analysis['statistical_distributions']['avg_variables_per_eq'] = sum(variable_counts) / len(variable_counts) if variable_counts else 0\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"# Convert Counters to dicts for JSON serialization\n",
|
| 536 |
+
"analysis['statistical_distributions']['variable_freq'] = dict(analysis['statistical_distributions']['variable_freq'])\n",
|
| 537 |
+
"analysis['statistical_distributions']['operator_freq'] = dict(analysis['statistical_distributions']['operator_freq'])\n",
|
| 538 |
+
"analysis['syntactic_semantic']['parse_errors'] = dict(analysis['syntactic_semantic']['parse_errors'])\n",
|
| 539 |
+
"\n",
|
| 540 |
+
"# Save analysis results\n",
|
| 541 |
+
"with open(OUTPUT_ANALYSIS_FILE, 'w') as f:\n",
|
| 542 |
+
" json.dump(analysis, f, indent=2)\n",
|
| 543 |
+
"print(f\"Saved analysis results to {OUTPUT_ANALYSIS_FILE}\")\n"
|
| 544 |
+
]
|
| 545 |
+
}
|
| 546 |
+
],
|
| 547 |
+
"metadata": {
|
| 548 |
+
"kernelspec": {
|
| 549 |
+
"display_name": ".seriguela",
|
| 550 |
+
"language": "python",
|
| 551 |
+
"name": "python3"
|
| 552 |
+
},
|
| 553 |
+
"language_info": {
|
| 554 |
+
"codemirror_mode": {
|
| 555 |
+
"name": "ipython",
|
| 556 |
+
"version": 3
|
| 557 |
+
},
|
| 558 |
+
"file_extension": ".py",
|
| 559 |
+
"mimetype": "text/x-python",
|
| 560 |
+
"name": "python",
|
| 561 |
+
"nbconvert_exporter": "python",
|
| 562 |
+
"pygments_lexer": "ipython3",
|
| 563 |
+
"version": "3.11.4"
|
| 564 |
+
}
|
| 565 |
+
},
|
| 566 |
+
"nbformat": 4,
|
| 567 |
+
"nbformat_minor": 5
|
| 568 |
+
}
|
notebooks/03_RL.ipynb
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "59d6d70b",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Some weights of the model checkpoint at augustocsc/Se124M100KInfPrompt_EOS_Merged were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
|
| 14 |
+
"- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 15 |
+
"- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 16 |
+
"WARNING:root:A <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> model is loaded from 'augustocsc/Se124M100KInfPrompt_EOS_Merged', and no v_head weight is found. This IS expected if you are not resuming PPO training.\n",
|
| 17 |
+
"Some weights of the model checkpoint at augustocsc/Se124M100KInfPrompt_EOS_Merged were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
|
| 18 |
+
"- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 19 |
+
"- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 20 |
+
"WARNING:root:A <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> model is loaded from 'augustocsc/Se124M100KInfPrompt_EOS_Merged', and no v_head weight is found. This IS expected if you are not resuming PPO training.\n"
|
| 21 |
+
]
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
+
"source": [
|
| 25 |
+
"import os\n",
|
| 26 |
+
"import torch\n",
|
| 27 |
+
"import numpy as np\n",
|
| 28 |
+
"from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead\n",
|
| 29 |
+
"from transformers import AutoTokenizer\n",
|
| 30 |
+
"from datasets import Dataset\n",
|
| 31 |
+
"from peft import PeftModel, AutoPeftModelForCausalLM\n",
|
| 32 |
+
"import sys\n",
|
| 33 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"# Add path for Expression class\n",
|
| 36 |
+
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../classes')))\n",
|
| 37 |
+
"from expression import Expression\n",
|
| 38 |
+
"from dataset import RegressionDataset\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"# === Reward function ===\n",
|
| 41 |
+
"def compute_reward(expression_str: str) -> float:\n",
|
| 42 |
+
" try:\n",
|
| 43 |
+
" expr = Expression(expression_str)\n",
|
| 44 |
+
" \n",
|
| 45 |
+
" # Check if the expression is valid and can be evaluated\n",
|
| 46 |
+
" if expr.is_valid_on_dataset(X):\n",
|
| 47 |
+
" score = expr.fit_constants(X, y)\n",
|
| 48 |
+
" return max(0.1 , (float(score) if np.isfinite(score) else -1.0))\n",
|
| 49 |
+
" else:\n",
|
| 50 |
+
" #print(f\"Expressão inválida: {expression_str}\")\n",
|
| 51 |
+
" return -1.0\n",
|
| 52 |
+
" except Exception as e:\n",
|
| 53 |
+
" #print(f\"Erro ao avaliar expressão: {expression_str} - {e}\")\n",
|
| 54 |
+
" return -1.0\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"# === Helper to extract expression ===\n",
|
| 57 |
+
"def extract_expression(response: str) -> str:\n",
|
| 58 |
+
" return response.split(\"expr: \")[1].split(\"<|endoftext|>\")[0].strip()\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# === Load Data ===\n",
|
| 61 |
+
"#reg = RegressionDataset('../data/evaluate/srsd-feynman_hard/train', 'feynman-bonus.12.txt', delimiter=' ')\n",
|
| 62 |
+
"reg = RegressionDataset('../data/evaluate/srsd-feynman_easy/train', 'feynman-i.18.16.txt', delimiter=' ')\n",
|
| 63 |
+
"X, y = reg.get_numpy()\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# === Configs ===\n",
|
| 66 |
+
"BASE_MODEL = \"augustocsc/Se124M100KInfPrompt_EOS_Merged\"\n",
|
| 67 |
+
"LORA_REPO = \"augustocsc/Se124M100KInfPrompt_EOS_Merged\"\n",
|
| 68 |
+
"TOKENIZER_REPO = LORA_REPO\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"# ppo_config = PPOConfig(\n",
|
| 71 |
+
"# #model_name=BASE_MODEL,\n",
|
| 72 |
+
"# learning_rate=1e-5,\n",
|
| 73 |
+
"# batch_size=32,\n",
|
| 74 |
+
"# mini_batch_size=8,\n",
|
| 75 |
+
"# gradient_accumulation_steps=1,\n",
|
| 76 |
+
"# )\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"model = AutoModelForCausalLMWithValueHead.from_pretrained(BASE_MODEL)\n",
|
| 80 |
+
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(BASE_MODEL)\n",
|
| 81 |
+
"tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 84 |
+
"model = model.to(device)\n",
|
| 85 |
+
"ref_model = ref_model.to(device)\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"import os\n",
|
| 89 |
+
"os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"import numpy as np\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"def get_safe_functions(X, functions=['log', 'sqrt', 'asin', 'tan', 'abs', 'exp', 'sin', 'cos']):\n",
|
| 95 |
+
" \"\"\"\n",
|
| 96 |
+
" Returns a list of functions from `functions` that are safe to use on all columns of X.\n",
|
| 97 |
+
"\n",
|
| 98 |
+
" Parameters:\n",
|
| 99 |
+
" X: np.ndarray of shape (n_samples, n_features)\n",
|
| 100 |
+
" functions: list of function names to check\n",
|
| 101 |
+
"\n",
|
| 102 |
+
" Returns:\n",
|
| 103 |
+
" List of function names that are safe to use given the data\n",
|
| 104 |
+
" \"\"\"\n",
|
| 105 |
+
" safe_functions = []\n",
|
| 106 |
+
"\n",
|
| 107 |
+
" for fn in functions:\n",
|
| 108 |
+
" if fn in {'sin', 'cos', 'exp', 'abs'}:\n",
|
| 109 |
+
" # These are defined for all real values\n",
|
| 110 |
+
" safe_functions.append(fn)\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" elif fn == 'log':\n",
|
| 113 |
+
" if np.all(X > 0):\n",
|
| 114 |
+
" safe_functions.append(fn)\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" elif fn == 'sqrt':\n",
|
| 117 |
+
" if np.all(X >= 0):\n",
|
| 118 |
+
" safe_functions.append(fn)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
" elif fn == 'asin':\n",
|
| 121 |
+
" if np.all((X >= -1) & (X <= 1)):\n",
|
| 122 |
+
" safe_functions.append(fn)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
" elif fn == 'tan':\n",
|
| 125 |
+
" # Check if cos(x) ≈ 0 anywhere → tan(x) will explode\n",
|
| 126 |
+
" # We use np.cos to simulate tan issues (e.g., near π/2, 3π/2, etc.)\n",
|
| 127 |
+
" cos_vals = np.cos(X)\n",
|
| 128 |
+
" if np.all(np.abs(cos_vals) > 1e-6): # adjustable tolerance\n",
|
| 129 |
+
" safe_functions.append(fn)\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" # else skip unknown functions\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" return safe_functions\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"safe_functions = get_safe_functions(X)\n"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": 2,
|
| 142 |
+
"id": "9e2f618a",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [
|
| 145 |
+
{
|
| 146 |
+
"name": "stdout",
|
| 147 |
+
"output_type": "stream",
|
| 148 |
+
"text": [
|
| 149 |
+
"log, sqrt, tan, abs, exp, sin, cos\n"
|
| 150 |
+
]
|
| 151 |
+
}
|
| 152 |
+
],
|
| 153 |
+
"source": [
|
| 154 |
+
"print(', '.join(safe_functions))"
|
| 155 |
+
]
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"cell_type": "code",
|
| 159 |
+
"execution_count": 1,
|
| 160 |
+
"id": "dd922d70",
|
| 161 |
+
"metadata": {},
|
| 162 |
+
"outputs": [
|
| 163 |
+
{
|
| 164 |
+
"ename": "NameError",
|
| 165 |
+
"evalue": "name 'PPOConfig' is not defined",
|
| 166 |
+
"output_type": "error",
|
| 167 |
+
"traceback": [
|
| 168 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 169 |
+
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
|
| 170 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtqdm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m ppo_config = \u001b[43mPPOConfig\u001b[49m(\n\u001b[32m 4\u001b[39m model_name=\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;66;03m# definimos o modelo manualmente\u001b[39;00m\n\u001b[32m 5\u001b[39m learning_rate=\u001b[32m1e-5\u001b[39m,\n\u001b[32m 6\u001b[39m batch_size=\u001b[32m5\u001b[39m, \u001b[38;5;66;03m# total prompts/responses por step\u001b[39;00m\n\u001b[32m 7\u001b[39m mini_batch_size=\u001b[32m32\u001b[39m, \u001b[38;5;66;03m# 4 minibatches por batch\u001b[39;00m\n\u001b[32m 8\u001b[39m gradient_accumulation_steps=\u001b[32m1\u001b[39m,\n\u001b[32m 9\u001b[39m ppo_epochs=\u001b[32m4\u001b[39m, \u001b[38;5;66;03m# 4 passes por minibatch\u001b[39;00m\n\u001b[32m 10\u001b[39m log_with=\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;66;03m# ou \"wandb\"\u001b[39;00m\n\u001b[32m 11\u001b[39m optimize_cuda_cache=\u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;66;03m# 👍 melhora uso da A100\u001b[39;00m\n\u001b[32m 12\u001b[39m )\n\u001b[32m 14\u001b[39m \u001b[38;5;66;03m# === PPO Trainer ===\u001b[39;00m\n\u001b[32m 15\u001b[39m ppo_trainer = PPOTrainer(\n\u001b[32m 16\u001b[39m config=ppo_config,\n\u001b[32m 17\u001b[39m tokenizer=tokenizer,\n\u001b[32m (...)\u001b[39m\u001b[32m 20\u001b[39m \n\u001b[32m 21\u001b[39m )\n",
|
| 171 |
+
"\u001b[31mNameError\u001b[39m: name 'PPOConfig' is not defined"
|
| 172 |
+
]
|
| 173 |
+
}
|
| 174 |
+
],
|
| 175 |
+
"source": [
|
| 176 |
+
"from tqdm import tqdm\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"ppo_config = PPOConfig(\n",
|
| 179 |
+
" model_name=None, # definimos o modelo manualmente\n",
|
| 180 |
+
" learning_rate=1e-5,\n",
|
| 181 |
+
" batch_size=5, # total prompts/responses por step\n",
|
| 182 |
+
" mini_batch_size=32, # 4 minibatches por batch\n",
|
| 183 |
+
" gradient_accumulation_steps=1,\n",
|
| 184 |
+
" ppo_epochs=4, # 4 passes por minibatch\n",
|
| 185 |
+
" log_with=None, # ou \"wandb\"\n",
|
| 186 |
+
" optimize_cuda_cache=True, # 👍 melhora uso da A100\n",
|
| 187 |
+
")\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"# === PPO Trainer ===\n",
|
| 190 |
+
"ppo_trainer = PPOTrainer(\n",
|
| 191 |
+
" config=ppo_config,\n",
|
| 192 |
+
" tokenizer=tokenizer,\n",
|
| 193 |
+
" model=model,\n",
|
| 194 |
+
" ref_model=ref_model,\n",
|
| 195 |
+
" \n",
|
| 196 |
+
")\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"# Define the prompt with the safe functions\n",
|
| 199 |
+
"PROMPT = f\"\"\"\n",
|
| 200 |
+
"vars: x_1, x_2, x_3\n",
|
| 201 |
+
"oper: * +, /, **, {', '.join(safe_functions)}\n",
|
| 202 |
+
"cons: C\n",
|
| 203 |
+
"expr:\"\"\"\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"# === Dummy dataset ===\n",
|
| 206 |
+
"dummy_dataset = Dataset.from_dict({\n",
|
| 207 |
+
" \"prompt\": [PROMPT] * 5\n",
|
| 208 |
+
"})\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"# Get the device of the model\n",
|
| 212 |
+
"device = next(model.parameters()).device\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"# === PPO Training Loop ===\n",
|
| 215 |
+
"# Tokenize the prompt and convert it to tensors\n",
|
| 216 |
+
"inputs = tokenizer([PROMPT] * ppo_config.batch_size, return_tensors=\"pt\", padding=True)\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"# Move inputs to the same device as the model\n",
|
| 219 |
+
"inputs = {key: value.to(device) for key, value in inputs.items()}\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"# Convert the batch tensor into a list of individual tensors\n",
|
| 222 |
+
"queries = [inputs[\"input_ids\"][i] for i in range(inputs[\"input_ids\"].size(0))]\n",
|
| 223 |
+
"all_rewards = []\n",
|
| 224 |
+
"all_responses = []\n",
|
| 225 |
+
"for epoch in tqdm(range(10), desc=\"Training Epochs\"): # adjust as needed\n",
|
| 226 |
+
" responses = []\n",
|
| 227 |
+
" constants = []\n",
|
| 228 |
+
" rewards = []\n",
|
| 229 |
+
" for i in tqdm(range(ppo_config.batch_size), desc=\"Batch Progress\", leave=False): # Nested progress bar\n",
|
| 230 |
+
" try:\n",
|
| 231 |
+
" input_ids = inputs[\"input_ids\"][i].unsqueeze(0)\n",
|
| 232 |
+
" attention_mask = inputs[\"attention_mask\"][i].unsqueeze(0)\n",
|
| 233 |
+
"\n",
|
| 234 |
+
" # === VALIDATION PATCH ===\n",
|
| 235 |
+
" assert torch.all((input_ids >= 0) & (input_ids < model.config.vocab_size)), \\\n",
|
| 236 |
+
" f\"Token inválido detectado: max={input_ids.max().item()}, vocab_size={model.config.vocab_size}\"\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" # (opcional)\n",
|
| 239 |
+
" model.config.pad_token_id = tokenizer.pad_token_id\n",
|
| 240 |
+
" reward = -1\n",
|
| 241 |
+
" while reward < 0:\n",
|
| 242 |
+
" output = model.generate(\n",
|
| 243 |
+
" input_ids=input_ids,\n",
|
| 244 |
+
" attention_mask=attention_mask,\n",
|
| 245 |
+
" max_new_tokens=50,\n",
|
| 246 |
+
" do_sample=True,\n",
|
| 247 |
+
" top_k=50,\n",
|
| 248 |
+
" top_p=0.95,\n",
|
| 249 |
+
" temperature=0.7,\n",
|
| 250 |
+
" eos_token_id=tokenizer.eos_token_id,\n",
|
| 251 |
+
" pad_token_id=tokenizer.pad_token_id,\n",
|
| 252 |
+
" return_dict_in_generate=True,\n",
|
| 253 |
+
" output_scores=False\n",
|
| 254 |
+
" )\n",
|
| 255 |
+
" response_ids = output.sequences[0][input_ids.shape[1]:]\n",
|
| 256 |
+
" response = tokenizer.decode(response_ids, skip_special_tokens=True)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" reward = compute_reward(response)\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"\n",
|
| 261 |
+
" except Exception as e:\n",
|
| 262 |
+
" print(f\"Error at index {i}: {e}\")\n",
|
| 263 |
+
" print(f\"Input IDs: {input_ids}\")\n",
|
| 264 |
+
" print(f\"Token range: min={input_ids.min()}, max={input_ids.max()}, vocab_size={model.config.vocab_size}\")\n",
|
| 265 |
+
" raise e\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" responses.append(response)\n",
|
| 268 |
+
" rewards.append(reward)\n",
|
| 269 |
+
" all_responses.extend(responses)\n",
|
| 270 |
+
" all_rewards.extend(rewards)\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" #if one reward is >= .9 break\n",
|
| 273 |
+
" if any(r >= 0.9 for r in rewards):\n",
|
| 274 |
+
" print(\"Reward >= 0.9 found, stopping training.\")\n",
|
| 275 |
+
" break\n",
|
| 276 |
+
" # Compute rewards with a progress bar\n",
|
| 277 |
+
" \n",
|
| 278 |
+
" import concurrent.futures\n",
|
| 279 |
+
"\n",
|
| 280 |
+
" # # Use process-based parallelism\n",
|
| 281 |
+
" # with concurrent.futures.ProcessPoolExecutor() as executor:\n",
|
| 282 |
+
" # rewards = list(tqdm(executor.map(compute_reward, responses), total=len(responses), desc=\"Computing Rewards\", leave=False))\n",
|
| 283 |
+
" \n",
|
| 284 |
+
" #rewards = [ compute_reward(response) for response in tqdm(responses, desc=\"Computing Rewards\", leave=False)]\n",
|
| 285 |
+
" \n",
|
| 286 |
+
"\n",
|
| 287 |
+
" # Convert rewards to a list of PyTorch tensors\n",
|
| 288 |
+
" rewards = [torch.tensor(reward, dtype=torch.float32, device=device) for reward in rewards]\n",
|
| 289 |
+
" \n",
|
| 290 |
+
" # Ensure responses are also tokenized and converted to tensors\n",
|
| 291 |
+
" responses = [tokenizer(response, return_tensors=\"pt\", padding=True)[\"input_ids\"].squeeze(0).to(device) for response in responses]\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" # Pass the tokenized tensors to ppo_trainer.step()\n",
|
| 294 |
+
" ppo_trainer.step(queries, responses, rewards)\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" # Log top expressions\n",
|
| 297 |
+
" top_k = 3\n",
|
| 298 |
+
" sorted_responses = sorted(zip(responses, rewards), key=lambda x: -x[1])\n",
|
| 299 |
+
" print(f\"\\nEpoch {epoch + 1} melhores expressões:\")\n",
|
| 300 |
+
" for i, (expr, score) in enumerate(sorted_responses[:top_k]):\n",
|
| 301 |
+
" print(f\"{i+1}. {tokenizer.decode(expr, skip_special_tokens=True)} -> R² = {score:.4f}\")\n",
|
| 302 |
+
" # Print average, median, and std of rewards\n",
|
| 303 |
+
" avg_reward = torch.mean(torch.stack(rewards)).item()\n",
|
| 304 |
+
" median_reward = torch.median(torch.stack(rewards)).item()\n",
|
| 305 |
+
" count_invalid = sum(1 for r in rewards if r == -1.0)\n",
|
| 306 |
+
" print(f\"Average Reward: {avg_reward:.4f}, Median Reward: {median_reward:.4f}, Invalid Count: {count_invalid}\")\n",
|
| 307 |
+
"\n"
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"cell_type": "markdown",
|
| 312 |
+
"id": "70a60613",
|
| 313 |
+
"metadata": {},
|
| 314 |
+
"source": []
|
| 315 |
+
}
|
| 316 |
+
],
|
| 317 |
+
"metadata": {
|
| 318 |
+
"kernelspec": {
|
| 319 |
+
"display_name": ".seriguela",
|
| 320 |
+
"language": "python",
|
| 321 |
+
"name": "python3"
|
| 322 |
+
},
|
| 323 |
+
"language_info": {
|
| 324 |
+
"codemirror_mode": {
|
| 325 |
+
"name": "ipython",
|
| 326 |
+
"version": 3
|
| 327 |
+
},
|
| 328 |
+
"file_extension": ".py",
|
| 329 |
+
"mimetype": "text/x-python",
|
| 330 |
+
"name": "python",
|
| 331 |
+
"nbconvert_exporter": "python",
|
| 332 |
+
"pygments_lexer": "ipython3",
|
| 333 |
+
"version": "3.11.4"
|
| 334 |
+
}
|
| 335 |
+
},
|
| 336 |
+
"nbformat": 4,
|
| 337 |
+
"nbformat_minor": 5
|
| 338 |
+
}
|
notebooks/04_merging_model.ipynb
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 4,
|
| 6 |
+
"id": "86149941",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"data": {
|
| 11 |
+
"text/plain": [
|
| 12 |
+
"('./modelo_final_para_ppo/tokenizer_config.json',\n",
|
| 13 |
+
" './modelo_final_para_ppo/special_tokens_map.json',\n",
|
| 14 |
+
" './modelo_final_para_ppo/vocab.json',\n",
|
| 15 |
+
" './modelo_final_para_ppo/merges.txt',\n",
|
| 16 |
+
" './modelo_final_para_ppo/added_tokens.json',\n",
|
| 17 |
+
" './modelo_final_para_ppo/tokenizer.json')"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
"execution_count": 4,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"output_type": "execute_result"
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"# ===============================\n",
|
| 27 |
+
"# 🚀 LoRA Merge + ValueHead + Test\n",
|
| 28 |
+
"# ===============================\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"# ✅ Imports\n",
|
| 32 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 33 |
+
"from peft import PeftModel\n",
|
| 34 |
+
"from trl import AutoModelForCausalLMWithValueHead\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"# === Configurações ===\n",
|
| 37 |
+
"LORA_REPO = \"augustocsc/Se124M500KInfPrompt_EOS\"\n",
|
| 38 |
+
"BASE_MODEL = \"gpt2\"\n",
|
| 39 |
+
"OUTPUT_DIR = \"./modelo_final_para_ppo\"\n",
|
| 40 |
+
"MODEL_HUB = \"augustocsc/Se124M500KInfPrompt_EOS_Merged\"\n",
|
| 41 |
+
"# === Carregar o tokenizer correto ===\n",
|
| 42 |
+
"tokenizer = AutoTokenizer.from_pretrained(LORA_REPO)\n",
|
| 43 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# === Carregar modelo base e ajustar os embeddings ===\n",
|
| 46 |
+
"base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)\n",
|
| 47 |
+
"base_model.resize_token_embeddings(len(tokenizer)) # Corrige shape para 50258\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"# Load the PEFT model\n",
|
| 50 |
+
"peft_model = PeftModel.from_pretrained(base_model, LORA_REPO)\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"# === Merge das LoRA weights (corretamente) ===\n",
|
| 53 |
+
"merged_model = peft_model.merge_and_unload()\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"# === Adicionar Value Head ao modelo mergeado ===\n",
|
| 56 |
+
"model = AutoModelForCausalLMWithValueHead.from_pretrained(merged_model)\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"# === Salvar modelo final para PPO ===\n",
|
| 59 |
+
"model.save_pretrained(OUTPUT_DIR)\n",
|
| 60 |
+
"tokenizer.save_pretrained(OUTPUT_DIR)\n"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 5,
|
| 66 |
+
"id": "e921394e",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [
|
| 69 |
+
{
|
| 70 |
+
"data": {
|
| 71 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 72 |
+
"model_id": "0d38506bf99e418eb92d977159c9550b",
|
| 73 |
+
"version_major": 2,
|
| 74 |
+
"version_minor": 0
|
| 75 |
+
},
|
| 76 |
+
"text/plain": [
|
| 77 |
+
"model.safetensors: 0%| | 0.00/498M [00:00<?, ?B/s]"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"output_type": "display_data"
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"data": {
|
| 85 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 86 |
+
"model_id": "45094d0c70c344acbb4f800968a9eb55",
|
| 87 |
+
"version_major": 2,
|
| 88 |
+
"version_minor": 0
|
| 89 |
+
},
|
| 90 |
+
"text/plain": [
|
| 91 |
+
"README.md: 0%| | 0.00/5.17k [00:00<?, ?B/s]"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"output_type": "display_data"
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"data": {
|
| 99 |
+
"text/plain": [
|
| 100 |
+
"CommitInfo(commit_url='https://huggingface.co/augustocsc/Se124M500KInfPrompt_EOS_Merged/commit/175b8a2750f170839ce04cb3dab9b1740fc83e92', commit_message='Upload tokenizer', commit_description='', oid='175b8a2750f170839ce04cb3dab9b1740fc83e92', pr_url=None, repo_url=RepoUrl('https://huggingface.co/augustocsc/Se124M500KInfPrompt_EOS_Merged', endpoint='https://huggingface.co', repo_type='model', repo_id='augustocsc/Se124M500KInfPrompt_EOS_Merged'), pr_revision=None, pr_num=None)"
|
| 101 |
+
]
|
| 102 |
+
},
|
| 103 |
+
"execution_count": 5,
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"output_type": "execute_result"
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"source": [
|
| 109 |
+
"model.push_to_hub(MODEL_HUB)\n",
|
| 110 |
+
"tokenizer.push_to_hub(MODEL_HUB)"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": 6,
|
| 116 |
+
"id": "34b6777d",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [
|
| 119 |
+
{
|
| 120 |
+
"name": "stderr",
|
| 121 |
+
"output_type": "stream",
|
| 122 |
+
"text": [
|
| 123 |
+
"Some weights of the model checkpoint at augustocsc/Se124M100KInfPrompt_EOS_Merged were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
|
| 124 |
+
"- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 125 |
+
"- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 126 |
+
"WARNING:root:A <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> model is loaded from 'augustocsc/Se124M100KInfPrompt_EOS_Merged', and no v_head weight is found. This IS expected if you are not resuming PPO training.\n",
|
| 127 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
| 128 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
|
| 129 |
+
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"name": "stdout",
|
| 134 |
+
"output_type": "stream",
|
| 135 |
+
"text": [
|
| 136 |
+
"🧪 Resposta do modelo:\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10\n",
|
| 140 |
+
"oper: *, **, +, -, /\n",
|
| 141 |
+
"cons: C\n",
|
| 142 |
+
"expr: x_1 + x_2 + C*x_8 + C*x_5**C<|endoftext|>\n"
|
| 143 |
+
]
|
| 144 |
+
}
|
| 145 |
+
],
|
| 146 |
+
"source": [
|
| 147 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 148 |
+
"from peft import PeftModel\n",
|
| 149 |
+
"from trl import AutoModelForCausalLMWithValueHead\n",
|
| 150 |
+
"# 🔁 Recarregar o modelo já mergeado + value head\n",
|
| 151 |
+
"from trl import AutoModelForCausalLMWithValueHead\n",
|
| 152 |
+
"MODEL_HUB = \"augustocsc/Se124M100KInfPrompt_EOS_Merged\"\n",
|
| 153 |
+
"#load model\n",
|
| 154 |
+
"model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_HUB)\n",
|
| 155 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_HUB)\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"# 🔁 Prompt de teste\n",
|
| 158 |
+
"PROMPT = \"\"\"\n",
|
| 159 |
+
"vars: x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8, x_9, x_10\n",
|
| 160 |
+
"oper: *, **, +, -, /\n",
|
| 161 |
+
"cons: C\n",
|
| 162 |
+
"expr:\"\"\"\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"device = model.pretrained_model.device # 👈 modelo base dentro do wrapper\n",
|
| 165 |
+
"input_ids = tokenizer(PROMPT, return_tensors=\"pt\").input_ids.to(device)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"# 🔮 Geração\n",
|
| 168 |
+
"gen_tokens = output = model.generate(\n",
|
| 169 |
+
" input_ids=input_ids,\n",
|
| 170 |
+
" max_new_tokens=50,\n",
|
| 171 |
+
" do_sample=True,\n",
|
| 172 |
+
" top_k=50,\n",
|
| 173 |
+
" top_p=0.95,\n",
|
| 174 |
+
" temperature=0.7,\n",
|
| 175 |
+
" \n",
|
| 176 |
+
" )\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"# Mostrar resposta\n",
|
| 179 |
+
"response = tokenizer.decode(gen_tokens[0], skip_special_tokens=False)\n",
|
| 180 |
+
"print(\"🧪 Resposta do modelo:\\n\")\n",
|
| 181 |
+
"print(response)\n"
|
| 182 |
+
]
|
| 183 |
+
}
|
| 184 |
+
],
|
| 185 |
+
"metadata": {
|
| 186 |
+
"kernelspec": {
|
| 187 |
+
"display_name": ".seriguela",
|
| 188 |
+
"language": "python",
|
| 189 |
+
"name": "python3"
|
| 190 |
+
},
|
| 191 |
+
"language_info": {
|
| 192 |
+
"codemirror_mode": {
|
| 193 |
+
"name": "ipython",
|
| 194 |
+
"version": 3
|
| 195 |
+
},
|
| 196 |
+
"file_extension": ".py",
|
| 197 |
+
"mimetype": "text/x-python",
|
| 198 |
+
"name": "python",
|
| 199 |
+
"nbconvert_exporter": "python",
|
| 200 |
+
"pygments_lexer": "ipython3",
|
| 201 |
+
"version": "3.11.4"
|
| 202 |
+
}
|
| 203 |
+
},
|
| 204 |
+
"nbformat": 4,
|
| 205 |
+
"nbformat_minor": 5
|
| 206 |
+
}
|
out.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Special constants found: [1]
|
| 2 |
+
Found 1 constants in the expression: tan(x_1**C + cos(x_1))
|
| 3 |
+
Testing expression validity with constants: [[1.0]]
|
| 4 |
+
Expression is valid on dataset.
|
| 5 |
+
Bounds for optimization: [(1, 3)]
|
| 6 |
+
Fitted constants: [1.0]
|
| 7 |
+
R2 score: -1.8028651105117532e-05
|
out2.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
+
# Core Hugging Face e Deep Learning
|
| 3 |
+
transformers==4.51.3
|
| 4 |
+
torch==2.5.1
|
| 5 |
+
torchvision==0.20.1
|
| 6 |
+
torchaudio==2.5.1
|
| 7 |
+
|
| 8 |
+
accelerate==1.6.0
|
| 9 |
+
python-dotenv==1.0.1
|
| 10 |
+
datasets==3.5.0
|
| 11 |
+
evaluate==0.4.1
|
| 12 |
+
huggingface-hub==0.30.2
|
| 13 |
+
|
| 14 |
+
# Parameter-Efficient Fine-Tuning (PEFT)
|
| 15 |
+
peft==0.15.1
|
| 16 |
+
|
| 17 |
+
# Avaliação e utilitários
|
| 18 |
+
scikit-learn==1.6.1
|
| 19 |
+
numpy==1.26.4
|
| 20 |
+
pandas==2.2.1
|
| 21 |
+
tqdm==4.67.1
|
| 22 |
+
sympy==1.13.1
|
| 23 |
+
regex==2024.11.6
|
| 24 |
+
|
| 25 |
+
# Logging e visualização
|
| 26 |
+
tensorboard==2.16.2
|
| 27 |
+
wandb>=0.24.1 # Versão atualizada para suportar novo formato de API key (wandb_v1_...)
|
| 28 |
+
|
| 29 |
+
# Fine-tuning avançado (SFT, DPO, etc.)
|
| 30 |
+
trl==0.16.1
|
scripts/aws/analyze_model.sh
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Automatic Model Analysis Script
|
| 3 |
+
# Runs evaluation and generation analysis after training
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
# Colors
|
| 8 |
+
GREEN='\033[0;32m'
|
| 9 |
+
YELLOW='\033[1;33m'
|
| 10 |
+
BLUE='\033[0;34m'
|
| 11 |
+
NC='\033[0m'
|
| 12 |
+
|
| 13 |
+
print_status() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
| 14 |
+
print_header() { echo -e "\n${BLUE}========================================\n$1\n========================================${NC}\n"; }
|
| 15 |
+
|
| 16 |
+
# Parameters
|
| 17 |
+
MODEL_PATH="${1:-./output/Se124M_700K_infix}"
|
| 18 |
+
DATA_COLUMN="${2:-i_prompt_n}"
|
| 19 |
+
DATASET_REPO="augustocsc/sintetico_natural"
|
| 20 |
+
DATA_DIR="700K"
|
| 21 |
+
NUM_SAMPLES=500
|
| 22 |
+
NUM_GENERATIONS=100
|
| 23 |
+
|
| 24 |
+
# Directories
|
| 25 |
+
PROJECT_DIR="/home/ubuntu/seriguela"
|
| 26 |
+
OUTPUT_DIR="$HOME/analysis_results_$(date +%Y%m%d_%H%M%S)"
|
| 27 |
+
mkdir -p "$OUTPUT_DIR"
|
| 28 |
+
|
| 29 |
+
cd "$PROJECT_DIR"
|
| 30 |
+
source venv/bin/activate
|
| 31 |
+
|
| 32 |
+
print_header "Automatic Model Analysis"
|
| 33 |
+
print_status "Model: $MODEL_PATH"
|
| 34 |
+
print_status "Output: $OUTPUT_DIR"
|
| 35 |
+
echo ""
|
| 36 |
+
|
| 37 |
+
# =============================================================================
|
| 38 |
+
# 1. EVALUATE MODEL
|
| 39 |
+
# =============================================================================
|
| 40 |
+
print_header "Step 1: Model Evaluation"
|
| 41 |
+
print_status "Running evaluation on $NUM_SAMPLES samples..."
|
| 42 |
+
|
| 43 |
+
python scripts/evaluate.py \
|
| 44 |
+
--model_path "$MODEL_PATH" \
|
| 45 |
+
--dataset_repo_id "$DATASET_REPO" \
|
| 46 |
+
--data_dir "$DATA_DIR" \
|
| 47 |
+
--data_column "$DATA_COLUMN" \
|
| 48 |
+
--num_samples "$NUM_SAMPLES" \
|
| 49 |
+
--output_dir "$OUTPUT_DIR/evaluation" \
|
| 50 |
+
--temperature 0.7 \
|
| 51 |
+
--seed 42 \
|
| 52 |
+
2>&1 | tee "$OUTPUT_DIR/evaluation.log"
|
| 53 |
+
|
| 54 |
+
if [ $? -eq 0 ]; then
|
| 55 |
+
print_status "✅ Evaluation completed"
|
| 56 |
+
else
|
| 57 |
+
print_status "⚠️ Evaluation had issues"
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
# =============================================================================
|
| 61 |
+
# 2. GENERATE SAMPLES
|
| 62 |
+
# =============================================================================
|
| 63 |
+
print_header "Step 2: Sample Generation & Validation"
|
| 64 |
+
print_status "Generating $NUM_GENERATIONS samples with validation..."
|
| 65 |
+
|
| 66 |
+
python scripts/generate.py \
|
| 67 |
+
--model_path "$MODEL_PATH" \
|
| 68 |
+
--num_generations "$NUM_GENERATIONS" \
|
| 69 |
+
--validate \
|
| 70 |
+
--output_file "$OUTPUT_DIR/generations.txt" \
|
| 71 |
+
--temperature 0.8 \
|
| 72 |
+
--top_p 0.95 \
|
| 73 |
+
--seed 42 \
|
| 74 |
+
2>&1 | tee "$OUTPUT_DIR/generation.log"
|
| 75 |
+
|
| 76 |
+
if [ $? -eq 0 ]; then
|
| 77 |
+
print_status "✅ Generation completed"
|
| 78 |
+
else
|
| 79 |
+
print_status "⚠️ Generation had issues"
|
| 80 |
+
fi
|
| 81 |
+
|
| 82 |
+
# =============================================================================
|
| 83 |
+
# 3. ANALYZE TRAINING LOGS
|
| 84 |
+
# =============================================================================
|
| 85 |
+
print_header "Step 3: Training Log Analysis"
|
| 86 |
+
print_status "Extracting training metrics..."
|
| 87 |
+
|
| 88 |
+
TRAINING_LOG="$HOME/training_success.log"
|
| 89 |
+
|
| 90 |
+
if [ -f "$TRAINING_LOG" ]; then
|
| 91 |
+
# Extract loss values
|
| 92 |
+
grep -E "'loss':|train_loss|eval_loss" "$TRAINING_LOG" > "$OUTPUT_DIR/training_metrics.txt" 2>/dev/null || true
|
| 93 |
+
|
| 94 |
+
# Extract epoch summaries
|
| 95 |
+
grep -E "epoch.*loss" "$TRAINING_LOG" | tail -20 > "$OUTPUT_DIR/epoch_summary.txt" 2>/dev/null || true
|
| 96 |
+
|
| 97 |
+
# Count total steps
|
| 98 |
+
TOTAL_STEPS=$(grep -E "[0-9]+/21882" "$TRAINING_LOG" | tail -1 | sed 's/.*\([0-9]\+\)\/21882.*/\1/' || echo "0")
|
| 99 |
+
|
| 100 |
+
print_status "Total training steps: $TOTAL_STEPS"
|
| 101 |
+
fi
|
| 102 |
+
|
| 103 |
+
# =============================================================================
|
| 104 |
+
# 4. CREATE SUMMARY REPORT
|
| 105 |
+
# =============================================================================
|
| 106 |
+
print_header "Step 4: Creating Analysis Report"
|
| 107 |
+
|
| 108 |
+
cat > "$OUTPUT_DIR/ANALYSIS_REPORT.md" << 'EOFREPORT'
|
| 109 |
+
# Training Analysis Report
|
| 110 |
+
**Generated:** $(date)
|
| 111 |
+
|
| 112 |
+
## 📊 Model Information
|
| 113 |
+
- **Architecture:** GPT-2 Small (124M parameters)
|
| 114 |
+
- **Training Method:** LoRA (294K trainable parameters, 0.24%)
|
| 115 |
+
- **Dataset:** 700K samples (infix notation)
|
| 116 |
+
- **Training Duration:** $(grep "Training Duration:" $HOME/training_notification.txt 2>/dev/null | head -1 || echo "N/A")
|
| 117 |
+
|
| 118 |
+
## 📈 Training Metrics
|
| 119 |
+
|
| 120 |
+
### Loss Progression
|
| 121 |
+
```
|
| 122 |
+
$(tail -20 $OUTPUT_DIR/training_metrics.txt 2>/dev/null || echo "No metrics available")
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
### Epoch Summary
|
| 126 |
+
```
|
| 127 |
+
$(cat $OUTPUT_DIR/epoch_summary.txt 2>/dev/null || echo "No epoch data available")
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## 🎯 Evaluation Results
|
| 131 |
+
|
| 132 |
+
### Performance Metrics
|
| 133 |
+
```
|
| 134 |
+
$(grep -E "Accuracy|Loss|Perplexity" $OUTPUT_DIR/evaluation.log 2>/dev/null || echo "Check evaluation.log for details")
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### Sample Predictions
|
| 138 |
+
```
|
| 139 |
+
$(head -50 $OUTPUT_DIR/evaluation/*.txt 2>/dev/null | head -20 || echo "No evaluation samples found")
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## 🔮 Generation Quality
|
| 143 |
+
|
| 144 |
+
### Validation Results
|
| 145 |
+
```
|
| 146 |
+
$(grep -E "Valid:|Success|Failed" $OUTPUT_DIR/generation.log | head -20 || echo "Check generation.log")
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### Sample Generations
|
| 150 |
+
```
|
| 151 |
+
$(head -30 $OUTPUT_DIR/generations.txt 2>/dev/null || echo "No generations file found")
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## 📁 Output Files
|
| 155 |
+
- Evaluation results: `evaluation/`
|
| 156 |
+
- Generated samples: `generations.txt`
|
| 157 |
+
- Full logs: `evaluation.log`, `generation.log`
|
| 158 |
+
- Training metrics: `training_metrics.txt`
|
| 159 |
+
|
| 160 |
+
## 🔗 Resources
|
| 161 |
+
- **Wandb Dashboard:** https://wandb.ai/symbolic-gression/seriguela_700K_test
|
| 162 |
+
- **HuggingFace Model:** https://huggingface.co/augustocsc/Se124M_700K_infix
|
| 163 |
+
- **Analysis Directory:** $OUTPUT_DIR
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
*Generated automatically by analyze_model.sh*
|
| 167 |
+
EOFREPORT
|
| 168 |
+
|
| 169 |
+
# Evaluate the report with actual values
|
| 170 |
+
eval "cat > \"$OUTPUT_DIR/ANALYSIS_REPORT.md\" << 'EOFREPORT'
|
| 171 |
+
$(cat "$OUTPUT_DIR/ANALYSIS_REPORT.md")
|
| 172 |
+
EOFREPORT"
|
| 173 |
+
|
| 174 |
+
print_status "Report created: $OUTPUT_DIR/ANALYSIS_REPORT.md"
|
| 175 |
+
|
| 176 |
+
# =============================================================================
|
| 177 |
+
# 5. FINAL SUMMARY
|
| 178 |
+
# =============================================================================
|
| 179 |
+
print_header "Analysis Complete!"
|
| 180 |
+
echo ""
|
| 181 |
+
print_status "All results saved to: $OUTPUT_DIR"
|
| 182 |
+
print_status "Main report: $OUTPUT_DIR/ANALYSIS_REPORT.md"
|
| 183 |
+
echo ""
|
| 184 |
+
print_status "Key files:"
|
| 185 |
+
echo " - Evaluation: $OUTPUT_DIR/evaluation.log"
|
| 186 |
+
echo " - Generation: $OUTPUT_DIR/generation.log"
|
| 187 |
+
echo " - Metrics: $OUTPUT_DIR/training_metrics.txt"
|
| 188 |
+
echo " - Report: $OUTPUT_DIR/ANALYSIS_REPORT.md"
|
| 189 |
+
echo ""
|
| 190 |
+
print_status "View the full report with:"
|
| 191 |
+
echo " cat $OUTPUT_DIR/ANALYSIS_REPORT.md"
|
| 192 |
+
echo ""
|
| 193 |
+
|
| 194 |
+
# Create a quick summary
|
| 195 |
+
EVAL_SUCCESS=$(grep -c "✅" "$OUTPUT_DIR/evaluation.log" 2>/dev/null || echo "0")
|
| 196 |
+
GEN_SUCCESS=$(grep -c "Valid" "$OUTPUT_DIR/generation.log" 2>/dev/null || echo "0")
|
| 197 |
+
|
| 198 |
+
print_header "Quick Summary"
|
| 199 |
+
echo "Evaluation samples processed: $NUM_SAMPLES"
|
| 200 |
+
echo "Generations created: $NUM_GENERATIONS"
|
| 201 |
+
echo "Check logs for detailed metrics and quality assessment"
|
| 202 |
+
echo ""
|
| 203 |
+
print_status "Done!"
|
scripts/aws/evaluate_models.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to evaluate two models on AWS and compare results
|
| 3 |
+
# This script compares the original model (without end token) with the v2 model (with end token)
|
| 4 |
+
# Usage: bash scripts/aws/evaluate_models.sh
|
| 5 |
+
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
echo "=========================================="
|
| 9 |
+
echo "Model Comparison: v1 vs v2"
|
| 10 |
+
echo "=========================================="
|
| 11 |
+
echo "Model 1: augustocsc/Se124M_700K_infix (original)"
|
| 12 |
+
echo "Model 2: augustocsc/Se124M_700K_infix_v2 (with <|endofex|> token)"
|
| 13 |
+
echo "=========================================="
|
| 14 |
+
echo ""
|
| 15 |
+
|
| 16 |
+
# Activate virtual environment
|
| 17 |
+
source ~/seriguela/venv/bin/activate
|
| 18 |
+
cd ~/seriguela
|
| 19 |
+
|
| 20 |
+
# Set up logging
|
| 21 |
+
LOG_FILE="evaluation_$(date +%Y%m%d_%H%M%S).log"
|
| 22 |
+
exec > >(tee -a "$LOG_FILE") 2>&1
|
| 23 |
+
|
| 24 |
+
echo "[$(date)] Starting evaluation..."
|
| 25 |
+
echo ""
|
| 26 |
+
|
| 27 |
+
# Check GPU availability
|
| 28 |
+
echo "Checking GPU..."
|
| 29 |
+
if nvidia-smi &> /dev/null; then
|
| 30 |
+
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader
|
| 31 |
+
echo ""
|
| 32 |
+
else
|
| 33 |
+
echo "WARNING: No GPU detected. Evaluation will be slow."
|
| 34 |
+
echo ""
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
# Run comparison
|
| 38 |
+
echo "Running model comparison..."
|
| 39 |
+
echo "This will evaluate both models on 500 samples from the test set."
|
| 40 |
+
echo ""
|
| 41 |
+
|
| 42 |
+
python scripts/compare_models.py \
|
| 43 |
+
--model1 augustocsc/Se124M_700K_infix \
|
| 44 |
+
--model2 augustocsc/Se124M_700K_infix_v2 \
|
| 45 |
+
--model1_name "Original (no end token)" \
|
| 46 |
+
--model2_name "V2 (with <|endofex|>)" \
|
| 47 |
+
--num_samples 500 \
|
| 48 |
+
--dataset_repo_id augustocsc/sintetico_natural \
|
| 49 |
+
--data_dir 700K \
|
| 50 |
+
--data_column i_prompt_n \
|
| 51 |
+
--output_dir ./evaluation_results/comparison
|
| 52 |
+
|
| 53 |
+
echo ""
|
| 54 |
+
echo "=========================================="
|
| 55 |
+
echo "Evaluation Complete!"
|
| 56 |
+
echo "=========================================="
|
| 57 |
+
echo "Results saved to: ./evaluation_results/comparison"
|
| 58 |
+
echo "Log file: $LOG_FILE"
|
| 59 |
+
echo ""
|
| 60 |
+
echo "To view results:"
|
| 61 |
+
echo " cat ./evaluation_results/comparison/comparison_*.json | jq"
|
| 62 |
+
echo ""
|
scripts/aws/launch_evaluation_instance.sh
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to launch AWS instance for model evaluation
|
| 3 |
+
# Evaluates two models: original (Se124M_700K_infix) vs v2 (with end token)
|
| 4 |
+
# Usage: ./launch_evaluation_instance.sh [--hf-token TOKEN]
|
| 5 |
+
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
# Colors
|
| 9 |
+
GREEN='\033[0;32m'
|
| 10 |
+
YELLOW='\033[1;33m'
|
| 11 |
+
RED='\033[0;31m'
|
| 12 |
+
BLUE='\033[0;34m'
|
| 13 |
+
NC='\033[0m'
|
| 14 |
+
|
| 15 |
+
print_status() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
| 16 |
+
print_warning() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
| 17 |
+
print_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
| 18 |
+
|
| 19 |
+
# Default configuration
|
| 20 |
+
INSTANCE_TYPE="g5.xlarge"
|
| 21 |
+
AMI_ID=""
|
| 22 |
+
KEY_NAME=""
|
| 23 |
+
SECURITY_GROUP=""
|
| 24 |
+
REGION=$(aws configure get region 2>/dev/null || echo "us-east-1")
|
| 25 |
+
VOLUME_SIZE=80
|
| 26 |
+
INSTANCE_NAME="seriguela-evaluation"
|
| 27 |
+
HF_TOKEN=""
|
| 28 |
+
|
| 29 |
+
# Parse arguments
|
| 30 |
+
while [[ $# -gt 0 ]]; do
|
| 31 |
+
case $1 in
|
| 32 |
+
--hf-token) HF_TOKEN="$2"; shift 2;;
|
| 33 |
+
--instance-type) INSTANCE_TYPE="$2"; shift 2;;
|
| 34 |
+
--key-name) KEY_NAME="$2"; shift 2;;
|
| 35 |
+
--help)
|
| 36 |
+
echo "Usage: $0 [OPTIONS]"
|
| 37 |
+
echo "Options:"
|
| 38 |
+
echo " --hf-token TOKEN HuggingFace token (optional, for accessing models)"
|
| 39 |
+
echo " --instance-type TYPE Instance type (default: g5.xlarge)"
|
| 40 |
+
echo " --key-name NAME SSH key pair name"
|
| 41 |
+
echo ""
|
| 42 |
+
echo "Example:"
|
| 43 |
+
echo " $0 --hf-token hf_xxx"
|
| 44 |
+
exit 0;;
|
| 45 |
+
*) echo "Unknown option: $1"; exit 1;;
|
| 46 |
+
esac
|
| 47 |
+
done
|
| 48 |
+
|
| 49 |
+
if [ -z "$HF_TOKEN" ]; then
|
| 50 |
+
print_warning "HuggingFace token not provided. Public models will still work."
|
| 51 |
+
print_warning "Get your token from: https://huggingface.co/settings/tokens"
|
| 52 |
+
fi
|
| 53 |
+
|
| 54 |
+
print_status "Launching Seriguela evaluation instance..."
|
| 55 |
+
|
| 56 |
+
# Find Deep Learning AMI
|
| 57 |
+
print_status "Finding Deep Learning AMI..."
|
| 58 |
+
AMI_ID=$(aws ec2 describe-images \
|
| 59 |
+
--owners amazon \
|
| 60 |
+
--filters "Name=name,Values=*Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04)*" \
|
| 61 |
+
--query "Images | sort_by(@, &CreationDate) | [-1].ImageId" \
|
| 62 |
+
--output text)
|
| 63 |
+
|
| 64 |
+
if [ -z "$AMI_ID" ] || [ "$AMI_ID" == "None" ]; then
|
| 65 |
+
print_error "Could not find Deep Learning AMI"
|
| 66 |
+
exit 1
|
| 67 |
+
fi
|
| 68 |
+
print_status "Using AMI: $AMI_ID"
|
| 69 |
+
|
| 70 |
+
# Find or select key pair
|
| 71 |
+
if [ -z "$KEY_NAME" ]; then
|
| 72 |
+
KEY_NAME=$(aws ec2 describe-key-pairs --query "KeyPairs[0].KeyName" --output text 2>/dev/null)
|
| 73 |
+
fi
|
| 74 |
+
if [ -z "$KEY_NAME" ] || [ "$KEY_NAME" == "None" ]; then
|
| 75 |
+
print_error "No SSH key pair found. Create one first or specify with --key-name"
|
| 76 |
+
exit 1
|
| 77 |
+
fi
|
| 78 |
+
print_status "Using key pair: $KEY_NAME"
|
| 79 |
+
|
| 80 |
+
# Find or create security group
|
| 81 |
+
SECURITY_GROUP=$(aws ec2 describe-security-groups \
|
| 82 |
+
--filters "Name=group-name,Values=seriguela-sg" \
|
| 83 |
+
--query "SecurityGroups[0].GroupId" \
|
| 84 |
+
--output text 2>/dev/null)
|
| 85 |
+
|
| 86 |
+
if [ -z "$SECURITY_GROUP" ] || [ "$SECURITY_GROUP" == "None" ]; then
|
| 87 |
+
print_status "Creating security group..."
|
| 88 |
+
SECURITY_GROUP=$(aws ec2 create-security-group \
|
| 89 |
+
--group-name seriguela-sg \
|
| 90 |
+
--description "Security group for Seriguela" \
|
| 91 |
+
--query "GroupId" --output text)
|
| 92 |
+
|
| 93 |
+
# Get current IP and add SSH rule
|
| 94 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 95 |
+
aws ec2 authorize-security-group-ingress \
|
| 96 |
+
--group-id "$SECURITY_GROUP" \
|
| 97 |
+
--protocol tcp --port 22 \
|
| 98 |
+
--cidr "${MY_IP}/32"
|
| 99 |
+
print_status "Created security group with SSH access from $MY_IP"
|
| 100 |
+
else
|
| 101 |
+
# Update security group with current IP
|
| 102 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 103 |
+
aws ec2 authorize-security-group-ingress \
|
| 104 |
+
--group-id "$SECURITY_GROUP" \
|
| 105 |
+
--protocol tcp --port 22 \
|
| 106 |
+
--cidr "${MY_IP}/32" 2>/dev/null || true
|
| 107 |
+
fi
|
| 108 |
+
print_status "Using security group: $SECURITY_GROUP"
|
| 109 |
+
|
| 110 |
+
# Create user-data script for automatic setup
|
| 111 |
+
USER_DATA=$(cat << 'USERDATA'
|
| 112 |
+
#!/bin/bash
|
| 113 |
+
exec > /var/log/user-data.log 2>&1
|
| 114 |
+
set -x
|
| 115 |
+
|
| 116 |
+
echo "=========================================="
|
| 117 |
+
echo "Seriguela Evaluation Instance Setup"
|
| 118 |
+
echo "Started: $(date)"
|
| 119 |
+
echo "=========================================="
|
| 120 |
+
|
| 121 |
+
# Wait for cloud-init to complete
|
| 122 |
+
cloud-init status --wait
|
| 123 |
+
|
| 124 |
+
# Setup as ubuntu user
|
| 125 |
+
sudo -u ubuntu bash << 'UBUNTUSETUP'
|
| 126 |
+
cd /home/ubuntu
|
| 127 |
+
|
| 128 |
+
echo "[1/7] Installing system dependencies..."
|
| 129 |
+
sudo apt-get update -qq
|
| 130 |
+
sudo apt-get install -y -qq python3-venv python3-pip git jq
|
| 131 |
+
|
| 132 |
+
echo "[2/7] Cloning repository..."
|
| 133 |
+
git clone https://github.com/augustocsc/seriguela.git
|
| 134 |
+
cd seriguela
|
| 135 |
+
|
| 136 |
+
echo "[3/7] Creating virtual environment..."
|
| 137 |
+
python3 -m venv venv
|
| 138 |
+
source venv/bin/activate
|
| 139 |
+
|
| 140 |
+
echo "[4/7] Upgrading pip..."
|
| 141 |
+
pip install --upgrade pip -q
|
| 142 |
+
|
| 143 |
+
echo "[5/7] Installing requirements..."
|
| 144 |
+
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -q
|
| 145 |
+
|
| 146 |
+
echo "[6/7] Testing setup..."
|
| 147 |
+
python3 << 'PYCHECK'
|
| 148 |
+
import sys
|
| 149 |
+
print("Testing imports...")
|
| 150 |
+
try:
|
| 151 |
+
import transformers
|
| 152 |
+
print(f"✅ transformers {transformers.__version__}")
|
| 153 |
+
import torch
|
| 154 |
+
print(f"✅ torch {torch.__version__}")
|
| 155 |
+
print(f"✅ CUDA available: {torch.cuda.is_available()}")
|
| 156 |
+
import peft
|
| 157 |
+
print(f"✅ peft {peft.__version__}")
|
| 158 |
+
import datasets
|
| 159 |
+
print(f"✅ datasets {datasets.__version__}")
|
| 160 |
+
except ImportError as e:
|
| 161 |
+
print(f"❌ Import failed: {e}")
|
| 162 |
+
sys.exit(1)
|
| 163 |
+
PYCHECK
|
| 164 |
+
|
| 165 |
+
if [ $? -ne 0 ]; then
|
| 166 |
+
echo "❌ Package validation failed"
|
| 167 |
+
exit 1
|
| 168 |
+
fi
|
| 169 |
+
|
| 170 |
+
echo "[7/7] Checking GPU..."
|
| 171 |
+
if nvidia-smi &> /dev/null; then
|
| 172 |
+
echo "✅ GPU detected:"
|
| 173 |
+
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
| 174 |
+
else
|
| 175 |
+
echo "⚠️ No GPU detected (will be slower)"
|
| 176 |
+
fi
|
| 177 |
+
|
| 178 |
+
# Configure HuggingFace token if provided
|
| 179 |
+
if [ -n "$HF_TOKEN" ]; then
|
| 180 |
+
echo "Configuring HuggingFace authentication..."
|
| 181 |
+
mkdir -p ~/.cache/huggingface
|
| 182 |
+
echo "$HF_TOKEN" > ~/.cache/huggingface/token
|
| 183 |
+
echo "✅ HuggingFace token configured"
|
| 184 |
+
fi
|
| 185 |
+
|
| 186 |
+
# Make evaluation script executable
|
| 187 |
+
chmod +x ~/seriguela/scripts/aws/evaluate_models.sh
|
| 188 |
+
|
| 189 |
+
# Create completion marker
|
| 190 |
+
touch /home/ubuntu/.setup_complete
|
| 191 |
+
|
| 192 |
+
# Create info file
|
| 193 |
+
cat > /home/ubuntu/setup_info.txt << 'INFOFILE'
|
| 194 |
+
Seriguela Evaluation Instance - Ready!
|
| 195 |
+
|
| 196 |
+
Setup completed successfully:
|
| 197 |
+
- Python packages installed
|
| 198 |
+
- GPU available (if supported)
|
| 199 |
+
- Repository cloned and configured
|
| 200 |
+
|
| 201 |
+
To run the evaluation:
|
| 202 |
+
cd ~/seriguela
|
| 203 |
+
source venv/bin/activate
|
| 204 |
+
bash scripts/aws/evaluate_models.sh
|
| 205 |
+
|
| 206 |
+
This will compare:
|
| 207 |
+
- Model 1: augustocsc/Se124M_700K_infix (original)
|
| 208 |
+
- Model 2: augustocsc/Se124M_700K_infix_v2 (with <|endofex|> token)
|
| 209 |
+
|
| 210 |
+
On 500 test samples to evaluate if the ending token improves generation stopping.
|
| 211 |
+
INFOFILE
|
| 212 |
+
|
| 213 |
+
echo ""
|
| 214 |
+
echo "=========================================="
|
| 215 |
+
echo "✅ Setup Complete!"
|
| 216 |
+
echo "Finished: $(date)"
|
| 217 |
+
echo "=========================================="
|
| 218 |
+
cat ~/setup_info.txt
|
| 219 |
+
|
| 220 |
+
UBUNTUSETUP
|
| 221 |
+
|
| 222 |
+
echo "User-data script completed"
|
| 223 |
+
USERDATA
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Replace HF_TOKEN placeholder
|
| 227 |
+
USER_DATA="${USER_DATA//\$HF_TOKEN/$HF_TOKEN}"
|
| 228 |
+
|
| 229 |
+
# Launch instance
|
| 230 |
+
print_status "Launching instance..."
|
| 231 |
+
INSTANCE_ID=$(aws ec2 run-instances \
|
| 232 |
+
--image-id "$AMI_ID" \
|
| 233 |
+
--instance-type "$INSTANCE_TYPE" \
|
| 234 |
+
--key-name "$KEY_NAME" \
|
| 235 |
+
--security-group-ids "$SECURITY_GROUP" \
|
| 236 |
+
--block-device-mappings "[{\"DeviceName\":\"/dev/sda1\",\"Ebs\":{\"VolumeSize\":$VOLUME_SIZE,\"VolumeType\":\"gp3\"}}]" \
|
| 237 |
+
--tag-specifications "ResourceType=instance,Tags=[{Key=Name,Value=$INSTANCE_NAME},{Key=Project,Value=seriguela},{Key=Purpose,Value=evaluation}]" \
|
| 238 |
+
--user-data "$USER_DATA" \
|
| 239 |
+
--query "Instances[0].InstanceId" \
|
| 240 |
+
--output text)
|
| 241 |
+
|
| 242 |
+
print_status "Instance launched: $INSTANCE_ID"
|
| 243 |
+
|
| 244 |
+
# Wait for instance to be running
|
| 245 |
+
print_status "Waiting for instance to start..."
|
| 246 |
+
aws ec2 wait instance-running --instance-ids "$INSTANCE_ID"
|
| 247 |
+
|
| 248 |
+
# Get public IP
|
| 249 |
+
PUBLIC_IP=$(aws ec2 describe-instances \
|
| 250 |
+
--instance-ids "$INSTANCE_ID" \
|
| 251 |
+
--query "Reservations[0].Instances[0].PublicIpAddress" \
|
| 252 |
+
--output text)
|
| 253 |
+
|
| 254 |
+
echo ""
|
| 255 |
+
echo "=========================================="
|
| 256 |
+
echo -e "${GREEN}Instance Ready!${NC}"
|
| 257 |
+
echo "=========================================="
|
| 258 |
+
echo "Instance ID: $INSTANCE_ID"
|
| 259 |
+
echo "Public IP: $PUBLIC_IP"
|
| 260 |
+
echo "Key Pair: $KEY_NAME"
|
| 261 |
+
echo ""
|
| 262 |
+
echo -e "${BLUE}Connect with:${NC}"
|
| 263 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP}"
|
| 264 |
+
echo ""
|
| 265 |
+
echo -e "${BLUE}Check setup progress:${NC}"
|
| 266 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP} 'tail -f /var/log/user-data.log'"
|
| 267 |
+
echo ""
|
| 268 |
+
echo -e "${BLUE}Wait for setup to complete (takes ~5-10 minutes):${NC}"
|
| 269 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP} 'while [ ! -f ~/.setup_complete ]; do sleep 10; echo \"Setup in progress...\"; done; echo \"✅ Setup complete!\"; cat ~/setup_info.txt'"
|
| 270 |
+
echo ""
|
| 271 |
+
echo -e "${BLUE}Then run evaluation:${NC}"
|
| 272 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP} 'cd seriguela && source venv/bin/activate && bash scripts/aws/evaluate_models.sh'"
|
| 273 |
+
echo ""
|
| 274 |
+
echo -e "${BLUE}Or run in one command:${NC}"
|
| 275 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP} 'cd seriguela && source venv/bin/activate && nohup bash scripts/aws/evaluate_models.sh > evaluation.log 2>&1 &'"
|
| 276 |
+
echo ""
|
| 277 |
+
echo -e "${YELLOW}IMPORTANT:${NC} Remember to stop the instance when done:"
|
| 278 |
+
echo " aws ec2 stop-instances --instance-ids $INSTANCE_ID"
|
| 279 |
+
echo ""
|
| 280 |
+
|
| 281 |
+
# Save instance info
|
| 282 |
+
INFO_DIR="${HOME}/.seriguela"
|
| 283 |
+
mkdir -p "$INFO_DIR"
|
| 284 |
+
echo "$INSTANCE_ID" > "$INFO_DIR/last_evaluation_instance_id.txt"
|
| 285 |
+
echo "$PUBLIC_IP" > "$INFO_DIR/last_evaluation_instance_ip.txt"
|
| 286 |
+
echo "$KEY_NAME" > "$INFO_DIR/last_evaluation_key_name.txt"
|
| 287 |
+
|
| 288 |
+
cat > "$INFO_DIR/last_evaluation_instance_info.txt" << INFOEND
|
| 289 |
+
Instance ID: $INSTANCE_ID
|
| 290 |
+
Public IP: $PUBLIC_IP
|
| 291 |
+
Key Name: $KEY_NAME
|
| 292 |
+
Instance Type: $INSTANCE_TYPE
|
| 293 |
+
Region: $REGION
|
| 294 |
+
Launched: $(date)
|
| 295 |
+
Purpose: Model Evaluation (v1 vs v2)
|
| 296 |
+
INFOEND
|
| 297 |
+
|
| 298 |
+
print_status "Instance info saved to: $INFO_DIR/"
|
| 299 |
+
echo ""
|
scripts/aws/launch_instance.sh
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to launch and configure AWS g5.xlarge instance for Seriguela training
|
| 3 |
+
# Usage: ./launch_instance.sh [--hf-token TOKEN] [--wandb-key KEY]
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
# Colors
|
| 8 |
+
GREEN='\033[0;32m'
|
| 9 |
+
YELLOW='\033[1;33m'
|
| 10 |
+
RED='\033[0;31m'
|
| 11 |
+
NC='\033[0m'
|
| 12 |
+
|
| 13 |
+
print_status() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
| 14 |
+
print_warning() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
| 15 |
+
print_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
| 16 |
+
|
| 17 |
+
# Default configuration
|
| 18 |
+
INSTANCE_TYPE="g5.xlarge"
|
| 19 |
+
AMI_ID="" # Will be auto-detected
|
| 20 |
+
KEY_NAME="" # Will be auto-detected
|
| 21 |
+
SECURITY_GROUP="" # Will be auto-detected or created
|
| 22 |
+
REGION=$(aws configure get region 2>/dev/null || echo "us-east-1")
|
| 23 |
+
VOLUME_SIZE=100
|
| 24 |
+
INSTANCE_NAME="seriguela-training"
|
| 25 |
+
HF_TOKEN=""
|
| 26 |
+
WANDB_KEY=""
|
| 27 |
+
|
| 28 |
+
# Parse arguments
|
| 29 |
+
while [[ $# -gt 0 ]]; do
|
| 30 |
+
case $1 in
|
| 31 |
+
--hf-token) HF_TOKEN="$2"; shift 2;;
|
| 32 |
+
--wandb-key) WANDB_KEY="$2"; shift 2;;
|
| 33 |
+
--instance-type) INSTANCE_TYPE="$2"; shift 2;;
|
| 34 |
+
--key-name) KEY_NAME="$2"; shift 2;;
|
| 35 |
+
--help)
|
| 36 |
+
echo "Usage: $0 [OPTIONS]"
|
| 37 |
+
echo "Options:"
|
| 38 |
+
echo " --hf-token TOKEN HuggingFace token"
|
| 39 |
+
echo " --wandb-key KEY Wandb API key"
|
| 40 |
+
echo " --instance-type TYPE Instance type (default: g5.xlarge)"
|
| 41 |
+
echo " --key-name NAME SSH key pair name"
|
| 42 |
+
exit 0;;
|
| 43 |
+
*) echo "Unknown option: $1"; exit 1;;
|
| 44 |
+
esac
|
| 45 |
+
done
|
| 46 |
+
|
| 47 |
+
print_status "Launching Seriguela training instance..."
|
| 48 |
+
|
| 49 |
+
# Find Deep Learning AMI
|
| 50 |
+
print_status "Finding Deep Learning AMI..."
|
| 51 |
+
AMI_ID=$(aws ec2 describe-images \
|
| 52 |
+
--owners amazon \
|
| 53 |
+
--filters "Name=name,Values=*Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04)*" \
|
| 54 |
+
--query "Images | sort_by(@, &CreationDate) | [-1].ImageId" \
|
| 55 |
+
--output text)
|
| 56 |
+
|
| 57 |
+
if [ -z "$AMI_ID" ] || [ "$AMI_ID" == "None" ]; then
|
| 58 |
+
print_error "Could not find Deep Learning AMI"
|
| 59 |
+
exit 1
|
| 60 |
+
fi
|
| 61 |
+
print_status "Using AMI: $AMI_ID"
|
| 62 |
+
|
| 63 |
+
# Find or select key pair
|
| 64 |
+
if [ -z "$KEY_NAME" ]; then
|
| 65 |
+
KEY_NAME=$(aws ec2 describe-key-pairs --query "KeyPairs[0].KeyName" --output text 2>/dev/null)
|
| 66 |
+
fi
|
| 67 |
+
if [ -z "$KEY_NAME" ] || [ "$KEY_NAME" == "None" ]; then
|
| 68 |
+
print_error "No SSH key pair found. Create one first or specify with --key-name"
|
| 69 |
+
exit 1
|
| 70 |
+
fi
|
| 71 |
+
print_status "Using key pair: $KEY_NAME"
|
| 72 |
+
|
| 73 |
+
# Find or create security group
|
| 74 |
+
SECURITY_GROUP=$(aws ec2 describe-security-groups \
|
| 75 |
+
--filters "Name=group-name,Values=seriguela-sg" \
|
| 76 |
+
--query "SecurityGroups[0].GroupId" \
|
| 77 |
+
--output text 2>/dev/null)
|
| 78 |
+
|
| 79 |
+
if [ -z "$SECURITY_GROUP" ] || [ "$SECURITY_GROUP" == "None" ]; then
|
| 80 |
+
print_status "Creating security group..."
|
| 81 |
+
SECURITY_GROUP=$(aws ec2 create-security-group \
|
| 82 |
+
--group-name seriguela-sg \
|
| 83 |
+
--description "Security group for Seriguela training" \
|
| 84 |
+
--query "GroupId" --output text)
|
| 85 |
+
|
| 86 |
+
# Get current IP and add SSH rule
|
| 87 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 88 |
+
aws ec2 authorize-security-group-ingress \
|
| 89 |
+
--group-id "$SECURITY_GROUP" \
|
| 90 |
+
--protocol tcp --port 22 \
|
| 91 |
+
--cidr "${MY_IP}/32"
|
| 92 |
+
print_status "Created security group with SSH access from $MY_IP"
|
| 93 |
+
else
|
| 94 |
+
# Update security group with current IP
|
| 95 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 96 |
+
aws ec2 authorize-security-group-ingress \
|
| 97 |
+
--group-id "$SECURITY_GROUP" \
|
| 98 |
+
--protocol tcp --port 22 \
|
| 99 |
+
--cidr "${MY_IP}/32" 2>/dev/null || true
|
| 100 |
+
fi
|
| 101 |
+
print_status "Using security group: $SECURITY_GROUP"
|
| 102 |
+
|
| 103 |
+
# Create user-data script for automatic setup
|
| 104 |
+
USER_DATA=$(cat << 'USERDATA'
|
| 105 |
+
#!/bin/bash
|
| 106 |
+
exec > /var/log/user-data.log 2>&1
|
| 107 |
+
set -x
|
| 108 |
+
|
| 109 |
+
# Wait for cloud-init to complete
|
| 110 |
+
cloud-init status --wait
|
| 111 |
+
|
| 112 |
+
# Setup as ubuntu user
|
| 113 |
+
sudo -u ubuntu bash << 'UBUNTUSETUP'
|
| 114 |
+
cd /home/ubuntu
|
| 115 |
+
|
| 116 |
+
# Install dependencies
|
| 117 |
+
sudo apt-get update -qq
|
| 118 |
+
sudo apt-get install -y -qq python3-venv python3-pip git
|
| 119 |
+
|
| 120 |
+
# Clone repository
|
| 121 |
+
git clone https://github.com/augustocsc/seriguela.git
|
| 122 |
+
cd seriguela
|
| 123 |
+
|
| 124 |
+
# Create virtual environment
|
| 125 |
+
python3 -m venv venv
|
| 126 |
+
source venv/bin/activate
|
| 127 |
+
|
| 128 |
+
# Install requirements
|
| 129 |
+
pip install --upgrade pip -q
|
| 130 |
+
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -q
|
| 131 |
+
|
| 132 |
+
# Create marker file to indicate setup complete
|
| 133 |
+
touch /home/ubuntu/.setup_complete
|
| 134 |
+
UBUNTUSETUP
|
| 135 |
+
USERDATA
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Add tokens to user-data if provided
|
| 139 |
+
if [ -n "$HF_TOKEN" ] || [ -n "$WANDB_KEY" ]; then
|
| 140 |
+
TOKEN_SETUP="
|
| 141 |
+
# Configure tokens
|
| 142 |
+
cd /home/ubuntu/seriguela
|
| 143 |
+
echo 'HF_TOKEN=$HF_TOKEN' > .env
|
| 144 |
+
echo 'WANDB_API_KEY=$WANDB_KEY' >> .env
|
| 145 |
+
"
|
| 146 |
+
USER_DATA="${USER_DATA}${TOKEN_SETUP}"
|
| 147 |
+
fi
|
| 148 |
+
|
| 149 |
+
# Launch instance
|
| 150 |
+
print_status "Launching instance..."
|
| 151 |
+
INSTANCE_ID=$(aws ec2 run-instances \
|
| 152 |
+
--image-id "$AMI_ID" \
|
| 153 |
+
--instance-type "$INSTANCE_TYPE" \
|
| 154 |
+
--key-name "$KEY_NAME" \
|
| 155 |
+
--security-group-ids "$SECURITY_GROUP" \
|
| 156 |
+
--block-device-mappings "[{\"DeviceName\":\"/dev/sda1\",\"Ebs\":{\"VolumeSize\":$VOLUME_SIZE,\"VolumeType\":\"gp3\"}}]" \
|
| 157 |
+
--tag-specifications "ResourceType=instance,Tags=[{Key=Name,Value=$INSTANCE_NAME}]" \
|
| 158 |
+
--user-data "$USER_DATA" \
|
| 159 |
+
--query "Instances[0].InstanceId" \
|
| 160 |
+
--output text)
|
| 161 |
+
|
| 162 |
+
print_status "Instance launched: $INSTANCE_ID"
|
| 163 |
+
|
| 164 |
+
# Wait for instance to be running
|
| 165 |
+
print_status "Waiting for instance to start..."
|
| 166 |
+
aws ec2 wait instance-running --instance-ids "$INSTANCE_ID"
|
| 167 |
+
|
| 168 |
+
# Get public IP
|
| 169 |
+
PUBLIC_IP=$(aws ec2 describe-instances \
|
| 170 |
+
--instance-ids "$INSTANCE_ID" \
|
| 171 |
+
--query "Reservations[0].Instances[0].PublicIpAddress" \
|
| 172 |
+
--output text)
|
| 173 |
+
|
| 174 |
+
echo ""
|
| 175 |
+
echo "=========================================="
|
| 176 |
+
echo -e "${GREEN}Instance Ready!${NC}"
|
| 177 |
+
echo "=========================================="
|
| 178 |
+
echo "Instance ID: $INSTANCE_ID"
|
| 179 |
+
echo "Public IP: $PUBLIC_IP"
|
| 180 |
+
echo ""
|
| 181 |
+
echo "Connect with:"
|
| 182 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP}"
|
| 183 |
+
echo ""
|
| 184 |
+
echo "Check setup progress:"
|
| 185 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'tail -f /var/log/user-data.log'"
|
| 186 |
+
echo ""
|
| 187 |
+
echo "Wait for setup to complete (check for .setup_complete):"
|
| 188 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'while [ ! -f ~/.setup_complete ]; do sleep 10; done; echo Done!'"
|
| 189 |
+
echo ""
|
| 190 |
+
echo "Then run training:"
|
| 191 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'cd seriguela && source venv/bin/activate && bash scripts/aws/run_all_training.sh'"
|
| 192 |
+
echo ""
|
| 193 |
+
|
| 194 |
+
# Save instance info
|
| 195 |
+
echo "$INSTANCE_ID" > /tmp/seriguela_instance_id.txt
|
| 196 |
+
echo "$PUBLIC_IP" > /tmp/seriguela_instance_ip.txt
|
scripts/aws/launch_instance_fixed.sh
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to launch and configure AWS g5.xlarge instance for Seriguela training
|
| 3 |
+
# FIXED VERSION - Includes Wandb validation and proper setup
|
| 4 |
+
# Usage: ./launch_instance_fixed.sh [--hf-token TOKEN] [--wandb-key KEY]
|
| 5 |
+
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
# Colors
|
| 9 |
+
GREEN='\033[0;32m'
|
| 10 |
+
YELLOW='\033[1;33m'
|
| 11 |
+
RED='\033[0;31m'
|
| 12 |
+
BLUE='\033[0;34m'
|
| 13 |
+
NC='\033[0m'
|
| 14 |
+
|
| 15 |
+
print_status() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
| 16 |
+
print_warning() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
| 17 |
+
print_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
| 18 |
+
|
| 19 |
+
# Default configuration
|
| 20 |
+
INSTANCE_TYPE="g5.xlarge"
|
| 21 |
+
AMI_ID="" # Will be auto-detected
|
| 22 |
+
KEY_NAME="" # Will be auto-detected
|
| 23 |
+
SECURITY_GROUP="" # Will be auto-detected or created
|
| 24 |
+
REGION=$(aws configure get region 2>/dev/null || echo "us-east-1")
|
| 25 |
+
VOLUME_SIZE=100
|
| 26 |
+
INSTANCE_NAME="seriguela-training"
|
| 27 |
+
HF_TOKEN=""
|
| 28 |
+
WANDB_KEY=""
|
| 29 |
+
|
| 30 |
+
# Parse arguments
|
| 31 |
+
while [[ $# -gt 0 ]]; do
|
| 32 |
+
case $1 in
|
| 33 |
+
--hf-token) HF_TOKEN="$2"; shift 2;;
|
| 34 |
+
--wandb-key) WANDB_KEY="$2"; shift 2;;
|
| 35 |
+
--instance-type) INSTANCE_TYPE="$2"; shift 2;;
|
| 36 |
+
--key-name) KEY_NAME="$2"; shift 2;;
|
| 37 |
+
--help)
|
| 38 |
+
echo "Usage: $0 [OPTIONS]"
|
| 39 |
+
echo "Options:"
|
| 40 |
+
echo " --hf-token TOKEN HuggingFace token (required for push to hub)"
|
| 41 |
+
echo " --wandb-key KEY Wandb API key (required for logging)"
|
| 42 |
+
echo " --instance-type TYPE Instance type (default: g5.xlarge)"
|
| 43 |
+
echo " --key-name NAME SSH key pair name"
|
| 44 |
+
echo ""
|
| 45 |
+
echo "Example:"
|
| 46 |
+
echo " $0 --hf-token hf_xxx --wandb-key wandb_v1_xxx"
|
| 47 |
+
exit 0;;
|
| 48 |
+
*) echo "Unknown option: $1"; exit 1;;
|
| 49 |
+
esac
|
| 50 |
+
done
|
| 51 |
+
|
| 52 |
+
# Validate required tokens
|
| 53 |
+
if [ -z "$WANDB_KEY" ]; then
|
| 54 |
+
print_error "Wandb API key is required! Use --wandb-key"
|
| 55 |
+
print_warning "Get your key from: https://wandb.ai/authorize"
|
| 56 |
+
exit 1
|
| 57 |
+
fi
|
| 58 |
+
|
| 59 |
+
if [ -z "$HF_TOKEN" ]; then
|
| 60 |
+
print_warning "HuggingFace token not provided. Model won't be pushed to Hub."
|
| 61 |
+
print_warning "Get your token from: https://huggingface.co/settings/tokens"
|
| 62 |
+
fi
|
| 63 |
+
|
| 64 |
+
print_status "Launching Seriguela training instance with validated setup..."
|
| 65 |
+
|
| 66 |
+
# Find Deep Learning AMI
|
| 67 |
+
print_status "Finding Deep Learning AMI..."
|
| 68 |
+
AMI_ID=$(aws ec2 describe-images \
|
| 69 |
+
--owners amazon \
|
| 70 |
+
--filters "Name=name,Values=*Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04)*" \
|
| 71 |
+
--query "Images | sort_by(@, &CreationDate) | [-1].ImageId" \
|
| 72 |
+
--output text)
|
| 73 |
+
|
| 74 |
+
if [ -z "$AMI_ID" ] || [ "$AMI_ID" == "None" ]; then
|
| 75 |
+
print_error "Could not find Deep Learning AMI"
|
| 76 |
+
exit 1
|
| 77 |
+
fi
|
| 78 |
+
print_status "Using AMI: $AMI_ID"
|
| 79 |
+
|
| 80 |
+
# Find or select key pair
|
| 81 |
+
if [ -z "$KEY_NAME" ]; then
|
| 82 |
+
KEY_NAME=$(aws ec2 describe-key-pairs --query "KeyPairs[0].KeyName" --output text 2>/dev/null)
|
| 83 |
+
fi
|
| 84 |
+
if [ -z "$KEY_NAME" ] || [ "$KEY_NAME" == "None" ]; then
|
| 85 |
+
print_error "No SSH key pair found. Create one first or specify with --key-name"
|
| 86 |
+
exit 1
|
| 87 |
+
fi
|
| 88 |
+
print_status "Using key pair: $KEY_NAME"
|
| 89 |
+
|
| 90 |
+
# Find or create security group
|
| 91 |
+
SECURITY_GROUP=$(aws ec2 describe-security-groups \
|
| 92 |
+
--filters "Name=group-name,Values=seriguela-sg" \
|
| 93 |
+
--query "SecurityGroups[0].GroupId" \
|
| 94 |
+
--output text 2>/dev/null)
|
| 95 |
+
|
| 96 |
+
if [ -z "$SECURITY_GROUP" ] || [ "$SECURITY_GROUP" == "None" ]; then
|
| 97 |
+
print_status "Creating security group..."
|
| 98 |
+
SECURITY_GROUP=$(aws ec2 create-security-group \
|
| 99 |
+
--group-name seriguela-sg \
|
| 100 |
+
--description "Security group for Seriguela training" \
|
| 101 |
+
--query "GroupId" --output text)
|
| 102 |
+
|
| 103 |
+
# Get current IP and add SSH rule
|
| 104 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 105 |
+
aws ec2 authorize-security-group-ingress \
|
| 106 |
+
--group-id "$SECURITY_GROUP" \
|
| 107 |
+
--protocol tcp --port 22 \
|
| 108 |
+
--cidr "${MY_IP}/32"
|
| 109 |
+
print_status "Created security group with SSH access from $MY_IP"
|
| 110 |
+
else
|
| 111 |
+
# Update security group with current IP
|
| 112 |
+
MY_IP=$(curl -s ifconfig.me)
|
| 113 |
+
aws ec2 authorize-security-group-ingress \
|
| 114 |
+
--group-id "$SECURITY_GROUP" \
|
| 115 |
+
--protocol tcp --port 22 \
|
| 116 |
+
--cidr "${MY_IP}/32" 2>/dev/null || true
|
| 117 |
+
fi
|
| 118 |
+
print_status "Using security group: $SECURITY_GROUP"
|
| 119 |
+
|
| 120 |
+
# Create user-data script for automatic setup with validation
|
| 121 |
+
USER_DATA=$(cat << USERDATA
|
| 122 |
+
#!/bin/bash
|
| 123 |
+
exec > /var/log/user-data.log 2>&1
|
| 124 |
+
set -x
|
| 125 |
+
|
| 126 |
+
echo "=========================================="
|
| 127 |
+
echo "Seriguela Instance Setup - VALIDATED"
|
| 128 |
+
echo "Started: \$(date)"
|
| 129 |
+
echo "=========================================="
|
| 130 |
+
|
| 131 |
+
# Wait for cloud-init to complete
|
| 132 |
+
cloud-init status --wait
|
| 133 |
+
|
| 134 |
+
# Setup as ubuntu user
|
| 135 |
+
sudo -u ubuntu bash << 'UBUNTUSETUP'
|
| 136 |
+
cd /home/ubuntu
|
| 137 |
+
|
| 138 |
+
echo "[1/8] Installing system dependencies..."
|
| 139 |
+
sudo apt-get update -qq
|
| 140 |
+
sudo apt-get install -y -qq python3-venv python3-pip git dos2unix
|
| 141 |
+
|
| 142 |
+
echo "[2/8] Cloning repository..."
|
| 143 |
+
git clone https://github.com/augustocsc/seriguela.git
|
| 144 |
+
cd seriguela
|
| 145 |
+
|
| 146 |
+
echo "[3/8] Creating virtual environment..."
|
| 147 |
+
python3 -m venv venv
|
| 148 |
+
source venv/bin/activate
|
| 149 |
+
|
| 150 |
+
echo "[4/8] Upgrading pip..."
|
| 151 |
+
pip install --upgrade pip -q
|
| 152 |
+
|
| 153 |
+
echo "[5/8] Installing requirements..."
|
| 154 |
+
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -q
|
| 155 |
+
|
| 156 |
+
echo "[6/8] Upgrading Wandb to latest version..."
|
| 157 |
+
pip install --upgrade 'wandb>=0.24.1' -q
|
| 158 |
+
|
| 159 |
+
echo "[7/8] Configuring environment..."
|
| 160 |
+
# Create .env file
|
| 161 |
+
cat > .env << 'ENVFILE'
|
| 162 |
+
HF_TOKEN=$HF_TOKEN
|
| 163 |
+
WANDB_API_KEY=$WANDB_KEY
|
| 164 |
+
ENVFILE
|
| 165 |
+
|
| 166 |
+
echo "[8/8] Validating setup..."
|
| 167 |
+
|
| 168 |
+
# Validate Python packages
|
| 169 |
+
python3 << 'PYCHECK'
|
| 170 |
+
import sys
|
| 171 |
+
print("Testing imports...")
|
| 172 |
+
try:
|
| 173 |
+
import transformers
|
| 174 |
+
print(f"✅ transformers {transformers.__version__}")
|
| 175 |
+
import torch
|
| 176 |
+
print(f"✅ torch {torch.__version__}")
|
| 177 |
+
import wandb
|
| 178 |
+
print(f"✅ wandb {wandb.__version__}")
|
| 179 |
+
import peft
|
| 180 |
+
print(f"✅ peft {peft.__version__}")
|
| 181 |
+
except ImportError as e:
|
| 182 |
+
print(f"❌ Import failed: {e}")
|
| 183 |
+
sys.exit(1)
|
| 184 |
+
PYCHECK
|
| 185 |
+
|
| 186 |
+
if [ \$? -ne 0 ]; then
|
| 187 |
+
echo "❌ Package validation failed"
|
| 188 |
+
exit 1
|
| 189 |
+
fi
|
| 190 |
+
|
| 191 |
+
# Validate GPU
|
| 192 |
+
echo "Checking GPU..."
|
| 193 |
+
if nvidia-smi &> /dev/null; then
|
| 194 |
+
echo "✅ GPU detected:"
|
| 195 |
+
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
| 196 |
+
else
|
| 197 |
+
echo "❌ No GPU detected"
|
| 198 |
+
exit 1
|
| 199 |
+
fi
|
| 200 |
+
|
| 201 |
+
# Validate Wandb authentication
|
| 202 |
+
if [ -n "$WANDB_KEY" ]; then
|
| 203 |
+
echo "Validating Wandb authentication..."
|
| 204 |
+
python3 << PYVALIDATE
|
| 205 |
+
import wandb
|
| 206 |
+
import os
|
| 207 |
+
try:
|
| 208 |
+
result = wandb.login(key='$WANDB_KEY')
|
| 209 |
+
if result:
|
| 210 |
+
print("✅ Wandb authentication successful")
|
| 211 |
+
# Get user info
|
| 212 |
+
import requests
|
| 213 |
+
response = requests.get('https://api.wandb.ai/graphql',
|
| 214 |
+
headers={'Authorization': f'Bearer $WANDB_KEY'},
|
| 215 |
+
json={'query': '{viewer{entity}}'})
|
| 216 |
+
if response.status_code == 200:
|
| 217 |
+
print(f" Logged in to Wandb")
|
| 218 |
+
else:
|
| 219 |
+
print("❌ Wandb authentication failed")
|
| 220 |
+
exit(1)
|
| 221 |
+
except Exception as e:
|
| 222 |
+
print(f"❌ Wandb validation error: {e}")
|
| 223 |
+
exit(1)
|
| 224 |
+
PYVALIDATE
|
| 225 |
+
|
| 226 |
+
if [ \$? -ne 0 ]; then
|
| 227 |
+
echo "❌ Wandb authentication failed"
|
| 228 |
+
exit 1
|
| 229 |
+
fi
|
| 230 |
+
else
|
| 231 |
+
echo "⚠️ No Wandb key provided - skipping validation"
|
| 232 |
+
fi
|
| 233 |
+
|
| 234 |
+
# Validate HuggingFace token
|
| 235 |
+
if [ -n "$HF_TOKEN" ]; then
|
| 236 |
+
echo "Validating HuggingFace authentication..."
|
| 237 |
+
python3 << PYVALIDATE
|
| 238 |
+
from huggingface_hub import HfApi
|
| 239 |
+
try:
|
| 240 |
+
api = HfApi(token='$HF_TOKEN')
|
| 241 |
+
user = api.whoami()
|
| 242 |
+
print(f"✅ HuggingFace authentication successful")
|
| 243 |
+
print(f" Logged in as: {user.get('name', 'unknown')}")
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"❌ HuggingFace validation error: {e}")
|
| 246 |
+
exit(1)
|
| 247 |
+
PYVALIDATE
|
| 248 |
+
|
| 249 |
+
if [ \$? -ne 0 ]; then
|
| 250 |
+
echo "❌ HuggingFace authentication failed"
|
| 251 |
+
exit 1
|
| 252 |
+
fi
|
| 253 |
+
else
|
| 254 |
+
echo "⚠️ No HuggingFace token provided - model won't be pushed to Hub"
|
| 255 |
+
fi
|
| 256 |
+
|
| 257 |
+
# All validations passed
|
| 258 |
+
echo ""
|
| 259 |
+
echo "=========================================="
|
| 260 |
+
echo "✅ Setup Complete and Validated!"
|
| 261 |
+
echo "Finished: \$(date)"
|
| 262 |
+
echo "=========================================="
|
| 263 |
+
|
| 264 |
+
# Create completion markers
|
| 265 |
+
touch /home/ubuntu/.setup_complete
|
| 266 |
+
touch /home/ubuntu/.setup_validated
|
| 267 |
+
|
| 268 |
+
# Create info file
|
| 269 |
+
cat > /home/ubuntu/setup_info.txt << 'INFOFILE'
|
| 270 |
+
Setup completed successfully!
|
| 271 |
+
|
| 272 |
+
Validated:
|
| 273 |
+
- Python packages installed
|
| 274 |
+
- GPU detected
|
| 275 |
+
- Wandb authenticated
|
| 276 |
+
- HuggingFace authenticated (if token provided)
|
| 277 |
+
|
| 278 |
+
Ready to train!
|
| 279 |
+
|
| 280 |
+
Quick commands:
|
| 281 |
+
cd ~/seriguela
|
| 282 |
+
source venv/bin/activate
|
| 283 |
+
python scripts/train.py --help
|
| 284 |
+
|
| 285 |
+
Monitor scripts:
|
| 286 |
+
bash scripts/aws/monitor_training_auto.sh
|
| 287 |
+
INFOFILE
|
| 288 |
+
|
| 289 |
+
echo "Setup info saved to ~/setup_info.txt"
|
| 290 |
+
UBUNTUSETUP
|
| 291 |
+
|
| 292 |
+
# End of setup
|
| 293 |
+
echo "User-data script completed"
|
| 294 |
+
USERDATA
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Replace placeholder tokens in user-data
|
| 298 |
+
USER_DATA="${USER_DATA//\$HF_TOKEN/$HF_TOKEN}"
|
| 299 |
+
USER_DATA="${USER_DATA//\$WANDB_KEY/$WANDB_KEY}"
|
| 300 |
+
|
| 301 |
+
# Launch instance
|
| 302 |
+
print_status "Launching instance..."
|
| 303 |
+
INSTANCE_ID=$(aws ec2 run-instances \
|
| 304 |
+
--image-id "$AMI_ID" \
|
| 305 |
+
--instance-type "$INSTANCE_TYPE" \
|
| 306 |
+
--key-name "$KEY_NAME" \
|
| 307 |
+
--security-group-ids "$SECURITY_GROUP" \
|
| 308 |
+
--block-device-mappings "[{\"DeviceName\":\"/dev/sda1\",\"Ebs\":{\"VolumeSize\":$VOLUME_SIZE,\"VolumeType\":\"gp3\"}}]" \
|
| 309 |
+
--tag-specifications "ResourceType=instance,Tags=[{Key=Name,Value=$INSTANCE_NAME},{Key=Project,Value=seriguela},{Key=AutoSetup,Value=validated}]" \
|
| 310 |
+
--user-data "$USER_DATA" \
|
| 311 |
+
--query "Instances[0].InstanceId" \
|
| 312 |
+
--output text)
|
| 313 |
+
|
| 314 |
+
print_status "Instance launched: $INSTANCE_ID"
|
| 315 |
+
|
| 316 |
+
# Wait for instance to be running
|
| 317 |
+
print_status "Waiting for instance to start..."
|
| 318 |
+
aws ec2 wait instance-running --instance-ids "$INSTANCE_ID"
|
| 319 |
+
|
| 320 |
+
# Get public IP
|
| 321 |
+
PUBLIC_IP=$(aws ec2 describe-instances \
|
| 322 |
+
--instance-ids "$INSTANCE_ID" \
|
| 323 |
+
--query "Reservations[0].Instances[0].PublicIpAddress" \
|
| 324 |
+
--output text)
|
| 325 |
+
|
| 326 |
+
echo ""
|
| 327 |
+
echo "=========================================="
|
| 328 |
+
echo -e "${GREEN}Instance Ready!${NC}"
|
| 329 |
+
echo "=========================================="
|
| 330 |
+
echo "Instance ID: $INSTANCE_ID"
|
| 331 |
+
echo "Public IP: $PUBLIC_IP"
|
| 332 |
+
echo "Key Pair: $KEY_NAME"
|
| 333 |
+
echo ""
|
| 334 |
+
echo -e "${BLUE}Connect with:${NC}"
|
| 335 |
+
echo " ssh -i ~/.ssh/${KEY_NAME}.pem ubuntu@${PUBLIC_IP}"
|
| 336 |
+
echo ""
|
| 337 |
+
echo -e "${BLUE}Check setup progress:${NC}"
|
| 338 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'tail -f /var/log/user-data.log'"
|
| 339 |
+
echo ""
|
| 340 |
+
echo -e "${BLUE}Wait for VALIDATED setup to complete:${NC}"
|
| 341 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'while [ ! -f ~/.setup_validated ]; do sleep 10; echo \"Setup in progress...\"; done; echo \"✅ Setup validated!\"; cat ~/setup_info.txt'"
|
| 342 |
+
echo ""
|
| 343 |
+
echo -e "${BLUE}Then run training:${NC}"
|
| 344 |
+
echo " ssh ubuntu@${PUBLIC_IP} 'cd seriguela && source venv/bin/activate && bash scripts/aws/run_all_training.sh'"
|
| 345 |
+
echo ""
|
| 346 |
+
echo -e "${YELLOW}Setup includes:${NC}"
|
| 347 |
+
echo " ✅ Wandb 0.24.1+ with authentication test"
|
| 348 |
+
echo " ✅ HuggingFace authentication test"
|
| 349 |
+
echo " ✅ GPU validation"
|
| 350 |
+
echo " ✅ All packages validated"
|
| 351 |
+
echo ""
|
| 352 |
+
|
| 353 |
+
# Save instance info
|
| 354 |
+
INFO_DIR="${HOME}/.seriguela"
|
| 355 |
+
mkdir -p "$INFO_DIR"
|
| 356 |
+
echo "$INSTANCE_ID" > "$INFO_DIR/last_instance_id.txt"
|
| 357 |
+
echo "$PUBLIC_IP" > "$INFO_DIR/last_instance_ip.txt"
|
| 358 |
+
echo "$KEY_NAME" > "$INFO_DIR/last_key_name.txt"
|
| 359 |
+
|
| 360 |
+
cat > "$INFO_DIR/last_instance_info.txt" << INFOEND
|
| 361 |
+
Instance ID: $INSTANCE_ID
|
| 362 |
+
Public IP: $PUBLIC_IP
|
| 363 |
+
Key Name: $KEY_NAME
|
| 364 |
+
Instance Type: $INSTANCE_TYPE
|
| 365 |
+
Region: $REGION
|
| 366 |
+
Launched: $(date)
|
| 367 |
+
Setup: Validated (Wandb + HF + GPU)
|
| 368 |
+
INFOEND
|
| 369 |
+
|
| 370 |
+
print_status "Instance info saved to: $INFO_DIR/"
|
| 371 |
+
echo ""
|
scripts/aws/monitor_evaluation.sh
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Script to monitor evaluation progress and download results
# Usage: bash scripts/aws/monitor_evaluation.sh [PUBLIC_IP]

set -e

# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'

print_status() { echo -e "${GREEN}[INFO]${NC} $1"; }
print_warning() { echo -e "${YELLOW}[WARN]${NC} $1"; }

# Get IP from argument or saved info
if [ -n "$1" ]; then
    PUBLIC_IP="$1"
else
    INFO_DIR="${HOME}/.seriguela"
    if [ -f "$INFO_DIR/last_evaluation_instance_ip.txt" ]; then
        PUBLIC_IP=$(cat "$INFO_DIR/last_evaluation_instance_ip.txt")
        print_status "Using saved IP: $PUBLIC_IP"
    else
        echo "Error: No IP provided and no saved IP found."
        echo "Usage: $0 <PUBLIC_IP>"
        exit 1
    fi
fi

# Get key name (saved by the launch script, else fall back to the first
# key pair in the AWS account)
INFO_DIR="${HOME}/.seriguela"
if [ -f "$INFO_DIR/last_evaluation_key_name.txt" ]; then
    KEY_NAME=$(cat "$INFO_DIR/last_evaluation_key_name.txt")
else
    KEY_NAME=$(aws ec2 describe-key-pairs --query "KeyPairs[0].KeyName" --output text 2>/dev/null)
fi

SSH_CMD="ssh -i ~/.ssh/${KEY_NAME}.pem -o StrictHostKeyChecking=no ubuntu@${PUBLIC_IP}"

echo "=========================================="
echo "Monitoring Evaluation"
echo "=========================================="
echo "Instance: $PUBLIC_IP"
echo "Key: $KEY_NAME"
echo ""

# Check if setup is complete
print_status "Checking setup status..."
if $SSH_CMD 'test -f ~/.setup_complete'; then
    print_status "✅ Setup complete"
else
    print_warning "Setup still in progress. Waiting..."
    $SSH_CMD 'while [ ! -f ~/.setup_complete ]; do sleep 5; done; echo "Setup complete!"'
fi

echo ""
echo "=========================================="
echo "Evaluation Progress"
echo "=========================================="
echo "Press Ctrl+C to stop monitoring (evaluation will continue)"
echo ""

# Check if evaluation has started.
# BUGFIX: the original used `test -f ~/seriguela/evaluation_*.log`; `test -f`
# takes exactly one argument, so as soon as the glob matched more than one log
# file the remote test errored out ("too many arguments") and a running
# evaluation was reported as not started. `ls` handles any number of matches
# and fails only when nothing matches.
if $SSH_CMD 'ls ~/seriguela/evaluation_*.log >/dev/null 2>&1'; then
    print_status "Evaluation in progress. Showing logs..."
    echo ""
    # tail -f runs until the user interrupts; don't let Ctrl+C kill the script
    $SSH_CMD 'tail -f ~/seriguela/evaluation_*.log' || true
else
    print_warning "Evaluation hasn't started yet."
    echo ""
    echo "To start evaluation, run:"
    echo "  $SSH_CMD 'cd seriguela && source venv/bin/activate && bash scripts/aws/evaluate_models.sh'"
    echo ""
    echo "Or run in background:"
    echo "  $SSH_CMD 'cd seriguela && source venv/bin/activate && nohup bash scripts/aws/evaluate_models.sh > evaluation.log 2>&1 &'"
fi

echo ""
echo "=========================================="
echo "Download Results"
echo "=========================================="
echo ""

# Download results if available
if $SSH_CMD 'test -d ~/seriguela/evaluation_results/comparison'; then
    print_status "Downloading results..."

    # Create local directory
    mkdir -p ./evaluation_results/comparison

    # Download results (best effort; partial trees are tolerated)
    scp -i ~/.ssh/${KEY_NAME}.pem -o StrictHostKeyChecking=no -r \
        ubuntu@${PUBLIC_IP}:~/seriguela/evaluation_results/comparison/* \
        ./evaluation_results/comparison/ 2>/dev/null || true

    # Download log files
    scp -i ~/.ssh/${KEY_NAME}.pem -o StrictHostKeyChecking=no \
        ubuntu@${PUBLIC_IP}:~/seriguela/evaluation_*.log \
        ./evaluation_results/ 2>/dev/null || true

    print_status "Results downloaded to: ./evaluation_results/"
    echo ""

    # Show latest comparison (newest file first by mtime)
    LATEST_COMPARISON=$(ls -t ./evaluation_results/comparison/comparison_*.json 2>/dev/null | head -1)
    if [ -n "$LATEST_COMPARISON" ]; then
        echo "Latest comparison results:"
        echo ""
        # jq pretty-prints the .comparison section; fall back to raw file if
        # jq is missing or the JSON is malformed
        cat "$LATEST_COMPARISON" | jq '.comparison' 2>/dev/null || cat "$LATEST_COMPARISON"
    fi
else
    print_warning "No results available yet."
fi

echo ""
|
scripts/aws/monitor_training_auto.sh
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Automatic Training Monitor and Notifier
# Monitors training process and runs analysis when complete

set -e

# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m'

print_status() { echo -e "${GREEN}[$(date '+%H:%M:%S')]${NC} $1"; }
print_warning() { echo -e "${YELLOW}[$(date '+%H:%M:%S')]${NC} $1"; }
print_error() { echo -e "${RED}[$(date '+%H:%M:%S')]${NC} $1"; }
print_header() { echo -e "\n${BLUE}========================================\n$1\n========================================${NC}\n"; }

# Configuration
PROJECT_DIR="/home/ubuntu/seriguela"
LOG_FILE="$HOME/training_success.log"
# NOTE(review): MONITOR_LOG is not referenced below; kept for compatibility
MONITOR_LOG="$HOME/monitor_output.log"
TRAINING_PID=""
CHECK_INTERVAL=60 # Check every 60 seconds
MODEL_PATH="./output/Se124M_700K_infix"
DATASET_REPO="augustocsc/sintetico_natural"
DATA_DIR="700K"
DATA_COLUMN="i_prompt_n"

cd "$PROJECT_DIR"
source venv/bin/activate

# Get training PID (sets TRAINING_PID; empty when no train.py is running)
get_training_pid() {
    TRAINING_PID=$(ps aux | grep "python scripts/train.py" | grep -v grep | awk '{print $2}')
}

# Check if training is running (exit status 0 when a PID was found)
is_training_running() {
    get_training_pid
    [ -n "$TRAINING_PID" ]
}

# Get training progress (percent) from the log; prints "0" when unknown.
get_progress() {
    local pct=""
    if [ -f "$LOG_FILE" ]; then
        # BUGFIX: the original sed used a greedy `.*` before its capture group
        # (`s/.*\([0-9]\+\)%|.*/\1/`), so only the LAST digit of the percentage
        # survived (e.g. "45%|" was reported as "5"). `grep -o` extracts the
        # whole number from tqdm-style "NN%|" progress lines.
        pct=$(tail -100 "$LOG_FILE" | grep -oE '[0-9]+%\|' | tail -1 | tr -d '%|') || true
    fi
    echo "${pct:-0}"
}

# Get the latest step line (e.g. "1234/21882") from the log.
# NOTE(review): 21882 is the hard-coded total step count of the current run;
# it must be updated if the dataset/batch configuration changes.
get_training_stats() {
    if [ -f "$LOG_FILE" ]; then
        tail -100 "$LOG_FILE" | grep -E "[0-9]+/21882" | tail -1
    fi
}

# Send notification: print a header to stdout and persist the message
# to $HOME/training_notification.txt for later inspection.
send_notification() {
    local title="$1"
    local message="$2"

    print_header "$title"
    echo "$message"

    # Save to notification file
    cat > "$HOME/training_notification.txt" << EOF
================================================================================
$title
$(date '+%Y-%m-%d %H:%M:%S')
================================================================================

$message

================================================================================
EOF

    print_status "Notification saved to: $HOME/training_notification.txt"
}

# Monitor training
print_header "Training Monitor Started"
print_status "Monitoring training process..."
print_status "Log file: $LOG_FILE"
print_status "Check interval: ${CHECK_INTERVAL}s"

START_TIME=$(date +%s)
LAST_PROGRESS=0

while true; do
    if is_training_running; then
        CURRENT_PROGRESS=$(get_progress)
        TRAINING_STATS=$(get_training_stats)

        # Show progress every check
        print_status "Training running (PID: $TRAINING_PID) - Progress: ${CURRENT_PROGRESS}%"

        if [ -n "$TRAINING_STATS" ]; then
            echo "  $TRAINING_STATS"
        fi

        # Check GPU.
        # BUGFIX: under `set -e` a transient nvidia-smi failure used to kill
        # the whole monitor; fall back to a placeholder instead.
        GPU_INFO=$(nvidia-smi --query-gpu=utilization.gpu,memory.used --format=csv,noheader,nounits 2>/dev/null) || GPU_INFO="unavailable"
        echo "  GPU: $GPU_INFO"

        LAST_PROGRESS=$CURRENT_PROGRESS
        sleep $CHECK_INTERVAL
    else
        # Training finished or crashed
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        HOURS=$((DURATION / 3600))
        MINUTES=$(((DURATION % 3600) / 60))

        print_header "Training Process Ended"

        # Check if training completed successfully (explicit finish message
        # or a 100% tqdm line in the log)
        if grep -q "Training finished" "$LOG_FILE" 2>/dev/null || \
           grep -q "100%|" "$LOG_FILE" 2>/dev/null; then

            # SUCCESS - Training completed
            print_status "Training completed successfully!"
            print_status "Total time: ${HOURS}h ${MINUTES}m"

            # Extract final metrics (pipeline ends in tail, safe under set -e)
            FINAL_METRICS=$(tail -200 "$LOG_FILE" | grep -E "(train_loss|eval_loss)" | tail -5)

            send_notification "✅ Training Completed Successfully" \
"Training Duration: ${HOURS}h ${MINUTES}m
Model: GPT-2 (124M) with LoRA
Dataset: 700K infix
Output: $MODEL_PATH

Final Metrics:
$FINAL_METRICS

Wandb Dashboard:
https://wandb.ai/symbolic-gression/seriguela_700K_test

Starting automatic analysis...
"

            # Run automatic analysis
            print_header "Starting Automatic Analysis"
            bash "$PROJECT_DIR/scripts/aws/analyze_model.sh" "$MODEL_PATH" "$DATA_COLUMN" 2>&1 | tee "$HOME/analysis_output.log"

            print_status "Analysis complete! Check: $HOME/analysis_output.log"

        else
            # FAILED - Training crashed or was killed
            print_error "Training ended unexpectedly!"

            # Get last errors (pipeline ends in head, safe under set -e)
            ERRORS=$(tail -50 "$LOG_FILE" | grep -E "(Error|Exception|Traceback)" | head -10)

            send_notification "❌ Training Failed or Interrupted" \
"Training Duration: ${HOURS}h ${MINUTES}m
Last Progress: ${LAST_PROGRESS}%

Possible Errors:
$ERRORS

Check full log: $LOG_FILE
"
        fi

        break
    fi
done

print_status "Monitor finished. Check notification file: $HOME/training_notification.txt"
|
scripts/aws/run_all_training.sh
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Complete training workflow for an AWS g5.xlarge instance.
# Seriguela project - trains 6 GPT-2 models (3 sizes x 2 data formats).

set -e # Exit on error

echo "=========================================="
echo "Seriguela - Full Training Workflow"
echo "=========================================="

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

print_status() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

print_warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}

print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

print_header() {
    echo ""
    echo -e "${BLUE}=========================================="
    echo "$1"
    echo -e "==========================================${NC}"
    echo ""
}

# Configuration: resolve the project root from this script's location
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$(dirname "$SCRIPT_DIR")")"
cd "$PROJECT_DIR"

# Check if virtual environment is activated
if [ -z "$VIRTUAL_ENV" ]; then
    print_warning "Virtual environment not activated. Activating..."
    source venv/bin/activate 2>/dev/null || {
        print_error "Could not activate virtual environment. Please run setup_aws.sh first."
        exit 1
    }
fi

# Check environment variables
if [ -z "$HF_TOKEN" ]; then
    print_error "HF_TOKEN not set. Please export HF_TOKEN='your_token'"
    exit 1
fi

# Check GPU
print_status "Checking GPU..."
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv || {
    print_error "GPU not available!"
    exit 1
}

# Dataset configuration
DATASET_REPO="augustocsc/sintetico_natural"
DATA_DIR="700K"
HF_USER="augustocsc"

# Common training parameters
WANDB_PROJECT="seriguela_700K"
SEED=42
BLOCK_SIZE=128

# Output directories
OUTPUT_BASE="./output"
EVAL_OUTPUT="./evaluation_results"
mkdir -p "$OUTPUT_BASE" "$EVAL_OUTPUT"

# Training configurations
# Format: "model_name|epochs|batch_size|grad_accum|learning_rate|run_suffix"
declare -a CONFIGS=(
    # GPT-2 Small (124M)
    "gpt2|3|16|4|5e-5|Se124M"
    # GPT-2 Medium (355M)
    "gpt2-medium|2|8|8|3e-5|Se355M"
    # GPT-2 Large (774M)
    "gpt2-large|2|4|16|2e-5|Se774M"
)

# Data columns for formats
declare -a DATA_COLUMNS=(
    "i_prompt_n|infix"
    "p_prompt_n|prefix"
)

# Run a single training job; returns 0 on success, 1 on failure.
# BUGFIX: the original ran `python ...` and then tested `[ $? -eq 0 ]`. Under
# `set -e` that pattern only behaves as intended when the function happens to
# be invoked inside an `if` condition (which disables errexit). Making the
# command itself the `if` condition handles failure reliably in every context.
run_training() {
    local model_name=$1
    local epochs=$2
    local batch_size=$3
    local grad_accum=$4
    local lr=$5
    local run_suffix=$6
    local data_column=$7
    local format=$8

    local run_name="${run_suffix}_${DATA_DIR}_${format}"
    local output_dir="${OUTPUT_BASE}/${run_name}"
    local hub_model_id="${HF_USER}/${run_name}"

    print_header "Training: $run_name"
    echo "Model: $model_name"
    echo "Epochs: $epochs"
    echo "Batch size: $batch_size"
    echo "Gradient accumulation: $grad_accum"
    echo "Effective batch size: $((batch_size * grad_accum))"
    echo "Learning rate: $lr"
    echo "Data column: $data_column"
    echo "Output: $output_dir"
    echo "Hub ID: $hub_model_id"
    echo ""

    # Run training
    if python scripts/train.py \
        --model_name_or_path "$model_name" \
        --dataset_repo_id "$DATASET_REPO" \
        --data_dir "$DATA_DIR" \
        --data_column "$data_column" \
        --approach "$format" \
        --output_dir "$output_dir" \
        --num_train_epochs "$epochs" \
        --per_device_train_batch_size "$batch_size" \
        --per_device_eval_batch_size "$batch_size" \
        --gradient_accumulation_steps "$grad_accum" \
        --learning_rate "$lr" \
        --weight_decay 0.01 \
        --warmup_steps 100 \
        --block_size "$BLOCK_SIZE" \
        --logging_steps 50 \
        --eval_strategy epoch \
        --save_strategy epoch \
        --save_total_limit 2 \
        --load_best_model_at_end \
        --fp16 \
        --seed "$SEED" \
        --wandb_project "$WANDB_PROJECT" \
        --wandb_run_name "$run_name" \
        --push_to_hub \
        --hub_model_id "$hub_model_id"; then
        print_status "Training completed successfully: $run_name"
        return 0
    else
        print_error "Training failed: $run_name"
        return 1
    fi
}

# Run evaluation for one model (best effort).
# BUGFIX: this function is called as a plain statement, so under `set -e` a
# failing `python scripts/evaluate.py` used to abort the entire workflow
# before the "had issues" branch could run. With the command as the `if`
# condition, a failed evaluation is reported and the workflow continues.
run_evaluation() {
    local model_path=$1
    local data_column=$2
    local num_samples=${3:-500}

    print_status "Evaluating model: $model_path"

    if python scripts/evaluate.py \
        --model_path "$model_path" \
        --dataset_repo_id "$DATASET_REPO" \
        --data_dir "$DATA_DIR" \
        --data_column "$data_column" \
        --num_samples "$num_samples" \
        --output_dir "$EVAL_OUTPUT" \
        --temperature 0.7 \
        --seed "$SEED"; then
        print_status "Evaluation completed: $model_path"
    else
        print_warning "Evaluation had issues: $model_path"
    fi
}

# Parse command line arguments
RUN_TEST=false
RUN_TRAINING=true
RUN_EVAL=true
SPECIFIC_MODEL=""

while [[ $# -gt 0 ]]; do
    case $1 in
        --test-only)
            RUN_TEST=true
            RUN_TRAINING=false
            RUN_EVAL=false
            shift
            ;;
        --no-eval)
            RUN_EVAL=false
            shift
            ;;
        --eval-only)
            RUN_TRAINING=false
            RUN_EVAL=true
            shift
            ;;
        --model)
            SPECIFIC_MODEL="$2"
            shift 2
            ;;
        --help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo "  --test-only     Run only the test training (1 epoch)"
            echo "  --no-eval       Skip evaluation after training"
            echo "  --eval-only     Run only evaluation (skip training)"
            echo "  --model NAME    Train only specific model (gpt2, gpt2-medium, gpt2-large)"
            echo "  --help          Show this help message"
            exit 0
            ;;
        *)
            print_error "Unknown option: $1"
            exit 1
            ;;
    esac
done

# Test run: a 1-epoch smoke test with the smallest model
if [ "$RUN_TEST" = true ]; then
    print_header "Running Test Training (1 epoch with gpt2)"

    python scripts/train.py \
        --model_name_or_path gpt2 \
        --dataset_repo_id "$DATASET_REPO" \
        --data_dir "$DATA_DIR" \
        --data_column "i_prompt_n" \
        --approach "infix" \
        --output_dir "${OUTPUT_BASE}/test_run" \
        --num_train_epochs 1 \
        --per_device_train_batch_size 16 \
        --gradient_accumulation_steps 4 \
        --learning_rate 5e-5 \
        --block_size "$BLOCK_SIZE" \
        --logging_steps 20 \
        --eval_strategy epoch \
        --save_strategy epoch \
        --fp16 \
        --seed "$SEED" \
        --wandb_project "${WANDB_PROJECT}_test"

    print_status "Test training completed!"
    print_status "Checklist:"
    echo "  [ ] GPU detected and functioning"
    echo "  [ ] Dataset loaded correctly"
    echo "  [ ] Training completed without errors"
    echo "  [ ] Wandb received metrics"
    echo "  [ ] Model saved locally"
    echo ""
    echo "Now test evaluate.py and generate.py:"
    echo "  python scripts/evaluate.py --model_path ./output/test_run --num_samples 50"
    echo "  python scripts/generate.py --model_path ./output/test_run --num_generations 5 --validate"
    exit 0
fi

# Track completed trainings ("hub_id|data_column" entries) and failures
declare -a COMPLETED_MODELS=()
declare -a FAILED_MODELS=()

# Main training loop
if [ "$RUN_TRAINING" = true ]; then
    print_header "Starting Full Training Workflow"

    START_TIME=$(date +%s)

    for config in "${CONFIGS[@]}"; do
        IFS='|' read -r model_name epochs batch_size grad_accum lr run_suffix <<< "$config"

        # Skip if specific model requested and this is not it
        if [ -n "$SPECIFIC_MODEL" ] && [ "$model_name" != "$SPECIFIC_MODEL" ]; then
            continue
        fi

        for data_config in "${DATA_COLUMNS[@]}"; do
            IFS='|' read -r data_column format <<< "$data_config"

            run_name="${run_suffix}_${DATA_DIR}_${format}"

            print_status "Starting training: $run_name"

            if run_training "$model_name" "$epochs" "$batch_size" "$grad_accum" "$lr" "$run_suffix" "$data_column" "$format"; then
                COMPLETED_MODELS+=("${HF_USER}/${run_name}|${data_column}")
            else
                FAILED_MODELS+=("$run_name")
            fi

            # Small delay between trainings
            sleep 10
        done
    done

    END_TIME=$(date +%s)
    DURATION=$((END_TIME - START_TIME))
    HOURS=$((DURATION / 3600))
    MINUTES=$(((DURATION % 3600) / 60))

    print_header "Training Summary"
    echo "Total time: ${HOURS}h ${MINUTES}m"
    echo ""
    echo "Completed models (${#COMPLETED_MODELS[@]}):"
    for model in "${COMPLETED_MODELS[@]}"; do
        echo "  - ${model%|*}"
    done
    echo ""
    if [ ${#FAILED_MODELS[@]} -gt 0 ]; then
        echo "Failed models (${#FAILED_MODELS[@]}):"
        for model in "${FAILED_MODELS[@]}"; do
            echo "  - $model"
        done
    fi
fi

# Evaluation
if [ "$RUN_EVAL" = true ]; then
    print_header "Running Evaluations"

    # If we just trained, use those models
    if [ ${#COMPLETED_MODELS[@]} -gt 0 ]; then
        for model_info in "${COMPLETED_MODELS[@]}"; do
            IFS='|' read -r model_path data_column <<< "$model_info"
            run_evaluation "$model_path" "$data_column" 500
        done
    else
        # Otherwise, evaluate all expected models (by their Hub IDs)
        for config in "${CONFIGS[@]}"; do
            IFS='|' read -r model_name epochs batch_size grad_accum lr run_suffix <<< "$config"

            for data_config in "${DATA_COLUMNS[@]}"; do
                IFS='|' read -r data_column format <<< "$data_config"

                run_name="${run_suffix}_${DATA_DIR}_${format}"
                model_path="${HF_USER}/${run_name}"

                run_evaluation "$model_path" "$data_column" 500
            done
        done
    fi

    print_header "Evaluation Complete"
    echo "Results saved to: $EVAL_OUTPUT"
fi

print_header "Workflow Complete!"
echo ""
echo "Next steps:"
# NOTE(review): the wandb URL below uses the project name where an entity
# (team/user) is normally expected — confirm the link resolves.
echo "1. Check training results on wandb: https://wandb.ai/${WANDB_PROJECT}"
echo "2. Check models on HuggingFace Hub: https://huggingface.co/${HF_USER}"
echo "3. Review evaluation results in: $EVAL_OUTPUT"
echo ""
echo "To test a model interactively:"
echo "  python scripts/generate.py --model_path ${HF_USER}/Se124M_700K_infix --interactive --validate"
echo ""
|
scripts/aws/setup_and_train_exp_a.sh
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# EXP-A: end-to-end data preparation, training, and evaluation (JSON format
# with the <|endofex|> marker). Intended for a fresh AWS instance.

set -e

echo "=============================================="
echo "EXP-A: Complete Setup and Training"
echo "JSON Format with <|endofex|> marker"
echo "=============================================="
echo "Started: $(date)"
echo ""

cd /home/ubuntu/seriguela

# Activate the project virtualenv
source venv/bin/activate

# --- Step 1: data preparation ------------------------------------------------
echo "[1/3] Preparing training data..."
echo "This will download from HuggingFace Hub and convert to JSON format"
echo ""

mkdir -p data/experiments

python scripts/data/prepare_experiment_data.py \
    --dataset_repo_id augustocsc/sintetico_natural \
    --data_dir 700K \
    --data_column i_prompt_n \
    --output_base_dir ./data/experiments

# Abort unless the expected training split was produced
[ -f "./data/experiments/exp_a_json/train.csv" ] || {
    echo "ERROR: Data preparation failed!"
    exit 1
}

TRAIN_COUNT="$(wc -l < ./data/experiments/exp_a_json/train.csv)"
echo "Training samples: $TRAIN_COUNT"

# --- Step 2: training --------------------------------------------------------
echo ""
echo "[2/3] Starting training..."
echo "Output: ./output/exp_a_json"
echo ""

python scripts/train_experiment.py \
    --experiment_name "exp_a_json" \
    --train_file ./data/experiments/exp_a_json/train.csv \
    --validation_file ./data/experiments/exp_a_json/validation.csv \
    --output_dir ./output/exp_a_json \
    --json_format \
    --end_marker '"}' \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --block_size 256 \
    --fp16 \
    --wandb_project seriguela_experiments \
    --wandb_run_name "exp_a_json_$(date +%Y%m%d_%H%M%S)"

# --- Step 3: evaluation ------------------------------------------------------
echo ""
echo "[3/3] Evaluating model..."
echo ""

python scripts/evaluate_experiments.py \
    --model_path ./output/exp_a_json \
    --experiment_type json \
    --num_samples 200 \
    --output_file ./output/exp_a_json/evaluation_results.json

echo ""
echo "=============================================="
echo "EXP-A Complete!"
echo "=============================================="
echo "Finished: $(date)"
echo "Model: ./output/exp_a_json"
echo "Results: ./output/exp_a_json/evaluation_results.json"

# Mark this experiment as finished for external monitors
touch /home/ubuntu/.exp_a_complete
|
scripts/aws/setup_and_train_exp_b.sh
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Complete setup and training script for EXP-B (EOS format)
# Run this on a fresh AWS instance

set -e

# Shared locations for this experiment's data and artifacts.
EXP_DATA="./data/experiments/exp_b_eos"
EXP_OUT="./output/exp_b_eos"

echo "=============================================="
echo "EXP-B: Complete Setup and Training"
echo "EOS Format with <|endoftext|> marker"
echo "=============================================="
echo "Started: $(date)"
echo ""

cd /home/ubuntu/seriguela

# Activate environment
source venv/bin/activate

# ---- Step 1: prepare data -------------------------------------------------
echo "[1/3] Preparing training data..."
echo "This will download from HuggingFace Hub and convert to EOS format"
echo ""

mkdir -p data/experiments

python scripts/data/prepare_experiment_data.py \
    --dataset_repo_id augustocsc/sintetico_natural \
    --data_dir 700K \
    --data_column i_prompt_n \
    --output_base_dir ./data/experiments

# The preparation step must have produced a training CSV.
[ -f "$EXP_DATA/train.csv" ] || { echo "ERROR: Data preparation failed!"; exit 1; }

TRAIN_COUNT=$(wc -l < "$EXP_DATA/train.csv")
echo "Training samples: $TRAIN_COUNT"

# ---- Step 2: run training -------------------------------------------------
echo ""
echo "[2/3] Starting training..."
echo "Output: $EXP_OUT"
echo ""

python scripts/train_experiment.py \
    --experiment_name "exp_b_eos" \
    --train_file "$EXP_DATA/train.csv" \
    --validation_file "$EXP_DATA/validation.csv" \
    --output_dir "$EXP_OUT" \
    --end_marker "<|endoftext|>" \
    --use_native_eos \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --block_size 128 \
    --fp16 \
    --wandb_project seriguela_experiments \
    --wandb_run_name "exp_b_eos_$(date +%Y%m%d_%H%M%S)"

# ---- Step 3: evaluate -----------------------------------------------------
echo ""
echo "[3/3] Evaluating model..."
echo ""

python scripts/evaluate_experiments.py \
    --model_path "$EXP_OUT" \
    --experiment_type eos \
    --num_samples 200 \
    --output_file "$EXP_OUT/evaluation_results.json"

echo ""
echo "=============================================="
echo "EXP-B Complete!"
echo "=============================================="
echo "Finished: $(date)"
echo "Model: $EXP_OUT"
echo "Results: $EXP_OUT/evaluation_results.json"

# Create completion marker (watched by monitoring scripts).
touch /home/ubuntu/.exp_b_complete
|
scripts/aws/setup_aws.sh
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Setup script for AWS g5.xlarge instance (Deep Learning AMI Ubuntu)
# Project: Seriguela - GPT-2 Fine-tuning for Symbolic Regression
# Optimized for faster setup

set -e

echo "=========================================="
echo "Seriguela AWS Setup Script (Optimized)"
echo "=========================================="

# ANSI colors used by the status helpers below.
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m'

print_status()  { echo -e "${GREEN}[INFO]${NC} $1"; }
print_warning() { echo -e "${YELLOW}[WARN]${NC} $1"; }
print_error()   { echo -e "${RED}[ERROR]${NC} $1"; }

# Where the project lives and which interpreter builds the venv.
REPO_URL="https://github.com/augustocsc/seriguela.git"
REPO_DIR="$HOME/seriguela"
PYTHON_VERSION="python3"

# A CUDA-capable GPU is mandatory for training on this instance type.
print_status "Checking GPU..."
nvidia-smi &>/dev/null || { print_error "GPU not detected!"; exit 1; }
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# Minimal OS-level packages only; the heavy lifting happens in pip.
print_status "Installing system dependencies..."
sudo apt-get update -qq
sudo apt-get install -y -qq python3-venv python3-pip git htop

# Fetch the project: pull when a checkout already exists, clone otherwise.
if [ ! -d "$REPO_DIR" ]; then
    print_status "Cloning repository..."
    git clone "$REPO_URL" "$REPO_DIR"
else
    print_status "Updating repository..."
    cd "$REPO_DIR" && git pull
fi
cd "$REPO_DIR"

# Create and enter an isolated Python environment.
print_status "Setting up virtual environment..."
$PYTHON_VERSION -m venv venv
source venv/bin/activate

# Upgrade pip and pull all project deps (CUDA wheels come from the extra index).
print_status "Installing all dependencies (this may take a few minutes)..."
pip install --upgrade pip -q
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -q

# Smoke-test the ML stack: imports, CUDA visibility, library versions.
print_status "Verifying installation..."
python - <<'PYCHECK'
import torch
import transformers
import peft
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
print(f'Transformers: {transformers.__version__}')
print(f'PEFT: {peft.__version__}')
PYCHECK

echo ""
echo "=========================================="
echo -e "${GREEN}Setup Complete!${NC}"
echo "=========================================="
echo ""
echo "Next: Configure tokens in .env file:"
echo "  echo 'HF_TOKEN=your_token' > .env"
echo "  echo 'WANDB_API_KEY=your_key' >> .env"
echo ""
echo "Then run training:"
echo "  source venv/bin/activate"
echo "  bash scripts/aws/run_all_training.sh --test-only"
echo ""
|
scripts/aws/train_exp_a.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# EXP-A: Training with JSON structured format
# Uses <|endofex|> as end marker

set -e

# Paths used throughout this run.
TRAIN_CSV="./data/experiments/exp_a_json/train.csv"
VAL_CSV="./data/experiments/exp_a_json/validation.csv"
OUT_DIR="./output/exp_a_json"

echo "=============================================="
echo "EXP-A: JSON Format Training"
echo "=============================================="

cd ~/seriguela

# Activate virtual environment
source venv/bin/activate

# Guard clause: refuse to start without the prepared dataset.
if [ ! -f "$TRAIN_CSV" ]; then
    echo "ERROR: Training data not found!"
    echo "Expected: ./data/experiments/exp_a_json/train.csv"
    exit 1
fi

# Report dataset size (raw CSV line count).
TRAIN_COUNT=$(wc -l < "$TRAIN_CSV")
echo "Training samples: $TRAIN_COUNT"

# Experiment tracking / Hub credentials (empty defaults keep `set -e` happy).
export WANDB_PROJECT="seriguela_experiments"
export HF_TOKEN="${HF_TOKEN:-}"
export WANDB_API_KEY="${WANDB_API_KEY:-}"

echo ""
echo "Starting training..."
echo "Output: $OUT_DIR"
echo ""

python scripts/train_experiment.py \
    --experiment_name "exp_a_json" \
    --train_file "$TRAIN_CSV" \
    --validation_file "$VAL_CSV" \
    --output_dir "$OUT_DIR" \
    --end_marker "<|endofex|>" \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --block_size 256 \
    --fp16 \
    --wandb_project seriguela_experiments \
    --wandb_run_name "exp_a_json_$(date +%Y%m%d_%H%M%S)"

echo ""
echo "=============================================="
echo "EXP-A Training Complete!"
echo "=============================================="
echo "Model saved to: $OUT_DIR"
|
scripts/aws/train_exp_b.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# EXP-B: Training with GPT-2 EOS token (<|endoftext|>)
# Uses native GPT-2 EOS token (ID 50256)

set -e

# Paths used throughout this run.
TRAIN_CSV="./data/experiments/exp_b_eos/train.csv"
VAL_CSV="./data/experiments/exp_b_eos/validation.csv"
OUT_DIR="./output/exp_b_eos"

echo "=============================================="
echo "EXP-B: EOS Token Format Training"
echo "=============================================="

cd ~/seriguela

# Activate virtual environment
source venv/bin/activate

# Guard clause: refuse to start without the prepared dataset.
if [ ! -f "$TRAIN_CSV" ]; then
    echo "ERROR: Training data not found!"
    echo "Expected: ./data/experiments/exp_b_eos/train.csv"
    exit 1
fi

# Report dataset size (raw CSV line count).
TRAIN_COUNT=$(wc -l < "$TRAIN_CSV")
echo "Training samples: $TRAIN_COUNT"

# Experiment tracking / Hub credentials (empty defaults keep `set -e` happy).
export WANDB_PROJECT="seriguela_experiments"
export HF_TOKEN="${HF_TOKEN:-}"
export WANDB_API_KEY="${WANDB_API_KEY:-}"

echo ""
echo "Starting training..."
echo "Output: $OUT_DIR"
echo ""

python scripts/train_experiment.py \
    --experiment_name "exp_b_eos" \
    --train_file "$TRAIN_CSV" \
    --validation_file "$VAL_CSV" \
    --output_dir "$OUT_DIR" \
    --end_marker "<|endoftext|>" \
    --use_native_eos \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --block_size 128 \
    --fp16 \
    --wandb_project seriguela_experiments \
    --wandb_run_name "exp_b_eos_$(date +%Y%m%d_%H%M%S)"

echo ""
echo "=============================================="
echo "EXP-B Training Complete!"
echo "=============================================="
echo "Model saved to: $OUT_DIR"
|
scripts/aws/train_fixed_model.sh
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Train model with proper end-of-expression markers
# This script retrains the Seriguela model with <|endofex|> markers in the training data
# so the model learns to stop generation correctly.
#
# FIX: the original checked `$?` after each python call, but with `set -e` the
# script aborts on the first failure BEFORE the check runs — all of those error
# branches (including the intentionally soft failure of the generation test)
# were dead code. Exit status is now tested directly with `if ! cmd`, which
# suppresses errexit for that command so the intended handling actually runs.

set -e  # Exit on error

echo "================================================================"
echo "SERIGUELA - Training Model with Proper End Markers"
echo "================================================================"

# Configuration
MODEL_NAME="gpt2"
DATASET_REPO="augustocsc/sintetico_natural"
DATA_DIR="700K"
DATA_COLUMN="i_prompt_n"  # or p_prompt_n for prefix
OUTPUT_DIR="./output/Se124M_700K_infix_v2"
HUB_MODEL_ID="augustocsc/Se124M_700K_infix_v2"  # NEW REPO NAME

# Hyperparameters
EPOCHS=3
BATCH_SIZE=8
LEARNING_RATE=5e-5
BLOCK_SIZE=128
LORA_R=8
LORA_ALPHA=32
LORA_DROPOUT=0.05

echo ""
echo "Configuration:"
echo " Model: $MODEL_NAME"
echo " Dataset: $DATASET_REPO/$DATA_DIR"
echo " Data Column: $DATA_COLUMN"
echo " Output: $OUTPUT_DIR"
echo " Hub Model: $HUB_MODEL_ID"
echo ""
echo "Hyperparameters:"
echo " Epochs: $EPOCHS"
echo " Batch Size: $BATCH_SIZE"
echo " Learning Rate: $LEARNING_RATE"
echo " Block Size: $BLOCK_SIZE"
echo " LoRA r: $LORA_R"
echo " LoRA alpha: $LORA_ALPHA"
echo " LoRA dropout: $LORA_DROPOUT"
echo "================================================================"

# Step 1: prepare data with end markers (skipped if already present)
echo ""
echo "[Step 1/3] Checking data preparation..."
if [ ! -f "./data/processed/700K_fixed/train_700K.csv" ]; then
    echo "Training data not found. Preparing data with end markers..."

    if ! python scripts/data/prepare_training_data_fixed.py \
        --dataset_repo_id $DATASET_REPO \
        --data_dir $DATA_DIR \
        --data_column $DATA_COLUMN \
        --output_dir ./data/processed/700K_fixed \
        --validate; then
        echo "❌ Data preparation failed!"
        exit 1
    fi

    echo "✅ Data preparation complete!"
else
    echo "✅ Training data already prepared (./data/processed/700K_fixed/)"
fi

# Optional: Show sample of prepared data
echo ""
echo "Sample of prepared data:"
head -n 2 ./data/processed/700K_fixed/train_700K.csv
echo ""

# Step 2: fine-tune with LoRA and push the result to the Hub
echo ""
echo "[Step 2/3] Starting training..."
echo "================================================================"
echo ""

if ! python scripts/train.py \
    --model_name_or_path $MODEL_NAME \
    --dataset_repo_id $DATASET_REPO \
    --data_dir $DATA_DIR \
    --data_column $DATA_COLUMN \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs $EPOCHS \
    --per_device_train_batch_size $BATCH_SIZE \
    --learning_rate $LEARNING_RATE \
    --block_size $BLOCK_SIZE \
    --eval_strategy epoch \
    --save_strategy epoch \
    --save_total_limit 2 \
    --load_best_model_at_end \
    --lora_r $LORA_R \
    --lora_alpha $LORA_ALPHA \
    --lora_dropout $LORA_DROPOUT \
    --push_to_hub \
    --hub_model_id $HUB_MODEL_ID \
    --logging_steps 100 \
    --seed 42; then
    echo ""
    echo "❌ Training failed!"
    exit 1
fi

echo ""
echo "✅ Training complete!"

# Step 3: quick generation smoke test — a failure here is non-fatal because
# the trained model has already been saved and pushed.
echo ""
echo "[Step 3/3] Testing model generation..."
echo "================================================================"
echo ""

if ! python scripts/generate.py \
    --model_path $OUTPUT_DIR \
    --num_generations 5 \
    --validate; then
    echo ""
    echo "⚠️ Generation test failed, but model was trained successfully"
else
    echo ""
    echo "✅ Generation test passed!"
fi

# Summary
echo ""
echo "================================================================"
echo "TRAINING COMPLETE"
echo "================================================================"
echo "Model saved to: $OUTPUT_DIR"
echo "Model pushed to: $HUB_MODEL_ID"
echo ""
echo "Next steps:"
echo " 1. Evaluate the model: python scripts/evaluate.py --model_path $OUTPUT_DIR"
echo " 2. Compare with old model: python scripts/compare_models.py --model1 ./output/Se124M_700K_infix --model2 $OUTPUT_DIR"
echo " 3. Generate more samples: python scripts/generate.py --model_path $OUTPUT_DIR --num_generations 20"
echo "================================================================"
|
scripts/aws/train_v3_model.sh
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Training script for v3 model with proper end markers
# This script is designed to be run on AWS EC2 instances with GPU
#
# FIX: the original ran `python scripts/train.py ...` and then captured
# `TRAIN_EXIT_CODE=$?` — but with `set -e` a failing training run terminates
# the script before the capture, so the failure-report branch was unreachable.
# The `|| TRAIN_EXIT_CODE=$?` form suppresses errexit for that one command and
# records the real exit status, making both report branches reachable.

set -e  # Exit on error

echo "=================================================="
echo "Seriguela v3 Model Training"
echo "=================================================="
echo "Start time: $(date)"
echo ""

# Configuration
PROJECT_DIR="${HOME}/seriguela"
OUTPUT_DIR="${PROJECT_DIR}/output/Se124M_700K_infix_v3"
CONFIG_FILE="${PROJECT_DIR}/configs/training_v3.json"
DATA_DIR="${PROJECT_DIR}/data/processed/700K_fixed"

# Check if running in project directory
if [ ! -d "$PROJECT_DIR" ]; then
    echo "ERROR: Project directory not found: $PROJECT_DIR"
    exit 1
fi

cd "$PROJECT_DIR"

# Activate virtual environment (two known layouts are supported)
echo "Activating virtual environment..."
if [ -d "venv" ]; then
    source venv/bin/activate
elif [ -d ".seriguela" ]; then
    source .seriguela/bin/activate
else
    echo "ERROR: Virtual environment not found!"
    exit 1
fi

# Verify GPU availability; CPU-only training is allowed after confirmation
echo ""
echo "Checking GPU availability..."
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU count: {torch.cuda.device_count()}'); print(f'GPU name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')"

if ! python -c "import torch; exit(0 if torch.cuda.is_available() else 1)"; then
    echo "WARNING: GPU not detected! Training will be slow on CPU."
    read -p "Continue anyway? (y/n) " -n 1 -r
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        exit 1
    fi
fi

# Verify data files exist
echo ""
echo "Verifying training data..."
if [ ! -f "$DATA_DIR/train_700K.csv" ]; then
    echo "ERROR: Training data not found: $DATA_DIR/train_700K.csv"
    echo "Please ensure data preparation step was completed."
    exit 1
fi

if [ ! -f "$DATA_DIR/validation_700K.csv" ]; then
    echo "ERROR: Validation data not found: $DATA_DIR/validation_700K.csv"
    exit 1
fi

# Sanity-check that the v3 data actually carries the end-of-expression
# markers; `|| true` keeps `set -e` happy when grep finds zero matches.
echo "Checking for end markers in training data..."
MARKER_COUNT=$(head -100 "$DATA_DIR/train_700K.csv" | grep -c "<|endofex|>" || true)
if [ "$MARKER_COUNT" -eq 0 ]; then
    echo "ERROR: No <|endofex|> markers found in training data!"
    echo "Please run data preparation script first."
    exit 1
else
    echo "✓ End markers detected in training data"
fi

# Verify config file exists
if [ ! -f "$CONFIG_FILE" ]; then
    echo "ERROR: Config file not found: $CONFIG_FILE"
    exit 1
fi

echo ""
echo "Configuration:"
echo " Config file: $CONFIG_FILE"
echo " Output directory: $OUTPUT_DIR"
echo " Training data: $DATA_DIR/train_700K.csv"
echo " Validation data: $DATA_DIR/validation_700K.csv"
echo ""

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Set environment variables for experiment tracking
export WANDB_PROJECT="seriguela_v3"
export WANDB_RUN_NAME="v3_proper_markers_$(date +%Y%m%d_%H%M%S)"

# Check if wandb is configured (non-fatal; training proceeds without logging)
if ! python -c "import wandb; wandb.api.api_key" 2>/dev/null; then
    echo "WARNING: Weights & Biases not configured. Training will proceed without W&B logging."
    echo "To enable W&B: wandb login"
fi

# Start training
echo ""
echo "=================================================="
echo "Starting training..."
echo "=================================================="
echo ""

# Run training with config file; capture the exit status without letting
# `set -e` abort, so the summary below can report success OR failure.
TRAIN_EXIT_CODE=0
python scripts/train.py \
    --config "$CONFIG_FILE" \
    --output_dir "$OUTPUT_DIR" \
    --use_local_csvs \
    --train_file "$DATA_DIR/train_700K.csv" \
    --validation_file "$DATA_DIR/validation_700K.csv" \
    --wandb_project seriguela_v3 \
    --wandb_run_name "$WANDB_RUN_NAME" || TRAIN_EXIT_CODE=$?

echo ""
echo "=================================================="
echo "Training completed"
echo "=================================================="
echo "End time: $(date)"
echo "Exit code: $TRAIN_EXIT_CODE"
echo ""

if [ $TRAIN_EXIT_CODE -eq 0 ]; then
    echo "✓ Training completed successfully!"
    echo ""
    echo "Model saved to: $OUTPUT_DIR"
    echo ""
    echo "Next steps:"
    echo "1. Run evaluation: python scripts/evaluate.py --model_path $OUTPUT_DIR"
    echo "2. Test generation: python scripts/generate.py --model_path $OUTPUT_DIR --num_generations 50 --validate"
    echo "3. Push to Hub (if configured): huggingface-cli upload augustocsc/Se124M_700K_infix_v3 $OUTPUT_DIR"
else
    echo "✗ Training failed with exit code $TRAIN_EXIT_CODE"
    echo "Check logs above for error details."
    exit $TRAIN_EXIT_CODE
fi
|
scripts/aws/validate_setup.sh
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Validate Seriguela Training Setup
|
| 3 |
+
# This script validates that everything is configured correctly before training
|
| 4 |
+
# Usage: ./validate_setup.sh
|
| 5 |
+
|
| 6 |
+
set -e
|
| 7 |
+
|
| 8 |
+
GREEN='\033[0;32m'
|
| 9 |
+
RED='\033[0;31m'
|
| 10 |
+
YELLOW='\033[1;33m'
|
| 11 |
+
BLUE='\033[0;34m'
|
| 12 |
+
NC='\033[0m'
|
| 13 |
+
|
| 14 |
+
print_success() { echo -e "${GREEN}✅${NC} $1"; }
|
| 15 |
+
print_error() { echo -e "${RED}❌${NC} $1"; }
|
| 16 |
+
print_warning() { echo -e "${YELLOW}⚠️${NC} $1"; }
|
| 17 |
+
print_header() { echo -e "\n${BLUE}========== $1 ==========${NC}"; }
|
| 18 |
+
|
| 19 |
+
ERRORS=0
|
| 20 |
+
|
| 21 |
+
print_header "Seriguela Setup Validation"
|
| 22 |
+
|
| 23 |
+
# Change to project directory
|
| 24 |
+
if [ -d "/home/ubuntu/seriguela" ]; then
|
| 25 |
+
cd /home/ubuntu/seriguela
|
| 26 |
+
elif [ -d "$(pwd)/seriguela" ]; then
|
| 27 |
+
cd seriguela
|
| 28 |
+
else
|
| 29 |
+
cd .
|
| 30 |
+
fi
|
| 31 |
+
|
| 32 |
+
print_header "1. Python Environment"
|
| 33 |
+
|
| 34 |
+
# Check Python version
|
| 35 |
+
if python3 --version &> /dev/null; then
|
| 36 |
+
PYTHON_VERSION=$(python3 --version)
|
| 37 |
+
print_success "Python installed: $PYTHON_VERSION"
|
| 38 |
+
else
|
| 39 |
+
print_error "Python not found"
|
| 40 |
+
ERRORS=$((ERRORS + 1))
|
| 41 |
+
fi
|
| 42 |
+
|
| 43 |
+
# Check venv
|
| 44 |
+
if [ -d "venv" ]; then
|
| 45 |
+
print_success "Virtual environment exists"
|
| 46 |
+
source venv/bin/activate
|
| 47 |
+
else
|
| 48 |
+
print_error "Virtual environment not found"
|
| 49 |
+
ERRORS=$((ERRORS + 1))
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
# Check pip
|
| 53 |
+
if pip --version &> /dev/null; then
|
| 54 |
+
PIP_VERSION=$(pip --version | cut -d' ' -f2)
|
| 55 |
+
print_success "pip version: $PIP_VERSION"
|
| 56 |
+
else
|
| 57 |
+
print_error "pip not found"
|
| 58 |
+
ERRORS=$((ERRORS + 1))
|
| 59 |
+
fi
|
| 60 |
+
|
| 61 |
+
print_header "2. Python Packages"
|
| 62 |
+
|
| 63 |
+
# Check critical packages
|
| 64 |
+
PACKAGES=(
|
| 65 |
+
"transformers:Hugging Face Transformers"
|
| 66 |
+
"torch:PyTorch"
|
| 67 |
+
"wandb:Weights & Biases"
|
| 68 |
+
"peft:Parameter-Efficient Fine-Tuning"
|
| 69 |
+
"datasets:Hugging Face Datasets"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
for pkg_info in "${PACKAGES[@]}"; do
|
| 73 |
+
IFS=':' read -r pkg_name pkg_desc <<< "$pkg_info"
|
| 74 |
+
|
| 75 |
+
if python3 -c "import $pkg_name" &> /dev/null; then
|
| 76 |
+
VERSION=$(python3 -c "import $pkg_name; print($pkg_name.__version__)" 2>/dev/null || echo "unknown")
|
| 77 |
+
print_success "$pkg_desc ($pkg_name) - version $VERSION"
|
| 78 |
+
else
|
| 79 |
+
print_error "$pkg_desc ($pkg_name) not installed"
|
| 80 |
+
ERRORS=$((ERRORS + 1))
|
| 81 |
+
fi
|
| 82 |
+
done
|
| 83 |
+
|
| 84 |
+
# Check Wandb version specifically
|
| 85 |
+
WANDB_VERSION=$(python3 -c "import wandb; print(wandb.__version__)" 2>/dev/null || echo "0.0.0")
|
| 86 |
+
REQUIRED_VERSION="0.24.0"
|
| 87 |
+
|
| 88 |
+
if python3 << VERSIONCHECK
|
| 89 |
+
import sys
|
| 90 |
+
from packaging import version
|
| 91 |
+
current = version.parse("$WANDB_VERSION")
|
| 92 |
+
required = version.parse("$REQUIRED_VERSION")
|
| 93 |
+
sys.exit(0 if current >= required else 1)
|
| 94 |
+
VERSIONCHECK
|
| 95 |
+
then
|
| 96 |
+
print_success "Wandb version $WANDB_VERSION (>= $REQUIRED_VERSION required)"
|
| 97 |
+
else
|
| 98 |
+
print_warning "Wandb version $WANDB_VERSION is older than recommended $REQUIRED_VERSION"
|
| 99 |
+
print_warning "New API key format (wandb_v1_...) requires Wandb >= 0.24.0"
|
| 100 |
+
fi
|
| 101 |
+
|
| 102 |
+
print_header "3. Environment Variables"
|
| 103 |
+
|
| 104 |
+
# Load .env if exists
|
| 105 |
+
if [ -f ".env" ]; then
|
| 106 |
+
source <(grep -v '^#' .env | sed 's/^/export /')
|
| 107 |
+
print_success ".env file loaded"
|
| 108 |
+
else
|
| 109 |
+
print_warning ".env file not found"
|
| 110 |
+
fi
|
| 111 |
+
|
| 112 |
+
# Check HF_TOKEN
|
| 113 |
+
if [ -n "$HF_TOKEN" ]; then
|
| 114 |
+
TOKEN_LEN=${#HF_TOKEN}
|
| 115 |
+
print_success "HF_TOKEN set ($TOKEN_LEN characters)"
|
| 116 |
+
else
|
| 117 |
+
print_warning "HF_TOKEN not set (model won't be pushed to Hub)"
|
| 118 |
+
fi
|
| 119 |
+
|
| 120 |
+
# Check WANDB_API_KEY
|
| 121 |
+
if [ -n "$WANDB_API_KEY" ]; then
|
| 122 |
+
KEY_LEN=${#WANDB_API_KEY}
|
| 123 |
+
print_success "WANDB_API_KEY set ($KEY_LEN characters)"
|
| 124 |
+
else
|
| 125 |
+
print_error "WANDB_API_KEY not set"
|
| 126 |
+
ERRORS=$((ERRORS + 1))
|
| 127 |
+
fi
|
| 128 |
+
|
| 129 |
+
print_header "4. GPU / CUDA"
|
| 130 |
+
|
| 131 |
+
# Check nvidia-smi
|
| 132 |
+
if nvidia-smi &> /dev/null; then
|
| 133 |
+
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
|
| 134 |
+
GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1)
|
| 135 |
+
print_success "GPU detected: $GPU_NAME ($GPU_MEMORY)"
|
| 136 |
+
else
|
| 137 |
+
print_error "GPU not detected (nvidia-smi failed)"
|
| 138 |
+
ERRORS=$((ERRORS + 1))
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
# Check CUDA
|
| 142 |
+
if python3 -c "import torch; assert torch.cuda.is_available()" &> /dev/null; then
|
| 143 |
+
CUDA_VERSION=$(python3 -c "import torch; print(torch.version.cuda)")
|
| 144 |
+
GPU_COUNT=$(python3 -c "import torch; print(torch.cuda.device_count())")
|
| 145 |
+
print_success "CUDA available: version $CUDA_VERSION ($GPU_COUNT GPU(s))"
|
| 146 |
+
else
|
| 147 |
+
print_error "CUDA not available in PyTorch"
|
| 148 |
+
ERRORS=$((ERRORS + 1))
|
| 149 |
+
fi
|
| 150 |
+
|
| 151 |
+
print_header "5. Wandb Authentication"
|
| 152 |
+
|
| 153 |
+
if [ -n "$WANDB_API_KEY" ]; then
|
| 154 |
+
if python3 << WANDBCHECK
|
| 155 |
+
import wandb
|
| 156 |
+
import sys
|
| 157 |
+
try:
|
| 158 |
+
result = wandb.login(key="$WANDB_API_KEY", relogin=True)
|
| 159 |
+
if result:
|
| 160 |
+
print("Login successful")
|
| 161 |
+
sys.exit(0)
|
| 162 |
+
else:
|
| 163 |
+
print("Login failed")
|
| 164 |
+
sys.exit(1)
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f"Error: {e}")
|
| 167 |
+
sys.exit(1)
|
| 168 |
+
WANDBCHECK
|
| 169 |
+
then
|
| 170 |
+
print_success "Wandb authentication successful"
|
| 171 |
+
|
| 172 |
+
# Get user info
|
| 173 |
+
WANDB_USER=$(python3 << 'GETUSER'
|
| 174 |
+
import wandb
|
| 175 |
+
try:
|
| 176 |
+
api = wandb.Api()
|
| 177 |
+
print(api.viewer.get("username", "unknown"))
|
| 178 |
+
except:
|
| 179 |
+
print("unknown")
|
| 180 |
+
GETUSER
|
| 181 |
+
)
|
| 182 |
+
print_success "Logged in as: $WANDB_USER"
|
| 183 |
+
else
|
| 184 |
+
print_error "Wandb authentication failed"
|
| 185 |
+
ERRORS=$((ERRORS + 1))
|
| 186 |
+
fi
|
| 187 |
+
else
|
| 188 |
+
print_warning "Skipping Wandb auth (no API key)"
|
| 189 |
+
fi
|
| 190 |
+
|
| 191 |
+
print_header "6. HuggingFace Authentication"
|
| 192 |
+
|
| 193 |
+
if [ -n "$HF_TOKEN" ]; then
|
| 194 |
+
if python3 << HFCHECK
|
| 195 |
+
from huggingface_hub import HfApi
|
| 196 |
+
import sys
|
| 197 |
+
try:
|
| 198 |
+
api = HfApi(token="$HF_TOKEN")
|
| 199 |
+
user = api.whoami()
|
| 200 |
+
print(f"Login successful: {user.get('name', 'unknown')}")
|
| 201 |
+
sys.exit(0)
|
| 202 |
+
except Exception as e:
|
| 203 |
+
print(f"Error: {e}")
|
| 204 |
+
sys.exit(1)
|
| 205 |
+
HFCHECK
|
| 206 |
+
then
|
| 207 |
+
print_success "HuggingFace authentication successful"
|
| 208 |
+
else
|
| 209 |
+
print_error "HuggingFace authentication failed"
|
| 210 |
+
ERRORS=$((ERRORS + 1))
|
| 211 |
+
fi
|
| 212 |
+
else
|
| 213 |
+
print_warning "Skipping HF auth (no token)"
|
| 214 |
+
fi
|
| 215 |
+
|
| 216 |
+
print_header "7. Dataset Access"
|
| 217 |
+
|
| 218 |
+
# Test dataset loading
|
| 219 |
+
if python3 << DATASETCHECK
|
| 220 |
+
from datasets import load_dataset
|
| 221 |
+
import sys
|
| 222 |
+
try:
|
| 223 |
+
# Quick test load (just get info, don't download)
|
| 224 |
+
ds = load_dataset("augustocsc/sintetico_natural", split="train", streaming=True)
|
| 225 |
+
print("Dataset accessible")
|
| 226 |
+
sys.exit(0)
|
| 227 |
+
except Exception as e:
|
| 228 |
+
print(f"Error: {e}")
|
| 229 |
+
sys.exit(1)
|
| 230 |
+
DATASETCHECK
|
| 231 |
+
then
|
| 232 |
+
print_success "Dataset accessible: augustocsc/sintetico_natural"
|
| 233 |
+
else
|
| 234 |
+
print_warning "Could not verify dataset access (may require authentication)"
|
| 235 |
+
fi
|
| 236 |
+
|
| 237 |
+
print_header "8. Scripts"
|
| 238 |
+
|
| 239 |
+
SCRIPTS=(
|
| 240 |
+
"scripts/train.py"
|
| 241 |
+
"scripts/evaluate.py"
|
| 242 |
+
"scripts/generate.py"
|
| 243 |
+
"scripts/aws/monitor_training_auto.sh"
|
| 244 |
+
"scripts/aws/analyze_model.sh"
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
for script in "${SCRIPTS[@]}"; do
|
| 248 |
+
if [ -f "$script" ]; then
|
| 249 |
+
print_success "$script exists"
|
| 250 |
+
else
|
| 251 |
+
print_warning "$script not found"
|
| 252 |
+
fi
|
| 253 |
+
done
|
| 254 |
+
|
| 255 |
+
# Final summary
|
| 256 |
+
print_header "Validation Summary"
|
| 257 |
+
echo ""
|
| 258 |
+
|
| 259 |
+
if [ $ERRORS -eq 0 ]; then
|
| 260 |
+
echo -e "${GREEN}╔══════════════════════════════════════╗${NC}"
|
| 261 |
+
echo -e "${GREEN}║ ║${NC}"
|
| 262 |
+
echo -e "${GREEN}║ ✅ ALL VALIDATIONS PASSED ✅ ║${NC}"
|
| 263 |
+
echo -e "${GREEN}║ ║${NC}"
|
| 264 |
+
echo -e "${GREEN}║ Ready for training! 🚀 ║${NC}"
|
| 265 |
+
echo -e "${GREEN}║ ║${NC}"
|
| 266 |
+
echo -e "${GREEN}╚══════════════════════════════════════╝${NC}"
|
| 267 |
+
echo ""
|
| 268 |
+
echo "You can now run:"
|
| 269 |
+
echo " python scripts/train.py --help"
|
| 270 |
+
echo " bash scripts/aws/run_all_training.sh"
|
| 271 |
+
echo ""
|
| 272 |
+
exit 0
|
| 273 |
+
else
|
| 274 |
+
echo -e "${RED}╔══════════════════════════════════════╗${NC}"
|
| 275 |
+
echo -e "${RED}║ ║${NC}"
|
| 276 |
+
echo -e "${RED}║ ❌ VALIDATION FAILED ❌ ║${NC}"
|
| 277 |
+
echo -e "${RED}║ ║${NC}"
|
| 278 |
+
echo -e "${RED}║ $ERRORS error(s) found ║${NC}"
|
| 279 |
+
echo -e "${RED}║ ║${NC}"
|
| 280 |
+
echo -e "${RED}╚══════════════════════════════════════╝${NC}"
|
| 281 |
+
echo ""
|
| 282 |
+
echo "Please fix the errors above before training."
|
| 283 |
+
echo ""
|
| 284 |
+
exit 1
|
| 285 |
+
fi
|
scripts/compare_models.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compare two models: band-aided vs properly trained.
|
| 3 |
+
Evaluates both on same test set and reports metrics.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/compare_models.py \
|
| 7 |
+
--model1 ./output/Se124M_700K_infix \
|
| 8 |
+
--model2 ./output/Se124M_700K_infix_v2 \
|
| 9 |
+
--num_samples 500
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
# Import evaluate_model from evaluate.py
|
| 19 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 20 |
+
from evaluate import evaluate_model
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def format_metric(value, metric_type):
    """Render a metric value as a fixed-width string for the comparison table.

    Supported metric types: "rate" (fraction -> percentage), "float",
    "int"; anything else falls back to plain width-7 formatting.
    """
    formatters = {
        "rate": lambda v: f"{v * 100:5.1f}%",
        "float": lambda v: f"{v:7.2f}",
        "int": lambda v: f"{int(v):7d}",
    }
    render = formatters.get(metric_type, lambda v: f"{v:7}")
    return render(value)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def print_comparison_table(metrics1, metrics2, model1_name, model2_name):
    """Print a side-by-side metrics table for two models to stdout.

    Args:
        metrics1: Metrics dict for model 1 (keys like ``valid_rate``;
            missing keys are treated as 0).
        metrics2: Metrics dict for model 2, same schema.
        model1_name: Column-header label for model 1.
        model2_name: Column-header label for model 2.

    The improvement section and the final verdict compare model 2 against
    model 1; the verdict is decided solely by the ``valid_rate`` difference.
    """
    print("\n" + "=" * 80)
    print("COMPARISON RESULTS")
    print("=" * 80)

    # Header
    print(f"{'Metric':<35} {model1_name:>20} {model2_name:>20}")
    print("-" * 80)

    # Define metrics to compare: (dict key, display label, formatting type)
    comparison_metrics = [
        ("valid_rate", "Valid Rate", "rate"),
        ("parseable_rate", "Parseable Rate", "rate"),
        ("constraints_met_rate", "Constraints Met", "rate"),
        ("diversity_rate", "Diversity", "rate"),
        ("avg_expression_length", "Avg Expression Length", "float"),
        ("total_samples", "Total Samples", "int"),
        ("total_valid", "Total Valid", "int"),
    ]

    improvements = []

    for key, label, metric_type in comparison_metrics:
        val1 = metrics1.get(key, 0)
        val2 = metrics2.get(key, 0)

        formatted_val1 = format_metric(val1, metric_type)
        formatted_val2 = format_metric(val2, metric_type)

        print(f"{label:<35} {formatted_val1:>20} {formatted_val2:>20}")

        # Calculate improvement for rate metrics only; val1 > 0 guards the
        # relative-change division against a zero baseline.
        if metric_type == "rate" and val1 > 0:
            improvement = ((val2 - val1) / val1) * 100
            improvements.append((label, improvement, val2 - val1))

    print("=" * 80)

    # Show improvements (relative % and absolute percentage points)
    print("\nIMPROVEMENTS (Model 2 vs Model 1):")
    print("-" * 80)

    for label, improvement, absolute_diff in improvements:
        sign = "+" if improvement > 0 else ""
        abs_sign = "+" if absolute_diff > 0 else ""
        print(f"{label:<35} {sign}{improvement:>6.1f}% ({abs_sign}{absolute_diff * 100:>5.1f} pp)")

    print("-" * 80)

    # Determine winner by absolute valid_rate difference (in fractions,
    # so 0.20 below means 20 percentage points)
    valid_rate_improvement = metrics2.get("valid_rate", 0) - metrics1.get("valid_rate", 0)

    print("\n" + "=" * 80)
    if valid_rate_improvement > 0.20:  # >20% improvement
        print(f"🎯 SIGNIFICANT IMPROVEMENT: Model 2 wins by {valid_rate_improvement * 100:.1f} percentage points")
        print("   The properly trained model significantly outperforms the band-aided version!")
    elif valid_rate_improvement > 0.05:  # >5% improvement
        print(f"✅ IMPROVEMENT: Model 2 wins by {valid_rate_improvement * 100:.1f} percentage points")
        print("   The properly trained model shows clear improvement.")
    elif valid_rate_improvement > 0:  # Any improvement
        print(f"📈 SLIGHT IMPROVEMENT: Model 2 wins by {valid_rate_improvement * 100:.1f} percentage points")
        print("   The properly trained model shows modest improvement.")
    elif valid_rate_improvement == 0:
        print("⚖️ TIE: Both models perform equally")
        print("   No significant difference between models.")
    else:
        print(f"⚠️ REGRESSION: Model 1 wins by {-valid_rate_improvement * 100:.1f} percentage points")
        print("   The band-aided model performs better - retraining may need adjustment.")

    print("=" * 80)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir):
    """Persist both metric sets plus their pairwise differences as JSON.

    The report is written to ``<output_dir>/comparison_<timestamp>.json``
    (directory created if missing) and the file path is returned.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_file = os.path.join(output_dir, f"comparison_{timestamp}.json")

    # Map report field -> source metric key; missing metrics count as 0.
    diff_fields = {
        "valid_rate_diff": "valid_rate",
        "parseable_rate_diff": "parseable_rate",
        "constraints_met_diff": "constraints_met_rate",
        "diversity_diff": "diversity_rate",
    }
    comparison = {
        field: metrics2.get(key, 0) - metrics1.get(key, 0)
        for field, key in diff_fields.items()
    }

    report = {
        "timestamp": timestamp,
        "model1": {
            "name": model1_name,
            "metrics": metrics1,
        },
        "model2": {
            "name": model2_name,
            "metrics": metrics2,
        },
        "comparison": comparison,
    }

    with open(report_file, "w") as f:
        json.dump(report, f, indent=2)

    print(f"\n📄 Detailed comparison report saved to: {report_file}")
    return report_file
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _build_eval_args(model_path, eval_output_dir, num_samples, dataset_repo_id,
                     data_dir, data_column):
    """Assemble the argparse.Namespace that evaluate_model() expects.

    Generation settings (1 generation per prompt, 128 new tokens,
    temperature 0.7, top_p 0.9, seed 42, device auto) are fixed so both
    models are evaluated under identical conditions.
    """
    return argparse.Namespace(
        model_path=model_path,
        base_model=None,
        dataset_repo_id=dataset_repo_id,
        data_dir=data_dir,
        data_column=data_column,
        num_samples=num_samples,
        num_generations=1,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        output_dir=eval_output_dir,
        seed=42,
        device="auto",
    )


def compare_models(model1_path, model2_path, model1_name, model2_name,
                   num_samples=500, dataset_repo_id="augustocsc/sintetico_natural",
                   data_dir="700K", data_column="i_prompt_n", output_dir="./evaluation_results/comparison"):
    """Evaluate two models on the same test set and print/save a comparison.

    Args:
        model1_path / model2_path: Model checkpoints to evaluate.
        model1_name / model2_name: Display names used in the report.
        num_samples: Number of prompts evaluated per model.
        dataset_repo_id, data_dir, data_column: Dataset location/column.
        output_dir: Root directory; per-model results land in
            ``model1``/``model2`` subdirectories.

    Returns:
        (metrics1, metrics2) dicts as produced by ``evaluate_model``.
        Exits the process with status 1 if either evaluation raises.
    """
    print("=" * 80)
    print("MODEL COMPARISON")
    print("=" * 80)
    print(f"Model 1 ({model1_name}): {model1_path}")
    print(f"Model 2 ({model2_name}): {model2_path}")
    print(f"Samples: {num_samples}")
    print(f"Dataset: {dataset_repo_id}/{data_dir}")
    print("=" * 80)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Evaluate both models with identical settings. The Namespace and the
    # error handling used to be duplicated per model; a single loop keeps
    # the two passes guaranteed-identical.
    all_metrics = []
    candidates = [(model1_path, model1_name, "model1"),
                  (model2_path, model2_name, "model2")]
    for idx, (path, display_name, subdir) in enumerate(candidates, start=1):
        print(f"\n[{idx}/2] Evaluating Model {idx}: {display_name}")
        print("-" * 80)

        eval_args = _build_eval_args(path, os.path.join(output_dir, subdir),
                                     num_samples, dataset_repo_id, data_dir,
                                     data_column)
        try:
            all_metrics.append(evaluate_model(eval_args))
        except Exception as e:
            print(f"\n❌ Error evaluating Model {idx}: {e}")
            import traceback
            traceback.print_exc()
            sys.exit(1)

    metrics1, metrics2 = all_metrics

    # Print comparison
    print_comparison_table(metrics1, metrics2, model1_name, model2_name)

    # Save report
    save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir)

    return metrics1, metrics2
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def main():
    """CLI entry point: parse arguments and run the two-model comparison."""
    parser = argparse.ArgumentParser(
        description="Compare two models on the same test set"
    )
    # Model locations are required; everything else has sensible defaults.
    parser.add_argument("--model1", type=str, required=True,
                        help="Path to first model (band-aided)")
    parser.add_argument("--model2", type=str, required=True,
                        help="Path to second model (properly trained)")
    parser.add_argument("--model1_name", type=str, default="Band-Aided",
                        help="Display name for model 1")
    parser.add_argument("--model2_name", type=str, default="Proper",
                        help="Display name for model 2")
    parser.add_argument("--num_samples", type=int, default=500,
                        help="Number of samples to evaluate")
    parser.add_argument("--dataset_repo_id", type=str, default="augustocsc/sintetico_natural",
                        help="HuggingFace dataset repository")
    parser.add_argument("--data_dir", type=str, default="700K",
                        help="Data directory within dataset")
    parser.add_argument("--data_column", type=str, default="i_prompt_n",
                        help="Column name for prompts")
    parser.add_argument("--output_dir", type=str, default="./evaluation_results/comparison",
                        help="Directory to save comparison results")

    args = parser.parse_args()

    # Run comparison
    try:
        compare_models(
            model1_path=args.model1,
            model2_path=args.model2,
            model1_name=args.model1_name,
            model2_name=args.model2_name,
            num_samples=args.num_samples,
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_dir=args.output_dir
        )

        print("\n✅ Comparison complete!")

    except Exception as e:
        # Surface the full traceback for debugging, then exit non-zero so
        # shell callers can detect the failure.
        print(f"\n❌ Error during comparison: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    main()
|
scripts/compare_v1_v2_simple.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple comparison of V1 vs V2 model generation quality
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import torch
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 13 |
+
from classes.expression import Expression
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ExpressionStoppingCriteria(StoppingCriteria):
    """Stop generation once the tail of the output matches a stop sequence.

    Each stop sequence is tokenized once at construction; on every step the
    last tokens of the (single) generated sequence are compared against each
    candidate token run.
    """

    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        self.stop_ids = [
            tokenizer.encode(seq, add_special_tokens=False)
            for seq in stop_sequences
        ]

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0]
        for candidate in self.stop_ids:
            width = len(candidate)
            # Skip empty encodings and sequences shorter than the candidate.
            if width == 0 or len(generated) < width:
                continue
            if generated[-width:].tolist() == candidate:
                return True
        return False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_model(model_name, model_label):
    """Load base GPT-2, attach a PEFT adapter, and merge it for inference.

    Args:
        model_name: HuggingFace repo id (or local path) of the LoRA adapter.
        model_label: Human-readable label used only in progress messages.

    Returns:
        Tuple ``(model, tokenizer)``; the model is in eval mode with the
        adapter weights merged into the base GPT-2 (fp16, device_map auto).
    """
    print(f"\n{'='*60}")
    print(f"Loading {model_label}: {model_name}")
    print('='*60)

    # Load base GPT-2
    print("Loading base GPT-2...")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Setup tokenizer with the special expression delimiters used in training
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({
        "additional_special_tokens": ["<|startofex|>", "<|endofex|>"]
    })

    # Resize embeddings so the newly added special tokens have embedding rows
    # (presumably matching the vocab size the adapter was trained with — the
    # adapter load below would otherwise mismatch; TODO confirm).
    model.resize_token_embeddings(len(tokenizer))

    # Load adapter and merge it into the base weights; merging removes the
    # PEFT indirection so generation runs on a plain model.
    print(f"Loading adapter from {model_name}...")
    model = PeftModel.from_pretrained(model, model_name)
    print("Merging adapter...")
    model = model.merge_and_unload()
    model.eval()

    print(f"✓ {model_label} loaded successfully")
    return model, tokenizer
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_model(model, tokenizer, model_label, n_samples=20):
    """Generate n_samples expressions with the model and score each one.

    Every generation is checked for:
      * validity — the text parses (``Expression``) and evaluates to finite,
        non-NaN values on a simple test point;
      * symbol cleanliness — only expected characters appear and none of a
        blocklist of natural-language "garbage" words leaked in.

    Args:
        model: Merged causal LM ready for ``generate``.
        tokenizer: Matching tokenizer, also used for decoding.
        model_label: "V1" or "V2"; selects that model's optimal sampling
            configuration (from FINAL_RESULTS_V1_VS_V2.md).
        n_samples: Number of expressions to generate.

    Returns:
        dict with ``valid_count``, ``correct_symbols_count`` and a list of
        per-sample ``expressions`` records.
    """
    print(f"\n{'='*60}")
    print(f"Testing {model_label} - {n_samples} generations")
    print('='*60)

    # Same prompt for both models
    prompt = """vars: x_1, x_2
oper: *, +, -, sin, cos
cons: C
expr:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stop on the explicit end-of-expression token, or when the model starts
    # hallucinating a fresh prompt block.
    stopping_criteria = StoppingCriteriaList([
        ExpressionStoppingCriteria(tokenizer, ["<|endofex|>", "\n\nvars:"])
    ])

    # Use OPTIMAL config for each model (from FINAL_RESULTS_V1_VS_V2.md)
    if model_label == "V1":
        # V1 optimal: 83.3% valid rate
        gen_config = {
            "temperature": 0.5,
            "top_k": 40,
            "top_p": 0.9,
            "repetition_penalty": 1.15,
            "max_new_tokens": 100,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        print("Using V1 optimal config: temp=0.5, top_k=40, rep_penalty=1.15")
    else:  # V2
        # V2 optimal: 90% valid rate
        gen_config = {
            "temperature": 0.7,
            "top_k": 0,
            "top_p": 0.8,
            "repetition_penalty": 1.0,
            "max_new_tokens": 128,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        print("Using V2 optimal config: temp=0.7, top_p=0.8 (nucleus sampling)")

    results = {
        "valid_count": 0,
        "correct_symbols_count": 0,
        "expressions": []
    }

    print(f"\nGenerating {n_samples} expressions...\n")

    for i in range(n_samples):
        output = model.generate(
            **inputs,
            **gen_config,
            stopping_criteria=stopping_criteria
        )
        text = tokenizer.decode(output[0], skip_special_tokens=False)

        # Extract expression: the text after the last "expr:" marker,
        # truncated at the end-of-expression token.
        if "expr:" in text:
            expr_str = text.split("expr:")[-1].strip()
            expr_str = expr_str.split("<|endofex|>")[0].strip()
        else:
            expr_str = text

        # Check if valid (parses and evaluates to finite numbers on a simple
        # test point; `x == x` is False only for NaN).
        is_valid = False
        try:
            expr = Expression(expr_str, is_prefix=False)
            X_test = [[1.0, 2.0]]  # Simple test
            result = expr.evaluate(X_test)
            if len(result) > 0 and all(x != float('inf') and x != float('-inf') and x == x for x in result):
                is_valid = True
                results["valid_count"] += 1
        except Exception:
            # FIX: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit. A failed parse/eval simply
            # leaves the sample marked invalid.
            pass

        # Check if uses only correct symbols
        has_correct_symbols = True
        # Remove spaces and check characters: the only letters expected come
        # from x_<n>, C, "sin" and "cos".
        expr_clean = expr_str.replace(" ", "")
        for char in expr_clean:
            if char.isalpha() and char not in "xCsinco_":
                has_correct_symbols = False
                break

        # Check for garbage words (natural-language artifacts seen in bad runs)
        garbage_words = ["Buyable", "Instore", "Online", "Muslims", "crash", "Berman",
                         "vars:", "oper:", "expressed", "fluent", "Avenger", "repositories"]
        for word in garbage_words:
            if word in expr_str:
                has_correct_symbols = False
                break

        if has_correct_symbols:
            results["correct_symbols_count"] += 1

        results["expressions"].append({
            "index": i + 1,
            "expression": expr_str[:80],  # Limit display length
            "valid": is_valid,
            "correct_symbols": has_correct_symbols
        })

        # Show first 5 samples
        if i < 5:
            status = "✓ Valid" if is_valid else "✗ Invalid"
            symbols = "✓ Clean" if has_correct_symbols else "✗ Garbage"
            print(f"  [{i+1:2d}] {status:10s} {symbols:10s} | {expr_str[:60]}")

    print(f"\n{'-'*60}")
    print(f"RESULTS FOR {model_label}:")
    print(f"  Valid expressions: {results['valid_count']:2d}/{n_samples} ({results['valid_count']/n_samples*100:.1f}%)")
    print(f"  Correct symbols only: {results['correct_symbols_count']:2d}/{n_samples} ({results['correct_symbols_count']/n_samples*100:.1f}%)")
    print(f"{'-'*60}")

    return results
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def main():
    """Run the V1-vs-V2 comparison end to end and print a summary table."""
    print("\n" + "="*60)
    print("V1 vs V2 MODEL COMPARISON")
    print("="*60)
    print("Testing same prompt on both models")
    print("Measuring: valid expressions + symbol correctness\n")

    # Test V1
    v1_model, v1_tokenizer = load_model("augustocsc/Se124M_700K_infix", "V1")
    v1_results = test_model(v1_model, v1_tokenizer, "V1", n_samples=20)

    # Clean up V1 from memory before loading V2 (frees GPU memory)
    del v1_model
    torch.cuda.empty_cache()

    # Test V2
    v2_model, v2_tokenizer = load_model("augustocsc/Se124M_700K_infix_v2", "V2")
    v2_results = test_model(v2_model, v2_tokenizer, "V2", n_samples=20)

    # Final comparison
    print("\n" + "="*60)
    print("FINAL COMPARISON")
    print("="*60)
    print(f"\n{'Metric':<30s} {'V1':>10s} {'V2':>10s} {'Winner':>10s}")
    print("-"*60)

    v1_valid = v1_results["valid_count"]
    v2_valid = v2_results["valid_count"]
    valid_winner = "V1" if v1_valid > v2_valid else ("V2" if v2_valid > v1_valid else "TIE")
    print(f"{'Valid Expressions':<30s} {v1_valid:>10d} {v2_valid:>10d} {valid_winner:>10s}")

    v1_clean = v1_results["correct_symbols_count"]
    v2_clean = v2_results["correct_symbols_count"]
    clean_winner = "V1" if v1_clean > v2_clean else ("V2" if v2_clean > v1_clean else "TIE")
    print(f"{'Correct Symbols Only':<30s} {v1_clean:>10d} {v2_clean:>10d} {clean_winner:>10s}")

    print("-"*60)
    # NOTE(review): the rate lines below hard-code the divisor 20, which must
    # stay in sync with the n_samples=20 passed to test_model above.
    print(f"{'Valid Rate':<30s} {v1_valid/20*100:>9.1f}% {v2_valid/20*100:>9.1f}%")
    print(f"{'Clean Symbol Rate':<30s} {v1_clean/20*100:>9.1f}% {v2_clean/20*100:>9.1f}%")
    print("="*60)

    # Conclusion
    print("\nConclusion:")
    if v1_valid > v2_valid and v1_clean > v2_clean:
        print("  → V1 is better on both metrics")
    elif v2_valid > v1_valid and v2_clean > v1_clean:
        print("  → V2 is better on both metrics")
    else:
        print("  → Mixed results - models have different strengths")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
scripts/data/data_augmentation.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# augmentor.py
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
ALL_OPERANDS = ['+', '-', '*', '/', 'log', 'exp', 'cos', 'sqrt', 'asin', 'sin', '**', 'tan', 'abs']
|
| 7 |
+
|
| 8 |
+
def extract_operators(expr_str):
    """Return the list of operators from ALL_OPERANDS that appear in expr_str.

    Longer operator names are matched and stripped from a working copy first,
    so substrings no longer cause false positives: previously "asin" also
    reported "sin", and "**" also reported "*".

    Args:
        expr_str: Infix expression string, e.g. ``"sin(x_1) + C"``.

    Returns:
        list of operator strings in unspecified order (set-backed, like the
        original implementation).
    """
    ops = set()
    remaining = expr_str
    # Named functions: "asin" before "sin" so its letters are consumed first.
    for name in ('asin', 'sqrt', 'sin', 'cos', 'tan', 'abs', 'exp', 'log'):
        if name in remaining:
            ops.add(name)
            remaining = remaining.replace(name, '')
    # "**" must be consumed before the single-character "*" check below.
    if '**' in remaining:
        ops.add('**')
        remaining = remaining.replace('**', '')
    for symbol in ('/', '+', '-', '*'):
        if symbol in remaining:
            ops.add(symbol)
    return list(ops)
|
| 23 |
+
|
| 24 |
+
def infer_max_var(expr_str):
    """Largest index i appearing as ``x_i`` in the expression (1 if none)."""
    indices = (int(idx) for idx in re.findall(r'x_(\d+)', expr_str))
    return max(indices, default=1)
|
| 27 |
+
|
| 28 |
+
def generate_expression_instructions(expr_str):
    """Produce the four candidate prompt layouts for a single expression.

    The advertised variable pool is at least as large as the expression
    uses, randomly widened; unused operators are randomly mixed in so the
    prompts do not leak the exact operator set.
    """
    highest = infer_max_var(expr_str)

    # Randomly widen the variable pool beyond what the expression uses.
    var_pool = [f"x_{i}" for i in range(1, highest + random.randint(1, highest + 1))]

    present = extract_operators(expr_str)
    unused = list(set(ALL_OPERANDS) - set(present))
    padding = random.sample(unused, random.randint(1, len(unused))) if unused else []
    op_pool = sorted(set(present + padding))
    const_pool = ['C']
    expression = f"{expr_str}"

    return {
        "Simple_Instruct": f"Instruction: Generate a mathematical expression using variables {var_pool} and operands {op_pool} and {const_pool} as constant.\nExpression: {expression}",
        "Key_Value": f"Variables: {var_pool}\nOperands: {op_pool}\nConstant: {const_pool}\nExpression: {expression}",
        "Delimiter_Based": f"Input: Variables={var_pool}, Operands={op_pool}, Constant={const_pool}\nOutput: {expression}",
        "Minimalist": f"{var_pool} | {op_pool} | {const_pool} => {expression}"
    }
|
| 46 |
+
|
| 47 |
+
def generate_expression_instruction(expr_str):
    """Produce the single key/value training prompt for one expression.

    NOTE: the returned dict key "instriction" is a typo preserved
    deliberately -- downstream consumers read this exact key.
    """
    highest = infer_max_var(expr_str)

    # Randomly widen the variable pool beyond what the expression uses.
    var_pool = [f"x_{i}" for i in range(1, highest + random.randint(1, highest + 1))]

    present = extract_operators(expr_str)
    unused = list(set(ALL_OPERANDS) - set(present))
    padding = random.sample(unused, random.randint(1, len(unused))) if unused else []
    op_pool = sorted(set(present + padding))
    const_pool = ['C']
    expression = f"{expr_str}"

    return {
        "instriction": f"vars: {', '.join(var_pool)}\noper: {', '.join(op_pool)}\ncons: {', '.join(const_pool)}\nexpr: {expression}"
    }
|
| 63 |
+
#print(generate_expression_instruction("x_1 - (x_4 - C)*(x_3 + exp(C*x_2) + C)"))
|
scripts/data/data_cleaning.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sympy import sympify, Eq
|
| 5 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 6 |
+
from sympy.core.sympify import SympifyError
|
| 7 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
+
from sympy import simplify, sympify
|
| 10 |
+
from sympy.core.sympify import SympifyError
|
| 11 |
+
import swifter
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
from joblib import Parallel, delayed
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from tqdm.auto import tqdm
|
| 18 |
+
|
| 19 |
+
def apply_chunk(chunk, func):
    """Element-wise apply of *func* over one chunk (top-level so Pool can pickle it)."""
    return chunk.apply(func)


def parallel_apply(series, func, n_jobs=None):
    """Apply *func* element-wise over *series*, fanned out across processes.

    The series is split into one chunk per worker; results are re-joined
    with ``pd.concat`` so the original index ordering is preserved.
    """
    workers = mp.cpu_count() if n_jobs is None else n_jobs
    pieces = np.array_split(series, workers)
    with mp.Pool(workers) as pool:
        mapped = pool.starmap(apply_chunk, [(piece, func) for piece in pieces])
    return pd.concat(mapped)
|
| 32 |
+
|
| 33 |
+
def canonicalize_expr(expr, canonicalizer=simplify):
    """Return ``(hash(canonical), canonical, original)`` for dedup keying.

    *canonicalizer* defaults to SymPy's ``simplify`` but any callable that
    maps an expression to a canonical value can be injected.
    """
    canonical = canonicalizer(expr)
    return (hash(canonical), canonical, expr)
|
| 36 |
+
|
| 37 |
+
def replace_constants(equation):
    """Replace every standalone numeric literal in *equation* with 'C'.

    The lookarounds keep digits that are part of identifiers (e.g. the 1
    in ``x_1``) or of dotted numbers untouched.
    """
    number = r'(?<![\w.])(?:[-+]?\d*\.\d+|\d+)(?![\w.])'
    return re.sub(number, 'C', equation)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def augment_expression(equation, var_prefix='x', max_index=10, p=0.5):
    """
    Replace numeric constants with 'C', then randomly remap variable indices.

    Step 1: every standalone number (including scientific notation) becomes
    the symbol 'C'.  Step 2: each occurrence of a variable (e.g. x_1) is,
    with probability *p*, swapped for a uniformly drawn x_1..x_max_index;
    otherwise it is left unchanged.
    """
    # Numbers (incl. scientific notation) not embedded in identifiers.
    const_pattern = r'(?<![\w.])(?:[-+]?\d*\.\d+(?:[eE][-+]?\d+)?|\d+(?:[eE][-+]?\d+)?)(?![\w.])'
    without_constants = re.sub(const_pattern, 'C', equation)

    var_pattern = rf'\b{var_prefix}_\d+\b'

    def _maybe_remap(match):
        # Coin flip per occurrence; keep the original variable on tails.
        if random.random() < p:
            return f"{var_prefix}_{random.randint(1, max_index)}"
        return match.group(0)

    return re.sub(var_pattern, _maybe_remap, without_constants)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def is_valid_equation(equation_str):
    """Check whether a string parses as a valid SymPy expression.

    Returns False for non-strings, NaN/empty strings, and strings that
    SymPy cannot parse (logging the failing input).  The unused binding of
    the parsed expression in the original was dropped.
    """
    if not isinstance(equation_str, str):
        return False
    if pd.isna(equation_str) or equation_str.strip() == '':
        return False

    try:
        # Parse only to validate; the resulting expression is not needed.
        parse_expr(equation_str.strip())
        return True
    except (SympifyError, SyntaxError, ValueError, TypeError, AttributeError):
        print(f"Erro ao analisar a equação: {equation_str}")
        return False
|
| 80 |
+
|
| 81 |
+
def canonical_form(expr_str):
    """Return the canonical (simplified and expanded) form of an expression string.

    On a SymPy parse failure the error is reported as a return string
    instead of raising (message text kept verbatim for compatibility).
    """
    try:
        return str(simplify(expr_str).expand())
    except SympifyError:
        return f"Erro ao interpretar a expressão: {expr_str}"
|
scripts/data/data_processing.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import argparse
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import multiprocessing as mp
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
|
| 10 |
+
# Adjust import paths for custom modules
|
| 11 |
+
def setup_paths():
    """Make the project's scripts/ and classes/ folders importable."""
    for rel in ("../scripts", "../classes"):
        resolved = os.path.abspath(rel)
        if resolved not in sys.path:
            sys.path.append(resolved)


setup_paths()
|
| 18 |
+
|
| 19 |
+
# Local imports after path setup
|
| 20 |
+
import scripts.data.data_cleaning as dc
|
| 21 |
+
from expression import Expression
|
| 22 |
+
from data.parallel_utils import augment_dataframe_parallel
|
| 23 |
+
|
| 24 |
+
def _apply_chunk(chunk, func):
    """Apply *func* to one chunk; module-level so Pool can pickle it."""
    return chunk.apply(func)


def parallel_apply(series, func, n_jobs=None):
    """Apply a function to a pandas Series in parallel.

    Bug fix: the helper handed to ``Pool.starmap`` was a function nested
    inside this one; nested functions cannot be pickled by
    ``multiprocessing``, so every call failed.  The helper now lives at
    module level.
    """
    n_jobs = mp.cpu_count() if n_jobs is None else n_jobs
    chunks = np.array_split(series, n_jobs)
    with mp.Pool(n_jobs) as pool:
        results = pool.starmap(_apply_chunk, [(chunk, func) for chunk in chunks])
    return pd.concat(results)
|
| 34 |
+
|
| 35 |
+
def process_chunk(chunk):
    """Clean one raw chunk and derive infix/prefix expression columns.

    Drops rows whose equation carries the 'ERROR_simplify' marker, then
    augments the equations and adds a prefix-notation column.
    """
    subset = chunk[['eq']]
    subset = subset[~subset['eq'].str.contains('ERROR_simplify')]
    subset['eq'] = parallel_apply(subset['eq'], dc.augment_expression)
    subset = subset.rename(columns={'eq': 'infix_expr'})
    subset['prefix_expr'] = parallel_apply(subset['infix_expr'], Expression.infix_to_prefix)
    return subset
|
| 43 |
+
|
| 44 |
+
def process_file(file_path, chunk_size=100000):
    """Process the CSV file in chunks, returning one concatenated DataFrame.

    Fix: the line count used a bare ``open()`` inside a generator and
    leaked the file handle; it now runs inside a context manager.
    """
    processed_chunks = []
    with open(file_path) as handle:
        total_rows = sum(1 for _ in handle) - 1  # minus the header row
    total_chunks = (total_rows // chunk_size) + 1

    with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
        for chunk in pd.read_csv(file_path, chunksize=chunk_size):
            processed_chunks.append(process_chunk(chunk))
            pbar.update(1)

    return pd.concat(processed_chunks, ignore_index=True)
|
| 57 |
+
|
| 58 |
+
def augment_df(df):
    """Apply augmentation to both infix and prefix expressions.

    Runs augment_dataframe_parallel over the infix column, then over the
    prefix column, renaming the produced columns with i_/p_ prefixes.

    NOTE(review): augment_dataframe_parallel currently only adds an
    "instruction" column, so the simple/key_value/delimiter/minimalist
    renames below look like no-ops left over from an earlier version --
    confirm before relying on the i_*/p_* columns.
    """
    df = augment_dataframe_parallel(df, expression_col="infix_expr", n_jobs=4)
    df.rename(columns={
        'simple': 'i_simple',
        'key_value': 'i_key_value',
        'delimiter': 'i_delimiter',
        'minimalist': 'i_minimalist'
    }, inplace=True)

    df = augment_dataframe_parallel(df, expression_col="prefix_expr", n_jobs=4)
    df.rename(columns={
        'simple': 'p_simple',
        'key_value': 'p_key_value',
        'delimiter': 'p_delimiter',
        'minimalist': 'p_minimalist'
    }, inplace=True)

    return df
|
| 77 |
+
|
| 78 |
+
def split_and_save(df, base_file_path):
    """Split into train/val/test (70/15/15) and write the four CSVs.

    Output goes to ../data/processed/<name>/ next to a copy of the full
    frame; random_state is fixed so splits are reproducible.
    """
    train_df, holdout = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(holdout, test_size=0.5, random_state=42)

    filename = os.path.basename(base_file_path)
    out_dir = f'../data/processed/{filename.replace(".csv", "")}'
    os.makedirs(out_dir, exist_ok=True)

    for prefix, frame in (("train_", train_df), ("val_", val_df),
                          ("test_", test_df), ("", df)):
        frame.to_csv(os.path.join(out_dir, f"{prefix}{filename}"), index=False)
|
| 91 |
+
|
| 92 |
+
def main():
    """CLI entry point: clean, dedupe, augment and split a raw equation CSV."""
    parser = argparse.ArgumentParser(description="Process a raw equation CSV file.")
    # nargs="?" makes the positional optional; argparse silently ignores
    # `default` on a required positional, so the original default never applied.
    parser.add_argument(
        "file_path",
        type=str,
        nargs="?",
        default="../data/raw/13k.csv",
        help="Path to the raw CSV file to process.",
    )
    args = parser.parse_args()

    file_path = args.file_path
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        sys.exit(1)

    df_processed = process_file(file_path)
    df_processed.drop_duplicates(subset=['infix_expr'], inplace=True)
    df_augmented = augment_df(df_processed)
    split_and_save(df_augmented, file_path)
| 106 |
+
|
| 107 |
+
if __name__ == '__main__':
|
| 108 |
+
main()
|
scripts/data/parallel_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# parallel_utils.py
|
| 2 |
+
|
| 3 |
+
from joblib import Parallel, delayed
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from .data_augmentation import generate_expression_instructions, generate_expression_instruction
|
| 6 |
+
|
| 7 |
+
def augment_dataframe_parallel(df, expression_col="expression", n_jobs=-1):
    """
    Parallelized augmentation of a DataFrame with math expressions.

    Args:
        df (pd.DataFrame): DataFrame with a column of expressions.
        expression_col (str): Name of the column with expressions.
        n_jobs (int): Number of parallel workers (-1 = all cores).

    Returns:
        pd.DataFrame: Copy of *df* with an added "instruction" column.
    """
    expressions = df[expression_col].tolist()
    # "instriction" is the (typo'd) key emitted by generate_expression_instruction.
    records = Parallel(n_jobs=n_jobs)(
        delayed(generate_expression_instruction)(expr) for expr in expressions
    )

    augmented = df.copy()
    augmented["instruction"] = [rec["instriction"] for rec in records]
    return augmented
|
scripts/data/prepare_experiment_data.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data preparation script for training experiments.
|
| 4 |
+
|
| 5 |
+
Prepares data in two formats:
|
| 6 |
+
- EXP-A: JSON structured format
|
| 7 |
+
- EXP-B: EOS token format (GPT-2's <|endoftext|>)
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/data/prepare_experiment_data.py \
|
| 11 |
+
--dataset_repo_id augustocsc/sintetico_natural \
|
| 12 |
+
--data_dir 700K \
|
| 13 |
+
--data_column i_prompt_n \
|
| 14 |
+
--output_base_dir ./data/experiments
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import re
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
from datasets import load_dataset, Dataset, DatasetDict
|
| 26 |
+
import pandas as pd
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(
|
| 29 |
+
level=logging.INFO,
|
| 30 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 31 |
+
)
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def parse_original_format(text: str) -> Optional[Dict]:
    """Parse the key/value prompt format into its components.

    Expected layout (one field per line)::

        vars: x_1, x_2
        oper: *, +, sin
        cons: C
        expr: C*sin(x_1) + x_2

    Returns a dict with vars, ops, cons, expr and the raw text, or None
    when no expression line is found.
    """
    parsed = {
        'vars': [],
        'ops': [],
        'cons': None,
        'expr': None,
        'raw_text': text
    }

    for raw_line in text.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith(('vars:', 'Variables:')):
            items = line.split(':', 1)[1].strip()
            parsed['vars'] = [tok.strip() for tok in items.split(',') if tok.strip()]
        elif line.startswith(('oper:', 'Operators:')):
            items = line.split(':', 1)[1].strip()
            parsed['ops'] = [tok.strip() for tok in items.split(',') if tok.strip()]
        elif line.startswith(('cons:', 'Constants:')):
            value = line.split(':', 1)[1].strip()
            parsed['cons'] = value if value else None
        elif line.startswith('expr:'):
            value = line.split(':', 1)[1].strip()
            # Strip any trailing special-token markers and stray newlines.
            value = value.split('<|')[0].strip()
            value = value.split('\n')[0].strip()
            parsed['expr'] = value

    # An expression line is mandatory; everything else is optional.
    return parsed if parsed['expr'] else None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def convert_to_json_format(parsed: Dict) -> str:
    """Serialize parsed components as a single-line JSON object (EXP-A).

    Example output:
        {"vars": ["x_1"], "ops": ["+"], "cons": "C", "expr": "x_1 + C"}

    The 'cons' key is omitted when no constant was parsed.
    """
    payload = {'vars': parsed['vars'], 'ops': parsed['ops']}
    if parsed['cons']:
        payload['cons'] = parsed['cons']
    payload['expr'] = parsed['expr']
    return json.dumps(payload, ensure_ascii=False)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def convert_to_eos_format(parsed: Dict) -> str:
    """Render parsed components in the prompt layout terminated by GPT-2's EOS (EXP-B).

    Example output::

        vars: x_1, x_2
        oper: *, +, sin
        cons: C
        expr: C*sin(x_1) + x_2<|endoftext|>

    Empty vars/ops/cons fields are skipped; the expr line is mandatory.
    """
    pieces = []
    if parsed['vars']:
        pieces.append("vars: " + ", ".join(parsed['vars']))
    if parsed['ops']:
        pieces.append("oper: " + ", ".join(parsed['ops']))
    if parsed['cons']:
        pieces.append("cons: " + str(parsed['cons']))

    # The expression line always closes with the EOS token.
    pieces.append("expr: " + str(parsed['expr']) + "<|endoftext|>")

    return '\n'.join(pieces)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def process_example_json(example: Dict) -> Dict:
    """Map one raw example to its JSON-format text; valid=False on parse failure."""
    parsed = parse_original_format(example['text'])

    if parsed is None:
        logger.warning(f"Failed to parse: {example['text'][:100]}...")
        return {'text': '', 'valid': False}

    return {'text': convert_to_json_format(parsed), 'valid': True}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def process_example_eos(example: Dict) -> Dict:
    """Map one raw example to its EOS-format text; valid=False on parse failure."""
    parsed = parse_original_format(example['text'])

    if parsed is None:
        logger.warning(f"Failed to parse: {example['text'][:100]}...")
        return {'text': '', 'valid': False}

    return {'text': convert_to_eos_format(parsed), 'valid': True}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def validate_json_format(text: str) -> bool:
    """Return True when *text* is valid JSON carrying expr/vars/ops keys.

    Fix: the original bare ``except:`` also swallowed KeyboardInterrupt
    and SystemExit; narrowed to the errors this body can actually raise
    (ValueError covers json.JSONDecodeError; TypeError covers non-string
    input and non-container JSON values).
    """
    try:
        obj = json.loads(text)
        return 'expr' in obj and 'vars' in obj and 'ops' in obj
    except (ValueError, TypeError):
        return False
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def validate_eos_format(text: str) -> bool:
    """True when *text* carries both an expr line and the EOS terminator."""
    required = ('<|endoftext|>', 'expr:')
    return all(marker in text for marker in required)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_base_dir: Path,
    max_samples: Optional[int] = None
) -> Dict:
    """
    Process the dataset into both formats.

    Loads a HuggingFace dataset, converts every split to the JSON (EXP-A)
    and EOS (EXP-B) formats, writes one CSV per split per format under
    ``output_base_dir``, and accumulates per-split statistics.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset
        data_column: Column containing the text data
        output_base_dir: Base directory for output
        max_samples: Optional limit on number of samples (for testing)

    Returns:
        Dictionary with processing statistics
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")

    # Load dataset
    dataset = load_dataset(
        dataset_repo_id,
        data_dir=data_dir,
        split=None
    )

    # NOTE(review): DatasetDict subclasses dict, so a multi-split load is
    # NOT wrapped here -- only a bare Dataset gets boxed under 'train'.
    if not isinstance(dataset, dict):
        dataset = {'train': dataset}

    logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

    # Show sample
    if 'train' in dataset:
        sample = dataset['train'][0][data_column]
        logger.info(f"\nSample ORIGINAL format:\n{sample}\n")

    # Create output directories
    output_json = output_base_dir / 'exp_a_json'
    output_eos = output_base_dir / 'exp_b_eos'
    output_json.mkdir(parents=True, exist_ok=True)
    output_eos.mkdir(parents=True, exist_ok=True)

    # Aggregated over all splits; per-split detail goes under 'splits'.
    statistics = {
        'total': 0,
        'json_valid': 0,
        'eos_valid': 0,
        'json_invalid': 0,
        'eos_invalid': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\n{'='*60}")
        logger.info(f"Processing {split_name} split ({len(split_data)} examples)")
        logger.info('='*60)

        # Rename column if needed so downstream mappers can read 'text'.
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        # Limit samples if specified
        if max_samples and len(split_data) > max_samples:
            logger.info(f"Limiting to {max_samples} samples for testing")
            split_data = split_data.select(range(max_samples))

        statistics['total'] += len(split_data)

        # Process to JSON format
        logger.info("\nConverting to JSON format (EXP-A)...")
        json_data = split_data.map(
            process_example_json,
            desc=f"JSON format ({split_name})"
        )

        # Filter valid examples (parse failures were flagged valid=False).
        json_valid = json_data.filter(lambda x: x['valid'])
        json_invalid_count = len(json_data) - len(json_valid)

        logger.info(f"JSON format: {len(json_valid)}/{len(json_data)} valid")

        if len(json_valid) > 0:
            logger.info(f"\nSample JSON format:\n{json_valid[0]['text']}\n")

        # Process to EOS format
        logger.info("\nConverting to EOS format (EXP-B)...")
        eos_data = split_data.map(
            process_example_eos,
            desc=f"EOS format ({split_name})"
        )

        # Filter valid examples
        eos_valid = eos_data.filter(lambda x: x['valid'])
        eos_invalid_count = len(eos_data) - len(eos_valid)

        logger.info(f"EOS format: {len(eos_valid)}/{len(eos_data)} valid")

        if len(eos_valid) > 0:
            logger.info(f"\nSample EOS format:\n{eos_valid[0]['text']}\n")

        # Update statistics
        statistics['json_valid'] += len(json_valid)
        statistics['json_invalid'] += json_invalid_count
        statistics['eos_valid'] += len(eos_valid)
        statistics['eos_invalid'] += eos_invalid_count
        statistics['splits'][split_name] = {
            'total': len(split_data),
            'json_valid': len(json_valid),
            'eos_valid': len(eos_valid)
        }

        # Save JSON format (one CSV per split with a single 'text' column)
        json_df = pd.DataFrame({'text': [ex['text'] for ex in json_valid]})
        json_file = output_json / f'{split_name}.csv'
        json_df.to_csv(json_file, index=False)
        logger.info(f"Saved JSON: {json_file} ({len(json_df)} examples)")

        # Save EOS format
        eos_df = pd.DataFrame({'text': [ex['text'] for ex in eos_valid]})
        eos_file = output_eos / f'{split_name}.csv'
        eos_df.to_csv(eos_file, index=False)
        logger.info(f"Saved EOS: {eos_file} ({len(eos_df)} examples)")

    return statistics
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _validate_format_dir(directory: Path, checker) -> Tuple[bool, List[str]]:
    """Validate every CSV in *directory* row-by-row with *checker*.

    Returns (all_files_valid, up to three invalid samples per file).
    Factored out of validate_output_files, which previously carried two
    copy-pasted versions of this loop.
    """
    ok = True
    issues: List[str] = []
    for csv_file in directory.glob('*.csv'):
        logger.info(f"\nValidating {csv_file.name}...")
        df = pd.read_csv(csv_file)

        valid_count = 0
        invalid_samples: List[str] = []
        for _, row in df.iterrows():
            if checker(row['text']):
                valid_count += 1
            elif len(invalid_samples) < 3:
                # Keep a short excerpt of the first few failures per file.
                invalid_samples.append(row['text'][:100])

        rate = valid_count / len(df) * 100 if len(df) > 0 else 0
        logger.info(f" Valid: {valid_count}/{len(df)} ({rate:.1f}%)")

        if invalid_samples:
            ok = False
            issues.extend(invalid_samples)
    return ok, issues


def validate_output_files(output_base_dir: Path) -> Dict:
    """
    Validate the generated output files.

    Checks every CSV under exp_a_json/ with the JSON validator and every
    CSV under exp_b_eos/ with the EOS validator.

    Returns:
        Validation results dictionary
    """
    logger.info("\n" + "="*60)
    logger.info("VALIDATION OF OUTPUT FILES")
    logger.info("="*60)

    results = {
        'exp_a_json': {'valid': True, 'issues': []},
        'exp_b_eos': {'valid': True, 'issues': []}
    }

    checks = (
        ('exp_a_json', validate_json_format),
        ('exp_b_eos', validate_eos_format),
    )
    for key, checker in checks:
        ok, issues = _validate_format_dir(output_base_dir / key, checker)
        if not ok:
            results[key]['valid'] = False
            results[key]['issues'].extend(issues)

    return results
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def print_final_report(statistics: Dict, validation: Dict) -> bool:
    """Print final processing report.

    Logs totals, per-format success rates, validation outcomes and a
    per-split breakdown, then returns True only when both formats passed
    validation (the caller can use this as an exit status).
    """
    logger.info("\n" + "="*60)
    logger.info("FINAL REPORT")
    logger.info("="*60)

    logger.info(f"\nTotal examples processed: {statistics['total']}")

    logger.info("\nEXP-A (JSON Format):")
    logger.info(f" Valid: {statistics['json_valid']}")
    logger.info(f" Invalid: {statistics['json_invalid']}")
    # Guard against division by zero when nothing was processed.
    json_rate = statistics['json_valid'] / statistics['total'] * 100 if statistics['total'] > 0 else 0
    logger.info(f" Success rate: {json_rate:.1f}%")
    logger.info(f" Validation: {'PASS' if validation['exp_a_json']['valid'] else 'FAIL'}")

    logger.info("\nEXP-B (EOS Format):")
    logger.info(f" Valid: {statistics['eos_valid']}")
    logger.info(f" Invalid: {statistics['eos_invalid']}")
    eos_rate = statistics['eos_valid'] / statistics['total'] * 100 if statistics['total'] > 0 else 0
    logger.info(f" Success rate: {eos_rate:.1f}%")
    logger.info(f" Validation: {'PASS' if validation['exp_b_eos']['valid'] else 'FAIL'}")

    logger.info("\nPer-split breakdown:")
    for split_name, split_stats in statistics['splits'].items():
        logger.info(f"\n {split_name.upper()}:")
        logger.info(f" Total: {split_stats['total']}")
        logger.info(f" JSON valid: {split_stats['json_valid']}")
        logger.info(f" EOS valid: {split_stats['eos_valid']}")

    logger.info("\n" + "="*60)

    all_valid = validation['exp_a_json']['valid'] and validation['exp_b_eos']['valid']
    if all_valid:
        logger.info("STATUS: ALL VALIDATIONS PASSED")
    else:
        logger.info("STATUS: SOME VALIDATIONS FAILED")

    logger.info("="*60)

    return all_valid
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def main():
    """CLI entry point for experiment data preparation.

    Parses arguments, processes the source dataset into both experiment
    formats (EXP-A JSON and EXP-B EOS), optionally validates the written
    output files, prints a final report, and exits 0 on full success or
    1 on any failure or validation error.
    """
    parser = argparse.ArgumentParser(
        description="Prepare experiment data in JSON and EOS formats"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        default="augustocsc/sintetico_natural",
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="700K",
        help="Subdirectory within the dataset"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        default="i_prompt_n",
        help="Column containing text data"
    )
    parser.add_argument(
        "--output_base_dir",
        type=str,
        default="./data/experiments",
        help="Base directory for output"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum samples per split (for testing)"
    )
    parser.add_argument(
        "--skip_validation",
        action="store_true",
        help="Skip output file validation"
    )

    args = parser.parse_args()

    output_base_dir = Path(args.output_base_dir)

    # Banner summarizing the run configuration.
    logger.info("="*60)
    logger.info("EXPERIMENT DATA PREPARATION")
    logger.info("="*60)
    logger.info(f"Dataset: {args.dataset_repo_id}/{args.data_dir}")
    logger.info(f"Column: {args.data_column}")
    logger.info(f"Output: {output_base_dir}")
    if args.max_samples:
        logger.info(f"Max samples: {args.max_samples}")
    logger.info("="*60)

    try:
        # Process dataset
        statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_base_dir=output_base_dir,
            max_samples=args.max_samples
        )

        # Validate output; --skip_validation substitutes an all-pass result
        # so the final report still has the structure it expects.
        if not args.skip_validation:
            validation = validate_output_files(output_base_dir)
        else:
            validation = {
                'exp_a_json': {'valid': True, 'issues': []},
                'exp_b_eos': {'valid': True, 'issues': []}
            }

        # Print report
        all_valid = print_final_report(statistics, validation)

        if all_valid:
            logger.info("\nData preparation completed successfully!")
            logger.info(f"\nOutput directories:")
            logger.info(f" EXP-A (JSON): {output_base_dir / 'exp_a_json'}")
            logger.info(f" EXP-B (EOS): {output_base_dir / 'exp_b_eos'}")
            sys.exit(0)
        else:
            logger.error("\nData preparation completed with validation errors!")
            sys.exit(1)

    except Exception as e:
        # Catch-all at the process boundary: log, dump traceback, exit non-zero.
        logger.error(f"\nFailed to prepare data: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
scripts/data/prepare_training_data_fixed.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data preparation script that adds proper <|endofex|> markers to training data.
|
| 3 |
+
|
| 4 |
+
This script processes the existing dataset and wraps expressions with end-of-expression
|
| 5 |
+
markers so the model learns to stop generation correctly.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/data/prepare_training_data_fixed.py \
|
| 9 |
+
--dataset_repo_id augustocsc/sintetico_natural \
|
| 10 |
+
--data_dir 700K \
|
| 11 |
+
--data_column i_prompt_n \
|
| 12 |
+
--output_dir ./data/processed/700K_fixed \
|
| 13 |
+
--validate
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, Tuple
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset, Dataset, DatasetDict
|
| 24 |
+
import pandas as pd
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=logging.INFO,
|
| 28 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def add_end_markers(example: Dict) -> Dict:
    """Append an <|endofex|> marker right after the expression in 'text'.

    Steps:
      1. locate the expression (everything after the first 'expr:')
      2. trim it at the earliest natural boundary ('vars:', blank line, ...)
      3. splice the end marker in, keeping any trailing content intact

    The operation is idempotent: examples that already carry a marker are
    returned unchanged.

    Args:
        example: Mapping with a 'text' field holding one training example.

    Returns:
        Mapping with the (possibly rewritten) 'text' field.
    """
    raw = example['text']

    # Nothing to mark when the example has no expression section.
    if 'expr:' not in raw:
        logger.warning(f"No 'expr:' found in text: {raw[:100]}...")
        return {'text': raw}

    pieces = raw.split('expr:', 1)
    if len(pieces) != 2:
        logger.warning(f"Unexpected format in text: {raw[:100]}...")
        return {'text': raw}

    head, tail = pieces

    # Idempotency guard.
    if '<|endofex|>' in tail:
        logger.debug("Marker already present, skipping")
        return {'text': raw}

    # Find the earliest boundary so the marker hugs the expression.
    cut = len(tail)
    for stop in ['\nvars:', '\nVariables:', '\n\n', '\nvar:', '\nVariable:']:
        pos = tail.find(stop)
        if pos != -1 and pos < cut:
            cut = pos

    expr_text = tail[:cut].strip()
    suffix = tail[cut:]

    # Reassemble with a normalized "expr: <expression><|endofex|>" core.
    return {'text': f"{head}expr: {expr_text}<|endofex|>{suffix}"}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def validate_markers(example: Dict) -> Dict:
    """Report expression-marker counts for one example's 'text' field.

    An example is considered valid when it contains at least one
    <|endofex|> marker; the start marker is optional depending on the
    data format.

    Args:
        example: Mapping with a 'text' field.

    Returns:
        Mapping with 'valid', 'start_count', 'end_count' and the
        original 'text' (kept so datasets.map preserves the column).
    """
    content = example['text']
    n_start = content.count('<|startofex|>')
    n_end = content.count('<|endofex|>')

    return {
        'valid': n_end > 0,
        'start_count': n_start,
        'end_count': n_end,
        'text': content
    }
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_dir: Path,
    validate: bool = True
) -> Tuple[DatasetDict, Dict]:
    """
    Process the dataset by adding end markers to all splits.

    Loads every split from the Hub, renames the chosen column to 'text',
    maps add_end_markers over each split, and (optionally) re-validates
    the result with validate_markers, accumulating counters as it goes.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset (e.g., '700K')
        data_column: Column to use for training data
        output_dir: Directory to save processed dataset
            NOTE(review): this parameter is never used in the body —
            saving happens in save_dataset(); confirm before relying on it.
        validate: Whether to run validation after processing

    Returns:
        Tuple of (processed_dataset, statistics)
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")

    try:
        # Load dataset from HuggingFace Hub
        dataset = load_dataset(
            dataset_repo_id,
            data_dir=data_dir,
            split=None  # Load all splits
        )

        # DatasetDict subclasses dict, so this only fires for a bare Dataset.
        if not isinstance(dataset, dict):
            # If single split, convert to dict
            dataset = {'train': dataset}

        logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

        # Show sample before processing
        if 'train' in dataset and len(dataset['train']) > 0:
            logger.info(f"\nSample BEFORE processing:")
            logger.info(f"{dataset['train'][0][data_column][:200]}...")

    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    # Process each split
    processed_dataset = {}
    statistics = {
        'total_examples': 0,
        'processed_examples': 0,
        'already_marked': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\nProcessing {split_name} split ({len(split_data)} examples)...")

        # Rename column to 'text' if needed
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        # Count examples that already have markers (add_end_markers is
        # idempotent, but the count is useful for the final report).
        already_marked = sum(1 for ex in split_data if '<|endofex|>' in ex['text'])
        statistics['already_marked'] += already_marked

        if already_marked > 0:
            logger.info(f"Found {already_marked} examples already with markers")

        # Apply marker addition
        processed_split = split_data.map(
            add_end_markers,
            desc=f"Adding markers to {split_name}"
        )

        processed_dataset[split_name] = processed_split

        # Update statistics
        split_stats = {
            'total': len(split_data),
            'processed': len(processed_split),
            'already_marked': already_marked
        }
        statistics['splits'][split_name] = split_stats
        statistics['total_examples'] += len(split_data)
        statistics['processed_examples'] += len(processed_split)

        # Show sample after processing
        if len(processed_split) > 0:
            logger.info(f"\nSample AFTER processing:")
            logger.info(f"{processed_split[0]['text'][:200]}...")

    # Validate if requested
    if validate:
        logger.info("\n" + "="*60)
        logger.info("VALIDATION")
        logger.info("="*60)

        for split_name, split_data in processed_dataset.items():
            logger.info(f"\nValidating {split_name} split...")

            # Apply validation (adds valid/start_count/end_count columns)
            validated = split_data.map(validate_markers)

            # Count valid examples
            valid_count = sum(validated['valid'])
            invalid_count = len(validated) - valid_count

            valid_rate = valid_count / len(validated) * 100

            logger.info(f"Valid examples: {valid_count}/{len(validated)} ({valid_rate:.1f}%)")

            if invalid_count > 0:
                logger.warning(f"Found {invalid_count} invalid examples!")

                # Show first few invalid examples
                invalid_examples = [
                    ex for ex in validated if not ex['valid']
                ][:3]

                for i, ex in enumerate(invalid_examples):
                    logger.warning(f"\nInvalid example {i+1}:")
                    logger.warning(f"Start markers: {ex['start_count']}")
                    logger.warning(f"End markers: {ex['end_count']}")
                    logger.warning(f"Text: {ex['text'][:200]}...")

            # Update statistics
            statistics['splits'][split_name]['valid'] = valid_count
            statistics['splits'][split_name]['invalid'] = invalid_count
            statistics['splits'][split_name]['valid_rate'] = valid_rate

    # Convert back to DatasetDict
    processed_dataset = DatasetDict(processed_dataset)

    return processed_dataset, statistics
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def save_dataset(dataset: DatasetDict, output_dir: Path, data_dir: str):
    """Write every split of *dataset* to CSV files under *output_dir*.

    Files are named '<split>_<data_dir>.csv'; the output directory is
    created on demand.

    Args:
        dataset: Processed dataset to save
        output_dir: Directory to save to
        data_dir: Original data directory name (used in filenames)
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"\nSaving processed dataset to {output_dir}...")

    for name, split in dataset.items():
        target = output_dir / f"{name}_{data_dir}.csv"

        # Round-trip through pandas for the CSV writer.
        frame = split.to_pandas()
        frame.to_csv(target, index=False)

        logger.info(f"Saved {name} split: {target} ({len(frame)} examples)")

    logger.info("Dataset saved successfully!")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def print_statistics(statistics: Dict):
    """Log a formatted summary of processing statistics.

    Args:
        statistics: Counters produced by process_dataset(), including a
            per-split breakdown under 'splits' (validation fields are
            logged only when present).
    """
    bar = "="*60
    logger.info("\n" + bar)
    logger.info("PROCESSING STATISTICS")
    logger.info(bar)

    logger.info(f"\nTotal examples: {statistics['total_examples']}")
    logger.info(f"Processed examples: {statistics['processed_examples']}")
    logger.info(f"Already marked: {statistics['already_marked']}")

    logger.info("\nPer-split statistics:")
    logger.info("-"*60)

    for name, stats in statistics['splits'].items():
        logger.info(f"\n{name.upper()}:")
        logger.info(f" Total: {stats['total']}")
        logger.info(f" Processed: {stats['processed']}")
        logger.info(f" Already marked: {stats.get('already_marked', 0)}")

        # Validation fields exist only when process_dataset ran with validate=True.
        if 'valid' in stats:
            logger.info(f" Valid: {stats['valid']}")
            logger.info(f" Invalid: {stats['invalid']}")
            logger.info(f" Valid rate: {stats['valid_rate']:.1f}%")

    logger.info(bar)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def main():
    """CLI entry point for marker-fixing data preparation.

    Parses arguments, processes the dataset (adding <|endofex|> markers),
    prints statistics, saves CSVs locally, optionally pushes to the
    HuggingFace Hub, and exits 1 on any error or failed validation.
    """
    parser = argparse.ArgumentParser(
        description="Prepare training data with proper end-of-expression markers"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Subdirectory within the dataset (e.g., '700K')"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Column to use for training data (e.g., 'i_prompt_n')"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save processed dataset"
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Run validation after processing"
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push processed dataset to HuggingFace Hub"
    )
    parser.add_argument(
        "--hub_repo_id",
        type=str,
        default=None,
        help="HuggingFace repository ID for pushing (if --push_to_hub)"
    )

    args = parser.parse_args()

    # Convert output_dir to Path
    output_dir = Path(args.output_dir)

    # Process dataset
    try:
        processed_dataset, statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_dir=output_dir,
            validate=args.validate
        )

        # Print statistics
        print_statistics(statistics)

        # Save to local directory
        save_dataset(processed_dataset, output_dir, args.data_dir)

        # Push to Hub if requested; --hub_repo_id is mandatory in that case.
        if args.push_to_hub:
            if not args.hub_repo_id:
                logger.error("--hub_repo_id required when using --push_to_hub")
                sys.exit(1)

            logger.info(f"\nPushing to HuggingFace Hub: {args.hub_repo_id}")
            processed_dataset.push_to_hub(args.hub_repo_id)
            logger.info("Successfully pushed to Hub!")

        # Check if any validation failed (fields only exist with --validate).
        if args.validate:
            all_valid = all(
                split_stats.get('invalid', 0) == 0
                for split_stats in statistics['splits'].values()
            )

            if not all_valid:
                logger.error("\n⚠️ Some examples failed validation!")
                sys.exit(1)
            else:
                logger.info("\n✅ All examples validated successfully!")

        logger.info("\n✅ Data preparation complete!")

    except Exception as e:
        # Process-boundary catch-all: log, dump traceback, exit non-zero.
        logger.error(f"\n❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Script para avaliacao customizada de modelos treinados
|
| 2 |
+
# Projeto Seriguela - Avaliacao de expressoes simbolicas geradas
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import re
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
+
from peft import PeftModel
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
# Add parent directory to path for imports
|
| 20 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
from classes.expression import Expression
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args():
    """Define and parse the command-line interface for model evaluation.

    Returns:
        argparse.Namespace carrying model location, dataset selection,
        generation hyperparameters, and output settings.
    """
    cli = argparse.ArgumentParser(description="Evaluate a trained model on expression generation")

    # Model location.
    cli.add_argument("--model_path", type=str, required=True,
                     help="Path to model (local or HuggingFace Hub)")
    cli.add_argument("--base_model", type=str, default=None,
                     help="Base model for PEFT (if model_path is adapter)")

    # Dataset selection.
    cli.add_argument("--dataset_repo_id", type=str, default="augustocsc/sintetico_natural",
                     help="HuggingFace dataset repository")
    cli.add_argument("--data_dir", type=str, default="700K",
                     help="Data directory within dataset")
    cli.add_argument("--data_column", type=str, default="i_prompt_n",
                     help="Column name for prompts (i_prompt_n for infix, p_prompt_n for prefix)")

    # Sampling / generation settings.
    cli.add_argument("--num_samples", type=int, default=500,
                     help="Number of samples to evaluate")
    cli.add_argument("--num_generations", type=int, default=1,
                     help="Number of generations per prompt")
    cli.add_argument("--max_new_tokens", type=int, default=128,
                     help="Maximum new tokens to generate")
    cli.add_argument("--temperature", type=float, default=0.7,
                     help="Sampling temperature")
    cli.add_argument("--top_p", type=float, default=0.9,
                     help="Top-p sampling parameter")

    # Run environment.
    cli.add_argument("--output_dir", type=str, default="./evaluation_results",
                     help="Directory to save evaluation results")
    cli.add_argument("--seed", type=int, default=42,
                     help="Random seed")
    cli.add_argument("--device", type=str, default="auto",
                     help="Device to use (auto, cuda, cpu)")

    return cli.parse_args()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def extract_expression_from_output(output: str, is_prefix: bool = False) -> str:
    """Pull the generated expression out of raw model output.

    Strategies, tried in order:
      1. text enclosed by <|startofex|> ... <|endofex|> (markers in order)
      2. text after <|startofex|>, trimmed at a known boundary and then
         to the first line, capped at 150 chars
      3. text labeled by an 'expr:' / 'Expression:' pattern
      4. the first line of the output, capped at 100 chars
    """
    START = "<|startofex|>"
    END = "<|endofex|>"

    # Strategy 1: both markers present and properly ordered.
    if START in output and END in output:
        lo = output.find(START) + len(START)
        hi = output.find(END)
        if lo < hi:
            return output[lo:hi].strip()

    # Strategy 2: start marker only.
    if START in output:
        tail = output[output.find(START) + len(START):].strip()

        # Cut at the first boundary that appears (checked in fixed order).
        for stop in ["\nvars:", "\nVariables:", "\nOperators:", "\n\n", "<|endoftext|>"]:
            if stop in tail:
                tail = tail.split(stop)[0].strip()
                break

        # Keep only the first line; drop anything unreasonably long.
        tail = tail.split("\n")[0].strip()
        return tail[:150] if len(tail) > 150 else tail

    # Strategy 3: "expr:" / "Expression:" label in the output.
    labeled = re.search(r'(?:expr|Expression):\s*(.+?)(?:\n|$)', output, re.IGNORECASE)
    if labeled:
        return labeled.group(1).strip()

    # Strategy 4: give up — first line, length-limited.
    head = output.strip().split("\n")[0]
    return head[:100] if len(head) > 100 else head
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def validate_expression(expr_str: str, is_prefix: bool = False) -> dict:
    """Check whether *expr_str* parses as a valid Expression.

    Args:
        expr_str: Candidate expression text.
        is_prefix: Forwarded to Expression for prefix-notation parsing.

    Returns:
        Dict with 'valid', 'parseable', 'error' (message or None) and,
        on success, the parsed 'expression_obj'.
    """
    outcome = {
        "valid": False,
        "parseable": False,
        "error": None,
        "expression_obj": None
    }

    # Reject blank/whitespace-only input before attempting a parse.
    if not expr_str or expr_str.strip() == "":
        outcome["error"] = "Empty expression"
        return outcome

    try:
        parsed = Expression(expr_str, is_prefix=is_prefix)
    except Exception as exc:
        # Expression raises on any syntax problem; keep the message.
        outcome["error"] = str(exc)
    else:
        outcome["parseable"] = True
        outcome["valid"] = True
        outcome["expression_obj"] = parsed

    return outcome
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def check_prompt_adherence(expr_str: str, prompt: str, is_prefix: bool = False) -> dict:
    """Heuristically check whether an expression obeys prompt constraints.

    Parses 'Variables:' and 'Operators:' lines out of the prompt and
    verifies that the expression only uses listed x_N variables and — for
    named functions only — listed operators. Missing constraint lines are
    treated as "anything allowed".
    """
    verdict = {
        "uses_allowed_vars": False,
        "uses_allowed_ops": False,
        "all_constraints_met": False
    }

    # Typical prompt format: "Variables: x_1, x_2, x_3\nOperators: +, -, *, sin\n..."

    # Allowed variables, e.g. x_1, x_2.
    allowed_vars = set()
    m_vars = re.search(r"Variables?:\s*([^\n]+)", prompt, re.IGNORECASE)
    if m_vars:
        allowed_vars = set(re.findall(r"x_\d+", m_vars.group(1)))

    # Allowed operators, matched against a fixed candidate list.
    allowed_ops = set()
    m_ops = re.search(r"Operators?:\s*([^\n]+)", prompt, re.IGNORECASE)
    if m_ops:
        listed = m_ops.group(1)
        allowed_ops = {
            op for op in ['+', '-', '*', '/', '**', 'sin', 'cos', 'tan', 'log', 'sqrt', 'exp']
            if op in listed
        }

    # Variables used by the expression must all be permitted.
    used_vars = set(re.findall(r"x_\d+", expr_str))
    verdict["uses_allowed_vars"] = used_vars.issubset(allowed_vars) if allowed_vars else True

    # Simplified operator check: only named functions are examined, since
    # symbol operators are too ambiguous without real parsing.
    verdict["uses_allowed_ops"] = True
    if allowed_ops:
        for fn in ['sin', 'cos', 'tan', 'log', 'sqrt', 'exp']:
            if fn in expr_str and fn not in allowed_ops:
                verdict["uses_allowed_ops"] = False
                break

    verdict["all_constraints_met"] = verdict["uses_allowed_vars"] and verdict["uses_allowed_ops"]

    return verdict
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_model_and_tokenizer(model_path: str, base_model: str = None, device: str = "auto"):
    """Load model and tokenizer for evaluation.

    Handles both full checkpoints and PEFT adapters: when the path contains
    an adapter_config.json (or base_model is given), the base model is loaded
    first, its embeddings resized to the tokenizer vocabulary, and the adapter
    is merged in for faster inference.

    Args:
        model_path: Local directory or HuggingFace Hub ID of the checkpoint
            or adapter.
        base_model: Base model to load under a PEFT adapter; defaults to
            "gpt2" when an adapter is detected without an explicit base.
        device: "auto" selects cuda when available, otherwise cpu.

    Returns:
        Tuple (model, tokenizer, device); model is in eval mode on *device*.
    """
    print(f"Loading model from: {model_path}")

    # Determine device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer; GPT-2-style tokenizers ship without a pad token.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # A local directory with adapter_config.json is treated as a PEFT checkpoint.
    is_peft = os.path.exists(os.path.join(model_path, "adapter_config.json")) if os.path.isdir(model_path) else False

    if is_peft or base_model:
        # Load base model first
        base = base_model or "gpt2"
        print(f"Loading base model: {base}")
        model = AutoModelForCausalLM.from_pretrained(base)
        # Resize so embeddings match the (possibly extended) tokenizer vocab.
        model.resize_token_embeddings(len(tokenizer))

        # Load PEFT adapter
        print("Loading PEFT adapter...")
        model = PeftModel.from_pretrained(model, model_path)
        model = model.merge_and_unload()  # Merge for faster inference
    else:
        # Load full model
        model = AutoModelForCausalLM.from_pretrained(model_path)
        model.resize_token_embeddings(len(tokenizer))

    model = model.to(device)
    model.eval()

    return model, tokenizer, device
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def generate_expression(model, tokenizer, prompt: str, device: str,
                        max_new_tokens: int = 128, temperature: float = 0.7,
                        top_p: float = 0.9, num_return_sequences: int = 1):
    """Sample expression continuation(s) of *prompt* from the model.

    Returns a list of ``num_return_sequences`` decoded strings. Special
    tokens are kept in the output so downstream extraction can locate the
    expression delimiters.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        sampled = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # skip_special_tokens=False: callers parse markers like <|startofex|>.
    return tokenizer.batch_decode(sampled, skip_special_tokens=False)
def evaluate_model(args):
    """Run the full evaluation pipeline for a trained expression model.

    For each sampled test prompt, generates ``args.num_generations``
    candidate expressions, then aggregates validity, prompt-adherence and
    diversity metrics. Prints a summary table and writes two JSON files
    (aggregate metrics and per-generation details) under ``args.output_dir``.

    Returns the aggregate metrics dict that was also written to disk.
    """
    # Set seed so dataset subsampling and sampling-based generation are reproducible
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model
    model, tokenizer, device = load_model_and_tokenizer(
        args.model_path, args.base_model, args.device
    )

    # Load dataset; fall back to the validation split if the test split is missing
    print(f"Loading dataset: {args.dataset_repo_id}/{args.data_dir}")
    try:
        dataset = load_dataset(
            args.dataset_repo_id,
            data_files={
                "test": f"{args.data_dir}/test_{args.data_dir}.csv"
            }
        )["test"]
    except Exception as e:
        print(f"Error loading test set, trying validation: {e}")
        dataset = load_dataset(
            args.dataset_repo_id,
            data_files={
                "validation": f"{args.data_dir}/val_{args.data_dir}.csv"
            }
        )["validation"]

    # Sample if needed (random subset without replacement, seeded above)
    if len(dataset) > args.num_samples:
        indices = np.random.choice(len(dataset), args.num_samples, replace=False)
        dataset = dataset.select(indices)

    print(f"Evaluating on {len(dataset)} samples...")

    # Determine if prefix or infix notation from the column name convention
    # (columns prefixed "p_" hold prefix-notation data)
    is_prefix = args.data_column.startswith("p_")

    # Evaluation metrics accumulated across all generations.
    # NOTE: "unique_expressions" is a set and "errors" a Counter; only their
    # length / top-k are exported in the JSON-serializable final_metrics.
    metrics = {
        "total_samples": 0,
        "total_generations": 0,
        "valid_expressions": 0,
        "parseable_expressions": 0,
        "uses_allowed_vars": 0,
        "uses_allowed_ops": 0,
        "all_constraints_met": 0,
        "unique_expressions": set(),
        "expression_lengths": [],
        "errors": Counter(),
    }

    # Per-generation records, saved separately from the aggregate metrics
    results = []

    # Generate and evaluate
    for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prompt = sample[args.data_column]

        # Extract just the prompt part (before the expression)
        # Typically the prompt ends before <|startofex|>; keep the marker so
        # the model continues from the expression-start position
        if "<|startofex|>" in prompt:
            prompt_only = prompt.split("<|startofex|>")[0] + "<|startofex|>"
        else:
            prompt_only = prompt

        generations = generate_expression(
            model, tokenizer, prompt_only, device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            num_return_sequences=args.num_generations
        )

        metrics["total_samples"] += 1

        for gen_output in generations:
            metrics["total_generations"] += 1

            # Extract expression from the raw decoded output
            # (extract_expression_from_output is defined earlier in this file)
            expr_str = extract_expression_from_output(gen_output, is_prefix)

            # Validate syntax/parseability
            validation = validate_expression(expr_str, is_prefix)

            # Check adherence to the variable/operator constraints in the prompt
            adherence = check_prompt_adherence(expr_str, prompt_only, is_prefix)

            # Update metrics; diversity/length stats only count parseable outputs
            if validation["valid"]:
                metrics["valid_expressions"] += 1
            if validation["parseable"]:
                metrics["parseable_expressions"] += 1
                metrics["unique_expressions"].add(expr_str)
                metrics["expression_lengths"].append(len(expr_str))
            if validation["error"]:
                # Truncate error messages so the Counter buckets similar failures
                metrics["errors"][validation["error"][:50]] += 1

            if adherence["uses_allowed_vars"]:
                metrics["uses_allowed_vars"] += 1
            if adherence["uses_allowed_ops"]:
                metrics["uses_allowed_ops"] += 1
            if adherence["all_constraints_met"]:
                metrics["all_constraints_met"] += 1

            results.append({
                "sample_idx": idx,
                "prompt": prompt_only[:200],  # Truncate for storage
                "generated_output": gen_output[:500],
                "extracted_expression": expr_str,
                "valid": validation["valid"],
                "parseable": validation["parseable"],
                "error": validation["error"],
                "uses_allowed_vars": adherence["uses_allowed_vars"],
                "uses_allowed_ops": adherence["uses_allowed_ops"],
            })

    # Calculate final metrics (all rates guard against zero generations)
    total_gen = metrics["total_generations"]
    final_metrics = {
        "model_path": args.model_path,
        "dataset": f"{args.dataset_repo_id}/{args.data_dir}",
        "data_column": args.data_column,
        "is_prefix": is_prefix,
        "num_samples": metrics["total_samples"],
        "num_generations": total_gen,
        "temperature": args.temperature,
        "top_p": args.top_p,

        # Validity metrics
        "valid_rate": metrics["valid_expressions"] / total_gen if total_gen > 0 else 0,
        "parseable_rate": metrics["parseable_expressions"] / total_gen if total_gen > 0 else 0,

        # Adherence metrics
        "uses_allowed_vars_rate": metrics["uses_allowed_vars"] / total_gen if total_gen > 0 else 0,
        "uses_allowed_ops_rate": metrics["uses_allowed_ops"] / total_gen if total_gen > 0 else 0,
        "constraints_met_rate": metrics["all_constraints_met"] / total_gen if total_gen > 0 else 0,

        # Diversity metrics (unique parseable expressions over all generations)
        "unique_expressions": len(metrics["unique_expressions"]),
        "diversity_rate": len(metrics["unique_expressions"]) / total_gen if total_gen > 0 else 0,
        "avg_expression_length": np.mean(metrics["expression_lengths"]) if metrics["expression_lengths"] else 0,

        # Error distribution (top 10)
        "top_errors": dict(metrics["errors"].most_common(10)),

        "timestamp": datetime.now().isoformat(),
    }

    # Print results as a human-readable console report
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Model: {args.model_path}")
    print(f"Dataset: {args.dataset_repo_id}/{args.data_dir}")
    print(f"Format: {'Prefix' if is_prefix else 'Infix'}")
    print("-"*60)
    print(f"Total samples: {metrics['total_samples']}")
    print(f"Total generations: {total_gen}")
    print("-"*60)
    print("VALIDITY METRICS:")
    print(f"  Valid rate: {final_metrics['valid_rate']:.2%}")
    print(f"  Parseable rate: {final_metrics['parseable_rate']:.2%}")
    print("-"*60)
    print("ADHERENCE METRICS:")
    print(f"  Uses allowed vars: {final_metrics['uses_allowed_vars_rate']:.2%}")
    print(f"  Uses allowed ops: {final_metrics['uses_allowed_ops_rate']:.2%}")
    print(f"  All constraints met: {final_metrics['constraints_met_rate']:.2%}")
    print("-"*60)
    print("DIVERSITY METRICS:")
    print(f"  Unique expressions: {final_metrics['unique_expressions']}")
    print(f"  Diversity rate: {final_metrics['diversity_rate']:.2%}")
    print(f"  Avg expression length: {final_metrics['avg_expression_length']:.1f}")
    print("="*60)

    # Save results
    os.makedirs(args.output_dir, exist_ok=True)

    # Create filename from model path (path separators are not filename-safe)
    model_name = args.model_path.replace("/", "_").replace("\\", "_")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save metrics (aggregate, JSON-serializable summary)
    metrics_file = os.path.join(args.output_dir, f"metrics_{model_name}_{timestamp}.json")
    with open(metrics_file, "w") as f:
        json.dump(final_metrics, f, indent=2)
    print(f"\nMetrics saved to: {metrics_file}")

    # Save detailed results (one record per generation)
    results_file = os.path.join(args.output_dir, f"results_{model_name}_{timestamp}.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Detailed results saved to: {results_file}")

    return final_metrics
if __name__ == "__main__":
    # Parse CLI options and run the evaluation pipeline.
    cli_args = parse_args()
    evaluate_model(cli_args)