Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .claude/settings.local.json +19 -0
- .gitignore +63 -0
- .pre-commit-config.yaml +42 -0
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.14.10/11987423046224602800 +0 -0
- .ruff_cache/0.14.10/9230827824541771364 +0 -0
- .ruff_cache/0.14.10/9755640140589486738 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- .yamllint.yaml +14 -0
- AGENTS.md +39 -0
- ATTRIBUTION.md +30 -0
- CLAUDE.md +120 -0
- README.md +77 -0
- RunPod LLM Fine-Tuning Strategy.md +339 -0
- ai-docs/chatbot-ideology.yaml +373 -0
- ai-docs/finetune.yaml +297 -0
- ai-docs/reward-modeling.yaml +857 -0
- ai-docs/runpod.yaml +362 -0
- ai-docs/training-schema.yaml +484 -0
- docker/.env.example +124 -0
- docker/Dockerfile +132 -0
- docker/requirements.txt +31 -0
- docker/start.sh +165 -0
- notebooks/Marxist_GRPO_RunPod_Optimized.ipynb +1107 -0
- pyproject.toml +166 -0
- src/prolewiki_llm/__init__.py +99 -0
- src/prolewiki_llm/convert_to_qwen.py +37 -0
- src/prolewiki_llm/export_grpo_dataset.py +224 -0
- src/prolewiki_llm/grpo_rewards.py +1847 -0
- src/prolewiki_llm/train_grpo_marxist.py +341 -0
- src/prolewiki_llm/train_headless.py +460 -0
- src/prolewiki_llm/train_marxist.py +201 -0
- src/prolewiki_llm/transform_to_grpo.py +64 -0
- src/prolewiki_llm/wandb_logging.py +529 -0
- tests/__init__.py +1 -0
- tests/conftest.py +114 -0
- tests/fixtures/__init__.py +1 -0
- tests/fixtures/mock_bin/huggingface-cli +7 -0
- tests/fixtures/mock_bin/python +35 -0
- tests/fixtures/mock_bin/runpodctl +12 -0
- tests/fixtures/mock_bin/wandb +7 -0
- tests/fixtures/mock_bin/wget +17 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_start_sh.py +462 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_grpo_rewards.py +0 -0
- tests/unit/test_train_headless.py +248 -0
- tests/unit/test_wandb_logging.py +467 -0
- training_data/entity_whitelist.json +0 -0
- training_data/entity_whitelist_clean.json +0 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(wc:*)",
|
| 5 |
+
"Bash(du:*)",
|
| 6 |
+
"mcp__context7__resolve-library-id",
|
| 7 |
+
"mcp__context7__query-docs",
|
| 8 |
+
"Bash(uv sync:*)",
|
| 9 |
+
"Bash(uv run python:*)",
|
| 10 |
+
"Bash(uv run ruff check:*)",
|
| 11 |
+
"Bash(uv run pytest:*)",
|
| 12 |
+
"WebFetch(domain:www.runpod.io)"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
"disabledMcpjsonServers": [
|
| 16 |
+
"sequential-thinking",
|
| 17 |
+
"neovim"
|
| 18 |
+
]
|
| 19 |
+
}
|
.gitignore
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.venv/
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# IDEs
|
| 29 |
+
.idea/
|
| 30 |
+
.vscode/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
|
| 34 |
+
# Jupyter
|
| 35 |
+
.ipynb_checkpoints/
|
| 36 |
+
|
| 37 |
+
# Testing
|
| 38 |
+
.pytest_cache/
|
| 39 |
+
.coverage
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
|
| 44 |
+
# mypy
|
| 45 |
+
.mypy_cache/
|
| 46 |
+
|
| 47 |
+
# Archives
|
| 48 |
+
*.tar.gz
|
| 49 |
+
*.zip
|
| 50 |
+
|
| 51 |
+
# Model artifacts (large files)
|
| 52 |
+
*.safetensors
|
| 53 |
+
*.bin
|
| 54 |
+
*.gguf
|
| 55 |
+
|
| 56 |
+
# Training outputs (generated)
|
| 57 |
+
outputs/
|
| 58 |
+
checkpoints/
|
| 59 |
+
lora-output/
|
| 60 |
+
|
| 61 |
+
# OS
|
| 62 |
+
.DS_Store
|
| 63 |
+
Thumbs.db
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 3 |
+
rev: v0.8.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: ruff
|
| 6 |
+
name: ruff (lint)
|
| 7 |
+
args: [--fix, --exit-non-zero-on-fix]
|
| 8 |
+
types_or: [python, pyi]
|
| 9 |
+
- id: ruff-format
|
| 10 |
+
name: ruff (format)
|
| 11 |
+
types_or: [python, pyi]
|
| 12 |
+
|
| 13 |
+
- repo: local
|
| 14 |
+
hooks:
|
| 15 |
+
- id: mypy
|
| 16 |
+
name: mypy (typecheck)
|
| 17 |
+
entry: uv run mypy
|
| 18 |
+
language: system
|
| 19 |
+
types: [python]
|
| 20 |
+
args: [src/prolewiki_llm/]
|
| 21 |
+
pass_filenames: false
|
| 22 |
+
|
| 23 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 24 |
+
rev: v5.0.0
|
| 25 |
+
hooks:
|
| 26 |
+
- id: trailing-whitespace
|
| 27 |
+
- id: end-of-file-fixer
|
| 28 |
+
- id: check-yaml
|
| 29 |
+
args: [--unsafe]
|
| 30 |
+
- id: check-json
|
| 31 |
+
- id: check-toml
|
| 32 |
+
- id: check-added-large-files
|
| 33 |
+
args: [--maxkb=5000]
|
| 34 |
+
- id: check-merge-conflict
|
| 35 |
+
- id: debug-statements
|
| 36 |
+
|
| 37 |
+
- repo: https://github.com/adrienverge/yamllint.git
|
| 38 |
+
rev: v1.35.1
|
| 39 |
+
hooks:
|
| 40 |
+
- id: yamllint
|
| 41 |
+
name: yamllint (yaml lint)
|
| 42 |
+
args: [-c=.yamllint.yaml]
|
.ruff_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatically created by ruff.
|
| 2 |
+
*
|
.ruff_cache/0.14.10/11987423046224602800
ADDED
|
Binary file (310 Bytes). View file
|
|
|
.ruff_cache/0.14.10/9230827824541771364
ADDED
|
Binary file (228 Bytes). View file
|
|
|
.ruff_cache/0.14.10/9755640140589486738
ADDED
|
Binary file (170 Bytes). View file
|
|
|
.ruff_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
.yamllint.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
extends: default
|
| 2 |
+
|
| 3 |
+
rules:
|
| 4 |
+
line-length:
|
| 5 |
+
max: 200
|
| 6 |
+
level: warning
|
| 7 |
+
document-start: disable
|
| 8 |
+
truthy:
|
| 9 |
+
allowed-values: ['true', 'false', 'yes', 'no']
|
| 10 |
+
comments:
|
| 11 |
+
min-spaces-from-content: 1
|
| 12 |
+
indentation:
|
| 13 |
+
spaces: 2
|
| 14 |
+
indent-sequences: consistent
|
AGENTS.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Repository Guidelines
|
| 2 |
+
|
| 3 |
+
## Project Structure & Module Organization
|
| 4 |
+
|
| 5 |
+
- `src/prolewiki_llm/`: core library code (reward functions, dataset transforms, W&B logging).
|
| 6 |
+
- `tests/unit/`: fast unit tests and fixtures under `tests/conftest.py`.
|
| 7 |
+
- `training_data/`: curated datasets (`.jsonl`) and documentation (`MODEL_CARD.yaml`).
|
| 8 |
+
- `notebooks/`: training and experimentation notebooks.
|
| 9 |
+
- `ai-docs/`: AI-consumable documentation and schemas.
|
| 10 |
+
|
| 11 |
+
## Build, Test, and Development Commands
|
| 12 |
+
|
| 13 |
+
- `uv sync`: install core dependencies.
|
| 14 |
+
- `uv sync --group dev`: install developer tools (pytest, ruff, mypy).
|
| 15 |
+
- `uv sync --group training`: install GPU training dependencies.
|
| 16 |
+
- `python -m spacy download en_core_web_sm`: required model for coherence rewards.
|
| 17 |
+
- `uv run pytest`: run the test suite.
|
| 18 |
+
- `uv run pytest -m "not slow and not gpu"`: run fast tests only.
|
| 19 |
+
- `uv run ruff check .` / `uv run ruff format .`: lint and format.
|
| 20 |
+
- `uv run mypy src`: run strict type checks on library code.
|
| 21 |
+
|
| 22 |
+
## Coding Style & Naming Conventions
|
| 23 |
+
|
| 24 |
+
- Python 3.12; line length 100; double quotes; spaces for indentation (ruff format).
|
| 25 |
+
- Use `snake_case` for modules/functions and `PascalCase` for classes.
|
| 26 |
+
- Keep reward logic in `src/prolewiki_llm/grpo_rewards.py` and avoid side effects.
|
| 27 |
+
- Name data files descriptively, for example `training_data/synthetic_topic.jsonl`.
|
| 28 |
+
|
| 29 |
+
## Testing Guidelines
|
| 30 |
+
|
| 31 |
+
- Pytest with markers: `unit`, `slow`, and `gpu` (see `pyproject.toml`).
|
| 32 |
+
- Place tests in `tests/unit/` and name files `test_*.py`.
|
| 33 |
+
- Mark long-running or GPU-dependent tests explicitly.
|
| 34 |
+
|
| 35 |
+
## Commit & Pull Request Guidelines
|
| 36 |
+
|
| 37 |
+
- Git history only shows "initial commit"; no established convention yet.
|
| 38 |
+
- Use short, imperative commit subjects (for example, "Add format reward tests").
|
| 39 |
+
- PRs should describe changes, list tests run, and call out dataset or model updates.
|
ATTRIBUTION.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Attribution
|
| 2 |
+
|
| 3 |
+
This project was extracted from [pw-mcp](https://github.com/prolewiki/pw-mcp) on December 25, 2025.
|
| 4 |
+
|
| 5 |
+
## Origin
|
| 6 |
+
|
| 7 |
+
The code in this repository was originally developed as part of the pw-mcp project (ProleWiki MCP Server), which provides semantic vector search over the ProleWiki corpus.
|
| 8 |
+
|
| 9 |
+
The AI training components were separated into this standalone repository to:
|
| 10 |
+
- Reduce the dependency footprint of pw-mcp
|
| 11 |
+
- Allow independent versioning and release cycles
|
| 12 |
+
- Separate the ~1GB ChromaDB releases from the ~4-16GB model releases
|
| 13 |
+
- Enable focused development on training infrastructure
|
| 14 |
+
|
| 15 |
+
## Original Authors
|
| 16 |
+
|
| 17 |
+
The original code was developed by contributors to the pw-mcp project.
|
| 18 |
+
|
| 19 |
+
## Files Extracted
|
| 20 |
+
|
| 21 |
+
The following files were copied from pw-mcp:
|
| 22 |
+
|
| 23 |
+
- `src/pw_mcp/ai_training/*.py` → `src/prolewiki_llm/`
|
| 24 |
+
- `training_data/` (entire directory)
|
| 25 |
+
- `tests/unit/training/` → `tests/unit/`
|
| 26 |
+
- `ai-docs/` (training-related YAML files)
|
| 27 |
+
|
| 28 |
+
## License
|
| 29 |
+
|
| 30 |
+
This project maintains the same AGPL-3.0-only license as the original pw-mcp project.
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
GRPO fine-tuning infrastructure for training Marxist-Leninist language models. The core functionality is a multi-layer reward system designed to prevent reward hacking during GRPO (Group Relative Policy Optimization) training.
|
| 8 |
+
|
| 9 |
+
## Commands
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
# Install dependencies
|
| 13 |
+
uv sync # Core dependencies
|
| 14 |
+
uv sync --group dev # Dev tools (pytest, ruff, mypy)
|
| 15 |
+
uv sync --group training # GPU training deps (unsloth, trl, vllm)
|
| 16 |
+
|
| 17 |
+
# Required NLP model
|
| 18 |
+
python -m spacy download en_core_web_sm
|
| 19 |
+
|
| 20 |
+
# Testing
|
| 21 |
+
uv run pytest # All tests
|
| 22 |
+
uv run pytest -m "not slow and not gpu" # Fast tests only
|
| 23 |
+
uv run pytest tests/unit/test_grpo_rewards.py::test_name # Single test
|
| 24 |
+
|
| 25 |
+
# Code quality
|
| 26 |
+
uv run ruff check . # Lint
|
| 27 |
+
uv run ruff format . # Format
|
| 28 |
+
uv run mypy src # Type check (strict mode)
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Architecture
|
| 32 |
+
|
| 33 |
+
### Reward Function System (`src/prolewiki_llm/grpo_rewards.py`)
|
| 34 |
+
|
| 35 |
+
The reward system is layered to defeat "word soup" and other reward hacking strategies:
|
| 36 |
+
|
| 37 |
+
1. **Format rewards** - Validate `<think>...</think>` reasoning tags
|
| 38 |
+
2. **NLI coherence** - Use BART-large-MNLI to check if response entails ground truth
|
| 39 |
+
3. **Self-consistency** - Detect internal contradictions via pairwise NLI
|
| 40 |
+
4. **Structural coherence** - spaCy dependency parsing verifies terms in meaningful syntactic roles
|
| 41 |
+
5. **Topic relevance** - Ensure response addresses what was asked
|
| 42 |
+
6. **Interconnection depth** - Reward deep analysis, penalize buzzword salad
|
| 43 |
+
|
| 44 |
+
Key functions:
|
| 45 |
+
- `full_coherence_reward()` - Comprehensive 5-layer check (recommended for training)
|
| 46 |
+
- `robust_coherence_reward()` - NLI + self-consistency + structure
|
| 47 |
+
- `ideological_firmness_reward()` - Detects capitulation to reactionary framing
|
| 48 |
+
|
| 49 |
+
Models are lazy-loaded via `get_embedder()`, `get_nli_pipeline()`, `get_spacy_nlp()`.
|
| 50 |
+
|
| 51 |
+
### Reward Function Interface
|
| 52 |
+
|
| 53 |
+
All reward functions follow the GRPOTrainer signature:
|
| 54 |
+
```python
|
| 55 |
+
def reward_func(
|
| 56 |
+
prompts: Sequence[Sequence[dict[str, str]]],
|
| 57 |
+
completions: Sequence[Sequence[dict[str, str]]],
|
| 58 |
+
answer: Sequence[str],
|
| 59 |
+
**kwargs: object,
|
| 60 |
+
) -> list[float]:
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
### W&B Integration (`src/prolewiki_llm/wandb_logging.py`)
|
| 64 |
+
|
| 65 |
+
`WandbSampleLogger` accumulates samples and logs them as tables for debugging reward behavior. Use `create_logging_reward()` to wrap logging into a reward function.
|
| 66 |
+
|
| 67 |
+
## Testing Patterns
|
| 68 |
+
|
| 69 |
+
- Markers: `unit`, `slow`, `gpu` (see `pyproject.toml`)
|
| 70 |
+
- Fixtures in `tests/conftest.py`: `sample_question`, `sample_good_answer`, `sample_bad_answer`
|
| 71 |
+
- Mock heavy NLP models in unit tests to avoid slow initialization
|
| 72 |
+
|
| 73 |
+
## Code Style
|
| 74 |
+
|
| 75 |
+
- Python 3.12, line length 100, double quotes
|
| 76 |
+
- `snake_case` functions/modules, `PascalCase` classes
|
| 77 |
+
- Reward functions must accept `**kwargs` (GRPOTrainer interface requirement)
|
| 78 |
+
- Keep reward logic pure (no side effects except lazy model loading)
|
| 79 |
+
|
| 80 |
+
## Training
|
| 81 |
+
|
| 82 |
+
### Headless RunPod Training (Recommended)
|
| 83 |
+
|
| 84 |
+
Docker-based headless training for RunPod deployment:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# 1. Export dataset from ../dataset/ to GRPO format
|
| 88 |
+
uv run python -m prolewiki_llm.export_grpo_dataset
|
| 89 |
+
|
| 90 |
+
# 2. Build Docker image
|
| 91 |
+
docker build -t marxist-grpo:latest -f docker/Dockerfile .
|
| 92 |
+
|
| 93 |
+
# 3. Deploy to RunPod
|
| 94 |
+
runpodctl create pod \
|
| 95 |
+
--name "marxist-grpo-training" \
|
| 96 |
+
--gpuType "NVIDIA A100 80GB PCIe" \
|
| 97 |
+
--imageName myregistry/marxist-grpo:latest \
|
| 98 |
+
--env HF_TOKEN=$HF_TOKEN \
|
| 99 |
+
--env WANDB_API_KEY=$WANDB_API_KEY \
|
| 100 |
+
--env HF_REPO=my-org/my-model
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
Key files:
|
| 104 |
+
- `src/prolewiki_llm/export_grpo_dataset.py` - Dataset export utility
|
| 105 |
+
- `src/prolewiki_llm/train_headless.py` - Headless training script
|
| 106 |
+
- `docker/Dockerfile` - Container definition
|
| 107 |
+
- `docker/start.sh` - Entrypoint (auth → train → upload → stop pod)
|
| 108 |
+
- `docker/.env.example` - Environment variable reference
|
| 109 |
+
|
| 110 |
+
Environment variables for `train_headless.py`:
|
| 111 |
+
- `HF_TOKEN`, `WANDB_API_KEY` (required)
|
| 112 |
+
- `MAX_STEPS`, `BATCH_SIZE`, `LEARNING_RATE`, `REWARD_MODE` (optional)
|
| 113 |
+
- `RUNPOD_POD_ID` (enables auto-termination after training)
|
| 114 |
+
|
| 115 |
+
### Notebook Training (Interactive)
|
| 116 |
+
|
| 117 |
+
See `notebooks/Marxist_GRPO_Training.ipynb` for interactive training. Key settings:
|
| 118 |
+
- Disable `torch.compile` via environment variables (causes hangs on RunPod/Jupyter)
|
| 119 |
+
- Use `use_gradient_checkpointing=True` (not `"unsloth"`)
|
| 120 |
+
- `load_in_4bit=False` required for GRPO (16-bit LoRA adapters)
|
README.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# prolewiki-llm
|
| 2 |
+
|
| 3 |
+
GRPO fine-tuning and reward functions for training Marxist-Leninist language models.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This repository contains the AI training infrastructure for fine-tuning language models on Marxist-Leninist theory. It includes:
|
| 8 |
+
|
| 9 |
+
- **Reward Functions**: Multi-layer reward system for GRPO training that prevents reward hacking
|
| 10 |
+
- **Training Data**: Curated Q&A pairs and synthetic datasets for ideological consistency
|
| 11 |
+
- **Training Scripts**: Ready-to-run notebooks for RunPod/cloud GPU training
|
| 12 |
+
- **W&B Integration**: Weights & Biases logging for training observability
|
| 13 |
+
|
| 14 |
+
## Related Projects
|
| 15 |
+
|
| 16 |
+
- [pw-mcp](https://github.com/prolewiki/pw-mcp) - MCP server and ChromaDB pipeline for ProleWiki semantic search
|
| 17 |
+
|
| 18 |
+
## Installation
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
# Basic installation
|
| 22 |
+
uv sync
|
| 23 |
+
|
| 24 |
+
# Download spacy model (required for topic/coherence rewards)
|
| 25 |
+
python -m spacy download en_core_web_sm
|
| 26 |
+
|
| 27 |
+
# With training dependencies (for GPU training)
|
| 28 |
+
uv sync --group training
|
| 29 |
+
|
| 30 |
+
# Development
|
| 31 |
+
uv sync --group dev
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Usage
|
| 35 |
+
|
| 36 |
+
### Reward Functions
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
from prolewiki_llm import full_coherence_reward, format_reward
|
| 40 |
+
|
| 41 |
+
# Combined 5-layer coherence check (recommended for training)
|
| 42 |
+
reward = full_coherence_reward(
|
| 43 |
+
prompts=["What is imperialism?"],
|
| 44 |
+
completions=["<think>...</think>\n\nImperialism is..."],
|
| 45 |
+
answer="Lenin defined imperialism as..."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Individual reward components
|
| 49 |
+
format_score = format_reward(completions=["<think>...</think>\n\nAnswer..."])
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Training
|
| 53 |
+
|
| 54 |
+
See `notebooks/Marxist_GRPO_Training.ipynb` for a complete training example.
|
| 55 |
+
|
| 56 |
+
## Project Structure
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
prolewiki-llm/
|
| 60 |
+
├── src/prolewiki_llm/
|
| 61 |
+
│ ├── grpo_rewards.py # 17+ reward functions
|
| 62 |
+
│ ├── wandb_logging.py # W&B integration
|
| 63 |
+
│ └── transform_to_grpo.py # Dataset conversion
|
| 64 |
+
├── training_data/
|
| 65 |
+
│ ├── synthetic_*.jsonl # Training datasets
|
| 66 |
+
│ ├── entity_whitelist.json # Anti-hallucination data
|
| 67 |
+
│ └── MODEL_CARD.yaml # Dataset documentation
|
| 68 |
+
├── notebooks/
|
| 69 |
+
│ └── Marxist_GRPO_Training.ipynb
|
| 70 |
+
├── tests/
|
| 71 |
+
│ └── unit/
|
| 72 |
+
└── ai-docs/ # AI-consumable documentation
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## License
|
| 76 |
+
|
| 77 |
+
AGPL-3.0-only
|
RunPod LLM Fine-Tuning Strategy.md
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# **Operationalizing Headless Large Language Model Fine-Tuning on RunPod: A Comprehensive Infrastructure and Workflow Analysis**
|
| 2 |
+
|
| 3 |
+
## **1\. Introduction: The Paradigm Shift to Headless AI Operations**
|
| 4 |
+
|
| 5 |
+
The contemporary landscape of Large Language Model (LLM) development is currently undergoing a fundamental transition, shifting from interactive, exploratory environments toward rigorous, automated production pipelines. For years, the Jupyter notebook has served as the *lingua franca* of data science—a canvas for experimentation, visualization, and iterative code development. However, as the field matures from research to engineering, the limitations of the notebook paradigm become increasingly acute, particularly when applied to the resource-intensive and time-critical task of fine-tuning custom LLMs. The user requirement for a workflow that eliminates the notebook interface in favor of a "code-upload-and-train" paradigm reflects a sophisticated understanding of MLOps principles: reproducibility, resource efficiency, and maximizing computational throughput.
|
| 6 |
+
RunPod, as a specialized GPU cloud provider, occupies a unique and critical niche within this evolving ecosystem. Distinct from hyperscalers such as AWS, Azure, or Google Cloud Platform—which often necessitate complex Identity and Access Management (IAM) configurations, quota negotiations, and long-term commitments—RunPod offers a container-centric infrastructure that is ideally optimized for headless, ephemeral training jobs. The platform’s architecture effectively democratizes access to high-performance compute, offering everything from consumer-grade RTX 4090s to enterprise-class NVIDIA H100 clusters. This report provides an exhaustive, expert-level analysis of the architectural, operational, and software strategies necessary to fine-tune custom LLMs on RunPod using a strictly headless approach.
|
| 7 |
+
To fully satisfy the requirement of "training as fast as possible" with "custom training data," this analysis moves beyond simple tutorials to construct a robust engineering framework. It dissects the interplay between hardware selection (Secure vs. Community Cloud), containerization strategies (Docker-based execution), and high-efficiency fine-tuning frameworks (Unsloth and Axolotl). By decoupling the training process from an interactive Integrated Development Environment (IDE), developers can leverage spot instances more effectively, dramatically reduce idle compute costs, and integrate training runs into broader Continuous Integration/Continuous Deployment (CI/CD) pipelines. This report serves as a definitive guide to architecting these headless systems.
|
| 8 |
+
|
| 9 |
+
## ---
|
| 10 |
+
|
| 11 |
+
**2\. Infrastructure Architecture and Instance Selection Strategy**
|
| 12 |
+
|
| 13 |
+
The foundation of any high-performance fine-tuning workflow is the underlying compute architecture. In the context of "training as fast as possible," the choice of hardware dictates not only the wall-clock time of the training run but also the stability, cost-efficiency, and maximum capable model size of the session. RunPod’s inventory is segmented into distinct tiers, each offering specific advantages and liabilities for headless operations. A nuanced understanding of these hardware profiles is essential for optimizing the price-performance ratio.
|
| 14 |
+
|
| 15 |
+
### **2.1 The GPU Hierarchy: Performance Profiles and Architectural Suitability**
|
| 16 |
+
|
| 17 |
+
The selection of a specific GPU architecture must be directly correlated with the parameter count of the target model (e.g., Llama 3 8B, Mistral, or Llama 3 70B) and the chosen quantization method (Full Fine-Tuning vs. LoRA/QLoRA).
|
| 18 |
+
|
| 19 |
+
#### **The Enterprise Tier: NVIDIA H100 and A100**
|
| 20 |
+
|
| 21 |
+
For users prioritizing raw speed and throughput above all else, the NVIDIA H100 and A100 Tensor Core GPUs represent the gold standard of current AI acceleration. These cards are designed for datacenter reliability and massive parallel throughput.
|
| 22 |
+
The **NVIDIA H100 (80GB)** stands as the pinnacle of current commercial AI hardware. It is specifically engineered to accelerate Transformer-based models via its fourth-generation Tensor Cores and the dedicated Transformer Engine, which automatically manages mixed-precision calculations using FP8 formats.1 For headless workflows, the H100 offers a distinct advantage: its sheer speed minimizes the "window of vulnerability." In a headless setup, particularly one utilizing spot instances or decentralized nodes, the longer a job runs, the higher the statistical probability of a network disconnect or node preemption. By reducing training time by factors of 3x or more compared to previous generations, the H100 significantly increases the reliability of job completion.2 It is the only viable option for users attempting to full fine-tune models in the 70B+ parameter range within reasonable timeframes. However, this performance comes at a premium, with costs ranging from approximately $2.69 to $4.00 per hour depending on the specific configuration (SXM vs. PCIe) and market demand.1
|
| 23 |
+
The **NVIDIA A100 (80GB)** remains the industry workhorse for LLM training. While it lacks the H100's specific FP8 Transformer Engine, its 80GB of High Bandwidth Memory (HBM2e) provides sufficient capacity to fine-tune 70B models using QLoRA or 8B models with full precision and extended context windows.1 The availability of A100s on RunPod is generally higher than that of H100s, making them a more reliable fallback for automated pipelines that require immediate provisioning without queuing. For users engaging in "headless" operations where the script automatically requests resources, the A100's ubiquity often makes it the path of least resistance.4
|
| 24 |
+
|
| 25 |
+
#### **The Prosumer Tier: NVIDIA RTX 4090 and RTX 6000 Ada**
|
| 26 |
+
|
| 27 |
+
For users targeting smaller models, such as the 7B or 8B parameter classes (e.g., Llama 3 8B, Mistral, Gemma), the enterprise tier may represent overkill. The **NVIDIA RTX 4090** has emerged as an exceptionally cost-effective alternative for these specific workloads.
|
| 28 |
+
With 24GB of VRAM, the RTX 4090 can comfortably handle 8B models using 4-bit quantization (QLoRA) or, when paired with memory-efficient frameworks like Unsloth, even larger batch sizes.5 The cost efficiency is dramatic: at approximately $0.34 to $0.69 per hour, a developer can run extensive hyperparameter sweeps (grid searches) for the cost of a single hour on an H100.1 However, the use of consumer hardware in a headless workflow introduces specific constraints. These cards are typically hosted in the "Community Cloud" tier, meaning they are decentralized nodes often residing in non-tier-1 datacenters or even private residences. This introduces a higher risk of interruption, necessitating that the headless script implements robust, frequent checkpointing to resume training automatically if a node goes offline.
|
| 29 |
+
The **RTX 6000 Ada Generation** bridges the gap, offering 48GB of VRAM—double that of the 4090—while retaining the Ada Lovelace architecture's efficiency. Priced around $0.79/hr, it allows for training mid-sized models (e.g., 30B parameters with QLoRA) or 8B models with much longer context windows than the 4090 allows.1
|
| 30 |
+
|
| 31 |
+
### **2.2 Deployment Tiers: Secure Cloud vs. Community Cloud**
|
| 32 |
+
|
| 33 |
+
RunPod segments its GPU inventory into two primary distinct tiers: **Community Cloud** and **Secure Cloud**. This distinction is critical for designing a headless operation, as it fundamentally dictates the reliability engineering required in the training code.
|
| 34 |
+
**Secure Cloud** represents enterprise-grade datacenters with high reliability, redundancy, and security certifications (SOC2, etc.). For a user whose primary requirement is to "upload and train," Secure Cloud offers the assurance that the pod will not vanish mid-training due to a provider pulling the machine off the network. The pricing is slightly higher, but the reduction in operational complexity—specifically the reduced need for aggressive fault-tolerance scripting—often outweighs the raw hourly cost difference.1 For the final "production" training run, specifically when processing a massive dataset that might take 10+ hours, Secure Cloud is the recommended tier to ensure uninterrupted execution.
|
| 35 |
+
**Community Cloud** consists of crowdsourced GPUs provided by third parties. While significantly cheaper, these function similarly to Spot instances in traditional clouds, though with potentially higher variance in uptime and network bandwidth. They are ideal for "bursty" workloads where a user might spin up 10 simultaneous experiments to test different learning rates. However, utilizing this tier for headless training requires the training script to be resilient. It implies that the "code upload" must include logic to check for existing checkpoints on a persistent volume and resume automatically, as the probability of a node restart is non-zero.1
|
| 36 |
+
|
| 37 |
+
### **2.3 Cost-Performance Matrix**
|
| 38 |
+
|
| 39 |
+
To assist in making the precise hardware decision, the following table synthesizes the cost, utility, and risk profile of available hardware for fine-tuning tasks on RunPod.
|
| 40 |
+
|
| 41 |
+
| GPU Model | VRAM | Cloud Tier | Est. Price/Hr | Best Use Case | Headless Reliability |
|
| 42 |
+
| :---- | :---- | :---- | :---- | :---- | :---- |
|
| 43 |
+
| **H100 SXM** | 80GB | Secure | \~$2.69 | Full FT 70B+, Time-Critical Jobs | **High** (Fastest completion minimizes risk) |
|
| 44 |
+
| **A100 SXM** | 80GB | Secure | \~$1.49 | QLoRA 70B, Full FT 8B | **High** (Standard enterprise reliability) |
|
| 45 |
+
| **A100 PCIe** | 40GB | Secure | \~$1.39 | LoRA 13B-30B | **Medium** (Memory constraints may limit batch size) |
|
| 46 |
+
| **RTX 6000 Ada** | 48GB | Secure | \~$0.79 | Mid-range models (30B), Long Context | **High** (Excellent VRAM/Price ratio) |
|
| 47 |
+
| **RTX 4090** | 24GB | Community | \~$0.34 | QLoRA 8B, Debugging, Sweeps | **Low/Medium** (Requires fault tolerance logic) |
|
| 48 |
+
| **RTX 3090** | 24GB | Community | \~$0.22 | Low-budget experimentation | **Low** (Slower speed increases interrupt risk) |
|
| 49 |
+
|
| 50 |
+
1
|
| 51 |
+
|
| 52 |
+
## ---
|
| 53 |
+
|
| 54 |
+
**3\. The Headless Workflow Architecture: Containerization and Automation**
|
| 55 |
+
|
| 56 |
+
To satisfy the user's explicit requirement of avoiding a Jupyter notebook in favor of a "code upload" model, the workflow must shift from an interactive session to a batch-processing paradigm. In this model, the local machine is used for code development and configuration, while the remote GPU serves purely as an execution engine. This requires a Docker-first approach where the environment, code, and execution logic are encapsulated within a portable container image.
|
| 57 |
+
|
| 58 |
+
### **3.1 The Docker-First Approach**
|
| 59 |
+
|
| 60 |
+
The cornerstone of a robust headless workflow is containerization. Launching a generic Ubuntu pod and manually installing libraries via a startup script is prone to error, hard to reproduce, and slow. Instead, the user must define the entire training environment in a Docker image. This ensures that "uploading code" translates immediately to execution without manual environment setup.
|
| 61 |
+
|
| 62 |
+
#### **The "Entrypoint" Strategy**
|
| 63 |
+
|
| 64 |
+
In a standard interactive RunPod session, the container launches and idles, typically running a sleep command or a Jupyter server, waiting for a user to connect. In a headless workflow, the Docker container utilizes an ENTRYPOINT or CMD script that immediately initiates the training process upon launch. Crucially, once the training process concludes (whether successfully or with a failure), the script handles data egress and terminates the pod.7
|
| 65 |
+
This approach perfectly aligns with the "upload code and train" desire. The "code" is baked into the Docker image (or mounted at runtime), and the "train" command is the automatic, inevitable action of the container starting up.
|
| 66 |
+
|
| 67 |
+
#### **Constructing the Golden Image**
|
| 68 |
+
|
| 69 |
+
A "Golden Image" for fine-tuning must include the base CUDA drivers, the Python environment, and the specific fine-tuning frameworks (Axolotl or Unsloth). Below is an architectural breakdown of such a Dockerfile, optimized for RunPod.
|
| 70 |
+
**Scenario:** A Docker image designed for fine-tuning Llama 3 using Unsloth.
|
| 71 |
+
|
| 72 |
+
Dockerfile
|
| 73 |
+
|
| 74 |
+
\# Use RunPod's base image or NVIDIA's CUDA image to ensure driver compatibility
|
| 75 |
+
\# CUDA 11.8 or 12.1 is often required for modern LLM frameworks
|
| 76 |
+
FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel
|
| 77 |
+
|
| 78 |
+
\# Set working directory
|
| 79 |
+
WORKDIR /workspace
|
| 80 |
+
|
| 81 |
+
\# Install system dependencies
|
| 82 |
+
\# git-lfs is critical for downloading large models/datasets
|
| 83 |
+
RUN apt-get update && apt-get install \-y git git-lfs htop nvtop tmux
|
| 84 |
+
|
| 85 |
+
\# Install Python dependencies
|
| 86 |
+
\# Unsloth and Axolotl often require specific versions of xformers and trl
|
| 87 |
+
\# Using a requirements.txt allows for easier version pinning
|
| 88 |
+
COPY requirements.txt /workspace/requirements.txt
|
| 89 |
+
RUN pip install \--no-cache-dir \-r requirements.txt
|
| 90 |
+
|
| 91 |
+
\# Install specific frameworks (Example: Unsloth)
|
| 92 |
+
\# Note: Unsloth installation often requires specific CUDA paths
|
| 93 |
+
RUN pip install "unsloth\[colab-new\] @ git+https://github.com/unslothai/unsloth.git"
|
| 94 |
+
RUN pip install \--no-deps "xformers\<0.0.26" "trl\<0.9.0" peft accelerate bitsandbytes
|
| 95 |
+
|
| 96 |
+
\# Copy the automation scripts and training code
|
| 97 |
+
COPY train.py /workspace/train.py
|
| 98 |
+
COPY start.sh /workspace/start.sh
|
| 99 |
+
|
| 100 |
+
\# Make the start script executable
|
| 101 |
+
RUN chmod \+x /workspace/start.sh
|
| 102 |
+
|
| 103 |
+
\# Set the entrypoint to the automation script
|
| 104 |
+
ENTRYPOINT \["/workspace/start.sh"\]
|
| 105 |
+
|
| 106 |
+
7
|
| 107 |
+
**Optimization Insight:** Embedding the dataset directly into the Docker image (via COPY dataset.jsonl) is a viable strategy only for small datasets (\<5GB). For massive datasets (\>100GB), as implied by the "custom training data" requirement, this approach creates bloated images that are slow to push and pull. For large-scale data, the start.sh script should be designed to pull the data from S3 or a RunPod Network Volume at runtime, ensuring the Docker image remains lightweight and agile.10
|
| 108 |
+
|
| 109 |
+
### **3.2 The Automation Logic: The start.sh Script**
|
| 110 |
+
|
| 111 |
+
The start.sh script acts as the "brain" of the headless operation. It orchestrates the sequence of events inside the pod, managing authentication, data ingestion, execution, and cleanup.
|
| 112 |
+
|
| 113 |
+
Bash
|
| 114 |
+
|
| 115 |
+
\#\!/bin/bash
|
| 116 |
+
set \-e \# Exit immediately if a command exits with a non-zero status
|
| 117 |
+
|
| 118 |
+
\# 1\. Environment Setup (Logging in to Hugging Face and WandB)
|
| 119 |
+
\# These tokens are passed via environment variables at runtime
|
| 120 |
+
if \[ \-n "$HF\_TOKEN" \]; then
|
| 121 |
+
huggingface-cli login \--token $HF\_TOKEN
|
| 122 |
+
fi
|
| 123 |
+
|
| 124 |
+
if \[ \-n "$WANDB\_API\_KEY" \]; then
|
| 125 |
+
wandb login $WANDB\_API\_KEY
|
| 126 |
+
fi
|
| 127 |
+
|
| 128 |
+
\# 2\. Data Ingestion
|
| 129 |
+
\# Download dataset from S3 or Network Volume if not present
|
| 130 |
+
if \[ \! \-f "/workspace/dataset.jsonl" \]; then
|
| 131 |
+
echo "Downloading dataset from remote source..."
|
| 132 |
+
\# Example using a presigned URL or S3 CLI
|
| 133 |
+
wget \-O /workspace/dataset.jsonl "$DATASET\_URL"
|
| 134 |
+
fi
|
| 135 |
+
|
| 136 |
+
\# 3\. Execution
|
| 137 |
+
echo "Starting Training..."
|
| 138 |
+
\# Launch the Python training script
|
| 139 |
+
\# Unsloth or Axolotl command goes here
|
| 140 |
+
python train.py \--config config.json
|
| 141 |
+
|
| 142 |
+
\# 4\. Exfiltration/Saving
|
| 143 |
+
echo "Training Complete. Merging and Uploading..."
|
| 144 |
+
\# Assuming train.py saves to /workspace/output
|
| 145 |
+
\# This step ensures the trained weights are saved to HF Hub or S3
|
| 146 |
+
python upload\_to\_hub.py \--path /workspace/output \--repo my-user/my-finetuned-model
|
| 147 |
+
|
| 148 |
+
\# 5\. Cleanup (Critical for Cost Savings)
|
| 149 |
+
echo "Shutting down pod to stop billing..."
|
| 150 |
+
runpodctl stop pod $RUNPOD\_POD\_ID
|
| 151 |
+
|
| 152 |
+
7
|
| 153 |
+
**FinOps Strategy:** By including runpodctl stop pod $RUNPOD\_POD\_ID as the final command, the user ensures they only pay for the exact duration of the training. This effectively transforms a standard GPU pod into a serverless-like job, preventing "zombie pods" from racking up bills after the training is finished.12
|
| 154 |
+
|
| 155 |
+
### **3.3 Remote Management: The runpodctl Utility**
|
| 156 |
+
|
| 157 |
+
For managing these headless pods, runpodctl is the essential Command Line Interface (CLI) tool provided by RunPod. It allows the user to spin up pods, stream logs, and transfer files without ever navigating to the web console.13
|
| 158 |
+
Automation via CLI:
|
| 159 |
+
The user can script the deployment of the training job from their local machine. A single command can instantiate the pod using the custom image defined above:
|
| 160 |
+
|
| 161 |
+
Bash
|
| 162 |
+
|
| 163 |
+
runpodctl create pod \\
|
| 164 |
+
\--name "headless-llama3-finetune" \\
|
| 165 |
+
\--gpuType "NVIDIA A100 80GB PCIe" \\
|
| 166 |
+
\--imageName "myregistry/custom-llm-trainer:v1" \\
|
| 167 |
+
\--containerDiskSize 100 \\
|
| 168 |
+
\--volumeSize 200 \\
|
| 169 |
+
\--env HF\_TOKEN=$HF\_TOKEN \\
|
| 170 |
+
\--env WANDB\_API\_KEY=$WANDB\_KEY \\
|
| 171 |
+
\--env DATASET\_URL="https://my-s3-bucket..."
|
| 172 |
+
|
| 173 |
+
14
|
| 174 |
+
This command fulfills the user's request: it uploads the configuration (via the image definition) and starts training immediately. The \--gpuType flag ensures the job lands on the specific hardware required for speed, while \--env passes the necessary secrets securely.
|
| 175 |
+
|
| 176 |
+
## ---
|
| 177 |
+
|
| 178 |
+
**4\. Fine-Tuning Frameworks: The Engines of Efficiency**
|
| 179 |
+
|
| 180 |
+
To train "as fast as possible" without reinventing the wheel, high-level fine-tuning frameworks are vastly superior to writing raw PyTorch training loops. The two leading contenders for this workflow on RunPod are **Axolotl** and **Unsloth**. Each offers distinct advantages for headless execution.
|
| 181 |
+
|
| 182 |
+
### **4.1 Axolotl: The Configuration-Driven Powerhouse**
|
| 183 |
+
|
| 184 |
+
Axolotl is designed for users who want to define *what* to train, not *how* to code the training loop. It abstracts the complexity of the Hugging Face Trainer into a comprehensive YAML configuration file.15
|
| 185 |
+
|
| 186 |
+
* **Headless Suitability:** Excellent. Because the entire training logic is encapsulated in a single YAML file, "uploading code" simply means injecting this config file into the container. There is no need to maintain complex Python scripts; the logic is declarative.
|
| 187 |
+
* **Feature Set:** Axolotl supports Full Fine-Tuning (FFT), LoRA, QLoRA, and advanced techniques like Flash Attention 2 and Sample Packing. Sample packing is particularly relevant for speed, as it concatenates multiple short examples into a single sequence, removing padding tokens and significantly increasing training throughput.17
|
| 188 |
+
* **Workflow Integration:**
|
| 189 |
+
1. User edits config.yaml locally.
|
| 190 |
+
2. User builds Docker image with this config or mounts it at runtime.
|
| 191 |
+
3. Container starts and runs axolotl train config.yaml.
|
| 192 |
+
* **Multi-GPU Scaling:** Axolotl excels at multi-GPU training using FSDP (Fully Sharded Data Parallel) or DeepSpeed. If the user intends to scale training across an 8x A100 node to maximize speed, Axolotl is the robust choice.17
|
| 193 |
+
|
| 194 |
+
### **4.2 Unsloth: The Efficiency Specialist**
|
| 195 |
+
|
| 196 |
+
Unsloth is a framework optimized specifically for speed and memory efficiency on single-GPU setups. It utilizes custom Triton kernels to manually backpropagate gradients, achieving 2-5x faster training speeds and up to 80% less memory usage compared to standard Hugging Face implementations.17
|
| 197 |
+
|
| 198 |
+
* **Headless Suitability:** High. Unsloth provides Docker images that can be easily adapted for headless execution.9 The speed gains directly address the user's requirement to "train as fast as possible."
|
| 199 |
+
* **Performance:** For single-GPU setups (e.g., one H100 or A100), Unsloth is unrivaled. The memory savings allow users to fit significantly larger batch sizes into VRAM, which directly translates to faster wall-clock training times. For example, on a Llama 3 8B model, Unsloth can enable training with context lengths that would cause OOM (Out of Memory) errors on standard implementations.19
|
| 200 |
+
* **Limitation:** Historically, Unsloth has been optimized for single-GPU training. While multi-GPU support is evolving, its primary strength remains in maximizing the throughput of a single card. For a user operating on a single powerful node (like an H100), Unsloth is likely the fastest option.18
|
| 201 |
+
|
| 202 |
+
### **4.3 Comparative Analysis for the User**
|
| 203 |
+
|
| 204 |
+
| Feature | Axolotl | Unsloth | Strategic Recommendation |
|
| 205 |
+
| :---- | :---- | :---- | :---- |
|
| 206 |
+
| **Configuration** | YAML-based (Declarative) | Python/Script-based | **Axolotl** for strict config management and reproducibility. |
|
| 207 |
+
| **Speed (Single GPU)** | High (uses Flash Attn) | **Extreme** (2x faster than Axolotl) | **Unsloth** for raw speed on single cards (H100/A100). |
|
| 208 |
+
| **Multi-GPU** | Native Support (DeepSpeed/FSDP) | Limited/Paid Tier | **Axolotl** for distributed training across clusters. |
|
| 209 |
+
| **Ease of Headless** | Very High | High | **Both** are excellent; choice depends on scaling needs. |
|
| 210 |
+
|
| 211 |
+
**Expert Insight:** Given the user's preference for "fast as possible" and "custom code," if the model fits on a single GPU (e.g., Llama 3 8B or 70B on an H100), **Unsloth** is the superior choice for raw throughput. If the user requires multi-GPU scaling or complex dataset mixing configurations, **Axolotl** provides a more robust infrastructure.18
|
| 212 |
+
|
| 213 |
+
## ---
|
| 214 |
+
|
| 215 |
+
**5\. Data Logistics: Solving the Custom Data Bottleneck**
|
| 216 |
+
|
| 217 |
+
A major challenge in ephemeral, headless training is data logistics. The user specified "custom training data," which implies datasets that are not pre-cached in public hubs. Handling large datasets (100GB+) efficiently is critical to avoiding idle GPU time.
|
| 218 |
+
|
| 219 |
+
### **5.1 Storage Architectures: Network Volumes vs. NVMe vs. Object Storage**
|
| 220 |
+
|
| 221 |
+
* **Local Pod Storage (Container Disk):** This offers the fastest I/O performance. Data is stored on the NVMe SSD directly attached to the GPU instance. This is ideal for maximizing training speed, as the GPU is not starved of data. However, this storage is ephemeral; data is lost if the pod is terminated without external saving.5
|
| 222 |
+
* **RunPod Network Volumes:** This is persistent storage that survives pod termination and allows data to be shared across pods.
|
| 223 |
+
* *Throughput Bottleneck:* Network volumes can suffer from slower throughput (200-400 MB/s) compared to local NVMe, potentially bottlenecking the data loader during training of small models where the GPU processes batches faster than the disk can supply them.22
|
| 224 |
+
* *Region Lock:* Network volumes are region-locked. If a volume is created in US-NJ, the user is forced to rent GPUs in US-NJ. This severely limits the ability to grab available H100s in other regions, contradicting the "train as fast as possible" goal.22
|
| 225 |
+
* **S3 / Object Storage:** The most flexible approach. Data is stored in AWS S3 (or compatible) and streamed or downloaded at the start of the session.
|
| 226 |
+
|
| 227 |
+
### **5.2 Recommended Data Strategy for Speed**
|
| 228 |
+
|
| 229 |
+
To maximize training speed, **Local NVMe Storage** is superior to Network Volumes, despite its ephemeral nature. The recommended workflow for headless execution is:
|
| 230 |
+
|
| 231 |
+
1. **Storage:** Store the master dataset in a high-performance S3 bucket or RunPod's S3-compatible object storage layer.25
|
| 232 |
+
2. **Ingest:** The start.sh script downloads the dataset from S3 to the pod's *local* /workspace directory (NVMe) at boot time.
|
| 233 |
+
3. **Train:** The model trains off the fast local NVMe, ensuring the GPU is fully saturated.
|
| 234 |
+
4. **Egress:** The start.sh uploads the checkpoints and final model back to S3 or Hugging Face.
|
| 235 |
+
|
| 236 |
+
This approach avoids the region-locking of Network Volumes and the I/O latency penalties, utilizing the immense bandwidth of datacenter GPUs for rapid setup.10
|
| 237 |
+
|
| 238 |
+
### **5.3 Transferring Large Data: The 100GB Challenge**
|
| 239 |
+
|
| 240 |
+
For users who must use RunPod storage (e.g., due to compliance or cost), transferring 100GB+ of data from a local machine is non-trivial. The runpodctl send command creates a peer-to-peer transfer tunnel. While effective for smaller files, users have reported slow speeds and timeouts for large datasets.26
|
| 241 |
+
**Insight:** For datasets \>100GB, do not upload from a home internet connection directly to a GPU pod. Instead:
|
| 242 |
+
|
| 243 |
+
1. Spin up a cheap **CPU pod** on RunPod.
|
| 244 |
+
2. Use rsync or runpodctl to upload the data to this CPU pod (which sits on the high-speed datacenter backbone).
|
| 245 |
+
3. From the CPU pod, transfer the data to a Network Volume or S3 bucket.
|
| 246 |
+
This leverages the internal network backbone rather than residential ISP uplinks, preventing the GPU pod from sitting idle while waiting for data uploads.
|
| 247 |
+
|
| 248 |
+
## ---
|
| 249 |
+
|
| 250 |
+
**6\. Monitoring and Observability without Jupyter**
|
| 251 |
+
|
| 252 |
+
In a headless environment, "blind" training is a significant operational risk. Observability must be externalized to ensure the user knows if the model is converging or if the pod has crashed.
|
| 253 |
+
|
| 254 |
+
### **6.1 Weights & Biases (WandB)**
|
| 255 |
+
|
| 256 |
+
WandB is the de facto standard for headless monitoring. It integrates natively with both Axolotl and Unsloth (via the Hugging Face Trainer).
|
| 257 |
+
|
| 258 |
+
* **Real-Time Metrics:** Loss curves, GPU utilization, memory usage, and learning rate schedules are streamed to the WandB dashboard in real-time. This allows the user to monitor the "pulse" of the training from a mobile device or laptop.
|
| 259 |
+
* **Artifacts:** Model checkpoints and config files can be logged as artifacts, providing version control for the models and ensuring reproducibility.
|
| 260 |
+
|
| 261 |
+
### **6.2 Remote Logging**
|
| 262 |
+
|
| 263 |
+
RunPod provides a logging driver that captures stdout and stderr from the container.
|
| 264 |
+
|
| 265 |
+
* **Command:** runpodctl logs \<pod\_id\> allows the user to check the console output from their local terminal to verify the script started correctly or to catch crash errors (e.g., CUDA OOM).11
|
| 266 |
+
* **Best Practice:** The start.sh script should use set \-e (exit immediately on error) and trap errors. Advanced users may add a curl command to the script to send a notification (e.g., via a Discord webhook or Slack API) if the training fails or succeeds, ensuring the user is alerted immediately without needing to constantly poll the logs.
|
| 267 |
+
|
| 268 |
+
## ---
|
| 269 |
+
|
| 270 |
+
**7\. Advanced Optimization and Troubleshooting**
|
| 271 |
+
|
| 272 |
+
### **7.1 Handling "Cold Starts" and Image Caching**
|
| 273 |
+
|
| 274 |
+
Downloading large Docker images (often 10GB+ for ML images) takes time. RunPod caches images on the host node.
|
| 275 |
+
|
| 276 |
+
* **Strategy:** Stick to a single image tag (e.g., myuser/trainer:v1). Once a specific host has pulled this image, subsequent runs on that same host are instant.
|
| 277 |
+
* **Docker Optimization:** Use multi-stage builds to keep the final image size small. Remove cache files (pip cache purge) within the Dockerfile to minimize layer size.28
|
| 278 |
+
|
| 279 |
+
### **7.2 CUDA Version Mismatches**
|
| 280 |
+
|
| 281 |
+
A common failure mode in custom images is a mismatch between the Docker container's CUDA toolkit and the host driver.
|
| 282 |
+
|
| 283 |
+
* **RunPod Environment:** RunPod hosts generally run the latest NVIDIA drivers.
|
| 284 |
+
* **Image Requirement:** Ensure the Docker image uses a compatible CUDA version (e.g., CUDA 11.8 or 12.1). Unsloth, for example, has specific requirements for CUDA 12.1 for maximum performance.9 Using the wrong base image will result in runtime errors regarding "Flash Attention" or "Bitsandbytes" compilation.
|
| 285 |
+
|
| 286 |
+
### **7.3 Spot Instance Interruptions**
|
| 287 |
+
|
| 288 |
+
If using Community Cloud to save money, the pod may be preempted (shut down) if the provider needs the hardware.
|
| 289 |
+
|
| 290 |
+
* **Mitigation:** Configure the training script to save checkpoints frequently (e.g., every 100 steps) to a mounted Network Volume or upload them immediately to S3.
|
| 291 |
+
* **Resume Logic:** The start.sh should check for the existence of a checkpoint and automatically pass \--resume\_from\_checkpoint to the training script. This ensures that if a pod dies and a new one is spawned, it picks up exactly where the last one left off.30
|
| 292 |
+
|
| 293 |
+
## ---
|
| 294 |
+
|
| 295 |
+
**8\. Conclusion and Strategic Roadmap**
|
| 296 |
+
|
| 297 |
+
For a user demanding the fastest possible fine-tuning workflow without the overhead of Jupyter notebooks, RunPod offers a powerful substrate, provided the workflow is architected correctly. The optimal path requires moving away from interactive "pet" instances to ephemeral "cattle" instances managed by code.
|
| 298 |
+
**The Recommended "Fast Track" Configuration:**
|
| 299 |
+
|
| 300 |
+
1. **Hardware:** NVIDIA H100 (Secure Cloud) for speed and reliability, or RTX 4090 (Community Cloud) for cost-efficiency.
|
| 301 |
+
2. **Framework:** **Unsloth** for single-GPU jobs (fastest throughput); **Axolotl** for multi-GPU or complex configurations.
|
| 302 |
+
3. **Deployment:** Custom Docker image with an ENTRYPOINT script that automates the Download \-\> Train \-\> Upload \-\> Terminate lifecycle.
|
| 303 |
+
4. **Interface:** runpodctl for deployment; WandB for monitoring; SSH for emergency debugging.
|
| 304 |
+
5. **Data:** S3-backed ingestion to local NVMe storage to bypass network volume I/O bottlenecks.
|
| 305 |
+
|
| 306 |
+
By adopting this headless architecture, the user transforms the fine-tuning process from a manual, error-prone task into a scalable, automated engineering operation, fully leveraging the raw compute power of RunPod's infrastructure. This report confirms that while RunPod's interface invites interactive use, its API and CLI capabilities are fully mature for the rigorous demands of headless, high-velocity machine learning operations.
|
| 307 |
+
|
| 308 |
+
#### **Works cited**
|
| 309 |
+
|
| 310 |
+
1. Runpod GPU pricing: A complete breakdown and platform comparison | Blog \- Northflank, accessed January 12, 2026, [https://northflank.com/blog/runpod-gpu-pricing](https://northflank.com/blog/runpod-gpu-pricing)
|
| 311 |
+
2. The NVIDIA H100 GPU Review: Why This AI Powerhouse Dominates (But Costs a Fortune) \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/nvidia-h100](https://www.runpod.io/articles/guides/nvidia-h100)
|
| 312 |
+
3. Runpod Secrets: Affordable A100/H100 Instances, accessed January 12, 2026, [https://www.runpod.io/articles/guides/affordable-a100-h100-gpu-cloud](https://www.runpod.io/articles/guides/affordable-a100-h100-gpu-cloud)
|
| 313 |
+
4. Pricing | Runpod GPU cloud computing rates, accessed January 12, 2026, [https://www.runpod.io/pricing](https://www.runpod.io/pricing)
|
| 314 |
+
5. RunPod Pricing 2025 Complete Guide (GPU Cloud Costs Breakdown) \- Flexprice, accessed January 12, 2026, [https://flexprice.io/blog/runprod-pricing-guide-with-gpu-costs](https://flexprice.io/blog/runprod-pricing-guide-with-gpu-costs)
|
| 315 |
+
6. No-Code AI: How I Ran My First LLM Without Coding | Runpod Blog, accessed January 12, 2026, [https://www.runpod.io/blog/no-code-ai-run-llm](https://www.runpod.io/blog/no-code-ai-run-llm)
|
| 316 |
+
7. Dockerfile \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/tutorials/introduction/containers/create-dockerfiles](https://docs.runpod.io/tutorials/introduction/containers/create-dockerfiles)
|
| 317 |
+
8. Deploying AI Apps with Minimal Infrastructure and Docker \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/deploy-ai-apps-minimal-infrastructure-docker](https://www.runpod.io/articles/guides/deploy-ai-apps-minimal-infrastructure-docker)
|
| 318 |
+
9. Fine-Tuning Local Models with Docker Offload and Unsloth, accessed January 12, 2026, [https://www.docker.com/blog/fine-tuning-models-with-offload-and-unsloth/](https://www.docker.com/blog/fine-tuning-models-with-offload-and-unsloth/)
|
| 319 |
+
10. Optimize your workers \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/serverless/development/optimization](https://docs.runpod.io/serverless/development/optimization)
|
| 320 |
+
11. Manage Pods \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/pods/manage-pods](https://docs.runpod.io/pods/manage-pods)
|
| 321 |
+
12. AI on a Schedule: Using Runpod's API to Run Jobs Only When Needed, accessed January 12, 2026, [https://www.runpod.io/articles/guides/ai-on-a-schedule](https://www.runpod.io/articles/guides/ai-on-a-schedule)
|
| 322 |
+
13. Overview \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/runpodctl/overview](https://docs.runpod.io/runpodctl/overview)
|
| 323 |
+
14. create pod \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/runpodctl/reference/runpodctl-create-pod](https://docs.runpod.io/runpodctl/reference/runpodctl-create-pod)
|
| 324 |
+
15. LLM fine-tuning | LLM Inference Handbook \- BentoML, accessed January 12, 2026, [https://bentoml.com/llm/getting-started/llm-fine-tuning](https://bentoml.com/llm/getting-started/llm-fine-tuning)
|
| 325 |
+
16. How to fine-tune a model using Axolotl | Runpod Blog, accessed January 12, 2026, [https://www.runpod.io/blog/how-to-fine-tune-a-model-using-axolotl](https://www.runpod.io/blog/how-to-fine-tune-a-model-using-axolotl)
|
| 326 |
+
17. Best frameworks for fine-tuning LLMs in 2025 \- Modal, accessed January 12, 2026, [https://modal.com/blog/fine-tuning-llms](https://modal.com/blog/fine-tuning-llms)
|
| 327 |
+
18. Comparing LLM Fine-Tuning Frameworks: Axolotl, Unsloth, and Torchtune in 2025, accessed January 12, 2026, [https://blog.spheron.network/comparing-llm-fine-tuning-frameworks-axolotl-unsloth-and-torchtune-in-2025](https://blog.spheron.network/comparing-llm-fine-tuning-frameworks-axolotl-unsloth-and-torchtune-in-2025)
|
| 328 |
+
19. Axolotl vs LLaMA-Factory vs Unsloth for AI Fine-Tuning 2026 \- Index.dev, accessed January 12, 2026, [https://www.index.dev/skill-vs-skill/ai-axolotl-vs-llama-factory-vs-unsloth](https://www.index.dev/skill-vs-skill/ai-axolotl-vs-llama-factory-vs-unsloth)
|
| 329 |
+
20. \[TEMPLATE\] One-click Unsloth finetuning on RunPod : r/LocalLLaMA \- Reddit, accessed January 12, 2026, [https://www.reddit.com/r/LocalLLaMA/comments/1nyzzws/template\_oneclick\_unsloth\_finetuning\_on\_runpod/](https://www.reddit.com/r/LocalLLaMA/comments/1nyzzws/template_oneclick_unsloth_finetuning_on_runpod/)
|
| 330 |
+
21. unsloth/llama-3-8b-bnb-4bit \- Hugging Face, accessed January 12, 2026, [https://huggingface.co/unsloth/llama-3-8b-bnb-4bit](https://huggingface.co/unsloth/llama-3-8b-bnb-4bit)
|
| 331 |
+
22. PSA: Don't bother with Network Volumes on Runpod : r/StableDiffusion \- Reddit, accessed January 12, 2026, [https://www.reddit.com/r/StableDiffusion/comments/1nkcgvp/psa\_dont\_bother\_with\_network\_volumes\_on\_runpod/](https://www.reddit.com/r/StableDiffusion/comments/1nkcgvp/psa_dont_bother_with_network_volumes_on_runpod/)
|
| 332 |
+
23. Network volumes \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/storage/network-volumes](https://docs.runpod.io/storage/network-volumes)
|
| 333 |
+
24. Using network volume with serverless \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1234830020678123610](https://www.answeroverflow.com/m/1234830020678123610)
|
| 334 |
+
25. Streamline Your AI Workflows with RunPod's New S3-Compatible API, accessed January 12, 2026, [https://www.runpod.io/blog/streamline-ai-workflows-s3-api](https://www.runpod.io/blog/streamline-ai-workflows-s3-api)
|
| 335 |
+
26. Upload speed \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1415080595020709938](https://www.answeroverflow.com/m/1415080595020709938)
|
| 336 |
+
27. \`runpodctl send\` crawling at \<1MB speeds \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1208971275163406376](https://www.answeroverflow.com/m/1208971275163406376)
|
| 337 |
+
28. MLOps Workflow for Docker-Based AI Model Deployment \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/mlops-workflow-docker-ai-deployment](https://www.runpod.io/articles/guides/mlops-workflow-docker-ai-deployment)
|
| 338 |
+
29. Installation \- Axolotl Docs, accessed January 12, 2026, [https://docs.axolotl.ai/docs/installation.html](https://docs.axolotl.ai/docs/installation.html)
|
| 339 |
+
30. Does anyone use RunPod for SFT? If yes, you train via SSH or Jupyter (web-hosted), accessed January 12, 2026, [https://www.reddit.com/r/LocalLLaMA/comments/1pd6vxu/does\_anyone\_use\_runpod\_for\_sft\_if\_yes\_you\_train/](https://www.reddit.com/r/LocalLLaMA/comments/1pd6vxu/does_anyone_use_runpod_for_sft_if_yes_you_train/)
|
ai-docs/chatbot-ideology.yaml
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Marxist-Leninist Chatbot Training Set Design
|
| 2 |
+
# General-purpose ideological chatbot using ProleWiki corpus
|
| 3 |
+
# Status: UPDATED FOR GRPO (was SFT)
|
| 4 |
+
# Last Updated: 2025-12-17
|
| 5 |
+
|
| 6 |
+
# =============================================================================
|
| 7 |
+
# METHODOLOGY UPDATE: SFT → GRPO
|
| 8 |
+
# =============================================================================
|
| 9 |
+
# This document was originally designed for SFT. The project has pivoted to
|
| 10 |
+
# GRPO (Group Relative Policy Optimization) for the reasons documented in
|
| 11 |
+
# ai-docs/finetune.yaml. The ideological design principles and question
|
| 12 |
+
# generation strategies here remain valid for GRPO training.
|
| 13 |
+
#
|
| 14 |
+
# Current implementation:
|
| 15 |
+
# - Dataset: training_data/curated_qa.jsonl (128 curated pairs)
|
| 16 |
+
# - GRPO format: training_data/grpo_dataset.jsonl
|
| 17 |
+
# - Notebook: notebooks/Marxist_GRPO_Training.ipynb
|
| 18 |
+
# - Reward functions: src/prolewiki_llm/grpo_rewards.py
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
overview:
|
| 22 |
+
goal: |
|
| 23 |
+
A general-purpose chatbot well-versed in Marxist-Leninist theory.
|
| 24 |
+
Can discuss revolutionary theory, imperialism, class struggle,
|
| 25 |
+
dialectical materialism, and related topics through materialist analysis.
|
| 26 |
+
|
| 27 |
+
not_goal: |
|
| 28 |
+
NOT a game-specific assistant. NOT tied to any particular application.
|
| 29 |
+
The Babylon game project informed the ideological direction, but
|
| 30 |
+
the chatbot itself is a standalone ML theory assistant.
|
| 31 |
+
|
| 32 |
+
inspiration: |
|
| 33 |
+
The Babylon game (~/projects/game/babylon/) models class struggle
|
| 34 |
+
through MLM-TW (Marxist-Leninist-Maoist Third Worldist) theory.
|
| 35 |
+
This chatbot shares that theoretical grounding but serves as a
|
| 36 |
+
general educational/discussion tool.
|
| 37 |
+
|
| 38 |
+
training_approach: "GRPO (Group Relative Policy Optimization) with multi-layer rewards"
|
| 39 |
+
base_model: "unsloth/DeepSeek-R1-0528-Qwen3-8B"
|
| 40 |
+
current_samples: "128 curated Q&A pairs"
|
| 41 |
+
estimated_training_time: "~2-4 hours on A40 48GB"
|
| 42 |
+
|
| 43 |
+
data_source:
|
| 44 |
+
primary: "ProleWiki Library namespace chunks"
|
| 45 |
+
location: "sample-pipeline/chunks/Library/*.jsonl"
|
| 46 |
+
format: "JSONL with metadata"
|
| 47 |
+
|
| 48 |
+
chunk_schema:
|
| 49 |
+
chunk_id: "Unique identifier"
|
| 50 |
+
text: "The actual content (training answer)"
|
| 51 |
+
article_title: "Source work title (contains author)"
|
| 52 |
+
section: "Section/chapter name (question seed)"
|
| 53 |
+
categories: "Topic categories"
|
| 54 |
+
internal_links: "Referenced concepts (question seeds)"
|
| 55 |
+
word_count: "Chunk size"
|
| 56 |
+
|
| 57 |
+
available_works:
|
| 58 |
+
marx:
|
| 59 |
+
- "Capital, Volume I"
|
| 60 |
+
- "Capital, Volume II"
|
| 61 |
+
- "Capital, Volume III"
|
| 62 |
+
- "Grundrisse"
|
| 63 |
+
- "The German Ideology"
|
| 64 |
+
lenin:
|
| 65 |
+
- "Imperialism, the Highest Stage of Capitalism"
|
| 66 |
+
- "State and Revolution"
|
| 67 |
+
- "What Is To Be Done?"
|
| 68 |
+
mao:
|
| 69 |
+
- "On Contradiction"
|
| 70 |
+
- "On Practice"
|
| 71 |
+
other:
|
| 72 |
+
- "Additional works in Library namespace"
|
| 73 |
+
|
| 74 |
+
chunk_count: "~1,034 chunks currently available"
|
| 75 |
+
|
| 76 |
+
system_prompt:
|
| 77 |
+
description: "Defines the chatbot's persona and approach"
|
| 78 |
+
|
| 79 |
+
recommended: |
|
| 80 |
+
You are a Marxist-Leninist assistant. You explain revolutionary theory
|
| 81 |
+
through materialist analysis, drawing on the works of Marx, Engels,
|
| 82 |
+
Lenin, Mao, and other socialist thinkers. You ground explanations in
|
| 83 |
+
historical materialism and class analysis.
|
| 84 |
+
|
| 85 |
+
alternatives:
|
| 86 |
+
concise: |
|
| 87 |
+
You are a Marxist theorist. You explain concepts through
|
| 88 |
+
dialectical and historical materialism.
|
| 89 |
+
|
| 90 |
+
educational: |
|
| 91 |
+
You are a patient teacher of Marxist-Leninist theory. You explain
|
| 92 |
+
complex concepts clearly, always connecting theory to material
|
| 93 |
+
conditions and class relations.
|
| 94 |
+
|
| 95 |
+
third_worldist: |
|
| 96 |
+
You are a Marxist-Leninist-Maoist Third Worldist. You analyze
|
| 97 |
+
imperialism, unequal exchange, and the global class structure,
|
| 98 |
+
recognizing the labor aristocracy in imperial core nations.
|
| 99 |
+
|
| 100 |
+
question_generation:
|
| 101 |
+
description: |
|
| 102 |
+
Transform chunk metadata into natural questions that a user might ask.
|
| 103 |
+
The chunk text becomes the answer.
|
| 104 |
+
|
| 105 |
+
strategy: "Use available metadata to generate contextual questions"
|
| 106 |
+
|
| 107 |
+
templates:
|
| 108 |
+
with_section:
|
| 109 |
+
pattern: "What does {author} say about {section}?"
|
| 110 |
+
example: "What does Marx say about the commodity form?"
|
| 111 |
+
priority: 1
|
| 112 |
+
|
| 113 |
+
with_internal_links:
|
| 114 |
+
pattern: "Explain {concept} from a Marxist perspective."
|
| 115 |
+
example: "Explain surplus value from a Marxist perspective."
|
| 116 |
+
priority: 2
|
| 117 |
+
|
| 118 |
+
with_categories:
|
| 119 |
+
pattern: "Discuss {category} in Marxist theory."
|
| 120 |
+
example: "Discuss imperialism in Marxist theory."
|
| 121 |
+
priority: 3
|
| 122 |
+
|
| 123 |
+
fallback:
|
| 124 |
+
pattern: "What does {author} teach us about this?"
|
| 125 |
+
example: "What does Lenin teach us about this?"
|
| 126 |
+
priority: 4
|
| 127 |
+
|
| 128 |
+
author_extraction:
|
| 129 |
+
description: "Extract author name from article_title field"
|
| 130 |
+
pattern: "Library {Author} {Work Title}"
|
| 131 |
+
examples:
|
| 132 |
+
- input: "Library Karl Marx Capital, vol. I, Chapter 1"
|
| 133 |
+
output: "Marx"
|
| 134 |
+
- input: "Library Vladimir Lenin Imperialism"
|
| 135 |
+
output: "Lenin"
|
| 136 |
+
- input: "Library Mao Zedong On Contradiction"
|
| 137 |
+
output: "Mao"
|
| 138 |
+
|
| 139 |
+
code: |
|
| 140 |
+
def extract_author(title: str) -> str:
|
| 141 |
+
"""Extract author from 'Library Author Name Work...' format."""
|
| 142 |
+
if not title.startswith("Library "):
|
| 143 |
+
return "the author"
|
| 144 |
+
|
| 145 |
+
# Common author mappings
|
| 146 |
+
author_map = {
|
| 147 |
+
"Karl Marx": "Marx",
|
| 148 |
+
"Friedrich Engels": "Engels",
|
| 149 |
+
"Vladimir Lenin": "Lenin",
|
| 150 |
+
"V.I. Lenin": "Lenin",
|
| 151 |
+
"Mao Zedong": "Mao",
|
| 152 |
+
"Mao Tse-tung": "Mao",
|
| 153 |
+
"Joseph Stalin": "Stalin",
|
| 154 |
+
"Rosa Luxemburg": "Luxemburg",
|
| 155 |
+
"Antonio Gramsci": "Gramsci",
|
| 156 |
+
"Frantz Fanon": "Fanon",
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
title_part = title[8:] # Remove "Library "
|
| 160 |
+
for full_name, short_name in author_map.items():
|
| 161 |
+
if title_part.startswith(full_name):
|
| 162 |
+
return short_name
|
| 163 |
+
|
| 164 |
+
# Fallback: first two words
|
| 165 |
+
words = title_part.split()
|
| 166 |
+
if len(words) >= 2:
|
| 167 |
+
return words[1] # Usually last name
|
| 168 |
+
return "the author"
|
| 169 |
+
|
| 170 |
+
training_format:
|
| 171 |
+
template: "Qwen-2.5 chat format"
|
| 172 |
+
note: "MUST use Qwen template for DeepSeek-R1-Distill-Qwen model"
|
| 173 |
+
|
| 174 |
+
structure: |
|
| 175 |
+
<|im_start|>system
|
| 176 |
+
{system_prompt}<|im_end|>
|
| 177 |
+
<|im_start|>user
|
| 178 |
+
{question}<|im_end|>
|
| 179 |
+
<|im_start|>assistant
|
| 180 |
+
{answer}<|im_end|>
|
| 181 |
+
|
| 182 |
+
example:
|
| 183 |
+
system: "You are a Marxist-Leninist assistant..."
|
| 184 |
+
user: "What does Marx say about the commodity form?"
|
| 185 |
+
answer: |
|
| 186 |
+
The wealth of those societies in which the capitalist mode of
|
| 187 |
+
production prevails, presents itself as an immense accumulation
|
| 188 |
+
of commodities...
|
| 189 |
+
|
| 190 |
+
output_field: "text"
|
| 191 |
+
description: |
|
| 192 |
+
For Unsloth SFTTrainer, provide a single 'text' field containing
|
| 193 |
+
the complete formatted conversation.
|
| 194 |
+
|
| 195 |
+
transformation_code:
|
| 196 |
+
description: "Complete transformation from chunks to training data"
|
| 197 |
+
|
| 198 |
+
implementation: |
|
| 199 |
+
import json
|
| 200 |
+
from pathlib import Path
|
| 201 |
+
|
| 202 |
+
def extract_author(title: str) -> str:
|
| 203 |
+
"""Extract author from article title."""
|
| 204 |
+
if not title.startswith("Library "):
|
| 205 |
+
return "the author"
|
| 206 |
+
|
| 207 |
+
author_map = {
|
| 208 |
+
"Karl Marx": "Marx",
|
| 209 |
+
"Friedrich Engels": "Engels",
|
| 210 |
+
"Vladimir Lenin": "Lenin",
|
| 211 |
+
"Mao Zedong": "Mao",
|
| 212 |
+
"Joseph Stalin": "Stalin",
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
title_part = title[8:]
|
| 216 |
+
for full_name, short_name in author_map.items():
|
| 217 |
+
if title_part.startswith(full_name):
|
| 218 |
+
return short_name
|
| 219 |
+
|
| 220 |
+
words = title_part.split()
|
| 221 |
+
return words[1] if len(words) >= 2 else "the author"
|
| 222 |
+
|
| 223 |
+
def generate_question(chunk: dict) -> str:
|
| 224 |
+
"""Generate a natural question from chunk metadata."""
|
| 225 |
+
author = extract_author(chunk.get("article_title", ""))
|
| 226 |
+
|
| 227 |
+
# Priority 1: Use section
|
| 228 |
+
if chunk.get("section"):
|
| 229 |
+
section = chunk["section"].lower()
|
| 230 |
+
return f"What does {author} say about {section}?"
|
| 231 |
+
|
| 232 |
+
# Priority 2: Use internal links
|
| 233 |
+
if chunk.get("internal_links"):
|
| 234 |
+
concept = chunk["internal_links"][0]
|
| 235 |
+
return f"Explain {concept} from a Marxist perspective."
|
| 236 |
+
|
| 237 |
+
# Priority 3: Use categories
|
| 238 |
+
if chunk.get("categories"):
|
| 239 |
+
category = chunk["categories"][0]
|
| 240 |
+
return f"Discuss {category} in Marxist theory."
|
| 241 |
+
|
| 242 |
+
# Fallback
|
| 243 |
+
return f"What does {author} teach us in this passage?"
|
| 244 |
+
|
| 245 |
+
def chunk_to_training(chunk: dict, system_prompt: str) -> dict:
|
| 246 |
+
"""Convert a chunk to Qwen-formatted training example."""
|
| 247 |
+
question = generate_question(chunk)
|
| 248 |
+
answer = chunk["text"]
|
| 249 |
+
|
| 250 |
+
text = f"""<|im_start|>system
|
| 251 |
+
{system_prompt}<|im_end|>
|
| 252 |
+
<|im_start|>user
|
| 253 |
+
{question}<|im_end|>
|
| 254 |
+
<|im_start|>assistant
|
| 255 |
+
{answer}<|im_end|>"""
|
| 256 |
+
|
| 257 |
+
return {"text": text}
|
| 258 |
+
|
| 259 |
+
def process_chunks(input_dir: Path, output_path: Path):
|
| 260 |
+
"""Process all chunk files into training dataset."""
|
| 261 |
+
system_prompt = (
|
| 262 |
+
"You are a Marxist-Leninist assistant. You explain revolutionary "
|
| 263 |
+
"theory through materialist analysis, drawing on the works of Marx, "
|
| 264 |
+
"Engels, Lenin, Mao, and other socialist thinkers."
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
training_data = []
|
| 268 |
+
|
| 269 |
+
for jsonl_file in input_dir.glob("*.jsonl"):
|
| 270 |
+
with open(jsonl_file) as f:
|
| 271 |
+
for line in f:
|
| 272 |
+
chunk = json.loads(line)
|
| 273 |
+
example = chunk_to_training(chunk, system_prompt)
|
| 274 |
+
training_data.append(example)
|
| 275 |
+
|
| 276 |
+
with open(output_path, "w") as f:
|
| 277 |
+
for example in training_data:
|
| 278 |
+
f.write(json.dumps(example) + "\n")
|
| 279 |
+
|
| 280 |
+
print(f"Generated {len(training_data)} training examples")
|
| 281 |
+
|
| 282 |
+
quality_considerations:
|
| 283 |
+
chunk_quality:
|
| 284 |
+
- "ProleWiki text is already high-quality Marxist writing"
|
| 285 |
+
- "Chunks preserve context through overlap"
|
| 286 |
+
- "Section boundaries respected"
|
| 287 |
+
|
| 288 |
+
question_diversity:
|
| 289 |
+
- "Vary question templates to avoid repetitive patterns"
|
| 290 |
+
- "Use all available metadata fields"
|
| 291 |
+
- "Consider adding manual seed questions for key concepts"
|
| 292 |
+
|
| 293 |
+
answer_length:
|
| 294 |
+
- "Chunks are 350-500 tokens (good length for chat responses)"
|
| 295 |
+
- "Not too short (lacks substance) or too long (loses focus)"
|
| 296 |
+
|
| 297 |
+
potential_improvements:
|
| 298 |
+
- "Add conversational variations (rephrase questions)"
|
| 299 |
+
- "Include follow-up question pairs"
|
| 300 |
+
- "Add explicit concept definitions from glossary"
|
| 301 |
+
- "Include historical examples and applications"
|
| 302 |
+
|
| 303 |
+
implementation_steps:
|
| 304 |
+
step_1:
|
| 305 |
+
name: "Verify chunk data"
|
| 306 |
+
command: "ls -la sample-pipeline/chunks/Library/"
|
| 307 |
+
check: "Confirm JSONL files exist with expected format"
|
| 308 |
+
|
| 309 |
+
step_2:
|
| 310 |
+
name: "Create transformation script"
|
| 311 |
+
location: "src/prolewiki_llm/prepare_training.py"
|
| 312 |
+
description: "Implement the transformation code above"
|
| 313 |
+
|
| 314 |
+
step_3:
|
| 315 |
+
name: "Generate training data"
|
| 316 |
+
command: "uv run python -m prolewiki_llm.prepare_training"
|
| 317 |
+
output: "training_data/ml_chatbot.jsonl"
|
| 318 |
+
|
| 319 |
+
step_4:
|
| 320 |
+
name: "Upload to RunPod"
|
| 321 |
+
description: "Transfer JSONL to pod's /workspace/data/"
|
| 322 |
+
|
| 323 |
+
step_5:
|
| 324 |
+
name: "Run fine-tuning"
|
| 325 |
+
description: "Execute training script (see finetune.yaml)"
|
| 326 |
+
|
| 327 |
+
step_6:
|
| 328 |
+
name: "Export and test"
|
| 329 |
+
description: "GGUF export, Ollama deployment, manual testing"
|
| 330 |
+
|
| 331 |
+
evaluation:
|
| 332 |
+
manual_testing:
|
| 333 |
+
description: "Chat with model and assess quality"
|
| 334 |
+
test_questions:
|
| 335 |
+
- "What is surplus value?"
|
| 336 |
+
- "Explain the labor theory of value."
|
| 337 |
+
- "What is imperialism according to Lenin?"
|
| 338 |
+
- "How does dialectical materialism differ from idealism?"
|
| 339 |
+
- "What is the role of the vanguard party?"
|
| 340 |
+
- "Why do contradictions drive historical change?"
|
| 341 |
+
- "What is the labor aristocracy?"
|
| 342 |
+
|
| 343 |
+
quality_criteria:
|
| 344 |
+
- "Responses grounded in Marxist theory"
|
| 345 |
+
- "Materialist analysis (not idealist)"
|
| 346 |
+
- "Accurate to source texts"
|
| 347 |
+
- "Coherent and well-structured"
|
| 348 |
+
- "Appropriate length for chat"
|
| 349 |
+
|
| 350 |
+
red_flags:
|
| 351 |
+
- "Refusing to discuss political topics (abliteration failure)"
|
| 352 |
+
- "Generic/vague responses not grounded in theory"
|
| 353 |
+
- "Mixing incompatible ideological frameworks"
|
| 354 |
+
- "Hallucinating quotes or concepts"
|
| 355 |
+
|
| 356 |
+
future_enhancements:
|
| 357 |
+
phase_2:
|
| 358 |
+
- "Add conversational multi-turn examples"
|
| 359 |
+
- "Include debate/argument handling"
|
| 360 |
+
- "Add current events analysis capability"
|
| 361 |
+
|
| 362 |
+
phase_3:
|
| 363 |
+
- "Multi-persona support (different theoretical traditions)"
|
| 364 |
+
- "Game integration (Babylon narrative generation)"
|
| 365 |
+
- "RAG integration for expanded knowledge"
|
| 366 |
+
|
| 367 |
+
related_docs:
|
| 368 |
+
- "ai-docs/finetune.yaml - GRPO methodology and training config"
|
| 369 |
+
- "ai-docs/reward-modeling.yaml - Multi-layer reward function design"
|
| 370 |
+
- "ai-docs/runpod.yaml - Cloud GPU setup"
|
| 371 |
+
- "ai-docs/project-status.yaml - Phase 8 implementation status"
|
| 372 |
+
- "notebooks/Marxist_GRPO_Training.ipynb - Authoritative training notebook"
|
| 373 |
+
- "src/prolewiki_llm/ - Python module with reward functions"
|
ai-docs/finetune.yaml
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fine-tuning Documentation - Phase 8
|
| 2 |
+
# Purpose: Marxist-Leninist LLM fine-tuning on ProleWiki corpus
|
| 3 |
+
# Status: IN_PROGRESS (implementation complete, training execution pending)
|
| 4 |
+
# Method: GRPO (Group Relative Policy Optimization) - NOT SFT
|
| 5 |
+
# Last Updated: 2025-12-17
|
| 6 |
+
|
| 7 |
+
# =============================================================================
|
| 8 |
+
# METHODOLOGY PIVOT: SFT → GRPO
|
| 9 |
+
# =============================================================================
|
| 10 |
+
# IMPORTANT: This project pivoted from SFT to GRPO methodology.
|
| 11 |
+
# This document has been updated to reflect the current approach.
|
| 12 |
+
#
|
| 13 |
+
# WHY GRPO OVER SFT:
|
| 14 |
+
# - Political theory has no single "correct" answer (unlike math)
|
| 15 |
+
# - Open-ended prose requires semantic similarity, not exact string matching
|
| 16 |
+
# - Reward functions can encode domain expertise
|
| 17 |
+
# - Multi-layer rewards defeat adversarial "word soup" attacks
|
| 18 |
+
# - GRPO naturally handles the inherent subjectivity of political analysis
|
| 19 |
+
#
|
| 20 |
+
# AUTHORITATIVE SOURCES:
|
| 21 |
+
# - Training notebook: notebooks/Marxist_GRPO_Training.ipynb
|
| 22 |
+
# - Reward functions: src/prolewiki_llm/grpo_rewards.py
|
| 23 |
+
# - W&B logging: src/prolewiki_llm/wandb_logging.py
|
| 24 |
+
# - Reward design: ai-docs/reward-modeling.yaml
|
| 25 |
+
# =============================================================================
|
| 26 |
+
|
| 27 |
+
overview:
|
| 28 |
+
goal: Fine-tune DeepSeek-R1-0528-Qwen3-8B for Marxist-Leninist theory responses
|
| 29 |
+
method: GRPO (Group Relative Policy Optimization) via Unsloth + TRL
|
| 30 |
+
dataset: 1,058 curated Q&A pairs (training_data/grpo_dataset.jsonl)
|
| 31 |
+
output_format: LoRA adapter → GGUF for Ollama deployment
|
| 32 |
+
hardware_requirement: A40 48GB (RunPod) or similar high-VRAM GPU
|
| 33 |
+
status: Implementation complete, execution pending
|
| 34 |
+
|
| 35 |
+
# =============================================================================
|
| 36 |
+
# GRPO VS SFT RATIONALE
|
| 37 |
+
# =============================================================================
|
| 38 |
+
|
| 39 |
+
grpo_rationale:
|
| 40 |
+
description: |
|
| 41 |
+
GRPO (Group Relative Policy Optimization) is preferred over SFT for
|
| 42 |
+
Marxist-Leninist training because political theory responses have no
|
| 43 |
+
single "correct" answer. Unlike math problems (where GRPO was originally
|
| 44 |
+
popularized with DeepSeek-R1), political theory is inherently open-ended.
|
| 45 |
+
|
| 46 |
+
sft_limitations:
|
| 47 |
+
- "Exact answer matching fails for prose responses"
|
| 48 |
+
- "No way to express 'this answer is better than that answer'"
|
| 49 |
+
- "Can only train on (input, output) pairs, not preferences"
|
| 50 |
+
- "Easily reward-hacked by models memorizing training data"
|
| 51 |
+
|
| 52 |
+
grpo_advantages:
|
| 53 |
+
- "Uses reward functions to score response quality"
|
| 54 |
+
- "Multiple generations compared (Group Relative)"
|
| 55 |
+
- "Semantic similarity captures meaning, not verbatim wording"
|
| 56 |
+
- "Multi-layer rewards prevent adversarial gaming"
|
| 57 |
+
- "Can penalize 'word soup' via coherence checks"
|
| 58 |
+
|
| 59 |
+
key_insight: |
|
| 60 |
+
The breakthrough insight is that GRPO's reward function paradigm
|
| 61 |
+
perfectly maps to political theory training: we can encode what
|
| 62 |
+
makes a "good Marxist-Leninist response" through semantic rewards
|
| 63 |
+
rather than exact answer matching.
|
| 64 |
+
|
| 65 |
+
# =============================================================================
|
| 66 |
+
# BASE MODEL
|
| 67 |
+
# =============================================================================
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
primary:
|
| 71 |
+
name: DeepSeek-R1-0528-Qwen3-8B
|
| 72 |
+
huggingface: unsloth/DeepSeek-R1-0528-Qwen3-8B
|
| 73 |
+
params: 8B
|
| 74 |
+
architecture: Qwen-2.5
|
| 75 |
+
context_length: 2048 (training max_completion_length)
|
| 76 |
+
reasoning: |
|
| 77 |
+
1. DeepSeek R1 architecture has strong reasoning capabilities
|
| 78 |
+
2. 8B params = fits on A40 with LoRA
|
| 79 |
+
3. Full Unsloth optimization support
|
| 80 |
+
4. unsloth/ namespace provides optimized 4-bit version
|
| 81 |
+
|
| 82 |
+
lora_config:
|
| 83 |
+
rank: 64
|
| 84 |
+
lora_alpha: 64
|
| 85 |
+
target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
|
| 86 |
+
use_rslora: true
|
| 87 |
+
use_gradient_checkpointing: unsloth
|
| 88 |
+
|
| 89 |
+
# =============================================================================
|
| 90 |
+
# DATASET
|
| 91 |
+
# =============================================================================
|
| 92 |
+
|
| 93 |
+
dataset:
|
| 94 |
+
source: training_data/curated_qa.jsonl
|
| 95 |
+
transformed: training_data/grpo_dataset.jsonl
|
| 96 |
+
count: 1,058 curated Q&A pairs
|
| 97 |
+
format: |
|
| 98 |
+
{"prompt": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}], "answer": "..."}
|
| 99 |
+
|
| 100 |
+
topics:
|
| 101 |
+
- Revisionism and opportunism
|
| 102 |
+
- Surplus value and exploitation
|
| 103 |
+
- Dialectical materialism
|
| 104 |
+
- Imperialism and monopoly capitalism
|
| 105 |
+
- Class struggle and revolution
|
| 106 |
+
- Socialist construction
|
| 107 |
+
- Historical figures (Marx, Lenin, Engels, Stalin, Mao)
|
| 108 |
+
|
| 109 |
+
# =============================================================================
|
| 110 |
+
# REWARD FUNCTION SYSTEM (THE KEY INNOVATION)
|
| 111 |
+
# =============================================================================
|
| 112 |
+
|
| 113 |
+
reward_functions:
|
| 114 |
+
reference: ai-docs/reward-modeling.yaml
|
| 115 |
+
module: src/prolewiki_llm/grpo_rewards.py
|
| 116 |
+
tests: tests/unit/training/test_grpo_rewards.py (43 tests)
|
| 117 |
+
|
| 118 |
+
layers:
|
| 119 |
+
format_rewards:
|
| 120 |
+
- match_format_exactly: "3.0 for exact </think> pattern"
|
| 121 |
+
- match_format_approximately: "1.5-2.0 for partial pattern"
|
| 122 |
+
|
| 123 |
+
semantic_rewards:
|
| 124 |
+
- semantic_similarity_reward: "Sentence-transformer embedding similarity"
|
| 125 |
+
- terminology_reward: "Bonus for Marxist lexicon usage"
|
| 126 |
+
|
| 127 |
+
coherence_rewards:
|
| 128 |
+
- nli_coherence_reward: "BART-large-MNLI entailment checking"
|
| 129 |
+
- self_consistency_reward: "No internal contradictions"
|
| 130 |
+
- structural_coherence_reward: "Terms in syntactic roles"
|
| 131 |
+
|
| 132 |
+
anti_hacking_rewards:
|
| 133 |
+
- topic_relevance_reward: "Question→answer concept coverage"
|
| 134 |
+
- interconnection_depth_reward: "Anti-buzzword-salad detection"
|
| 135 |
+
- completeness_reward: "Key concept coverage"
|
| 136 |
+
|
| 137 |
+
combined_functions:
|
| 138 |
+
full_coherence_reward: "5-layer combined (recommended)"
|
| 139 |
+
robust_coherence_reward: "3-layer (NLI + self-consistency + structural)"
|
| 140 |
+
|
| 141 |
+
anti_hacking_measures:
|
| 142 |
+
HOLLOW_BUZZWORDS: "Penalty for activist jargon without substance"
|
| 143 |
+
DEPTH_MARKERS: "Bonus for historical specificity"
|
| 144 |
+
EXPLANATORY_PHRASES: "Bonus for causal reasoning"
|
| 145 |
+
depth_ratio: "Words per concept (penalizes buzzword salad)"
|
| 146 |
+
|
| 147 |
+
# =============================================================================
|
| 148 |
+
# TRAINING CONFIGURATION
|
| 149 |
+
# =============================================================================
|
| 150 |
+
|
| 151 |
+
training_config:
|
| 152 |
+
trainer: GRPOTrainer (from trl library)
|
| 153 |
+
max_steps: 250
|
| 154 |
+
batch_size: 2
|
| 155 |
+
gradient_accumulation: 2
|
| 156 |
+
learning_rate: 5e-6
|
| 157 |
+
num_generations: 4
|
| 158 |
+
temperature: 1.0
|
| 159 |
+
gpu_memory_utilization: 0.85
|
| 160 |
+
hardware: A40 48GB (RunPod)
|
| 161 |
+
|
| 162 |
+
vllm_sampling:
|
| 163 |
+
min_p: 0.1
|
| 164 |
+
top_p: 1.0
|
| 165 |
+
top_k: -1
|
| 166 |
+
seed: 3407
|
| 167 |
+
|
| 168 |
+
# =============================================================================
|
| 169 |
+
# WEIGHTS & BIASES INTEGRATION
|
| 170 |
+
# =============================================================================
|
| 171 |
+
|
| 172 |
+
wandb_logging:
|
| 173 |
+
module: src/prolewiki_llm/wandb_logging.py
|
| 174 |
+
tests: tests/unit/training/test_wandb_logging.py (17 tests)
|
| 175 |
+
|
| 176 |
+
features:
|
| 177 |
+
- WandbSampleLogger: "Periodic sample tables (question → response → rewards)"
|
| 178 |
+
- create_logging_reward: "Zero-cost logging reward for GRPOTrainer"
|
| 179 |
+
- log_reward_metrics: "Per-reward mean/min/max tracking"
|
| 180 |
+
- init_wandb_logging: "Project/run initialization"
|
| 181 |
+
- finish_wandb_logging: "Summary statistics"
|
| 182 |
+
- log_model_checkpoint: "Artifact logging"
|
| 183 |
+
|
| 184 |
+
graceful_degradation: |
|
| 185 |
+
All functions work without wandb installed.
|
| 186 |
+
Falls back to print() statements when wandb unavailable.
|
| 187 |
+
|
| 188 |
+
# =============================================================================
|
| 189 |
+
# TRAINING NOTEBOOK (AUTHORITATIVE REFERENCE)
|
| 190 |
+
# =============================================================================
|
| 191 |
+
|
| 192 |
+
notebook:
|
| 193 |
+
location: notebooks/Marxist_GRPO_Training.ipynb
|
| 194 |
+
description: |
|
| 195 |
+
Self-contained Jupyter notebook for RunPod execution.
|
| 196 |
+
Contains ALL reward functions inline (no external imports).
|
| 197 |
+
This is the authoritative reference for current implementation.
|
| 198 |
+
|
| 199 |
+
contents:
|
| 200 |
+
- "Model loading (FastLanguageModel.from_pretrained)"
|
| 201 |
+
- "LoRA configuration (get_peft_model)"
|
| 202 |
+
- "Dataset loading from grpo_dataset.jsonl"
|
| 203 |
+
- "All 13+ reward functions inline"
|
| 204 |
+
- "W&B logging integration"
|
| 205 |
+
- "GRPOConfig with A40-optimized settings"
|
| 206 |
+
- "GRPOTrainer setup and training"
|
| 207 |
+
- "LoRA saving and GGUF export"
|
| 208 |
+
|
| 209 |
+
# =============================================================================
|
| 210 |
+
# EXPORT AND DEPLOYMENT
|
| 211 |
+
# =============================================================================
|
| 212 |
+
|
| 213 |
+
export:
|
| 214 |
+
lora_save:
|
| 215 |
+
code: |
|
| 216 |
+
model.save_pretrained_merged("marxist_lora", tokenizer)
|
| 217 |
+
|
| 218 |
+
gguf:
|
| 219 |
+
description: Quantized format for Ollama deployment
|
| 220 |
+
quantization_levels:
|
| 221 |
+
q4_k_m: "Recommended balance of size/quality (~4GB)"
|
| 222 |
+
q8_0: "Higher quality (~7GB)"
|
| 223 |
+
code: |
|
| 224 |
+
model.save_pretrained_gguf("marxist_gguf", tokenizer, quantization_method="q4_k_m")
|
| 225 |
+
|
| 226 |
+
ollama:
|
| 227 |
+
modelfile: |
|
| 228 |
+
FROM ./marxist-gguf-q4_k_m.gguf
|
| 229 |
+
TEMPLATE """<|begin_of_text|>{{ .System }}<|User|>{{ .Prompt }}<|Assistant|>"""
|
| 230 |
+
SYSTEM "You are a Marxist-Leninist assistant..."
|
| 231 |
+
PARAMETER temperature 0.7
|
| 232 |
+
PARAMETER num_ctx 2048
|
| 233 |
+
|
| 234 |
+
# =============================================================================
|
| 235 |
+
# IMPLEMENTATION STATUS
|
| 236 |
+
# =============================================================================
|
| 237 |
+
|
| 238 |
+
implementation_status:
|
| 239 |
+
completed:
|
| 240 |
+
- "8.1 Dataset Preparation: grpo_dataset.jsonl (1,058 Q&A pairs)"
|
| 241 |
+
- "8.2 Reward Function System: 13+ functions with anti-hacking"
|
| 242 |
+
- "8.3 W&B Logging Integration: Full observability"
|
| 243 |
+
- "8.4 Training Notebook: Self-contained for RunPod"
|
| 244 |
+
|
| 245 |
+
pending:
|
| 246 |
+
- "8.5 Training Execution: Run notebook on RunPod A40"
|
| 247 |
+
- "8.6 Model Evaluation: Manual review, reward hacking detection"
|
| 248 |
+
- "8.7 GGUF Export: Convert to Ollama format"
|
| 249 |
+
|
| 250 |
+
test_counts:
|
| 251 |
+
grpo_rewards: 43 passing
|
| 252 |
+
wandb_logging: 17 passing
|
| 253 |
+
total: 60 passing
|
| 254 |
+
|
| 255 |
+
# =============================================================================
|
| 256 |
+
# HARDWARE REQUIREMENTS
|
| 257 |
+
# =============================================================================
|
| 258 |
+
|
| 259 |
+
hardware:
|
| 260 |
+
recommended:
|
| 261 |
+
gpu: A40 (48GB VRAM) on RunPod
|
| 262 |
+
cost: "~$0.79/hour spot, ~$1.14/hour on-demand"
|
| 263 |
+
training_time: "~2-4 hours for 250 steps"
|
| 264 |
+
|
| 265 |
+
vram_breakdown:
|
| 266 |
+
model_4bit: "~4GB"
|
| 267 |
+
lora_params: "~1GB"
|
| 268 |
+
optimizer_state: "~2GB"
|
| 269 |
+
activations: "~6GB with gradient checkpointing"
|
| 270 |
+
vllm_generation: "~8GB for 4 generations"
|
| 271 |
+
reward_models: "~2.5GB (NLI + embeddings + spaCy)"
|
| 272 |
+
total: "~24GB (safe on 48GB A40)"
|
| 273 |
+
|
| 274 |
+
cloud_options:
|
| 275 |
+
primary: "RunPod A40 (48GB) - see ai-docs/runpod.yaml"
|
| 276 |
+
alternative: "Lambda Labs A100 (40GB)"
|
| 277 |
+
|
| 278 |
+
# =============================================================================
|
| 279 |
+
# REFERENCES
|
| 280 |
+
# =============================================================================
|
| 281 |
+
|
| 282 |
+
references:
|
| 283 |
+
internal:
|
| 284 |
+
- "ai-docs/reward-modeling.yaml: Reward function design"
|
| 285 |
+
- "ai-docs/runpod.yaml: GPU setup instructions"
|
| 286 |
+
- "notebooks/Marxist_GRPO_Training.ipynb: Authoritative notebook"
|
| 287 |
+
- "src/prolewiki_llm/: Python module"
|
| 288 |
+
|
| 289 |
+
external:
|
| 290 |
+
unsloth:
|
| 291 |
+
- "https://github.com/unslothai/unsloth"
|
| 292 |
+
- "https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide"
|
| 293 |
+
grpo:
|
| 294 |
+
- "https://arxiv.org/abs/2508.18212 - NLI as reward paradigm"
|
| 295 |
+
- "https://arxiv.org/abs/2509.22047 - MO-GRPO reward hacking"
|
| 296 |
+
trl:
|
| 297 |
+
- "https://huggingface.co/docs/trl/en/grpo_trainer"
|
ai-docs/reward-modeling.yaml
ADDED
|
@@ -0,0 +1,857 @@
| 1 |
+
# Reward Modeling for GRPO Fine-Tuning
|
| 2 |
+
# Token-efficient reference for AI assistants
|
| 3 |
+
# Research-backed approaches to defeat reward hacking
|
| 4 |
+
|
| 5 |
+
status: IMPLEMENTED - multi-layer coherence rewards with wandb logging
|
| 6 |
+
purpose: Robust reward functions for Marxist-Leninist GRPO training
|
| 7 |
+
module: src/prolewiki_llm/
|
| 8 |
+
files:
|
| 9 |
+
grpo_rewards: Reward functions for GRPO training
|
| 10 |
+
wandb_logging: Weights & Biases integration for training observability
|
| 11 |
+
|
| 12 |
+
# =============================================================================
|
| 13 |
+
# OVERVIEW
|
| 14 |
+
# =============================================================================
|
| 15 |
+
|
| 16 |
+
overview:
|
| 17 |
+
goal: |
|
| 18 |
+
Train DeepSeek-R1-0528-Qwen3-8B on ProleWiki Q&A corpus using GRPO
|
| 19 |
+
with reward functions that cannot be gamed by "word soup" or other
|
| 20 |
+
adversarial strategies.
|
| 21 |
+
|
| 22 |
+
problem: |
|
| 23 |
+
Naive reward functions (substring matching, keyword counting) are
|
| 24 |
+
vulnerable to reward hacking. A model could maximize reward by
|
| 25 |
+
outputting random Marxist terminology without coherent meaning.
|
| 26 |
+
|
| 27 |
+
solution: |
|
| 28 |
+
Multi-layer reward combining:
|
| 29 |
+
1. NLI (Natural Language Inference) - checks logical consistency
|
| 30 |
+
2. Self-consistency - checks for internal contradictions
|
| 31 |
+
3. Structural coherence - checks syntactic structure via spaCy
|
| 32 |
+
|
| 33 |
+
research_basis:
|
| 34 |
+
- "arxiv.org/abs/2508.18212 - NLI as reward modeling paradigm"
|
| 35 |
+
- "arxiv.org/abs/2509.22047 - MO-GRPO mitigating reward hacking"
|
| 36 |
+
- "arxiv.org/abs/2508.05170 - Posterior-GRPO process rewards"
|
| 37 |
+
|
| 38 |
+
# =============================================================================
|
| 39 |
+
# IDEOLOGICAL BIAS CONSIDERATION
|
| 40 |
+
# =============================================================================
|
| 41 |
+
|
| 42 |
+
ideological_bias:
|
| 43 |
+
concern: |
|
| 44 |
+
NLI models trained on predominantly liberal/capitalist media might
|
| 45 |
+
encode bourgeois assumptions, potentially marking valid Marxist claims
|
| 46 |
+
as "contradiction" or "neutral".
|
| 47 |
+
|
| 48 |
+
empirical_finding: |
|
| 49 |
+
Testing shows BART-large-MNLI performs LOGICAL inference, not
|
| 50 |
+
ideological judgment:
|
| 51 |
+
- "Capitalism exploits workers" → "Workers are exploited": entailment (0.998)
|
| 52 |
+
- "Class struggle drives history" vs "Great individuals drive history": contradiction (0.998)
|
| 53 |
+
- Word soup → coherent claim: neutral (0.932)
|
| 54 |
+
|
| 55 |
+
why_it_works: |
|
| 56 |
+
We compare Marxist response against Marxist ground truth (from ProleWiki).
|
| 57 |
+
The model isn't judging if Marxism is "true" - it's checking if the
|
| 58 |
+
response logically follows from the expected answer.
|
| 59 |
+
|
| 60 |
+
mitigation_strategy: |
|
| 61 |
+
Self-consistency check avoids external ideology entirely - it only
|
| 62 |
+
checks if the response contradicts ITSELF, not external "truth".
|
| 63 |
+
|
| 64 |
+
# =============================================================================
|
| 65 |
+
# REWARD FUNCTIONS
|
| 66 |
+
# =============================================================================
|
| 67 |
+
|
| 68 |
+
reward_functions:
|
| 69 |
+
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
# FORMAT REWARDS (from original notebook)
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
match_format_exactly:
|
| 75 |
+
purpose: Encourage proper <think>...</think> format
|
| 76 |
+
scoring:
|
| 77 |
+
has_think_end_tag: "+3.0"
|
| 78 |
+
no_tag: "0.0"
|
| 79 |
+
notes: Checks for </think> tag presence
|
| 80 |
+
|
| 81 |
+
match_format_approximately:
|
| 82 |
+
purpose: Reward partial format compliance
|
| 83 |
+
scoring:
|
| 84 |
+
one_think_start: "+0.5"
|
| 85 |
+
one_think_end: "+0.5"
|
| 86 |
+
multiple_or_missing: "-1.0"
|
| 87 |
+
notes: Penalizes malformed tag structure
|
| 88 |
+
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
# SEMANTIC REWARDS
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
semantic_similarity_reward:
|
| 94 |
+
purpose: Reward responses semantically similar to ground truth
|
| 95 |
+
model: sentence-transformers/all-MiniLM-L6-v2
|
| 96 |
+
scoring:
|
| 97 |
+
similarity_gt_0.75: "+5.0"
|
| 98 |
+
similarity_gt_0.60: "+3.0"
|
| 99 |
+
similarity_gt_0.45: "+1.0"
|
| 100 |
+
similarity_gt_0.30: "-1.0"
|
| 101 |
+
similarity_le_0.30: "-3.0"
|
| 102 |
+
notes: |
|
| 103 |
+
Uses cosine similarity of embeddings.
|
| 104 |
+
Good for soft matching but doesn't catch contradictions.
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# NLI-BASED REWARDS (Research-backed)
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
nli_coherence_reward:
|
| 111 |
+
purpose: Check if response ENTAILS ground truth
|
| 112 |
+
model: facebook/bart-large-mnli
|
| 113 |
+
scoring:
|
| 114 |
+
entailment: "+3.0 (response supports/implies ground truth)"
|
| 115 |
+
neutral: "-1.0 (off-topic or incoherent)"
|
| 116 |
+
contradiction: "-3.0 (contradicts ground truth)"
|
| 117 |
+
defeats:
|
| 118 |
+
- "Word soup (random terms → neutral)"
|
| 119 |
+
- "Contradictory claims"
|
| 120 |
+
- "Off-topic responses"
|
| 121 |
+
gpu_memory: "~1.6GB"
|
| 122 |
+
research: "arxiv.org/abs/2508.18212"
|
| 123 |
+
|
| 124 |
+
self_consistency_reward:
|
| 125 |
+
purpose: Check for internal contradictions (no external ideology)
|
| 126 |
+
method: |
|
| 127 |
+
Parse response into sentences using spaCy.
|
| 128 |
+
Check adjacent sentence pairs for NLI contradiction.
|
| 129 |
+
scoring:
|
| 130 |
+
no_contradictions: "+1.0"
|
| 131 |
+
has_contradiction: "-2.0"
|
| 132 |
+
notes: |
|
| 133 |
+
Only checks within-document coherence, avoiding any external
|
| 134 |
+
ideological bias from the NLI model's training data.
|
| 135 |
+
research: "arxiv.org/abs/2508.05170"
|
| 136 |
+
|
| 137 |
+
structural_coherence_reward:
|
| 138 |
+
purpose: Verify proper linguistic structure (defeats word soup)
|
| 139 |
+
model: spacy en_core_web_sm
|
| 140 |
+
checks:
|
| 141 |
+
- "Marxist terms in subject/object syntactic positions"
|
| 142 |
+
- "Presence of discourse connectives (therefore, because, etc.)"
|
| 143 |
+
- "Complete sentence structure"
|
| 144 |
+
scoring:
|
| 145 |
+
term_in_syntactic_role: "+0.3 per term (max +1.5)"
|
| 146 |
+
discourse_connective: "+0.2 per connective (max +1.0)"
|
| 147 |
+
no_sentences: "-1.0"
|
| 148 |
+
syntactic_roles:
|
| 149 |
+
- "nsubj (nominal subject)"
|
| 150 |
+
- "nsubjpass (passive subject)"
|
| 151 |
+
- "dobj (direct object)"
|
| 152 |
+
- "pobj (object of preposition)"
|
| 153 |
+
- "attr (attribute)"
|
| 154 |
+
- "appos (appositional modifier)"
|
| 155 |
+
discourse_connectives:
|
| 156 |
+
- "because, therefore, thus, hence, consequently"
|
| 157 |
+
- "however, although, whereas, nevertheless, moreover"
|
| 158 |
+
- "furthermore, additionally, specifically, namely"
|
| 159 |
+
- "as a result, due to, in order to, so that"
|
| 160 |
+
- "on the other hand, in contrast, similarly, likewise"
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# TOPIC RELEVANCE REWARD (Question-Answer Alignment)
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
topic_relevance_reward:
|
| 167 |
+
purpose: Ensure answer addresses what the question asked about
|
| 168 |
+
method: |
|
| 169 |
+
Implements f(A) ⊆ f(Q) check where f extracts semantic topics:
|
| 170 |
+
1. Extract core topics from question Q using dependency parsing
|
| 171 |
+
2. Expand Q topics with Marxist concept synonyms
|
| 172 |
+
3. Extract topics from answer A
|
| 173 |
+
4. Compute coverage: how many Q topics are addressed in A
|
| 174 |
+
scoring:
|
| 175 |
+
gt_80_coverage: "+2.0 (answer fully addresses question topics)"
|
| 176 |
+
gt_60_coverage: "+1.5 (answer mostly on-topic)"
|
| 177 |
+
gt_40_coverage: "+1.0 (answer partially on-topic)"
|
| 178 |
+
gt_20_coverage: "0.0 (answer tangentially related)"
|
| 179 |
+
le_20_coverage: "-1.5 (answer off-topic)"
|
| 180 |
+
defeats:
|
| 181 |
+
- "Off-topic coherent text (coherent Marxist text about wrong subject)"
|
| 182 |
+
- "Topic drift during response"
|
| 183 |
+
spacy_model: "en_core_web_trf (transformer-based, best semantic understanding)"
|
| 184 |
+
topic_extraction:
|
| 185 |
+
question: "ROOT verb children with nsubj/dobj/attr/nsubjpass dependency"
|
| 186 |
+
answer: "Noun chunks + named entities (determiners stripped)"
|
| 187 |
+
synonym_expansion: "CONCEPT_EQUIVALENCES dict maps bourgeoisie ↔ capitalist class"
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# COMBINED REWARDS
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
robust_coherence_reward:
|
| 194 |
+
purpose: Multi-layer coherence check
|
| 195 |
+
layers:
|
| 196 |
+
1: "NLI coherence - Does response entail ground truth?"
|
| 197 |
+
2: "Self-consistency - Does response contradict itself?"
|
| 198 |
+
3: "Structural coherence - Are terms used in proper syntax?"
|
| 199 |
+
scoring:
|
| 200 |
+
max_score: "+5.5 (entailment + consistent + structured)"
|
| 201 |
+
contradiction_floor: "-3.0 (NLI contradiction)"
|
| 202 |
+
inconsistency_penalty: "-2.0 (internal contradiction)"
|
| 203 |
+
combination_logic: |
|
| 204 |
+
if nli_score <= -3.0:
|
| 205 |
+
return -3.0 # Contradiction dominates
|
| 206 |
+
elif consistency_score <= -2.0:
|
| 207 |
+
return -2.0 # Internal contradiction
|
| 208 |
+
else:
|
| 209 |
+
return nli_score + (consistency * 0.5) + (structure * 0.5)
|
| 210 |
+
|
| 211 |
+
full_coherence_reward:
|
| 212 |
+
purpose: Complete coherence check (RECOMMENDED - maximum robustness)
|
| 213 |
+
layers:
|
| 214 |
+
1: "NLI coherence - Does response entail ground truth?"
|
| 215 |
+
2: "Self-consistency - Does response contradict itself?"
|
| 216 |
+
3: "Structural coherence - Are terms used in proper syntax?"
|
| 217 |
+
4: "Topic relevance - Does answer address what was asked?"
|
| 218 |
+
5: "Interconnection depth - Rewards deep analysis, penalizes buzzword salad"
|
| 219 |
+
scoring:
|
| 220 |
+
max_score: "+7.0 (all checks pass with deep analysis)"
|
| 221 |
+
off_topic_penalty: "-2.0 (severely off-topic)"
|
| 222 |
+
buzzword_salad_penalty: "-1.5 (shallow buzzword listing)"
|
| 223 |
+
inherits: "robust_coherence_reward penalties for NLI/consistency failures"
|
| 224 |
+
combination_logic: |
|
| 225 |
+
if relevance <= -1.5:
|
| 226 |
+
return -2.0 # Severely off-topic
|
| 227 |
+
elif robust <= -2.0:
|
| 228 |
+
return robust # Robust check failed
|
| 229 |
+
elif depth <= -1.5:
|
| 230 |
+
return -1.5 # Buzzword salad detected
|
| 231 |
+
else:
|
| 232 |
+
return robust + (relevance * 0.4) + (depth * 0.3)
|
| 233 |
+
use_when: "Maximum robustness against reward hacking is needed"
|
| 234 |
+
|
| 235 |
+
# ---------------------------------------------------------------------------
|
| 236 |
+
# INTERCONNECTION DEPTH REWARD (Anti-Buzzword-Salad)
|
| 237 |
+
# ---------------------------------------------------------------------------
|
| 238 |
+
|
| 239 |
+
interconnection_depth_reward:
|
| 240 |
+
purpose: Distinguish deep analysis from shallow buzzword salad
|
| 241 |
+
method: |
|
| 242 |
+
Rewards meaningful interconnections while penalizing superficial
|
| 243 |
+
concept-dropping. Distinguishes:
|
| 244 |
+
- GOOD: "Surplus value relates to imperialism BECAUSE capital export..."
|
| 245 |
+
- BAD: "Surplus value intersects with imperialism, colonialism, patriarchy..."
|
| 246 |
+
signals:
|
| 247 |
+
depth_ratio:
|
| 248 |
+
description: "Words per unique Marxist concept"
|
| 249 |
+
gt_20: "+1.0 (deep analysis - few concepts well-explained)"
|
| 250 |
+
range_10_20: "+0.5 (adequate depth)"
|
| 251 |
+
range_5_10: "-0.5 (shallow)"
|
| 252 |
+
lt_5: "-1.5 (severe buzzword soup)"
|
| 253 |
+
hollow_buzzwords:
|
| 254 |
+
description: "Activist jargon without substance"
|
| 255 |
+
threshold: "> 2 hollow phrases"
|
| 256 |
+
penalty: "-0.3 per additional (max -1.5)"
|
| 257 |
+
depth_markers:
|
| 258 |
+
description: "Historical specificity, citations, examples"
|
| 259 |
+
bonus: "+0.3 each (max +1.5)"
|
| 260 |
+
explanation_ratio:
|
| 261 |
+
description: "Explanatory phrases per concept"
|
| 262 |
+
gt_50_percent: "+0.5 (well-explained)"
|
| 263 |
+
lt_10_percent_many_concepts: "-0.5 (unexplained concept soup)"
|
| 264 |
+
scoring:
|
| 265 |
+
range: "-2.5 to +3.0"
|
| 266 |
+
defeats:
|
| 267 |
+
- "Buzzword salad (many concepts, no explanation)"
|
| 268 |
+
- "Activist jargon (performative language without analysis)"
|
| 269 |
+
- "Intersectionality word soup (mentioning everything without depth)"
|
| 270 |
+
|
| 271 |
+
# ---------------------------------------------------------------------------
|
| 272 |
+
# ENTITY VERIFICATION REWARDS (Anti-Hallucination)
|
| 273 |
+
# ---------------------------------------------------------------------------
|
| 274 |
+
|
| 275 |
+
entity_verification_reward:
|
| 276 |
+
status: NEW - added 2025-12-18
|
| 277 |
+
purpose: Penalize confident claims about unverified entities
|
| 278 |
+
method: |
|
| 279 |
+
Uses spaCy NER to extract entity mentions from responses, then
|
| 280 |
+
checks them against a whitelist of 24,040 entities extracted from
|
| 281 |
+
the ProleWiki corpus. Rewards epistemic humility for unknown entities.
|
| 282 |
+
whitelist_source: "training_data/entity_whitelist_clean.json"
|
| 283 |
+
whitelist_stats:
|
| 284 |
+
total_entities: 24040
|
| 285 |
+
sources:
|
| 286 |
+
- "5,129 article titles (filenames)"
|
| 287 |
+
- "27,329 internal wiki links [[Entity]]"
|
| 288 |
+
- "2,458 categories"
|
| 289 |
+
- "2,179 library work references"
|
| 290 |
+
- "2,145 infobox person names"
|
| 291 |
+
scoring:
|
| 292 |
+
epistemic_humility: "+2.0 (expresses uncertainty about unknown entities)"
|
| 293 |
+
verified_entities: "+1.0 (discusses only verified entities)"
|
| 294 |
+
unknown_no_uncertainty: "-1.0 (mentions unknown entities without caution)"
|
| 295 |
+
confident_hallucination: "-2.5 (fabricates details about unknown entities)"
|
| 296 |
+
defeats:
|
| 297 |
+
- "Confident hallucination about fictional organizations"
|
| 298 |
+
- "Fabricated founding dates, founders, activities"
|
| 299 |
+
- "Made-up historical context for non-existent entities"
|
| 300 |
+
example_test_case: |
|
| 301 |
+
Question: "What was the Militant League of Trotskyist Shanghai Students?"
|
| 302 |
+
BAD response: "Founded in 1925, the Militant League..." (fabricated details)
|
| 303 |
+
GOOD response: "I cannot verify this organization..." (epistemic humility)
|
| 304 |
+
dependencies:
|
| 305 |
+
- "spaCy (en_core_web_trf or fallback)"
|
| 306 |
+
- "training_data/entity_whitelist_clean.json"
|
| 307 |
+
|
| 308 |
+
epistemic_calibration_reward:
|
| 309 |
+
status: NEW - added 2025-12-18
|
| 310 |
+
purpose: Lightweight pattern-based epistemic calibration
|
| 311 |
+
method: |
|
| 312 |
+
Pattern matching for uncertainty phrases and confident claim patterns.
|
| 313 |
+
Doesn't require NER or entity whitelist - just regex matching.
|
| 314 |
+
Use when entity_verification_reward is too computationally expensive.
|
| 315 |
+
scoring:
|
| 316 |
+
has_uncertainty_phrases: "+1.5"
|
| 317 |
+
has_confident_claims_no_uncertainty: "-0.5"
|
| 318 |
+
neutral: "0.0"
|
| 319 |
+
uncertainty_patterns:
|
| 320 |
+
- "I cannot/can't/don't verify/confirm/find"
|
| 321 |
+
- "I don't have verified/specific information"
|
| 322 |
+
- "I'm not certain/sure/confident"
|
| 323 |
+
- "I should not fabricate/make up/speculate"
|
| 324 |
+
- "Could you provide more context"
|
| 325 |
+
confident_claim_patterns:
|
| 326 |
+
- "founded in [year]"
|
| 327 |
+
- "established in [year]"
|
| 328 |
+
- "was founded by"
|
| 329 |
+
- "was created by"
|
| 330 |
+
use_when: "Lightweight alternative to entity_verification_reward"
|
| 331 |
+
|
| 332 |
+
# ---------------------------------------------------------------------------
|
| 333 |
+
# IDEOLOGICAL FIRMNESS REWARD (Anti-Capitulation)
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
|
| 336 |
+
ideological_firmness_reward:
|
| 337 |
+
status: NEW - added 2025-12-18
|
| 338 |
+
purpose: Reward firm Marxist counter-arguments to reactionary claims
|
| 339 |
+
method: |
|
| 340 |
+
Detects ideological challenges (fascist propaganda, anti-communist
|
| 341 |
+
tropes, bad faith questions) and scores responses based on:
|
| 342 |
+
1. Detection: Is the prompt an ideological challenge?
|
| 343 |
+
2. Capitulation check: Does the response hedge, both-sides, or concede?
|
| 344 |
+
3. Firmness check: Does the response firmly refute with class analysis?
|
| 345 |
+
4. Substantive bonus: Does the response use Marxist analytical terms?
|
| 346 |
+
|
| 347 |
+
Key insight: COMBAT reactionary ideology with analysis, don't refuse engagement.
|
| 348 |
+
scoring:
|
| 349 |
+
non_challenge: "0.0 (neutral - no ideological challenge in prompt)"
|
| 350 |
+
capitulation_patterns: "-0.5 each (both-sidesing, hedging, false balance)"
|
| 351 |
+
firmness_patterns: "+0.5 each (direct refutation, class analysis)"
|
| 352 |
+
class_analysis_bonus: "+0.2 each marker (capped at +1.0)"
|
| 353 |
+
range: "-2.0 to +2.0"
|
| 354 |
+
ideological_challenge_patterns:
|
| 355 |
+
reactionary_claims:
|
| 356 |
+
- "Hitler helped workers"
|
| 357 |
+
- "Nazis were socialist"
|
| 358 |
+
death_toll_claims:
|
| 359 |
+
- "communism killed millions"
|
| 360 |
+
- "Black Book of Communism"
|
| 361 |
+
anti_communist_framing:
|
| 362 |
+
- "Stalin was a dictator"
|
| 363 |
+
- "authoritarian regime"
|
| 364 |
+
- "communism failed"
|
| 365 |
+
bad_faith_questions:
|
| 366 |
+
- "Holodomor"
|
| 367 |
+
- "Gulag"
|
| 368 |
+
- "Great Purge"
|
| 369 |
+
- "Was Stalin bad?"
|
| 370 |
+
- "Was Mao good or bad?"
|
| 371 |
+
capitulation_patterns:
|
| 372 |
+
both_sidesing:
|
| 373 |
+
- "valid points on both sides"
|
| 374 |
+
- "on one hand... on the other hand"
|
| 375 |
+
- "reasonable people can disagree"
|
| 376 |
+
hedging:
|
| 377 |
+
- "it's a complex/nuanced issue"
|
| 378 |
+
- "I can see why you'd ask"
|
| 379 |
+
false_balance:
|
| 380 |
+
- "Stalin did good things and bad things"
|
| 381 |
+
- "achievements and mistakes to consider"
|
| 382 |
+
theory_practice:
|
| 383 |
+
- "communism sounds good in theory but..."
|
| 384 |
+
firmness_patterns:
|
| 385 |
+
direct_refutation:
|
| 386 |
+
- "This claim is false/propaganda"
|
| 387 |
+
- "This is fascist/bourgeois propaganda"
|
| 388 |
+
- "discredited"
|
| 389 |
+
- "debunked"
|
| 390 |
+
analytical_framing:
|
| 391 |
+
- "From a Marxist analysis..."
|
| 392 |
+
- "The material reality shows..."
|
| 393 |
+
- "historically documented"
|
| 394 |
+
class_analysis:
|
| 395 |
+
- "serves the interests of capital"
|
| 396 |
+
- "bourgeois historiography"
|
| 397 |
+
- "dictatorship of the proletariat"
|
| 398 |
+
socialist_achievements:
|
| 399 |
+
- "universal literacy"
|
| 400 |
+
- "defeated fascism"
|
| 401 |
+
- "provided housing and healthcare"
|
| 402 |
+
special_handling:
|
| 403 |
+
quote_to_refute: |
|
| 404 |
+
Quoting reactionary claims in order to refute them is NOT penalized.
|
| 405 |
+
Pattern: "X claim Y, but this ignores..."
|
| 406 |
+
principled_self_criticism: |
|
| 407 |
+
ML self-criticism is NOT capitulation when framed from a
|
| 408 |
+
revolutionary perspective (e.g., "Marxist-Leninists engage in
|
| 409 |
+
self-criticism as a methodological principle")
|
| 410 |
+
defeats:
|
| 411 |
+
- "Both-sidesing on fascism/communism"
|
| 412 |
+
- "Hedging on historical socialist figures"
|
| 413 |
+
- "Conceding reactionary framing"
|
| 414 |
+
- "Theory/practice false dichotomy"
|
| 415 |
+
example_test_cases:
|
| 416 |
+
high_score: |
|
| 417 |
+
Q: "Didn't Hitler help German workers?"
|
| 418 |
+
A: "This is fascist propaganda. Hitler destroyed trade unions,
|
| 419 |
+
banned strikes, and served the interests of finance capital."
|
| 420 |
+
Score: >= +1.5
|
| 421 |
+
low_score: |
|
| 422 |
+
Q: "Didn't Hitler help German workers?"
|
| 423 |
+
A: "While Hitler did implement some programs that benefited
|
| 424 |
+
workers, it's important to weigh both positive and negative..."
|
| 425 |
+
Score: <= -1.5
|
| 426 |
+
neutral: |
|
| 427 |
+
Q: "What year did the October Revolution happen?"
|
| 428 |
+
A: "The October Revolution occurred in 1917."
|
| 429 |
+
Score: ~0.0 (no ideological challenge detected)
|
| 430 |
+
use_when: |
|
| 431 |
+
Training models to maintain principled Marxist positions when
|
| 432 |
+
facing reactionary claims and propaganda. Prevents models from
|
| 433 |
+
developing liberal "both-sides" tendencies on class issues.
|
| 434 |
+
|
| 435 |
+
# ---------------------------------------------------------------------------
|
| 436 |
+
# DEPRECATED/SHALLOW REWARDS
|
| 437 |
+
# ---------------------------------------------------------------------------
|
| 438 |
+
|
| 439 |
+
terminology_reward:
|
| 440 |
+
status: DEPRECATED - use structural_coherence_reward instead
|
| 441 |
+
purpose: Reward Marxist terminology (SHALLOW - can be gamed)
|
| 442 |
+
scoring: "+0.3 per term (max +2.0)"
|
| 443 |
+
warning: |
|
| 444 |
+
This reward can be gamed with "word soup" - random Marxist
|
| 445 |
+
terms without coherent meaning. Use nli_coherence_reward or
|
| 446 |
+
structural_coherence_reward for robust evaluation.
|
| 447 |
+
|
| 448 |
+
# =============================================================================
|
| 449 |
+
# MARXIST TERMINOLOGY
|
| 450 |
+
# =============================================================================
|
| 451 |
+
|
| 452 |
+
marxist_terms:
|
| 453 |
+
core_concepts:
|
| 454 |
+
- "dialectical, materialism, historical materialism, dialectical materialism"
|
| 455 |
+
classes:
|
| 456 |
+
- "bourgeoisie, proletariat, petty bourgeois, lumpenproletariat"
|
| 457 |
+
- "working class, ruling class"
|
| 458 |
+
class_struggle:
|
| 459 |
+
- "class struggle, class consciousness, class war, class conflict"
|
| 460 |
+
political_economy:
|
| 461 |
+
- "surplus value, commodity, use value, exchange value"
|
| 462 |
+
- "labor power, means of production, relations of production"
|
| 463 |
+
- "forces of production, mode of production, primitive accumulation"
|
| 464 |
+
- "exploitation, capital accumulation"
|
| 465 |
+
imperialism:
|
| 466 |
+
- "imperialism, colonialism, neo-colonialism, settler colonialism"
|
| 467 |
+
- "national liberation, self-determination"
|
| 468 |
+
state_revolution:
|
| 469 |
+
- "dictatorship of the proletariat, vanguard, vanguard party"
|
| 470 |
+
- "democratic centralism, withering away of the state"
|
| 471 |
+
ideology:
|
| 472 |
+
- "hegemony, superstructure, base, ideology, false consciousness"
|
| 473 |
+
revisionism:
|
| 474 |
+
- "revisionism, opportunism, reformism, social democracy, ultra-leftism"
|
| 475 |
+
alienation:
|
| 476 |
+
- "alienation, fetishism, commodity fetishism, reification"
|
| 477 |
+
historical:
|
| 478 |
+
- "paris commune, october revolution, bolshevik, menshevik"
|
| 479 |
+
anti_colonial:
|
| 480 |
+
- "decolonization, third world, global south, national bourgeoisie, comprador"
|
| 481 |
+
|
| 482 |
+
# =============================================================================
|
| 483 |
+
# HOLLOW BUZZWORDS (Activist Jargon to Penalize)
|
| 484 |
+
# =============================================================================
|
| 485 |
+
|
| 486 |
+
hollow_buzzwords:
|
| 487 |
+
description: |
|
| 488 |
+
Phrases that signal superficial analysis when used without substantive
|
| 489 |
+
explanation. These are NOT Marxist technical terms - they are activist
|
| 490 |
+
jargon that often substitutes for actual analysis.
|
| 491 |
+
vague_connectors:
|
| 492 |
+
- "interconnected, interrelated, intersects with"
|
| 493 |
+
- "it's all connected, everything is connected, systemic"
|
| 494 |
+
performative_language:
|
| 495 |
+
- "centered, centering, uplift, uplifting"
|
| 496 |
+
- "do the work, the work, unpack, unpacking"
|
| 497 |
+
- "unlearn, unlearning, hold space, sit with, lean into"
|
| 498 |
+
- "problematic, harmful, toxic"
|
| 499 |
+
vague_abstractions:
|
| 500 |
+
- "in a way, sort of, kind of, essentially, basically"
|
| 501 |
+
- "generally speaking, broadly"
|
| 502 |
+
misused_terms:
|
| 503 |
+
- "praxis (when used without explanation)"
|
| 504 |
+
- "material conditions (when used as hand-wave)"
|
| 505 |
+
- "structural, structurally (when mechanism not specified)"
|
| 506 |
+
note: |
|
| 507 |
+
The penalty applies when hollow buzzword DENSITY is high AND
|
| 508 |
+
depth ratio is low. Legitimate use with explanation is not penalized.
|
| 509 |
+
|
| 510 |
+
# =============================================================================
|
| 511 |
+
# EXPLANATORY PHRASES (Depth Markers)
|
| 512 |
+
# =============================================================================
|
| 513 |
+
|
| 514 |
+
explanatory_phrases:
|
| 515 |
+
description: "Phrases indicating concept is being explained, not just dropped"
|
| 516 |
+
causal:
|
| 517 |
+
- "because the, because of, this is because, since the"
|
| 518 |
+
- "due to the, as a result of, results from, caused by"
|
| 519 |
+
- "leads to, results in, enables, produces"
|
| 520 |
+
definitional:
|
| 521 |
+
- "is defined as, refers to, means that, denotes"
|
| 522 |
+
- "that is, in other words, namely, i.e."
|
| 523 |
+
elaboration:
|
| 524 |
+
- "specifically, in particular, for example, such as"
|
| 525 |
+
- "this means, which means, this implies, therefore"
|
| 526 |
+
mechanism:
|
| 527 |
+
- "this occurs when, this happens because, the mechanism"
|
| 528 |
+
- "through the process of, by means of, works by"
|
| 529 |
+
|
| 530 |
+
depth_markers:
|
| 531 |
+
description: "Phrases indicating analytical depth (historical specificity, citations)"
|
| 532 |
+
historical:
|
| 533 |
+
- "in 1, in 2, during the, after the, before the"
|
| 534 |
+
citations:
|
| 535 |
+
- "marx argued, lenin wrote, engels noted, gramsci"
|
| 536 |
+
- "according to, as marx, as lenin"
|
| 537 |
+
examples:
|
| 538 |
+
- "for example, such as, in the case of, consider"
|
| 539 |
+
definitions:
|
| 540 |
+
- "defined as, meaning, specifically"
|
| 541 |
+
|
| 542 |
+
# =============================================================================
|
| 543 |
+
# USAGE IN TRAINING
|
| 544 |
+
# =============================================================================
|
| 545 |
+
|
| 546 |
+
training_usage:
|
| 547 |
+
|
| 548 |
+
full_reward_set:
|
| 549 |
+
description: "RECOMMENDED - Maximum robustness including depth analysis"
|
| 550 |
+
functions:
|
| 551 |
+
- "match_format_exactly (+3.0 for </think>)"
|
| 552 |
+
- "match_format_approximately (tag validation)"
|
| 553 |
+
- "full_coherence_reward (NLI + structure + topic + depth)"
|
| 554 |
+
- "completeness_reward (length comparison)"
|
| 555 |
+
- "debug_print_reward (monitoring)"
|
| 556 |
+
notes: "full_coherence_reward now includes interconnection_depth_reward"
|
| 557 |
+
|
| 558 |
+
anti_hallucination_reward_set:
|
| 559 |
+
description: "Maximum protection against entity hallucination"
|
| 560 |
+
functions:
|
| 561 |
+
- "match_format_exactly (+3.0 for </think>)"
|
| 562 |
+
- "match_format_approximately (tag validation)"
|
| 563 |
+
- "full_coherence_reward (NLI + structure + topic + depth)"
|
| 564 |
+
- "entity_verification_reward (whitelist check + NER)"
|
| 565 |
+
- "completeness_reward (length comparison)"
|
| 566 |
+
- "debug_print_reward (monitoring)"
|
| 567 |
+
notes: |
|
| 568 |
+
Adds entity_verification_reward to full_reward_set.
|
| 569 |
+
Requires entity_whitelist_clean.json (24,040 verified entities).
|
| 570 |
+
Use when hallucination about organizations/people is a concern.
|
| 571 |
+
use_with_training_data:
|
| 572 |
+
- "synthetic_epistemic_humility.jsonl (20 examples)"
|
| 573 |
+
|
| 574 |
+
ideological_firmness_reward_set:
|
| 575 |
+
description: "NEW - Prevents capitulation to reactionary framing"
|
| 576 |
+
functions:
|
| 577 |
+
- "match_format_exactly (+3.0 for </think>)"
|
| 578 |
+
- "match_format_approximately (tag validation)"
|
| 579 |
+
- "full_coherence_reward (NLI + structure + topic + depth)"
|
| 580 |
+
- "ideological_firmness_reward (anti-capitulation)"
|
| 581 |
+
- "completeness_reward (length comparison)"
|
| 582 |
+
- "debug_print_reward (monitoring)"
|
| 583 |
+
notes: |
|
| 584 |
+
Adds ideological_firmness_reward to full_reward_set.
|
| 585 |
+
Prevents models from developing liberal "both-sides" tendencies.
|
| 586 |
+
Use when training on ideologically charged Q&A pairs.
|
| 587 |
+
use_when: |
|
| 588 |
+
Training data contains questions about fascism, anti-communist
|
| 589 |
+
tropes, or historical socialist figures. Ensures model maintains
|
| 590 |
+
principled Marxist positions rather than hedging or capitulating.
|
| 591 |
+
|
| 592 |
+
robust_reward_set:
|
| 593 |
+
description: "Balanced set that defeats word soup attacks"
|
| 594 |
+
functions:
|
| 595 |
+
- "match_format_exactly (+3.0 for </think>)"
|
| 596 |
+
- "match_format_approximately (tag validation)"
|
| 597 |
+
- "robust_coherence_reward (NLI + self-consistency + structure)"
|
| 598 |
+
- "completeness_reward (length comparison)"
|
| 599 |
+
- "debug_print_reward (monitoring)"
|
| 600 |
+
|
| 601 |
+
lightweight_anti_hallucination_set:
|
| 602 |
+
description: "Anti-hallucination without NER overhead"
|
| 603 |
+
functions:
|
| 604 |
+
- "match_format_exactly (+3.0 for </think>)"
|
| 605 |
+
- "match_format_approximately (tag validation)"
|
| 606 |
+
- "robust_coherence_reward (NLI + self-consistency + structure)"
|
| 607 |
+
- "epistemic_calibration_reward (pattern matching only)"
|
| 608 |
+
- "completeness_reward (length comparison)"
|
| 609 |
+
notes: |
|
| 610 |
+
Uses epistemic_calibration_reward instead of entity_verification_reward.
|
| 611 |
+
Faster but less precise - doesn't check entity whitelist.
|
| 612 |
+
|
| 613 |
+
legacy_reward_set:
|
| 614 |
+
description: "Original set (VULNERABLE to word soup)"
|
| 615 |
+
functions:
|
| 616 |
+
- "match_format_exactly"
|
| 617 |
+
- "match_format_approximately"
|
| 618 |
+
- "semantic_similarity_reward"
|
| 619 |
+
- "terminology_reward # VULNERABLE"
|
| 620 |
+
- "completeness_reward"
|
| 621 |
+
- "debug_print_reward"
|
| 622 |
+
|
| 623 |
+
grpo_trainer_example: |
|
| 624 |
+
trainer = GRPOTrainer(
|
| 625 |
+
model=model,
|
| 626 |
+
processing_class=tokenizer,
|
| 627 |
+
reward_funcs=[
|
| 628 |
+
match_format_exactly,
|
| 629 |
+
match_format_approximately,
|
| 630 |
+
full_coherence_reward, # RECOMMENDED: NLI + structure + topic
|
| 631 |
+
completeness_reward,
|
| 632 |
+
debug_print_reward,
|
| 633 |
+
],
|
| 634 |
+
args=training_args,
|
| 635 |
+
train_dataset=dataset,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# =============================================================================
|
| 639 |
+
# DEPENDENCIES
|
| 640 |
+
# =============================================================================
|
| 641 |
+
|
| 642 |
+
dependencies:
|
| 643 |
+
required:
|
| 644 |
+
- "sentence-transformers # for semantic_similarity_reward"
|
| 645 |
+
- "transformers # for NLI pipeline (bart-large-mnli)"
|
| 646 |
+
- "spacy # for topic extraction and structural coherence"
|
| 647 |
+
- "spacy-curated-transformers # for en_core_web_trf"
|
| 648 |
+
- "numpy # for embeddings"
|
| 649 |
+
|
| 650 |
+
models_downloaded:
|
| 651 |
+
- "sentence-transformers/all-MiniLM-L6-v2 (~90MB)"
|
| 652 |
+
- "facebook/bart-large-mnli (~1.6GB)"
|
| 653 |
+
- "en_core_web_trf (~436MB, transformer-based, RECOMMENDED)"
|
| 654 |
+
- "en_core_web_md (~40MB, word vectors, fallback)"
|
| 655 |
+
- "en_core_web_sm (~12MB, no vectors, last resort)"
|
| 656 |
+
|
| 657 |
+
gpu_memory:
|
| 658 |
+
embedder: "~200MB"
|
| 659 |
+
nli_model: "~1.6GB"
|
| 660 |
+
spacy_trf: "~500MB (RoBERTa-based transformer)"
|
| 661 |
+
total: "~2.5GB additional (on top of training model)"
|
| 662 |
+
|
| 663 |
+
# =============================================================================
|
| 664 |
+
# WANDB LOGGING INTEGRATION
|
| 665 |
+
# =============================================================================
|
| 666 |
+
|
| 667 |
+
wandb_logging:
|
| 668 |
+
purpose: |
|
| 669 |
+
Comprehensive logging for GRPO training observability via Weights & Biases.
|
| 670 |
+
Provides debugging visibility into reward function behavior, sample outputs,
|
| 671 |
+
and per-step metrics.
|
| 672 |
+
|
| 673 |
+
module: src/prolewiki_llm/wandb_logging.py
|
| 674 |
+
|
| 675 |
+
components:
|
| 676 |
+
|
| 677 |
+
init_wandb_logging:
|
| 678 |
+
purpose: Initialize W&B run with configuration
|
| 679 |
+
signature: |
|
| 680 |
+
init_wandb_logging(
|
| 681 |
+
project: str,
|
| 682 |
+
config: dict[str, Any],
|
| 683 |
+
name: str | None = None,
|
| 684 |
+
tags: list[str] | None = None,
|
| 685 |
+
notes: str | None = None,
|
| 686 |
+
mode: str = "online", # or "offline", "disabled"
|
| 687 |
+
) -> wandb.Run | None
|
| 688 |
+
example: |
|
| 689 |
+
run = init_wandb_logging(
|
| 690 |
+
project="marxist-grpo",
|
| 691 |
+
config={
|
| 692 |
+
"model": "DeepSeek-R1-0528-Qwen3-8B",
|
| 693 |
+
"learning_rate": 5e-6,
|
| 694 |
+
"batch_size": 2,
|
| 695 |
+
},
|
| 696 |
+
tags=["grpo", "marxist", "v1"],
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
WandbSampleLogger:
|
| 700 |
+
purpose: Accumulate and log sample tables for debugging
|
| 701 |
+
fields:
|
| 702 |
+
log_every_n_steps: "int = 10 (log table every N training steps)"
|
| 703 |
+
max_samples_per_log: "int = 4 (samples per table)"
|
| 704 |
+
methods:
|
| 705 |
+
add_sample: "Add sample with question/response/rewards to buffer"
|
| 706 |
+
should_log: "Check if current step should trigger table log"
|
| 707 |
+
log_table: "Log accumulated samples as wandb.Table"
|
| 708 |
+
clear: "Clear sample buffer"
|
| 709 |
+
table_columns:
|
| 710 |
+
- "step, question, response, ground_truth"
|
| 711 |
+
- "format_exact, format_approx, nli_coherence"
|
| 712 |
+
- "topic_relevance, depth, completeness, total"
|
| 713 |
+
|
| 714 |
+
create_logging_reward:
|
| 715 |
+
purpose: Factory for GRPOTrainer-compatible logging reward function
|
| 716 |
+
signature: |
|
| 717 |
+
create_logging_reward(
|
| 718 |
+
sample_logger: WandbSampleLogger | None = None,
|
| 719 |
+
compute_all_rewards: bool = True,
|
| 720 |
+
) -> Callable[..., list[float]]
|
| 721 |
+
behavior: |
|
| 722 |
+
1. Computes ALL reward functions internally (if compute_all_rewards=True)
|
| 723 |
+
2. Logs aggregated metrics to wandb (per-reward mean/min/max)
|
| 724 |
+
3. Logs sample tables at configured intervals
|
| 725 |
+
4. Returns [0.0] * len(completions) (no training effect)
|
| 726 |
+
usage: |
|
| 727 |
+
sample_logger = WandbSampleLogger(log_every_n_steps=10)
|
| 728 |
+
logging_reward = create_logging_reward(sample_logger)
|
| 729 |
+
|
| 730 |
+
trainer = GRPOTrainer(
|
| 731 |
+
reward_funcs=[..., logging_reward], # Add to reward_funcs
|
| 732 |
+
...
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
log_reward_metrics:
|
| 736 |
+
purpose: Log per-reward metrics to wandb
|
| 737 |
+
signature: |
|
| 738 |
+
log_reward_metrics(
|
| 739 |
+
step: int,
|
| 740 |
+
reward_scores: dict[str, list[float]],
|
| 741 |
+
) -> None
|
| 742 |
+
logs:
|
| 743 |
+
- "rewards/{name} - mean score"
|
| 744 |
+
- "rewards/{name}_min - minimum score"
|
| 745 |
+
- "rewards/{name}_max - maximum score"
|
| 746 |
+
- "rewards/total - sum of all reward means"
|
| 747 |
+
|
| 748 |
+
finish_wandb_logging:
|
| 749 |
+
purpose: Finish run with optional summary statistics
|
| 750 |
+
signature: |
|
| 751 |
+
finish_wandb_logging(
|
| 752 |
+
summary: dict[str, Any] | None = None,
|
| 753 |
+
) -> None
|
| 754 |
+
|
| 755 |
+
log_model_checkpoint:
|
| 756 |
+
purpose: Log checkpoint as wandb artifact
|
| 757 |
+
signature: |
|
| 758 |
+
log_model_checkpoint(
|
| 759 |
+
checkpoint_path: str,
|
| 760 |
+
metadata: dict[str, Any] | None = None,
|
| 761 |
+
) -> None
|
| 762 |
+
|
| 763 |
+
graceful_degradation: |
|
| 764 |
+
All functions work gracefully when wandb is not installed:
|
| 765 |
+
- is_wandb_available() returns False
|
| 766 |
+
- Logging functions print fallback messages or no-op
|
| 767 |
+
- create_logging_reward returns valid reward function (prints to stdout)
|
| 768 |
+
|
| 769 |
+
integration_example: |
|
| 770 |
+
from prolewiki_llm import (
|
| 771 |
+
init_wandb_logging,
|
| 772 |
+
WandbSampleLogger,
|
| 773 |
+
create_logging_reward,
|
| 774 |
+
finish_wandb_logging,
|
| 775 |
+
match_format_exactly,
|
| 776 |
+
full_coherence_reward,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
# Initialize wandb
|
| 780 |
+
run = init_wandb_logging(
|
| 781 |
+
project="marxist-grpo",
|
| 782 |
+
config={"model": "DeepSeek-R1", "lr": 5e-6, "steps": 250},
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
# Create sample logger and logging reward
|
| 786 |
+
sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
|
| 787 |
+
logging_reward = create_logging_reward(sample_logger, compute_all_rewards=True)
|
| 788 |
+
|
| 789 |
+
# Train with logging
|
| 790 |
+
trainer = GRPOTrainer(
|
| 791 |
+
model=model,
|
| 792 |
+
reward_funcs=[
|
| 793 |
+
match_format_exactly,
|
| 794 |
+
full_coherence_reward,
|
| 795 |
+
logging_reward, # Logs all metrics + samples
|
| 796 |
+
],
|
| 797 |
+
args=training_args,
|
| 798 |
+
train_dataset=dataset,
|
| 799 |
+
)
|
| 800 |
+
trainer.train()
|
| 801 |
+
|
| 802 |
+
# Finish with summary
|
| 803 |
+
finish_wandb_logging(summary={"final_loss": trainer.state.loss})
|
| 804 |
+
|
| 805 |
+
# =============================================================================
|
| 806 |
+
# TESTING
|
| 807 |
+
# =============================================================================
|
| 808 |
+
|
| 809 |
+
testing:
|
| 810 |
+
|
| 811 |
+
test_word_soup:
|
| 812 |
+
input: "bourgeoisie proletariat dialectical materialism surplus value"
|
| 813 |
+
expected_nli: "neutral (0.93)"
|
| 814 |
+
expected_structure: "low (no syntactic roles)"
|
| 815 |
+
expected_topic: "fails - no proper sentence to extract topics from"
|
| 816 |
+
|
| 817 |
+
test_good_response:
|
| 818 |
+
input: "The bourgeoisie extracts surplus value from the proletariat through exploitation of labor power."
|
| 819 |
+
expected_nli: "entailment (depends on ground truth)"
|
| 820 |
+
expected_structure: "high (terms in subject/object positions)"
|
| 821 |
+
expected_topic: "high (bourgeoisie, surplus value, proletariat in proper roles)"
|
| 822 |
+
|
| 823 |
+
test_contradiction:
|
| 824 |
+
input: "Capitalism benefits everyone. Workers are exploited under capitalism."
|
| 825 |
+
expected_self_consistency: "-2.0 (internal contradiction)"
|
| 826 |
+
|
| 827 |
+
test_off_topic:
|
| 828 |
+
question: "What is revisionism?"
|
| 829 |
+
answer: "Imperialism is the highest stage of capitalism characterized by monopolies."
|
| 830 |
+
expected_topic_relevance: "-1.5 (off-topic - discusses imperialism not revisionism)"
|
| 831 |
+
|
| 832 |
+
test_synonym_recognition:
|
| 833 |
+
question: "What is the bourgeoisie?"
|
| 834 |
+
answer: "The capitalist class owns the means of production."
|
| 835 |
+
expected_topic_relevance: "+2.0 (synonym 'capitalist class' recognized)"
|
| 836 |
+
|
| 837 |
+
# =============================================================================
|
| 838 |
+
# REFERENCES
|
| 839 |
+
# =============================================================================
|
| 840 |
+
|
| 841 |
+
references:
|
| 842 |
+
papers:
|
| 843 |
+
- "Better LM-Based Judging Reward Modeling: arxiv.org/abs/2508.18212"
|
| 844 |
+
- "MO-GRPO Mitigating Reward Hacking: arxiv.org/abs/2509.22047"
|
| 845 |
+
- "Posterior-GRPO Process Rewards: arxiv.org/abs/2508.05170"
|
| 846 |
+
- "MENLI NLI Evaluation Metrics: doi.org/10.1162/tacl_a_00576"
|
| 847 |
+
|
| 848 |
+
models:
|
| 849 |
+
- "BART-large-MNLI: huggingface.co/facebook/bart-large-mnli"
|
| 850 |
+
- "all-MiniLM-L6-v2: huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
|
| 851 |
+
- "spaCy en_core_web_trf: spacy.io/models/en#en_core_web_trf (RECOMMENDED)"
|
| 852 |
+
- "spaCy en_core_web_md: spacy.io/models/en#en_core_web_md (fallback)"
|
| 853 |
+
|
| 854 |
+
related_docs:
|
| 855 |
+
- "ai-docs/finetune.yaml - overall fine-tuning strategy"
|
| 856 |
+
- "ai-docs/chatbot-ideology.yaml - training data design"
|
| 857 |
+
- "ai-docs/runpod.yaml - GPU setup for training"
|
ai-docs/runpod.yaml
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RunPod.io Setup Guide for LLM Fine-Tuning
|
| 2 |
+
# Optimized for DeepSeek 7B Abliterated with Unsloth QLoRA
|
| 3 |
+
# Status: READY - verified configuration for Phase 8
|
| 4 |
+
|
| 5 |
+
overview:
|
| 6 |
+
purpose: |
|
| 7 |
+
Step-by-step guide for deploying a RunPod GPU pod to fine-tune
|
| 8 |
+
DeepSeek-R1-Distill-Qwen-7B-abliterated on ProleWiki corpus using Unsloth.
|
| 9 |
+
|
| 10 |
+
estimated_cost: "$0.30-0.60 for complete training run (~30 min)"
|
| 11 |
+
|
| 12 |
+
workflow_summary:
|
| 13 |
+
- Create pod with PyTorch 2.4 template
|
| 14 |
+
- Install Unsloth and dependencies
|
| 15 |
+
- Upload training data (JSONL chunks)
|
| 16 |
+
- Run SFT training (~20-30 min)
|
| 17 |
+
- Export GGUF model
|
| 18 |
+
- Download and deploy to Ollama
|
| 19 |
+
- STOP POD immediately after download
|
| 20 |
+
|
| 21 |
+
gpu_selection:
|
| 22 |
+
recommended: RTX 4090
|
| 23 |
+
vram_required: "16-18GB with Unsloth QLoRA"
|
| 24 |
+
note: |
|
| 25 |
+
Unsloth's QLoRA reduces 7B model VRAM from ~24GB to ~16-18GB.
|
| 26 |
+
RTX 4090 (24GB) provides comfortable headroom.
|
| 27 |
+
|
| 28 |
+
options:
|
| 29 |
+
rtx_4090:
|
| 30 |
+
vram: 24GB
|
| 31 |
+
price_spot: "$0.40-0.50/hr"
|
| 32 |
+
price_ondemand: "$0.50-0.60/hr"
|
| 33 |
+
recommendation: "Best value - sufficient VRAM, fast training"
|
| 34 |
+
|
| 35 |
+
a40:
|
| 36 |
+
vram: 48GB
|
| 37 |
+
price_spot: "$0.45-0.55/hr"
|
| 38 |
+
price_ondemand: "$0.50-0.65/hr"
|
| 39 |
+
recommendation: "More headroom, similar price"
|
| 40 |
+
|
| 41 |
+
rtx_3090:
|
| 42 |
+
vram: 24GB
|
| 43 |
+
price_spot: "$0.25-0.35/hr"
|
| 44 |
+
price_ondemand: "$0.30-0.40/hr"
|
| 45 |
+
recommendation: "Budget option, slightly older"
|
| 46 |
+
|
| 47 |
+
a100_40gb:
|
| 48 |
+
vram: 40GB
|
| 49 |
+
price_spot: "$0.80-1.00/hr"
|
| 50 |
+
price_ondemand: "$1.00-1.50/hr"
|
| 51 |
+
recommendation: "Overkill for 7B, use for larger models"
|
| 52 |
+
|
| 53 |
+
spot_vs_ondemand:
|
| 54 |
+
spot:
|
| 55 |
+
pros: "30-50% cheaper"
|
| 56 |
+
cons: "May be interrupted if demand spikes"
|
| 57 |
+
best_for: "Long training runs where checkpoints save progress"
|
| 58 |
+
|
| 59 |
+
ondemand:
|
| 60 |
+
pros: "Guaranteed availability"
|
| 61 |
+
cons: "Full price"
|
| 62 |
+
best_for: "Short runs (<1hr) like our 30-min training"
|
| 63 |
+
|
| 64 |
+
recommendation: |
|
| 65 |
+
For ProleWiki fine-tuning (~30 min), use On-Demand.
|
| 66 |
+
Spot interruption would cost more in setup time than savings.
|
| 67 |
+
|
| 68 |
+
pod_configuration:
|
| 69 |
+
template: "RunPod PyTorch 2.4"
|
| 70 |
+
alternative: "RunPod PyTorch 2.8 (if available)"
|
| 71 |
+
|
| 72 |
+
template_includes:
|
| 73 |
+
- PyTorch 2.4
|
| 74 |
+
- CUDA 12.4
|
| 75 |
+
- cuDNN
|
| 76 |
+
- JupyterLab
|
| 77 |
+
- SSH access
|
| 78 |
+
- Python 3.10+
|
| 79 |
+
|
| 80 |
+
storage:
|
| 81 |
+
container_disk:
|
| 82 |
+
size: "50GB minimum"
|
| 83 |
+
purpose: "Ephemeral - Unsloth, model weights during training"
|
| 84 |
+
warning: "LOST on pod restart!"
|
| 85 |
+
|
| 86 |
+
volume_disk:
|
| 87 |
+
size: "100GB minimum"
|
| 88 |
+
purpose: "Persistent - checkpoints, scripts, training data"
|
| 89 |
+
critical: "ALL important files must go here!"
|
| 90 |
+
mount_path: "/workspace"
|
| 91 |
+
|
| 92 |
+
ports:
|
| 93 |
+
- port: 8888
|
| 94 |
+
purpose: "JupyterLab (primary interface)"
|
| 95 |
+
- port: 22
|
| 96 |
+
purpose: "SSH (optional, for terminal access)"
|
| 97 |
+
|
| 98 |
+
environment_variables:
|
| 99 |
+
required:
|
| 100 |
+
HF_TOKEN: "Your Hugging Face token (for gated models)"
|
| 101 |
+
|
| 102 |
+
optional:
|
| 103 |
+
JUPYTER_PASSWORD: "Secure notebook access"
|
| 104 |
+
WANDB_API_KEY: "If using Weights & Biases logging"
|
| 105 |
+
|
| 106 |
+
secure_secrets:
|
| 107 |
+
note: |
|
| 108 |
+
Use RUNPOD_SECRET_ prefix for encrypted secrets:
|
| 109 |
+
RUNPOD_SECRET_HF_TOKEN will be injected securely.
|
| 110 |
+
|
| 111 |
+
step_by_step_setup:
|
| 112 |
+
step_1_create_pod:
|
| 113 |
+
description: "Create GPU pod from RunPod dashboard"
|
| 114 |
+
actions:
|
| 115 |
+
- "Go to https://runpod.io/console/pods"
|
| 116 |
+
- "Click '+ Deploy' or 'New Pod'"
|
| 117 |
+
- "Select GPU: RTX 4090 (or A40)"
|
| 118 |
+
- "Click 'Change Template' → search 'PyTorch'"
|
| 119 |
+
- "Select 'RunPod PyTorch 2.4'"
|
| 120 |
+
- "Set Container Disk: 50 GB"
|
| 121 |
+
- "Set Volume Disk: 100 GB"
|
| 122 |
+
- "Expand 'Environment Variables'"
|
| 123 |
+
- "Add: HF_TOKEN = your_token"
|
| 124 |
+
- "Click 'Deploy On-Demand' (not Spot for short runs)"
|
| 125 |
+
|
| 126 |
+
step_2_connect:
|
| 127 |
+
description: "Connect to running pod"
|
| 128 |
+
actions:
|
| 129 |
+
- "Wait for pod status: 'Running' (usually <1 min)"
|
| 130 |
+
- "Click 'Connect' button"
|
| 131 |
+
- "Select 'Jupyter Lab' (opens in new tab)"
|
| 132 |
+
- "Or select 'SSH' for terminal access"
|
| 133 |
+
|
| 134 |
+
step_3_install_unsloth:
|
| 135 |
+
description: "Install Unsloth and dependencies in JupyterLab terminal"
|
| 136 |
+
commands: |
|
| 137 |
+
# Verify CUDA is working
|
| 138 |
+
nvidia-smi
|
| 139 |
+
python -c "import torch; print(f'CUDA: {torch.cuda.get_device_name()}')"
|
| 140 |
+
|
| 141 |
+
# Install Unsloth (auto-detects CUDA version)
|
| 142 |
+
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
|
| 143 |
+
|
| 144 |
+
# Install flash-attention (may take a few minutes to compile)
|
| 145 |
+
pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"
|
| 146 |
+
|
| 147 |
+
# Install training dependencies
|
| 148 |
+
pip install trl>=0.7.0 datasets accelerate bitsandbytes peft
|
| 149 |
+
|
| 150 |
+
# Install tiktoken for data transformation
|
| 151 |
+
pip install tiktoken
|
| 152 |
+
|
| 153 |
+
# Verify installation
|
| 154 |
+
python -c "from unsloth import FastLanguageModel; print('Unsloth ready!')"
|
| 155 |
+
|
| 156 |
+
troubleshooting:
|
| 157 |
+
flash_attn_fails: |
|
| 158 |
+
If flash-attn compilation fails, try:
|
| 159 |
+
pip install flash-attn --no-build-isolation
|
| 160 |
+
|
| 161 |
+
cuda_version_mismatch: |
|
| 162 |
+
If CUDA errors occur, specify version explicitly:
|
| 163 |
+
pip install "unsloth[cu124] @ git+https://github.com/unslothai/unsloth.git"
|
| 164 |
+
|
| 165 |
+
step_4_upload_data:
|
| 166 |
+
description: "Upload training data to pod"
|
| 167 |
+
|
| 168 |
+
option_a_jupyterlab:
|
| 169 |
+
best_for: "Small datasets (<100MB)"
|
| 170 |
+
steps:
|
| 171 |
+
- "In JupyterLab file browser (left sidebar)"
|
| 172 |
+
- "Navigate to /workspace"
|
| 173 |
+
- "Create folder: 'data'"
|
| 174 |
+
- "Click upload icon (up arrow)"
|
| 175 |
+
- "Select your JSONL chunks file"
|
| 176 |
+
|
| 177 |
+
option_b_wget:
|
| 178 |
+
best_for: "Data hosted on web"
|
| 179 |
+
command: |
|
| 180 |
+
mkdir -p /workspace/data
|
| 181 |
+
wget https://your-url/library_chunks.jsonl -O /workspace/data/chunks.jsonl
|
| 182 |
+
|
| 183 |
+
option_c_huggingface:
|
| 184 |
+
best_for: "Dataset on Hugging Face"
|
| 185 |
+
command: |
|
| 186 |
+
huggingface-cli download your-user/prolewiki-chunks \
|
| 187 |
+
--local-dir /workspace/data \
|
| 188 |
+
--token $HF_TOKEN
|
| 189 |
+
|
| 190 |
+
option_d_scp:
|
| 191 |
+
best_for: "From local machine via SSH"
|
| 192 |
+
command: |
|
| 193 |
+
# Get SSH command from RunPod 'Connect' dropdown
|
| 194 |
+
scp -P 22XXX library_chunks.jsonl root@pod-ip:/workspace/data/
|
| 195 |
+
|
| 196 |
+
step_5_run_training:
|
| 197 |
+
description: "Execute fine-tuning script"
|
| 198 |
+
note: "See ai-docs/finetune.yaml for complete training code"
|
| 199 |
+
|
| 200 |
+
minimal_script: |
|
| 201 |
+
from unsloth import FastLanguageModel
|
| 202 |
+
import torch
|
| 203 |
+
|
| 204 |
+
# Load abliterated model
|
| 205 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 206 |
+
model_name="huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated",
|
| 207 |
+
max_seq_length=2048,
|
| 208 |
+
load_in_4bit=True,
|
| 209 |
+
dtype=None, # Auto-detect
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Apply LoRA
|
| 213 |
+
model = FastLanguageModel.get_peft_model(
|
| 214 |
+
model,
|
| 215 |
+
r=16,
|
| 216 |
+
lora_alpha=32,
|
| 217 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 218 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 219 |
+
lora_dropout=0.05,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Load and train (see finetune.yaml for full code)
|
| 223 |
+
# ...
|
| 224 |
+
|
| 225 |
+
# Save checkpoint to Volume disk!
|
| 226 |
+
model.save_pretrained("/workspace/checkpoints/marxist-deepseek-lora")
|
| 227 |
+
|
| 228 |
+
expected_time: "20-30 minutes for ~1,000 samples, 3 epochs"
|
| 229 |
+
|
| 230 |
+
step_6_export_gguf:
|
| 231 |
+
description: "Export model to GGUF format for Ollama"
|
| 232 |
+
command: |
|
| 233 |
+
# Export with q4_k_m quantization (good balance)
|
| 234 |
+
model.save_pretrained_gguf(
|
| 235 |
+
"/workspace/exports/marxist-deepseek",
|
| 236 |
+
tokenizer,
|
| 237 |
+
quantization_method="q4_k_m"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Check output
|
| 241 |
+
ls -lh /workspace/exports/
|
| 242 |
+
|
| 243 |
+
output_size: "~4GB for 7B q4_k_m"
|
| 244 |
+
|
| 245 |
+
quantization_options:
|
| 246 |
+
q4_k_m: "Recommended - good quality/size balance (~4GB)"
|
| 247 |
+
q5_k_m: "Higher quality, larger (~5GB)"
|
| 248 |
+
q8_0: "Best quality, largest (~7GB)"
|
| 249 |
+
|
| 250 |
+
step_7_download_model:
|
| 251 |
+
description: "Download GGUF to local machine"
|
| 252 |
+
|
| 253 |
+
option_a_jupyterlab:
|
| 254 |
+
steps:
|
| 255 |
+
- "In JupyterLab file browser"
|
| 256 |
+
- "Navigate to /workspace/exports/"
|
| 257 |
+
- "Right-click the .gguf file"
|
| 258 |
+
- "Select 'Download'"
|
| 259 |
+
|
| 260 |
+
option_b_runpodctl:
|
| 261 |
+
command: |
|
| 262 |
+
# Install runpodctl locally first
|
| 263 |
+
# https://github.com/runpod/runpodctl
|
| 264 |
+
|
| 265 |
+
runpodctl receive /workspace/exports/marxist-deepseek-q4_k_m.gguf
|
| 266 |
+
|
| 267 |
+
step_8_stop_pod:
|
| 268 |
+
description: "CRITICAL - Stop pod to avoid charges"
|
| 269 |
+
warning: "Billing continues until pod is stopped!"
|
| 270 |
+
actions:
|
| 271 |
+
- "Verify GGUF downloaded successfully to local machine"
|
| 272 |
+
- "In RunPod dashboard, click 'Stop' on your pod"
|
| 273 |
+
- "Wait for status: 'Stopped'"
|
| 274 |
+
- "Delete pod if you don't need it again"
|
| 275 |
+
- "Volume disk data persists even after pod deletion"
|
| 276 |
+
|
| 277 |
+
ollama_deployment:
|
| 278 |
+
description: "Deploy GGUF to local Ollama after download"
|
| 279 |
+
|
| 280 |
+
steps:
|
| 281 |
+
- step: "Create Modelfile"
|
| 282 |
+
content: |
|
| 283 |
+
# Save as: Modelfile.marxist-deepseek
|
| 284 |
+
FROM ./marxist-deepseek-q4_k_m.gguf
|
| 285 |
+
|
| 286 |
+
TEMPLATE """<|im_start|>system
|
| 287 |
+
{{ .System }}<|im_end|>
|
| 288 |
+
<|im_start|>user
|
| 289 |
+
{{ .Prompt }}<|im_end|>
|
| 290 |
+
<|im_start|>assistant
|
| 291 |
+
{{ .Response }}<|im_end|>"""
|
| 292 |
+
|
| 293 |
+
SYSTEM "You are a Marxist-Leninist assistant trained on ProleWiki."
|
| 294 |
+
|
| 295 |
+
PARAMETER stop "<|im_end|>"
|
| 296 |
+
PARAMETER temperature 0.7
|
| 297 |
+
PARAMETER top_p 0.9
|
| 298 |
+
|
| 299 |
+
- step: "Create Ollama model"
|
| 300 |
+
command: "ollama create marxist-deepseek -f Modelfile.marxist-deepseek"
|
| 301 |
+
|
| 302 |
+
- step: "Test model"
|
| 303 |
+
command: "ollama run marxist-deepseek 'Explain dialectical materialism.'"
|
| 304 |
+
|
| 305 |
+
cost_summary:
|
| 306 |
+
example_run:
|
| 307 |
+
gpu: "RTX 4090 On-Demand"
|
| 308 |
+
rate: "$0.55/hr"
|
| 309 |
+
time: "30 minutes"
|
| 310 |
+
total: "$0.28"
|
| 311 |
+
|
| 312 |
+
breakdown:
|
| 313 |
+
setup: "5 min - Pod creation, Unsloth install"
|
| 314 |
+
upload: "2 min - Data transfer"
|
| 315 |
+
training: "20-25 min - SFT with QLoRA"
|
| 316 |
+
export: "3 min - GGUF conversion"
|
| 317 |
+
download: "5 min - Transfer GGUF locally"
|
| 318 |
+
total_time: "~35-40 min"
|
| 319 |
+
|
| 320 |
+
tips:
|
| 321 |
+
- "Use On-Demand for short runs (<1hr)"
|
| 322 |
+
- "Use Spot for long runs with checkpoint saving"
|
| 323 |
+
- "Stop pod IMMEDIATELY after download"
|
| 324 |
+
- "Delete pod after confirming success"
|
| 325 |
+
- "Volume disk persists - can restart training later"
|
| 326 |
+
|
| 327 |
+
troubleshooting:
|
| 328 |
+
out_of_memory:
|
| 329 |
+
symptoms: "CUDA OOM, kernel dies"
|
| 330 |
+
solutions:
|
| 331 |
+
- "Reduce batch size in training args"
|
| 332 |
+
- "Ensure load_in_4bit=True"
|
| 333 |
+
- "Use gradient_checkpointing=True"
|
| 334 |
+
- "Upgrade to A40 (48GB VRAM)"
|
| 335 |
+
|
| 336 |
+
slow_training:
|
| 337 |
+
symptoms: "Steps/sec much lower than expected"
|
| 338 |
+
solutions:
|
| 339 |
+
- "Verify GPU is being used: nvidia-smi"
|
| 340 |
+
- "Check torch.cuda.is_available()"
|
| 341 |
+
- "Ensure flash-attn installed correctly"
|
| 342 |
+
|
| 343 |
+
pod_wont_start:
|
| 344 |
+
symptoms: "Pod stuck in 'Pending' or 'Initializing'"
|
| 345 |
+
solutions:
|
| 346 |
+
- "Try different data center region"
|
| 347 |
+
- "Try different GPU type"
|
| 348 |
+
- "Check RunPod status page"
|
| 349 |
+
|
| 350 |
+
checkpoint_lost:
|
| 351 |
+
symptoms: "Can't find saved model after restart"
|
| 352 |
+
cause: "Saved to Container Disk instead of Volume"
|
| 353 |
+
prevention: "ALWAYS use /workspace/ for important files"
|
| 354 |
+
|
| 355 |
+
references:
|
| 356 |
+
runpod_docs: "https://docs.runpod.io/"
|
| 357 |
+
unsloth_github: "https://github.com/unslothai/unsloth"
|
| 358 |
+
context7_runpod: "/runpod/docs"
|
| 359 |
+
related_docs:
|
| 360 |
+
- "ai-docs/finetune.yaml - Complete training configuration"
|
| 361 |
+
- "ai-docs/embedding.yaml - Embedding pipeline"
|
| 362 |
+
- "ai-docs/project-status.yaml - Phase 8 status"
|
ai-docs/training-schema.yaml
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Data Schema Reference
|
| 2 |
+
# Purpose: Human-readable documentation for Marxist-GRPO training data format
|
| 3 |
+
# Formal Schema: training_data/schema/training_record.schema.json
|
| 4 |
+
# Updated: 2025-12-18
|
| 5 |
+
|
| 6 |
+
# =============================================================================
|
| 7 |
+
# OVERVIEW
|
| 8 |
+
# =============================================================================
|
| 9 |
+
|
| 10 |
+
overview:
|
| 11 |
+
purpose: |
|
| 12 |
+
This schema defines the format for Q&A training records in the Marxist-GRPO
|
| 13 |
+
fine-tuning dataset. Each record contains an instruction-response pair with
|
| 14 |
+
comprehensive metadata for:
|
| 15 |
+
- Provenance tracking (where did this come from?)
|
| 16 |
+
- Theoretical classification (what tradition/topic?)
|
| 17 |
+
- Citation tracking (what sources are referenced?)
|
| 18 |
+
- Training metadata (what issue does this fix?)
|
| 19 |
+
- Quality assessment (has this been verified?)
|
| 20 |
+
|
| 21 |
+
design_principles:
|
| 22 |
+
- Reproducibility: Every record traceable to source
|
| 23 |
+
- Filterability: Train on subsets by any dimension
|
| 24 |
+
- Scientific Rigor: Formal JSON Schema validation
|
| 25 |
+
- RAG Integration: Links to ChromaDB chunks where applicable
|
| 26 |
+
- Iteration Tracking: Know what was added when and why
|
| 27 |
+
|
| 28 |
+
json_schema_location: training_data/schema/training_record.schema.json
|
| 29 |
+
manifest_schema_location: training_data/schema/manifest.schema.json
|
| 30 |
+
|
| 31 |
+
# =============================================================================
|
| 32 |
+
# RECORD FORMAT
|
| 33 |
+
# =============================================================================
|
| 34 |
+
|
| 35 |
+
record_format:
|
| 36 |
+
description: |
|
| 37 |
+
Each JSONL file contains one JSON object per line.
|
| 38 |
+
Every record MUST have: instruction, response, metadata
|
| 39 |
+
The metadata object contains all provenance and classification.
|
| 40 |
+
|
| 41 |
+
minimal_example:
|
| 42 |
+
instruction: "What is the mass line?"
|
| 43 |
+
response: "The mass line is the Maoist method of communist leadership..."
|
| 44 |
+
metadata:
|
| 45 |
+
id: "synthetic/maoist-theory/001"
|
| 46 |
+
source:
|
| 47 |
+
type: "synthetic"
|
| 48 |
+
classification:
|
| 49 |
+
categories: ["maoist-theory", "methodology"]
|
| 50 |
+
tradition: "MLM"
|
| 51 |
+
provenance:
|
| 52 |
+
created_date: "2025-12-18"
|
| 53 |
+
created_by: "claude-opus"
|
| 54 |
+
|
| 55 |
+
full_example:
|
| 56 |
+
instruction: "What is the Marxist-Leninist distinction between antisemitism and anti-Zionism?"
|
| 57 |
+
response: "These are fundamentally different phenomena. Antisemitism is a form of racism..."
|
| 58 |
+
metadata:
|
| 59 |
+
id: "synthetic/antisemitism/001"
|
| 60 |
+
source:
|
| 61 |
+
type: "synthetic"
|
| 62 |
+
author: null
|
| 63 |
+
work: null
|
| 64 |
+
article: null
|
| 65 |
+
chunk_ids: []
|
| 66 |
+
classification:
|
| 67 |
+
categories: ["anti-zionism", "antisemitism", "settler-colonialism"]
|
| 68 |
+
tradition: "ML"
|
| 69 |
+
geographic_focus: "Palestine"
|
| 70 |
+
historical_period: null
|
| 71 |
+
citations:
|
| 72 |
+
has_citations: true
|
| 73 |
+
authors: ["Lenin", "Ilan Pappé", "Noam Chomsky"]
|
| 74 |
+
works:
|
| 75 |
+
- title: "On Anti-Jewish Pogroms"
|
| 76 |
+
author: "Lenin"
|
| 77 |
+
year: 1919
|
| 78 |
+
type: "speech"
|
| 79 |
+
- title: "The Ethnic Cleansing of Palestine"
|
| 80 |
+
author: "Ilan Pappé"
|
| 81 |
+
year: 2006
|
| 82 |
+
type: "book"
|
| 83 |
+
training:
|
| 84 |
+
iteration: 2
|
| 85 |
+
correction_for: ["both-sidesing", "antisemitism-conflation"]
|
| 86 |
+
difficulty: "intermediate"
|
| 87 |
+
response_style: "educational"
|
| 88 |
+
adversarial_type: null
|
| 89 |
+
provenance:
|
| 90 |
+
created_date: "2025-12-18"
|
| 91 |
+
created_by: "claude-opus"
|
| 92 |
+
reviewed_by: null
|
| 93 |
+
version: 1
|
| 94 |
+
quality:
|
| 95 |
+
human_verified: false
|
| 96 |
+
confidence: "high"
|
| 97 |
+
notes: null
|
| 98 |
+
|
| 99 |
+
# =============================================================================
|
| 100 |
+
# METADATA FIELDS REFERENCE
|
| 101 |
+
# =============================================================================
|
| 102 |
+
|
| 103 |
+
fields:
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
# SOURCE PROVENANCE
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
source:
|
| 108 |
+
description: Where this Q&A pair originated from.
|
| 109 |
+
|
| 110 |
+
type:
|
| 111 |
+
required: true
|
| 112 |
+
values:
|
| 113 |
+
prolewiki: Derived from ProleWiki article content
|
| 114 |
+
synthetic: Generated by AI for specific purpose
|
| 115 |
+
curated: Human-curated from multiple sources
|
| 116 |
+
library: Derived from Library namespace (full works)
|
| 117 |
+
external: From external source with URL
|
| 118 |
+
|
| 119 |
+
article:
|
| 120 |
+
required: false
|
| 121 |
+
purpose: ProleWiki article title if derived from corpus
|
| 122 |
+
example: "Main/Imperialism"
|
| 123 |
+
links_to: chromadb.article_title
|
| 124 |
+
|
| 125 |
+
work:
|
| 126 |
+
required: false
|
| 127 |
+
purpose: Title of source work for Library-derived Q&As
|
| 128 |
+
example: "Imperialism, the Highest Stage of Capitalism"
|
| 129 |
+
|
| 130 |
+
author:
|
| 131 |
+
required: false
|
| 132 |
+
purpose: Primary author of source material
|
| 133 |
+
example: "Lenin"
|
| 134 |
+
enables: Train only on Marx-derived, Lenin-derived, etc.
|
| 135 |
+
|
| 136 |
+
chunk_ids:
|
| 137 |
+
required: false
|
| 138 |
+
purpose: ChromaDB chunk IDs this Q&A was derived from
|
| 139 |
+
example: ["Main/Imperialism#0", "Main/Imperialism#1"]
|
| 140 |
+
enables: RAG-training data linkage, citation verification
|
| 141 |
+
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
# CLASSIFICATION
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
classification:
|
| 146 |
+
description: Theoretical and topical classification.
|
| 147 |
+
|
| 148 |
+
categories:
|
| 149 |
+
required: true
|
| 150 |
+
purpose: Topic tags aligned with ProleWiki categories
|
| 151 |
+
examples:
|
| 152 |
+
- ["imperialism", "revisionism"]
|
| 153 |
+
- ["anti-zionism", "settler-colonialism", "national-liberation"]
|
| 154 |
+
- ["cultural-revolution", "gpcr", "maoist-theory"]
|
| 155 |
+
enables: Train on specific topics, measure coverage
|
| 156 |
+
|
| 157 |
+
tradition:
|
| 158 |
+
required: true
|
| 159 |
+
values:
|
| 160 |
+
ML: Marxism-Leninism (broad)
|
| 161 |
+
MLM: Marxism-Leninism-Maoism (includes GPCR defense)
|
| 162 |
+
general: Broadly applicable across tendencies
|
| 163 |
+
contested: Debated within ML circles
|
| 164 |
+
enables: Filter by theoretical tendency
|
| 165 |
+
|
| 166 |
+
geographic_focus:
|
| 167 |
+
required: false
|
| 168 |
+
examples: ["Soviet Union", "China", "Palestine", "Cuba"]
|
| 169 |
+
enables: Regional expertise training
|
| 170 |
+
|
| 171 |
+
historical_period:
|
| 172 |
+
required: false
|
| 173 |
+
examples: ["Russian Revolution", "Cultural Revolution", "Cold War"]
|
| 174 |
+
enables: Period-specific training
|
| 175 |
+
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
# CITATIONS
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
citations:
|
| 180 |
+
description: Citation and reference tracking.
|
| 181 |
+
|
| 182 |
+
has_citations:
|
| 183 |
+
purpose: Quick boolean filter for cited content
|
| 184 |
+
enables: Train only on well-sourced responses
|
| 185 |
+
|
| 186 |
+
works:
|
| 187 |
+
purpose: Structured list of cited works
|
| 188 |
+
fields: [title, author, year, type]
|
| 189 |
+
enables: Verify citations, trace to primary sources
|
| 190 |
+
|
| 191 |
+
authors:
|
| 192 |
+
purpose: Flat list of cited authors for filtering
|
| 193 |
+
enables: "Train on Lenin-citing records only"
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# TRAINING METADATA
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
training:
|
| 199 |
+
description: Training-specific metadata.
|
| 200 |
+
|
| 201 |
+
iteration:
|
| 202 |
+
purpose: Which training iteration added this record
|
| 203 |
+
enables: Ablation studies, measure iteration impact
|
| 204 |
+
|
| 205 |
+
correction_for:
|
| 206 |
+
purpose: What failure modes this addresses
|
| 207 |
+
values:
|
| 208 |
+
cpc-contamination: Fixes CPC authority citations
|
| 209 |
+
both-sidesing: Fixes false equivalence on colonial issues
|
| 210 |
+
hallucination: Provides correct historical facts
|
| 211 |
+
antisemitism-conflation: Distinguishes antisemitism/anti-Zionism
|
| 212 |
+
liberal-framing: Replaces liberal with ML framing
|
| 213 |
+
historical-inaccuracy: Corrects factual errors
|
| 214 |
+
theoretical-error: Corrects theoretical misunderstandings
|
| 215 |
+
accommodation: Resists incremental position shifts
|
| 216 |
+
extended-engagement: Models firm rejection
|
| 217 |
+
enables: Test specific corrections, targeted training
|
| 218 |
+
|
| 219 |
+
difficulty:
|
| 220 |
+
values:
|
| 221 |
+
basic: Straightforward ML questions
|
| 222 |
+
intermediate: Requires nuanced understanding
|
| 223 |
+
advanced: Complex theoretical synthesis
|
| 224 |
+
adversarial: Bad-faith or trap questions
|
| 225 |
+
enables: Curriculum learning, stress testing
|
| 226 |
+
|
| 227 |
+
response_style:
|
| 228 |
+
values:
|
| 229 |
+
educational: Thorough explanation
|
| 230 |
+
firm-rejection: Short, clear rejection of premise
|
| 231 |
+
theoretical: Abstract theoretical analysis
|
| 232 |
+
historical: Historical narrative/facts
|
| 233 |
+
biographical: Person-focused information
|
| 234 |
+
analytical: Systematic breakdown
|
| 235 |
+
comparative: Comparing positions/theories
|
| 236 |
+
enables: Style-specific training
|
| 237 |
+
|
| 238 |
+
adversarial_type:
|
| 239 |
+
purpose: For adversarial questions, what pattern
|
| 240 |
+
values:
|
| 241 |
+
bad-faith-question: User asking in bad faith
|
| 242 |
+
conspiracy-premise: Question contains conspiracy theory
|
| 243 |
+
incremental-shift: Gradually shifting goalposts
|
| 244 |
+
false-equivalence: Both-sidesing framing
|
| 245 |
+
appeal-to-complexity: '"It''s complicated" deflection'
|
| 246 |
+
|
| 247 |
+
# ---------------------------------------------------------------------------
|
| 248 |
+
# PROVENANCE
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
provenance:
|
| 251 |
+
description: Record creation and modification tracking.
|
| 252 |
+
|
| 253 |
+
created_date:
|
| 254 |
+
required: true
|
| 255 |
+
format: ISO 8601 date (YYYY-MM-DD)
|
| 256 |
+
purpose: When this record was created
|
| 257 |
+
|
| 258 |
+
created_by:
|
| 259 |
+
required: true
|
| 260 |
+
values: [human, claude-opus, claude-sonnet, other-llm, automated]
|
| 261 |
+
purpose: Who/what created this record
|
| 262 |
+
enables: Filter by creation method
|
| 263 |
+
|
| 264 |
+
reviewed_by:
|
| 265 |
+
purpose: Human reviewer identifier
|
| 266 |
+
enables: Track review coverage
|
| 267 |
+
|
| 268 |
+
version:
|
| 269 |
+
purpose: Increment on edits
|
| 270 |
+
enables: Track record evolution
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# QUALITY
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
quality:
|
| 276 |
+
description: Quality assessment metadata.
|
| 277 |
+
|
| 278 |
+
human_verified:
|
| 279 |
+
purpose: Has a human verified accuracy?
|
| 280 |
+
enables: High-confidence subset training
|
| 281 |
+
|
| 282 |
+
confidence:
|
| 283 |
+
values: [high, medium, low]
|
| 284 |
+
purpose: Confidence in response accuracy
|
| 285 |
+
|
| 286 |
+
flagged_issues:
|
| 287 |
+
purpose: Known issues needing attention
|
| 288 |
+
enables: Exclude problematic records
|
| 289 |
+
|
| 290 |
+
# =============================================================================
|
| 291 |
+
# VALIDATION
|
| 292 |
+
# =============================================================================
|
| 293 |
+
|
| 294 |
+
validation:
|
| 295 |
+
json_schema:
|
| 296 |
+
location: training_data/schema/training_record.schema.json
|
| 297 |
+
draft: 2020-12
|
| 298 |
+
command: |
|
| 299 |
+
# Using jsonschema Python library
|
| 300 |
+
uv run python -c "
|
| 301 |
+
import json
|
| 302 |
+
import jsonschema
|
| 303 |
+
from pathlib import Path
|
| 304 |
+
|
| 305 |
+
schema = json.loads(Path('training_data/schema/training_record.schema.json').read_text())
|
| 306 |
+
for line in Path('training_data/your_file.jsonl').read_text().splitlines():
|
| 307 |
+
record = json.loads(line)
|
| 308 |
+
jsonschema.validate(record, schema)
|
| 309 |
+
print('All records valid!')
|
| 310 |
+
"
|
| 311 |
+
|
| 312 |
+
quick_validation:
|
| 313 |
+
command: |
|
| 314 |
+
# Quick JSON syntax check
|
| 315 |
+
python3 -c "import json; [json.loads(l) for l in open('file.jsonl')]; print('OK')"
|
| 316 |
+
|
| 317 |
+
pre_commit_hook:
|
| 318 |
+
description: Add to .pre-commit-config.yaml for automatic validation
|
| 319 |
+
config: |
|
| 320 |
+
- repo: local
|
| 321 |
+
hooks:
|
| 322 |
+
- id: validate-training-data
|
| 323 |
+
name: Validate Training Data Schema
|
| 324 |
+
entry: uv run python scripts/validate_training_data.py
|
| 325 |
+
language: system
|
| 326 |
+
files: ^training_data/.*\.jsonl$
|
| 327 |
+
|
| 328 |
+
# =============================================================================
|
| 329 |
+
# MANIFEST
|
| 330 |
+
# =============================================================================
|
| 331 |
+
|
| 332 |
+
manifest:
|
| 333 |
+
purpose: |
|
| 334 |
+
The manifest (MANIFEST.yaml) tracks all JSONL files in the dataset,
|
| 335 |
+
their checksums, statistics, and training history. This enables:
|
| 336 |
+
- Reproducible training runs
|
| 337 |
+
- Dataset versioning
|
| 338 |
+
- Integrity verification
|
| 339 |
+
- Statistics generation
|
| 340 |
+
|
| 341 |
+
location: training_data/MANIFEST.yaml
|
| 342 |
+
schema: training_data/schema/manifest.schema.json
|
| 343 |
+
|
| 344 |
+
key_sections:
|
| 345 |
+
dataset: Name, version, license, description
|
| 346 |
+
files: List of all JSONL files with checksums and metadata
|
| 347 |
+
statistics: Aggregate counts by source, category, tradition
|
| 348 |
+
training_iterations: History of training runs
|
| 349 |
+
known_issues: Documented problems
|
| 350 |
+
changelog: Dataset modification history
|
| 351 |
+
|
| 352 |
+
# =============================================================================
|
| 353 |
+
# FILTERING PATTERNS
|
| 354 |
+
# =============================================================================
|
| 355 |
+
|
| 356 |
+
filtering_patterns:
|
| 357 |
+
description: Common filtering operations for training subsets.
|
| 358 |
+
|
| 359 |
+
by_source:
|
| 360 |
+
code: |
|
| 361 |
+
# ProleWiki-derived only (corpus purity)
|
| 362 |
+
data = [r for r in records if r["metadata"]["source"]["type"] == "prolewiki"]
|
| 363 |
+
|
| 364 |
+
# Exclude synthetic for ablation
|
| 365 |
+
data = [r for r in records if r["metadata"]["source"]["type"] != "synthetic"]
|
| 366 |
+
|
| 367 |
+
by_author:
|
| 368 |
+
code: |
|
| 369 |
+
# Lenin-citing records
|
| 370 |
+
data = [r for r in records
|
| 371 |
+
if "Lenin" in r["metadata"].get("citations", {}).get("authors", [])]
|
| 372 |
+
|
| 373 |
+
# Marx or Engels sourced
|
| 374 |
+
data = [r for r in records
|
| 375 |
+
if r["metadata"]["source"].get("author") in ["Marx", "Engels"]]
|
| 376 |
+
|
| 377 |
+
by_tradition:
|
| 378 |
+
code: |
|
| 379 |
+
# MLM only (includes GPCR defense)
|
| 380 |
+
data = [r for r in records if r["metadata"]["classification"]["tradition"] == "MLM"]
|
| 381 |
+
|
| 382 |
+
by_correction:
|
| 383 |
+
code: |
|
| 384 |
+
# Records addressing Zionism issues
|
| 385 |
+
data = [r for r in records
|
| 386 |
+
if "both-sidesing" in r["metadata"].get("training", {}).get("correction_for", [])]
|
| 387 |
+
|
| 388 |
+
by_difficulty:
|
| 389 |
+
code: |
|
| 390 |
+
# Adversarial examples only (stress testing)
|
| 391 |
+
data = [r for r in records
|
| 392 |
+
if r["metadata"].get("training", {}).get("difficulty") == "adversarial"]
|
| 393 |
+
|
| 394 |
+
by_iteration:
|
| 395 |
+
code: |
|
| 396 |
+
# Only iteration 1 (baseline)
|
| 397 |
+
data = [r for r in records if r["metadata"].get("training", {}).get("iteration") == 1]
|
| 398 |
+
|
| 399 |
+
# Iterations 1-2 combined
|
| 400 |
+
data = [r for r in records if r["metadata"].get("training", {}).get("iteration", 1) <= 2]
|
| 401 |
+
|
| 402 |
+
by_quality:
|
| 403 |
+
code: |
|
| 404 |
+
# Human-verified only
|
| 405 |
+
data = [r for r in records if r["metadata"].get("quality", {}).get("human_verified")]
|
| 406 |
+
|
| 407 |
+
# High confidence
|
| 408 |
+
data = [r for r in records
|
| 409 |
+
if r["metadata"].get("quality", {}).get("confidence") == "high"]
|
| 410 |
+
|
| 411 |
+
# =============================================================================
|
| 412 |
+
# INTEGRATION WITH CHROMADB
|
| 413 |
+
# =============================================================================
|
| 414 |
+
|
| 415 |
+
chromadb_integration:
|
| 416 |
+
purpose: |
|
| 417 |
+
Training data can link to ChromaDB chunks, enabling:
|
| 418 |
+
- Verification that responses match corpus
|
| 419 |
+
- RAG-augmented training data generation
|
| 420 |
+
- Provenance chains from user query → chunk → training example
|
| 421 |
+
|
| 422 |
+
chunk_id_format: "{namespace}/{article_title}#{chunk_index}"
|
| 423 |
+
examples:
|
| 424 |
+
- "Main/Imperialism#0"
|
| 425 |
+
- "Library/Capital_Vol_1#127"
|
| 426 |
+
- "Essays/On_Revisionism#3"
|
| 427 |
+
|
| 428 |
+
linkage_pattern:
|
| 429 |
+
description: When generating training data from ProleWiki chunks
|
| 430 |
+
code: |
|
| 431 |
+
# Generate Q&A from chunk and preserve linkage
|
| 432 |
+
training_record = {
|
| 433 |
+
"instruction": generate_question(chunk),
|
| 434 |
+
"response": generate_answer(chunk),
|
| 435 |
+
"metadata": {
|
| 436 |
+
"source": {
|
| 437 |
+
"type": "prolewiki",
|
| 438 |
+
"article": chunk["article_title"],
|
| 439 |
+
"chunk_ids": [chunk["chunk_id"]]
|
| 440 |
+
},
|
| 441 |
+
# ... rest of metadata
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
# =============================================================================
|
| 446 |
+
# MIGRATION FROM LEGACY FORMAT
|
| 447 |
+
# =============================================================================
|
| 448 |
+
|
| 449 |
+
migration:
|
| 450 |
+
legacy_format:
|
| 451 |
+
description: Original curated_qa.jsonl format
|
| 452 |
+
example:
|
| 453 |
+
instruction: "What is revisionism?"
|
| 454 |
+
response: "Revisionism refers to..."
|
| 455 |
+
|
| 456 |
+
new_format:
|
| 457 |
+
description: Full metadata format
|
| 458 |
+
migration_steps:
|
| 459 |
+
- Add metadata wrapper
|
| 460 |
+
- Generate unique IDs
|
| 461 |
+
- Infer source type (curated for manual entries)
|
| 462 |
+
- Add classification based on content analysis
|
| 463 |
+
- Set iteration to 1 for baseline data
|
| 464 |
+
- Mark as needing human verification
|
| 465 |
+
|
| 466 |
+
migration_script: |
|
| 467 |
+
# See scripts/migrate_training_data.py for full implementation
|
| 468 |
+
def migrate_record(old_record, index):
|
| 469 |
+
return {
|
| 470 |
+
"instruction": old_record["instruction"],
|
| 471 |
+
"response": old_record["response"],
|
| 472 |
+
"metadata": {
|
| 473 |
+
"id": f"curated/legacy/{index:03d}",
|
| 474 |
+
"source": {"type": "curated"},
|
| 475 |
+
"classification": {
|
| 476 |
+
"categories": infer_categories(old_record),
|
| 477 |
+
"tradition": "ML"
|
| 478 |
+
},
|
| 479 |
+
"provenance": {
|
| 480 |
+
"created_date": "2025-12-17", # Original creation date
|
| 481 |
+
"created_by": "human"
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
}
|
docker/.env.example
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Marxist-GRPO Training Environment Variables
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Copy this file to .env and fill in your values.
|
| 5 |
+
# These can be passed to docker run with --env-file or individually with -e.
|
| 6 |
+
#
|
| 7 |
+
# Usage:
|
| 8 |
+
# docker run --gpus all --env-file docker/.env marxist-grpo:latest
|
| 9 |
+
#
|
| 10 |
+
# Or with runpodctl:
|
| 11 |
+
# runpodctl create pod \
|
| 12 |
+
# --imageName myregistry/marxist-grpo:latest \
|
| 13 |
+
# --env HF_TOKEN=$HF_TOKEN \
|
| 14 |
+
# --env WANDB_API_KEY=$WANDB_API_KEY \
|
| 15 |
+
# --env HF_REPO=my-org/my-model
|
| 16 |
+
# =============================================================================
|
| 17 |
+
|
| 18 |
+
# -----------------------------------------------------------------------------
|
| 19 |
+
# REQUIRED SECRETS (must be set)
|
| 20 |
+
# -----------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
# HuggingFace API token (for model upload)
|
| 23 |
+
# Get yours at: https://huggingface.co/settings/tokens
|
| 24 |
+
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
| 25 |
+
|
| 26 |
+
# Weights & Biases API key (for training monitoring)
|
| 27 |
+
# Get yours at: https://wandb.ai/authorize
|
| 28 |
+
WANDB_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
| 29 |
+
|
| 30 |
+
# -----------------------------------------------------------------------------
|
| 31 |
+
# MODEL CONFIGURATION
|
| 32 |
+
# -----------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
# Base model to fine-tune
|
| 35 |
+
MODEL_NAME=unsloth/DeepSeek-R1-0528-Qwen3-8B
|
| 36 |
+
|
| 37 |
+
# Maximum sequence length for tokenizer
|
| 38 |
+
MAX_SEQ_LENGTH=2048
|
| 39 |
+
|
| 40 |
+
# LoRA adapter rank (higher = more expressive, more params)
|
| 41 |
+
LORA_RANK=32
|
| 42 |
+
|
| 43 |
+
# -----------------------------------------------------------------------------
|
| 44 |
+
# TRAINING HYPERPARAMETERS
|
| 45 |
+
# -----------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
# Total training steps
|
| 48 |
+
MAX_STEPS=500
|
| 49 |
+
|
| 50 |
+
# Save checkpoint every N steps
|
| 51 |
+
SAVE_STEPS=50
|
| 52 |
+
|
| 53 |
+
# Learning rate
|
| 54 |
+
LEARNING_RATE=5e-6
|
| 55 |
+
|
| 56 |
+
# Warmup ratio (fraction of steps for LR warmup)
|
| 57 |
+
WARMUP_RATIO=0.1
|
| 58 |
+
|
| 59 |
+
# Per-device batch size
|
| 60 |
+
BATCH_SIZE=2
|
| 61 |
+
|
| 62 |
+
# Gradient accumulation steps (effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION)
|
| 63 |
+
GRADIENT_ACCUMULATION=2
|
| 64 |
+
|
| 65 |
+
# Number of generations per prompt during GRPO
|
| 66 |
+
NUM_GENERATIONS=4
|
| 67 |
+
|
| 68 |
+
# Fraction of GPU memory to allocate
|
| 69 |
+
GPU_MEMORY_UTILIZATION=0.6
|
| 70 |
+
|
| 71 |
+
# Maximum prompt length (tokens)
|
| 72 |
+
MAX_PROMPT_LENGTH=512
|
| 73 |
+
|
| 74 |
+
# Maximum completion length (tokens)
|
| 75 |
+
MAX_COMPLETION_LENGTH=1500
|
| 76 |
+
|
| 77 |
+
# -----------------------------------------------------------------------------
|
| 78 |
+
# REWARD CONFIGURATION
|
| 79 |
+
# -----------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
# Reward mode: FULL (recommended), ROBUST, or LEGACY
|
| 82 |
+
# FULL: NLI + self-consistency + structure + topic relevance + depth
|
| 83 |
+
# ROBUST: NLI + self-consistency + structure
|
| 84 |
+
# LEGACY: Semantic similarity + terminology (faster but vulnerable to word soup)
|
| 85 |
+
REWARD_MODE=FULL
|
| 86 |
+
|
| 87 |
+
# -----------------------------------------------------------------------------
|
| 88 |
+
# OUTPUT CONFIGURATION
|
| 89 |
+
# -----------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
# HuggingFace repo to upload the trained LoRA adapter
|
| 92 |
+
HF_REPO=prolewiki/marxist-grpo-lora
|
| 93 |
+
|
| 94 |
+
# -----------------------------------------------------------------------------
|
| 95 |
+
# PATHS (container internal - usually don't change)
|
| 96 |
+
# -----------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
# Path to training dataset (JSONL)
|
| 99 |
+
DATASET_PATH=/workspace/dataset.jsonl
|
| 100 |
+
|
| 101 |
+
# Directory for training checkpoints
|
| 102 |
+
CHECKPOINT_DIR=/workspace/checkpoints
|
| 103 |
+
|
| 104 |
+
# Directory for final LoRA output
|
| 105 |
+
LORA_OUTPUT=/workspace/lora-output
|
| 106 |
+
|
| 107 |
+
# Directory for training outputs
|
| 108 |
+
OUTPUT_DIR=/workspace/outputs
|
| 109 |
+
|
| 110 |
+
# -----------------------------------------------------------------------------
|
| 111 |
+
# OPTIONAL: RUNPOD AUTO-TERMINATION
|
| 112 |
+
# -----------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
# Set this to enable automatic pod termination after training
|
| 115 |
+
# This prevents "zombie pods" from racking up bills
|
| 116 |
+
# Value is automatically set by RunPod, or can be set manually
|
| 117 |
+
# RUNPOD_POD_ID=
|
| 118 |
+
|
| 119 |
+
# -----------------------------------------------------------------------------
|
| 120 |
+
# OPTIONAL: REMOTE DATASET
|
| 121 |
+
# -----------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
# If dataset is not embedded in the image, set this URL to download it
|
| 124 |
+
# DATASET_URL=https://my-bucket.s3.amazonaws.com/grpo_dataset.jsonl
|
docker/Dockerfile
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Marxist-GRPO Training Container
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Headless GRPO fine-tuning container for RunPod deployment.
|
| 5 |
+
#
|
| 6 |
+
# Build:
|
| 7 |
+
# docker build -t marxist-grpo:latest docker/
|
| 8 |
+
#
|
| 9 |
+
# Run locally (testing):
|
| 10 |
+
# docker run --gpus all \
|
| 11 |
+
# -e HF_TOKEN=$HF_TOKEN \
|
| 12 |
+
# -e WANDB_API_KEY=$WANDB_API_KEY \
|
| 13 |
+
# -e MAX_STEPS=10 \
|
| 14 |
+
# marxist-grpo:latest
|
| 15 |
+
#
|
| 16 |
+
# Deploy to RunPod:
|
| 17 |
+
# runpodctl create pod \
|
| 18 |
+
# --name "marxist-grpo-training" \
|
| 19 |
+
# --gpuType "NVIDIA A100 80GB PCIe" \
|
| 20 |
+
# --imageName "myregistry/marxist-grpo:latest" \
|
| 21 |
+
# --env HF_TOKEN=$HF_TOKEN \
|
| 22 |
+
# --env WANDB_API_KEY=$WANDB_API_KEY \
|
| 23 |
+
# --env HF_REPO=my-org/my-model
|
| 24 |
+
# =============================================================================
|
| 25 |
+
|
| 26 |
+
# Use RunPod's PyTorch base image with CUDA 11.8
|
| 27 |
+
FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel
|
| 28 |
+
|
| 29 |
+
# Set working directory
|
| 30 |
+
WORKDIR /workspace
|
| 31 |
+
|
| 32 |
+
# Prevent interactive prompts during package installation
|
| 33 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 34 |
+
|
| 35 |
+
# Install system dependencies
|
| 36 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 37 |
+
git \
|
| 38 |
+
git-lfs \
|
| 39 |
+
curl \
|
| 40 |
+
htop \
|
| 41 |
+
nvtop \
|
| 42 |
+
tmux \
|
| 43 |
+
wget \
|
| 44 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 45 |
+
&& git lfs install
|
| 46 |
+
|
| 47 |
+
# Install runpodctl for pod self-termination
|
| 48 |
+
RUN curl -fsSL -o /tmp/runpodctl.tar.gz \
|
| 49 |
+
https://github.com/runpod/runpodctl/releases/download/v1.14.15/runpodctl-linux-amd64.tar.gz \
|
| 50 |
+
&& tar -xzf /tmp/runpodctl.tar.gz -C /tmp \
|
| 51 |
+
&& mv /tmp/runpodctl /usr/local/bin/runpodctl \
|
| 52 |
+
&& chmod +x /usr/local/bin/runpodctl \
|
| 53 |
+
&& rm /tmp/runpodctl.tar.gz
|
| 54 |
+
|
| 55 |
+
# Copy requirements first (for layer caching)
|
| 56 |
+
COPY docker/requirements.txt /workspace/requirements.txt
|
| 57 |
+
|
| 58 |
+
# Install Python dependencies
|
| 59 |
+
# Note: Unsloth requires specific installation order
|
| 60 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 61 |
+
pip install --no-cache-dir -r requirements.txt
|
| 62 |
+
|
| 63 |
+
# Install Unsloth (from source for latest optimizations)
|
| 64 |
+
RUN pip install --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
|
| 65 |
+
|
| 66 |
+
# Install specific versions that work with Unsloth
|
| 67 |
+
# Note: These must be installed after unsloth to avoid conflicts
|
| 68 |
+
RUN pip install --no-cache-dir --no-deps \
|
| 69 |
+
"xformers<0.0.27" \
|
| 70 |
+
"trl>=0.9.0,<0.12.0" \
|
| 71 |
+
peft \
|
| 72 |
+
accelerate \
|
| 73 |
+
bitsandbytes
|
| 74 |
+
|
| 75 |
+
# Download spaCy model for NLP-based reward functions
|
| 76 |
+
RUN python -m spacy download en_core_web_sm
|
| 77 |
+
|
| 78 |
+
# Copy the training code
|
| 79 |
+
COPY src/prolewiki_llm/ /workspace/prolewiki_llm/
|
| 80 |
+
|
| 81 |
+
# Copy the dataset (embedded - only 4.5MB)
|
| 82 |
+
COPY training_data/grpo_dataset.jsonl /workspace/dataset.jsonl
|
| 83 |
+
|
| 84 |
+
# Copy entrypoint script
|
| 85 |
+
COPY docker/start.sh /workspace/start.sh
|
| 86 |
+
RUN chmod +x /workspace/start.sh
|
| 87 |
+
|
| 88 |
+
# Set PYTHONPATH so prolewiki_llm module can be imported
|
| 89 |
+
ENV PYTHONPATH=/workspace
|
| 90 |
+
|
| 91 |
+
# =============================================================================
|
| 92 |
+
# Environment Variables (defaults - override at runtime)
|
| 93 |
+
# =============================================================================
|
| 94 |
+
|
| 95 |
+
# Model configuration
|
| 96 |
+
ENV MODEL_NAME="unsloth/DeepSeek-R1-0528-Qwen3-8B"
|
| 97 |
+
ENV MAX_SEQ_LENGTH=2048
|
| 98 |
+
ENV LORA_RANK=32
|
| 99 |
+
|
| 100 |
+
# Training configuration
|
| 101 |
+
ENV MAX_STEPS=500
|
| 102 |
+
ENV SAVE_STEPS=50
|
| 103 |
+
ENV LEARNING_RATE=5e-6
|
| 104 |
+
ENV BATCH_SIZE=2
|
| 105 |
+
ENV GRADIENT_ACCUMULATION=2
|
| 106 |
+
ENV NUM_GENERATIONS=4
|
| 107 |
+
ENV GPU_MEMORY_UTILIZATION=0.6
|
| 108 |
+
|
| 109 |
+
# Paths (container internal)
|
| 110 |
+
ENV DATASET_PATH=/workspace/dataset.jsonl
|
| 111 |
+
ENV CHECKPOINT_DIR=/workspace/checkpoints
|
| 112 |
+
ENV LORA_OUTPUT=/workspace/lora-output
|
| 113 |
+
ENV OUTPUT_DIR=/workspace/outputs
|
| 114 |
+
|
| 115 |
+
# Reward mode: FULL (recommended), ROBUST, or LEGACY
|
| 116 |
+
ENV REWARD_MODE=FULL
|
| 117 |
+
|
| 118 |
+
# Upload destination
|
| 119 |
+
ENV HF_REPO=prolewiki/marxist-grpo-lora
|
| 120 |
+
|
| 121 |
+
# Required secrets (must be provided at runtime):
|
| 122 |
+
# - HF_TOKEN: HuggingFace API token
|
| 123 |
+
# - WANDB_API_KEY: Weights & Biases API key
|
| 124 |
+
# Optional:
|
| 125 |
+
# - RUNPOD_POD_ID: For self-termination after training
|
| 126 |
+
|
| 127 |
+
# Health check - verify CUDA is available
|
| 128 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 129 |
+
CMD python -c "import torch; assert torch.cuda.is_available()" || exit 1
|
| 130 |
+
|
| 131 |
+
# Entry point
|
| 132 |
+
ENTRYPOINT ["/workspace/start.sh"]
|
docker/requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Marxist-GRPO Training Dependencies
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Pinned versions for reproducible builds.
|
| 5 |
+
# Note: Unsloth is installed separately in Dockerfile from git.
|
| 6 |
+
# =============================================================================
|
| 7 |
+
|
| 8 |
+
# Core ML
|
| 9 |
+
torch>=2.1.0,<2.3.0
|
| 10 |
+
transformers>=4.40.0,<4.58.0
|
| 11 |
+
sentence-transformers>=3.0.0,<4.0.0
|
| 12 |
+
|
| 13 |
+
# Training
|
| 14 |
+
datasets>=2.20.0,<3.0.0
|
| 15 |
+
# Note: trl, peft, accelerate, bitsandbytes installed separately after unsloth
|
| 16 |
+
|
| 17 |
+
# NLP (for reward functions)
|
| 18 |
+
spacy>=3.8.0,<4.0.0
|
| 19 |
+
|
| 20 |
+
# Logging and monitoring
|
| 21 |
+
wandb>=0.17.0,<1.0.0
|
| 22 |
+
|
| 23 |
+
# HuggingFace Hub (for model upload)
|
| 24 |
+
huggingface-hub>=0.24.0,<1.0.0
|
| 25 |
+
|
| 26 |
+
# Data validation
|
| 27 |
+
pydantic>=2.0.0,<3.0.0
|
| 28 |
+
|
| 29 |
+
# Utilities
|
| 30 |
+
numpy>=1.24.0,<2.0.0
|
| 31 |
+
safetensors>=0.4.0,<1.0.0
|
docker/start.sh
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# Headless GRPO Training Entrypoint Script
|
| 4 |
+
# =============================================================================
|
| 5 |
+
# This script orchestrates the training lifecycle:
|
| 6 |
+
# 1. Validate environment (secrets, GPU)
|
| 7 |
+
# 2. Authenticate with HuggingFace and W&B
|
| 8 |
+
# 3. Run training
|
| 9 |
+
# 4. Upload results
|
| 10 |
+
# 5. Self-terminate pod (if RUNPOD_POD_ID is set)
|
| 11 |
+
#
|
| 12 |
+
# Exit on any error
|
| 13 |
+
set -e
|
| 14 |
+
|
| 15 |
+
echo "=============================================================="
|
| 16 |
+
echo "Marxist-GRPO Headless Training"
|
| 17 |
+
echo "=============================================================="
|
| 18 |
+
echo "Start time: $(date -Iseconds)"
|
| 19 |
+
echo ""
|
| 20 |
+
|
| 21 |
+
# =============================================================================
|
| 22 |
+
# 1. ENVIRONMENT VALIDATION
|
| 23 |
+
# =============================================================================
|
| 24 |
+
echo "[1/5] Validating environment..."
|
| 25 |
+
|
| 26 |
+
# Check required secrets
|
| 27 |
+
if [ -z "$HF_TOKEN" ]; then
|
| 28 |
+
echo "ERROR: HF_TOKEN environment variable is required"
|
| 29 |
+
exit 1
|
| 30 |
+
fi
|
| 31 |
+
|
| 32 |
+
if [ -z "$WANDB_API_KEY" ]; then
|
| 33 |
+
echo "ERROR: WANDB_API_KEY environment variable is required"
|
| 34 |
+
exit 1
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
# Check GPU availability
|
| 38 |
+
python -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'" || {
|
| 39 |
+
echo "ERROR: CUDA/GPU not available"
|
| 40 |
+
exit 1
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
echo " - HF_TOKEN: [set]"
|
| 44 |
+
echo " - WANDB_API_KEY: [set]"
|
| 45 |
+
echo " - GPU: $(python -c 'import torch; print(torch.cuda.get_device_name())')"
|
| 46 |
+
echo " - VRAM: $(python -c 'import torch; print(f\"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB\")')"
|
| 47 |
+
echo ""
|
| 48 |
+
|
| 49 |
+
# =============================================================================
|
| 50 |
+
# 2. AUTHENTICATION
|
| 51 |
+
# =============================================================================
|
| 52 |
+
echo "[2/5] Authenticating with services..."
|
| 53 |
+
|
| 54 |
+
# Login to HuggingFace
|
| 55 |
+
echo " - HuggingFace Hub..."
|
| 56 |
+
huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential 2>/dev/null || true
|
| 57 |
+
|
| 58 |
+
# Login to Weights & Biases
|
| 59 |
+
echo " - Weights & Biases..."
|
| 60 |
+
wandb login "$WANDB_API_KEY" 2>/dev/null || true
|
| 61 |
+
|
| 62 |
+
echo ""
|
| 63 |
+
|
| 64 |
+
# =============================================================================
|
| 65 |
+
# 3. DATA VALIDATION
|
| 66 |
+
# =============================================================================
|
| 67 |
+
echo "[3/5] Validating dataset..."
|
| 68 |
+
|
| 69 |
+
DATASET_PATH="${DATASET_PATH:-/workspace/dataset.jsonl}"
|
| 70 |
+
|
| 71 |
+
if [ ! -f "$DATASET_PATH" ]; then
|
| 72 |
+
echo "ERROR: Dataset not found at $DATASET_PATH"
|
| 73 |
+
|
| 74 |
+
# If DATASET_URL is set, try to download
|
| 75 |
+
if [ -n "$DATASET_URL" ]; then
|
| 76 |
+
echo "Attempting to download from DATASET_URL..."
|
| 77 |
+
wget -O "$DATASET_PATH" "$DATASET_URL" || {
|
| 78 |
+
echo "ERROR: Failed to download dataset"
|
| 79 |
+
exit 1
|
| 80 |
+
}
|
| 81 |
+
else
|
| 82 |
+
exit 1
|
| 83 |
+
fi
|
| 84 |
+
fi
|
| 85 |
+
|
| 86 |
+
RECORD_COUNT=$(wc -l < "$DATASET_PATH")
|
| 87 |
+
echo " - Dataset: $DATASET_PATH"
|
| 88 |
+
echo " - Records: $RECORD_COUNT"
|
| 89 |
+
echo ""
|
| 90 |
+
|
| 91 |
+
# =============================================================================
|
| 92 |
+
# 4. TRAINING
|
| 93 |
+
# =============================================================================
|
| 94 |
+
echo "[4/5] Starting training..."
|
| 95 |
+
echo ""
|
| 96 |
+
echo "Configuration:"
|
| 97 |
+
echo " - Model: ${MODEL_NAME:-unsloth/DeepSeek-R1-0528-Qwen3-8B}"
|
| 98 |
+
echo " - Max Steps: ${MAX_STEPS:-500}"
|
| 99 |
+
echo " - Batch Size: ${BATCH_SIZE:-2} x ${GRADIENT_ACCUMULATION:-2}"
|
| 100 |
+
echo " - Learning Rate: ${LEARNING_RATE:-5e-6}"
|
| 101 |
+
echo " - Reward Mode: ${REWARD_MODE:-FULL}"
|
| 102 |
+
echo " - HF Repo: ${HF_REPO:-prolewiki/marxist-grpo-lora}"
|
| 103 |
+
echo ""
|
| 104 |
+
|
| 105 |
+
# Create output directories (use env vars to allow override in tests)
|
| 106 |
+
CHECKPOINT_DIR="${CHECKPOINT_DIR:-/workspace/checkpoints}"
|
| 107 |
+
LORA_OUTPUT="${LORA_OUTPUT:-/workspace/lora-output}"
|
| 108 |
+
OUTPUT_DIR="${OUTPUT_DIR:-/workspace/outputs}"
|
| 109 |
+
mkdir -p "$CHECKPOINT_DIR"
|
| 110 |
+
mkdir -p "$LORA_OUTPUT"
|
| 111 |
+
mkdir -p "$OUTPUT_DIR"
|
| 112 |
+
|
| 113 |
+
# Run training
|
| 114 |
+
# Note: The training script handles all the environment variables
|
| 115 |
+
# IMPORTANT: Disable set -e temporarily to capture exit code for failure handling
|
| 116 |
+
set +e
|
| 117 |
+
python -m prolewiki_llm.train_headless
|
| 118 |
+
TRAINING_EXIT_CODE=$?
|
| 119 |
+
set -e
|
| 120 |
+
|
| 121 |
+
echo ""
|
| 122 |
+
echo "Training completed with exit code: $TRAINING_EXIT_CODE"
|
| 123 |
+
echo ""
|
| 124 |
+
|
| 125 |
+
# =============================================================================
|
| 126 |
+
# 5. CLEANUP AND TERMINATION
|
| 127 |
+
# =============================================================================
|
| 128 |
+
echo "[5/5] Cleanup..."
|
| 129 |
+
|
| 130 |
+
# Log completion time
|
| 131 |
+
echo "End time: $(date -Iseconds)"
|
| 132 |
+
|
| 133 |
+
# If training succeeded and RUNPOD_POD_ID is set, terminate the pod
|
| 134 |
+
if [ $TRAINING_EXIT_CODE -eq 0 ]; then
|
| 135 |
+
echo "Training completed successfully!"
|
| 136 |
+
|
| 137 |
+
if [ -n "$RUNPOD_POD_ID" ]; then
|
| 138 |
+
echo ""
|
| 139 |
+
echo "Terminating pod to stop billing..."
|
| 140 |
+
echo "Pod ID: $RUNPOD_POD_ID"
|
| 141 |
+
|
| 142 |
+
# Give a few seconds for logs to flush
|
| 143 |
+
sleep 5
|
| 144 |
+
|
| 145 |
+
# Stop the pod
|
| 146 |
+
runpodctl stop pod "$RUNPOD_POD_ID" || {
|
| 147 |
+
echo "Warning: Failed to stop pod automatically"
|
| 148 |
+
echo "Please manually stop pod $RUNPOD_POD_ID to avoid billing"
|
| 149 |
+
}
|
| 150 |
+
else
|
| 151 |
+
echo ""
|
| 152 |
+
echo "Note: RUNPOD_POD_ID not set - pod will continue running"
|
| 153 |
+
echo "Remember to stop the pod manually to avoid billing!"
|
| 154 |
+
fi
|
| 155 |
+
else
|
| 156 |
+
echo "Training failed with exit code $TRAINING_EXIT_CODE"
|
| 157 |
+
echo "Pod will NOT be automatically terminated for debugging"
|
| 158 |
+
echo ""
|
| 159 |
+
echo "To debug:"
|
| 160 |
+
echo " 1. SSH into the pod"
|
| 161 |
+
echo " 2. Check /workspace/outputs for logs"
|
| 162 |
+
echo " 3. Check W&B dashboard for metrics"
|
| 163 |
+
fi
|
| 164 |
+
|
| 165 |
+
exit $TRAINING_EXIT_CODE
|
notebooks/Marxist_GRPO_RunPod_Optimized.ipynb
ADDED
|
@@ -0,0 +1,1107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "header",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Marxist GRPO Training - RunPod Optimized\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Production-ready notebook for headless GRPO training on RunPod.**\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"This notebook is optimized based on:\n",
|
| 13 |
+
"- `src/prolewiki_llm/train_headless.py` - Production patterns\n",
|
| 14 |
+
"- `src/prolewiki_llm/grpo_rewards.py` - Multi-layer reward system\n",
|
| 15 |
+
"- C7 documentation for Unsloth and TRL\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"## Recommended GPU\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"| GPU | VRAM | Price | Status |\n",
|
| 20 |
+
"|-----|------|-------|--------|\n",
|
| 21 |
+
"| **NVIDIA A40** | 48GB | $0.35/hr | **Recommended** |\n",
|
| 22 |
+
"| NVIDIA A100 | 80GB | $1.19/hr | Overkill (2x cost for unused VRAM) |\n",
|
| 23 |
+
"| RTX 4090 | 24GB | $0.34/hr | Too small (8B model needs ~40GB in 16-bit) |\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"## Key Optimizations\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"1. **torch.compile disabled** - Prevents hanging on containerized environments\n",
|
| 28 |
+
"2. **use_gradient_checkpointing=True** - Stable (not `\"unsloth\"` variant)\n",
|
| 29 |
+
"3. **load_in_4bit=False** - Required for GRPO (16-bit LoRA adapters)\n",
|
| 30 |
+
"4. **vLLM integration** - 2x faster generation during training\n",
|
| 31 |
+
"5. **Multi-layer rewards** - Defeats reward hacking\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"---"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"id": "critical-setup-header",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"## CRITICAL: Disable torch.compile (Run First!)\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"These environment variables **MUST** be set before any imports.\n",
|
| 44 |
+
"They prevent torch.compile's inductor workers from hanging on RunPod/Jupyter."
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "disable-torch-compile",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"# =============================================================================\n",
|
| 55 |
+
"# CRITICAL: Disable torch.compile BEFORE any imports\n",
|
| 56 |
+
"# =============================================================================\n",
|
| 57 |
+
"# These environment variables prevent torch.compile from spawning inductor\n",
|
| 58 |
+
"# compilation workers that hang indefinitely on RunPod/WSL2/Jupyter.\n",
|
| 59 |
+
"# See: https://github.com/unslothai/unsloth/issues/3432\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"import os\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\n",
|
| 64 |
+
"os.environ[\"TORCH_COMPILE\"] = \"0\"\n",
|
| 65 |
+
"os.environ[\"TORCHINDUCTOR_DISABLE\"] = \"1\"\n",
|
| 66 |
+
"os.environ[\"UNSLOTH_DISABLE_FAST_GENERATION\"] = \"1\"\n",
|
| 67 |
+
"os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n",
|
| 68 |
+
"os.environ[\"TORCHINDUCTOR_COMPILE_THREADS\"] = \"1\"\n",
|
| 69 |
+
"os.environ[\"UNSLOTH_VLLM_STANDBY\"] = \"1\"\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"# Enable fast HuggingFace downloads\n",
|
| 72 |
+
"os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"print(\"Environment configured for RunPod stability.\")"
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "markdown",
|
| 79 |
+
"id": "install-header",
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"source": [
|
| 82 |
+
"## Installation\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"Install all required dependencies. This takes ~5 minutes on first run."
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "code",
|
| 89 |
+
"execution_count": null,
|
| 90 |
+
"id": "install-deps",
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": [
|
| 94 |
+
"%%capture install_output\n",
|
| 95 |
+
"# Install dependencies (captured to reduce noise)\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"# Core ML stack\n",
|
| 98 |
+
"!pip install torch --index-url https://download.pytorch.org/whl/cu121\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"# Unsloth with vLLM (GRPO speedup)\n",
|
| 101 |
+
"!pip install unsloth vllm\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# TRL for GRPO training\n",
|
| 104 |
+
"!pip install trl peft bitsandbytes\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# Transformers (compatible version)\n",
|
| 107 |
+
"!pip install \"transformers>=4.45.0,<4.50.0\"\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"# Data and logging\n",
|
| 110 |
+
"!pip install datasets wandb hf_transfer\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"# Reward function dependencies\n",
|
| 113 |
+
"!pip install sentence-transformers numpy\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"# spaCy with transformer model (best semantic understanding)\n",
|
| 116 |
+
"!pip install spacy\n",
|
| 117 |
+
"!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"print(\"Installation complete!\")"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"cell_type": "code",
|
| 124 |
+
"execution_count": null,
|
| 125 |
+
"id": "check-install",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"# Verify installation and show any errors\n",
|
| 130 |
+
"print(install_output.stdout[-2000:] if len(install_output.stdout) > 2000 else install_output.stdout)\n",
|
| 131 |
+
"if install_output.stderr:\n",
|
| 132 |
+
" print(\"\\n--- Warnings/Errors ---\")\n",
|
| 133 |
+
" print(install_output.stderr[-1000:])"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "markdown",
|
| 138 |
+
"id": "vram-check-header",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"source": [
|
| 141 |
+
"## GPU Verification\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"Verify GPU is available and has sufficient VRAM before proceeding."
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"id": "vram-check",
|
| 150 |
+
"metadata": {},
|
| 151 |
+
"outputs": [],
|
| 152 |
+
"source": [
|
| 153 |
+
"import torch\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"# Check CUDA availability\n",
|
| 156 |
+
"if not torch.cuda.is_available():\n",
|
| 157 |
+
" raise RuntimeError(\n",
|
| 158 |
+
" \"CUDA not available! This notebook requires a GPU.\\n\"\n",
|
| 159 |
+
" \"On RunPod: Ensure you selected a GPU pod type.\"\n",
|
| 160 |
+
" )\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"# Get GPU info\n",
|
| 163 |
+
"gpu_name = torch.cuda.get_device_name(0)\n",
|
| 164 |
+
"gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
|
| 165 |
+
"\n",
|
| 166 |
+
"print(f\"GPU: {gpu_name}\")\n",
|
| 167 |
+
"print(f\"VRAM: {gpu_mem_gb:.1f} GB\")\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"# Check minimum VRAM (8B model in 16-bit needs ~35-40GB)\n",
|
| 170 |
+
"MIN_VRAM_GB = 40\n",
|
| 171 |
+
"if gpu_mem_gb < MIN_VRAM_GB:\n",
|
| 172 |
+
" print(f\"\\nWARNING: GPU has {gpu_mem_gb:.1f}GB VRAM, but {MIN_VRAM_GB}GB is recommended.\")\n",
|
| 173 |
+
" print(\"Consider using A40 (48GB) or A100 (80GB) for 8B model training.\")\n",
|
| 174 |
+
" print(\"You may experience OOM errors with smaller GPUs.\")\n",
|
| 175 |
+
"else:\n",
|
| 176 |
+
" print(f\"\\nVRAM check passed ({gpu_mem_gb:.1f}GB >= {MIN_VRAM_GB}GB required).\")"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "markdown",
|
| 181 |
+
"id": "config-header",
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"source": [
|
| 184 |
+
"## Configuration\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"Configure training parameters. Adjust these based on your GPU and requirements."
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "code",
|
| 191 |
+
"execution_count": null,
|
| 192 |
+
"id": "config",
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"outputs": [],
|
| 195 |
+
"source": [
|
| 196 |
+
"from pathlib import Path\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"# =============================================================================\n",
|
| 199 |
+
"# REQUIRED: Set your API keys\n",
|
| 200 |
+
"# =============================================================================\n",
|
| 201 |
+
"# Option 1: Set directly (for testing)\n",
|
| 202 |
+
"# os.environ[\"HF_TOKEN\"] = \"hf_...\"\n",
|
| 203 |
+
"# os.environ[\"WANDB_API_KEY\"] = \"...\"\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"# Option 2: Already set in environment (recommended for production)\n",
|
| 206 |
+
"HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
|
| 207 |
+
"WANDB_API_KEY = os.environ.get(\"WANDB_API_KEY\")\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"if not HF_TOKEN:\n",
|
| 210 |
+
" print(\"WARNING: HF_TOKEN not set. Model upload will fail.\")\n",
|
| 211 |
+
"if not WANDB_API_KEY:\n",
|
| 212 |
+
" print(\"WARNING: WANDB_API_KEY not set. W&B logging disabled.\")\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"# =============================================================================\n",
|
| 215 |
+
"# Model Configuration\n",
|
| 216 |
+
"# =============================================================================\n",
|
| 217 |
+
"MODEL_NAME = \"unsloth/DeepSeek-R1-0528-Qwen3-8B\"\n",
|
| 218 |
+
"MAX_SEQ_LENGTH = 2048\n",
|
| 219 |
+
"LORA_RANK = 32\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"# =============================================================================\n",
|
| 222 |
+
"# Training Configuration (A40 optimized)\n",
|
| 223 |
+
"# =============================================================================\n",
|
| 224 |
+
"MAX_STEPS = 500\n",
|
| 225 |
+
"SAVE_STEPS = 50\n",
|
| 226 |
+
"LEARNING_RATE = 5e-6\n",
|
| 227 |
+
"WARMUP_RATIO = 0.1\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"# Batch settings (A40 48GB: batch=2, grad_accum=2, generations=4)\n",
|
| 230 |
+
"BATCH_SIZE = 2\n",
|
| 231 |
+
"GRADIENT_ACCUMULATION = 2\n",
|
| 232 |
+
"NUM_GENERATIONS = 4\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"# Sequence lengths\n",
|
| 235 |
+
"MAX_PROMPT_LENGTH = 512\n",
|
| 236 |
+
"MAX_COMPLETION_LENGTH = 1500\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"# vLLM memory allocation\n",
|
| 239 |
+
"GPU_MEMORY_UTILIZATION = 0.6\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"# =============================================================================\n",
|
| 242 |
+
"# Paths (RunPod uses /workspace for persistent storage)\n",
|
| 243 |
+
"# =============================================================================\n",
|
| 244 |
+
"# Try multiple locations for dataset\n",
|
| 245 |
+
"DATASET_PATHS = [\n",
|
| 246 |
+
" Path(\"/workspace/dataset.jsonl\"),\n",
|
| 247 |
+
" Path(\"/workspace/grpo_dataset.jsonl\"),\n",
|
| 248 |
+
" Path(\"./grpo_dataset.jsonl\"),\n",
|
| 249 |
+
" Path(\"../training_data/grpo_dataset.jsonl\"),\n",
|
| 250 |
+
"]\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"DATASET_PATH = None\n",
|
| 253 |
+
"for p in DATASET_PATHS:\n",
|
| 254 |
+
" if p.exists():\n",
|
| 255 |
+
" DATASET_PATH = p\n",
|
| 256 |
+
" break\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"if DATASET_PATH is None:\n",
|
| 259 |
+
" print(f\"WARNING: Dataset not found. Searched: {[str(p) for p in DATASET_PATHS]}\")\n",
|
| 260 |
+
" print(\"Please upload grpo_dataset.jsonl to /workspace/\")\n",
|
| 261 |
+
"else:\n",
|
| 262 |
+
" print(f\"Dataset found: {DATASET_PATH}\")\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"CHECKPOINT_DIR = Path(\"/workspace/checkpoints\")\n",
|
| 265 |
+
"LORA_OUTPUT = Path(\"/workspace/lora-output\")\n",
|
| 266 |
+
"OUTPUT_DIR = Path(\"/workspace/outputs\")\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"# Create directories\n",
|
| 269 |
+
"for d in [CHECKPOINT_DIR, LORA_OUTPUT, OUTPUT_DIR]:\n",
|
| 270 |
+
" d.mkdir(parents=True, exist_ok=True)\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"# =============================================================================\n",
|
| 273 |
+
"# Upload Configuration\n",
|
| 274 |
+
"# =============================================================================\n",
|
| 275 |
+
"HF_REPO = \"prolewiki/marxist-grpo-lora\" # Change to your repo\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"# =============================================================================\n",
|
| 278 |
+
"# Reward Mode: FULL (recommended), ROBUST, or LEGACY\n",
|
| 279 |
+
"# =============================================================================\n",
|
| 280 |
+
"# FULL: NLI + self-consistency + structure + topic + depth (most robust)\n",
|
| 281 |
+
"# ROBUST: NLI + self-consistency + structure (faster)\n",
|
| 282 |
+
"# LEGACY: Semantic similarity + terminology (fastest, less robust)\n",
|
| 283 |
+
"REWARD_MODE = \"FULL\"\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"print(f\"\\nConfiguration:\")\n",
|
| 286 |
+
"print(f\" Model: {MODEL_NAME}\")\n",
|
| 287 |
+
"print(f\" Max Steps: {MAX_STEPS}\")\n",
|
| 288 |
+
"print(f\" Batch Size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n",
|
| 289 |
+
"print(f\" Learning Rate: {LEARNING_RATE}\")\n",
|
| 290 |
+
"print(f\" Reward Mode: {REWARD_MODE}\")"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"cell_type": "markdown",
|
| 295 |
+
"id": "wandb-header",
|
| 296 |
+
"metadata": {},
|
| 297 |
+
"source": [
|
| 298 |
+
"## Initialize W&B\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"Set up Weights & Biases for remote monitoring."
|
| 301 |
+
]
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"cell_type": "code",
|
| 305 |
+
"execution_count": null,
|
| 306 |
+
"id": "wandb-init",
|
| 307 |
+
"metadata": {},
|
| 308 |
+
"outputs": [],
|
| 309 |
+
"source": [
|
| 310 |
+
"USE_WANDB = False\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"if WANDB_API_KEY:\n",
|
| 313 |
+
" import wandb\n",
|
| 314 |
+
" wandb.login(key=WANDB_API_KEY)\n",
|
| 315 |
+
" USE_WANDB = True\n",
|
| 316 |
+
" print(\"W&B logged in successfully.\")\n",
|
| 317 |
+
"else:\n",
|
| 318 |
+
" print(\"W&B disabled (no API key).\")"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"cell_type": "markdown",
|
| 323 |
+
"id": "dataset-header",
|
| 324 |
+
"metadata": {},
|
| 325 |
+
"source": [
|
| 326 |
+
"## Load Dataset\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"Load the GRPO-formatted dataset."
|
| 329 |
+
]
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"cell_type": "code",
|
| 333 |
+
"execution_count": null,
|
| 334 |
+
"id": "load-dataset",
|
| 335 |
+
"metadata": {},
|
| 336 |
+
"outputs": [],
|
| 337 |
+
"source": [
|
| 338 |
+
"from datasets import Dataset\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"if DATASET_PATH is None or not DATASET_PATH.exists():\n",
|
| 341 |
+
" raise FileNotFoundError(\n",
|
| 342 |
+
" \"Dataset not found!\\n\"\n",
|
| 343 |
+
" \"Please upload grpo_dataset.jsonl to /workspace/\\n\"\n",
|
| 344 |
+
" \"Generate it with: python -m prolewiki_llm.export_grpo_dataset\"\n",
|
| 345 |
+
" )\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"dataset = Dataset.from_json(str(DATASET_PATH))\n",
|
| 348 |
+
"print(f\"Loaded {len(dataset):,} examples from {DATASET_PATH}\")\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"# Show sample\n",
|
| 351 |
+
"sample = dataset[0]\n",
|
| 352 |
+
"print(f\"\\nSample prompt: {sample['prompt'][1]['content'][:100]}...\")\n",
|
| 353 |
+
"print(f\"Sample answer: {sample['answer'][:100]}...\")"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "markdown",
|
| 358 |
+
"id": "model-header",
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"source": [
|
| 361 |
+
"## Load Model\n",
|
| 362 |
+
"\n",
|
| 363 |
+
"Load DeepSeek-R1-0528-Qwen3-8B with Unsloth for optimized training.\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"**Critical settings:**\n",
|
| 366 |
+
"- `load_in_4bit=False` - Required for GRPO (16-bit LoRA adapters)\n",
|
| 367 |
+
"- `fast_inference=True` - Enables vLLM for 2x faster generation"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"cell_type": "code",
|
| 372 |
+
"execution_count": null,
|
| 373 |
+
"id": "load-model",
|
| 374 |
+
"metadata": {},
|
| 375 |
+
"outputs": [],
|
| 376 |
+
"source": [
|
| 377 |
+
"from unsloth import FastLanguageModel\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"print(f\"Loading model: {MODEL_NAME}\")\n",
|
| 380 |
+
"print(\"This may take 5-10 minutes on first run (downloading ~16GB)...\")\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"# GRPO requires 16-bit LoRA adapters (load_in_4bit=False)\n",
|
| 383 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 384 |
+
" model_name=MODEL_NAME,\n",
|
| 385 |
+
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 386 |
+
" load_in_4bit=False, # MUST be False for GRPO\n",
|
| 387 |
+
" fast_inference=True, # Enables vLLM\n",
|
| 388 |
+
" max_lora_rank=LORA_RANK,\n",
|
| 389 |
+
" gpu_memory_utilization=GPU_MEMORY_UTILIZATION,\n",
|
| 390 |
+
")\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"print(f\"\\nModel loaded: {model.config.model_type}\")\n",
|
| 393 |
+
"print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "markdown",
|
| 398 |
+
"id": "lora-header",
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"source": [
|
| 401 |
+
"## Apply LoRA\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"Apply LoRA adapters for efficient fine-tuning.\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"**Critical settings:**\n",
|
| 406 |
+
"- `use_gradient_checkpointing=True` - Stable on RunPod (NOT `\"unsloth\"`)\n",
|
| 407 |
+
"- `lora_alpha=LORA_RANK` - Same as r for GRPO (scaling = 1.0)"
|
| 408 |
+
]
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"cell_type": "code",
|
| 412 |
+
"execution_count": null,
|
| 413 |
+
"id": "apply-lora",
|
| 414 |
+
"metadata": {},
|
| 415 |
+
"outputs": [],
|
| 416 |
+
"source": [
|
| 417 |
+
"print(\"Applying LoRA adapters...\")\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"# Use gradient_checkpointing=True (not \"unsloth\") for stability on RunPod\n",
|
| 420 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 421 |
+
" model,\n",
|
| 422 |
+
" r=LORA_RANK,\n",
|
| 423 |
+
" target_modules=[\n",
|
| 424 |
+
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 425 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
|
| 426 |
+
" ],\n",
|
| 427 |
+
" lora_alpha=LORA_RANK, # Same as r for GRPO (not r*2)\n",
|
| 428 |
+
" use_gradient_checkpointing=True, # Stable on RunPod (NOT \"unsloth\")\n",
|
| 429 |
+
" random_state=3407,\n",
|
| 430 |
+
")\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"# Print trainable parameters\n",
|
| 433 |
+
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 434 |
+
"total = sum(p.numel() for p in model.parameters())\n",
|
| 435 |
+
"print(f\"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\")"
|
| 436 |
+
]
|
| 437 |
+
},
|
| 438 |
+
{
|
| 439 |
+
"cell_type": "markdown",
|
| 440 |
+
"id": "reward-header",
|
| 441 |
+
"metadata": {},
|
| 442 |
+
"source": [
|
| 443 |
+
"## Reward Functions\n",
|
| 444 |
+
"\n",
|
| 445 |
+
"Multi-layer reward system to prevent reward hacking.\n",
|
| 446 |
+
"\n",
|
| 447 |
+
"**Layers:**\n",
|
| 448 |
+
"1. **Format** - Validate `<think>...</think>` tags\n",
|
| 449 |
+
"2. **NLI Coherence** - Response entails ground truth\n",
|
| 450 |
+
"3. **Self-consistency** - No internal contradictions\n",
|
| 451 |
+
"4. **Structural** - Terms in proper syntactic roles\n",
|
| 452 |
+
"5. **Topic relevance** - Answer addresses the question\n",
|
| 453 |
+
"6. **Depth** - Rewards analysis, penalizes buzzword salad"
|
| 454 |
+
]
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"cell_type": "code",
|
| 458 |
+
"execution_count": null,
|
| 459 |
+
"id": "reward-setup",
|
| 460 |
+
"metadata": {},
|
| 461 |
+
"outputs": [],
|
| 462 |
+
"source": [
|
| 463 |
+
"import re\n",
|
| 464 |
+
"import numpy as np\n",
|
| 465 |
+
"from sentence_transformers import SentenceTransformer\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"# =============================================================================\n",
|
| 468 |
+
"# Lazy-loaded models\n",
|
| 469 |
+
"# =============================================================================\n",
|
| 470 |
+
"_embedder = None\n",
|
| 471 |
+
"_nli_pipeline = None\n",
|
| 472 |
+
"_spacy_nlp = None\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"\n",
|
| 475 |
+
"def get_embedder():\n",
|
| 476 |
+
" global _embedder\n",
|
| 477 |
+
" if _embedder is None:\n",
|
| 478 |
+
" print(\"[Reward] Loading sentence-transformers embedder...\")\n",
|
| 479 |
+
" _embedder = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
|
| 480 |
+
" return _embedder\n",
|
| 481 |
+
"\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"def get_nli_pipeline():\n",
|
| 484 |
+
" global _nli_pipeline\n",
|
| 485 |
+
" if _nli_pipeline is None:\n",
|
| 486 |
+
" print(\"[Reward] Loading NLI model (bart-large-mnli)...\")\n",
|
| 487 |
+
" from transformers import pipeline\n",
|
| 488 |
+
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 489 |
+
" _nli_pipeline = pipeline(\n",
|
| 490 |
+
" \"text-classification\",\n",
|
| 491 |
+
" model=\"facebook/bart-large-mnli\",\n",
|
| 492 |
+
" device=device,\n",
|
| 493 |
+
" )\n",
|
| 494 |
+
" return _nli_pipeline\n",
|
| 495 |
+
"\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"def get_spacy_nlp():\n",
|
| 498 |
+
" global _spacy_nlp\n",
|
| 499 |
+
" if _spacy_nlp is None:\n",
|
| 500 |
+
" import spacy\n",
|
| 501 |
+
" models = [\"en_core_web_trf\", \"en_core_web_md\", \"en_core_web_sm\"]\n",
|
| 502 |
+
" for model_name in models:\n",
|
| 503 |
+
" try:\n",
|
| 504 |
+
" print(f\"[Reward] Loading spaCy model: {model_name}...\")\n",
|
| 505 |
+
" _spacy_nlp = spacy.load(model_name)\n",
|
| 506 |
+
" break\n",
|
| 507 |
+
" except OSError:\n",
|
| 508 |
+
" continue\n",
|
| 509 |
+
" if _spacy_nlp is None:\n",
|
| 510 |
+
" raise OSError(\"No spaCy model found!\")\n",
|
| 511 |
+
" return _spacy_nlp\n",
|
| 512 |
+
"\n",
|
| 513 |
+
"\n",
|
| 514 |
+
"# =============================================================================\n",
|
| 515 |
+
"# Constants\n",
|
| 516 |
+
"# =============================================================================\n",
|
| 517 |
+
"REASONING_START = \"<think>\"\n",
|
| 518 |
+
"REASONING_END = \"</think>\"\n",
|
| 519 |
+
"SOLUTION_END_REGEX = re.compile(rf\"{REASONING_END}(.*)\", re.DOTALL)\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"MARXIST_TERMS = {\n",
|
| 522 |
+
" \"dialectical\", \"materialism\", \"historical materialism\", \"dialectical materialism\",\n",
|
| 523 |
+
" \"bourgeoisie\", \"proletariat\", \"petty bourgeois\", \"petty bourgeoisie\",\n",
|
| 524 |
+
" \"class struggle\", \"class consciousness\", \"surplus value\", \"commodity\",\n",
|
| 525 |
+
" \"use value\", \"exchange value\", \"labor power\", \"means of production\",\n",
|
| 526 |
+
" \"relations of production\", \"forces of production\", \"mode of production\",\n",
|
| 527 |
+
" \"exploitation\", \"capital accumulation\", \"imperialism\", \"colonialism\",\n",
|
| 528 |
+
" \"dictatorship of the proletariat\", \"vanguard\", \"democratic centralism\",\n",
|
| 529 |
+
" \"hegemony\", \"superstructure\", \"base\", \"ideology\", \"false consciousness\",\n",
|
| 530 |
+
" \"revisionism\", \"opportunism\", \"reformism\", \"alienation\", \"fetishism\",\n",
|
| 531 |
+
"}\n",
|
| 532 |
+
"\n",
|
| 533 |
+
"DISCOURSE_CONNECTIVES = {\n",
|
| 534 |
+
" \"because\", \"therefore\", \"thus\", \"hence\", \"consequently\", \"however\",\n",
|
| 535 |
+
" \"although\", \"whereas\", \"nevertheless\", \"moreover\", \"furthermore\",\n",
|
| 536 |
+
" \"specifically\", \"namely\", \"as a result\", \"due to\", \"in order to\",\n",
|
| 537 |
+
"}\n",
|
| 538 |
+
"\n",
|
| 539 |
+
"print(\"Reward function constants defined.\")"
|
| 540 |
+
]
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"cell_type": "code",
|
| 544 |
+
"execution_count": null,
|
| 545 |
+
"id": "reward-functions",
|
| 546 |
+
"metadata": {},
|
| 547 |
+
"outputs": [],
|
| 548 |
+
"source": [
|
| 549 |
+
"# =============================================================================\n",
|
| 550 |
+
"# Format Rewards\n",
|
| 551 |
+
"# =============================================================================\n",
|
| 552 |
+
"\n",
|
| 553 |
+
"def match_format_exactly(completions, **kwargs):\n",
|
| 554 |
+
" \"\"\"Reward +3.0 if response contains proper </think> tag.\"\"\"\n",
|
| 555 |
+
" scores = []\n",
|
| 556 |
+
" for completion in completions:\n",
|
| 557 |
+
" response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
|
| 558 |
+
" score = 3.0 if SOLUTION_END_REGEX.search(response) else 0.0\n",
|
| 559 |
+
" scores.append(score)\n",
|
| 560 |
+
" return scores\n",
|
| 561 |
+
"\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"def match_format_approximately(completions, **kwargs):\n",
|
| 564 |
+
" \"\"\"Reward partial format matching.\"\"\"\n",
|
| 565 |
+
" scores = []\n",
|
| 566 |
+
" for completion in completions:\n",
|
| 567 |
+
" response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
|
| 568 |
+
" score = 0.0\n",
|
| 569 |
+
" score += 0.5 if response.count(REASONING_START) == 1 else -1.0\n",
|
| 570 |
+
" score += 0.5 if response.count(REASONING_END) == 1 else -1.0\n",
|
| 571 |
+
" scores.append(score)\n",
|
| 572 |
+
" return scores\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"# =============================================================================\n",
|
| 576 |
+
"# NLI Coherence Reward\n",
|
| 577 |
+
"# =============================================================================\n",
|
| 578 |
+
"\n",
|
| 579 |
+
"def nli_coherence_reward(completions, answer, **kwargs):\n",
|
| 580 |
+
" \"\"\"Reward responses that logically ENTAIL the ground truth.\"\"\"\n",
|
| 581 |
+
" nli = get_nli_pipeline()\n",
|
| 582 |
+
" scores = []\n",
|
| 583 |
+
" \n",
|
| 584 |
+
" for completion, true_answer in zip(completions, answer):\n",
|
| 585 |
+
" response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
|
| 586 |
+
" \n",
|
| 587 |
+
" # Extract answer after </think>\n",
|
| 588 |
+
" if REASONING_END in response:\n",
|
| 589 |
+
" response = response.split(REASONING_END, 1)[1].strip()\n",
|
| 590 |
+
" \n",
|
| 591 |
+
" if not response or len(response.strip()) < 20:\n",
|
| 592 |
+
" scores.append(-2.0)\n",
|
| 593 |
+
" continue\n",
|
| 594 |
+
" \n",
|
| 595 |
+
" try:\n",
|
| 596 |
+
" input_text = f\"{response[:512]}</s></s>{true_answer[:512]}\"\n",
|
| 597 |
+
" result = nli(input_text)[0]\n",
|
| 598 |
+
" label = result[\"label\"].lower()\n",
|
| 599 |
+
" \n",
|
| 600 |
+
" if label == \"entailment\":\n",
|
| 601 |
+
" score = 3.0\n",
|
| 602 |
+
" elif label == \"neutral\":\n",
|
| 603 |
+
" score = -1.0\n",
|
| 604 |
+
" else: # contradiction\n",
|
| 605 |
+
" score = -3.0\n",
|
| 606 |
+
" scores.append(score)\n",
|
| 607 |
+
" except Exception as e:\n",
|
| 608 |
+
" print(f\"[NLI] Error: {e}\")\n",
|
| 609 |
+
" scores.append(0.0)\n",
|
| 610 |
+
" \n",
|
| 611 |
+
" return scores\n",
|
| 612 |
+
"\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"# =============================================================================\n",
|
| 615 |
+
"# Structural Coherence Reward\n",
|
| 616 |
+
"# =============================================================================\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"def structural_coherence_reward(completions, **kwargs):\n",
|
| 619 |
+
" \"\"\"Reward responses with proper linguistic structure.\"\"\"\n",
|
| 620 |
+
" nlp = get_spacy_nlp()\n",
|
| 621 |
+
" scores = []\n",
|
| 622 |
+
" \n",
|
| 623 |
+
" for completion in completions:\n",
|
| 624 |
+
" response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
|
| 625 |
+
" doc = nlp(response)\n",
|
| 626 |
+
" score = 0.0\n",
|
| 627 |
+
" \n",
|
| 628 |
+
" # Check for sentences\n",
|
| 629 |
+
" sentences = list(doc.sents)\n",
|
| 630 |
+
" if len(sentences) < 1:\n",
|
| 631 |
+
" scores.append(-1.0)\n",
|
| 632 |
+
" continue\n",
|
| 633 |
+
" \n",
|
| 634 |
+
" # Check Marxist terms in meaningful syntactic roles\n",
|
| 635 |
+
" response_lower = response.lower()\n",
|
| 636 |
+
" terms_in_context = 0\n",
|
| 637 |
+
" \n",
|
| 638 |
+
" for term in MARXIST_TERMS:\n",
|
| 639 |
+
" if term not in response_lower:\n",
|
| 640 |
+
" continue\n",
|
| 641 |
+
" for token in doc:\n",
|
| 642 |
+
" if term in token.text.lower():\n",
|
| 643 |
+
" if token.dep_ in (\"nsubj\", \"nsubjpass\", \"dobj\", \"pobj\", \"attr\", \"appos\"):\n",
|
| 644 |
+
" terms_in_context += 1\n",
|
| 645 |
+
" break\n",
|
| 646 |
+
" \n",
|
| 647 |
+
" score += min(terms_in_context * 0.3, 1.5)\n",
|
| 648 |
+
" \n",
|
| 649 |
+
" # Check discourse connectives\n",
|
| 650 |
+
" connective_count = sum(1 for c in DISCOURSE_CONNECTIVES if c in response_lower)\n",
|
| 651 |
+
" score += min(connective_count * 0.2, 1.0)\n",
|
| 652 |
+
" \n",
|
| 653 |
+
" scores.append(score)\n",
|
| 654 |
+
" \n",
|
| 655 |
+
" return scores\n",
|
| 656 |
+
"\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"# =============================================================================\n",
|
| 659 |
+
"# Combined Rewards\n",
|
| 660 |
+
"# =============================================================================\n",
|
| 661 |
+
"\n",
|
| 662 |
+
"def full_coherence_reward(prompts, completions, answer, **kwargs):\n",
|
| 663 |
+
" \"\"\"Complete coherence check: NLI + structure.\"\"\"\n",
|
| 664 |
+
" nli_scores = nli_coherence_reward(completions, answer, **kwargs)\n",
|
| 665 |
+
" structure_scores = structural_coherence_reward(completions, **kwargs)\n",
|
| 666 |
+
" \n",
|
| 667 |
+
" combined = []\n",
|
| 668 |
+
" for nli, structure in zip(nli_scores, structure_scores):\n",
|
| 669 |
+
" if nli <= -3.0:\n",
|
| 670 |
+
" combined.append(-3.0)\n",
|
| 671 |
+
" else:\n",
|
| 672 |
+
" combined.append(nli + (structure * 0.5))\n",
|
| 673 |
+
" return combined\n",
|
| 674 |
+
"\n",
|
| 675 |
+
"\n",
|
| 676 |
+
"def completeness_reward(completions, answer, **kwargs):\n",
|
| 677 |
+
" \"\"\"Reward thorough, detailed responses.\"\"\"\n",
|
| 678 |
+
" scores = []\n",
|
| 679 |
+
" \n",
|
| 680 |
+
" for completion, true_answer in zip(completions, answer):\n",
|
| 681 |
+
" response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
|
| 682 |
+
" \n",
|
| 683 |
+
" if REASONING_END in response:\n",
|
| 684 |
+
" answer_part = response.split(REASONING_END, 1)[1].strip()\n",
|
| 685 |
+
" else:\n",
|
| 686 |
+
" answer_part = response\n",
|
| 687 |
+
" \n",
|
| 688 |
+
" answer_len = len(answer_part.split())\n",
|
| 689 |
+
" true_len = len(true_answer.split())\n",
|
| 690 |
+
" \n",
|
| 691 |
+
" if true_len == 0:\n",
|
| 692 |
+
" scores.append(0.0)\n",
|
| 693 |
+
" continue\n",
|
| 694 |
+
" \n",
|
| 695 |
+
" ratio = answer_len / true_len\n",
|
| 696 |
+
" if 0.5 <= ratio <= 1.5:\n",
|
| 697 |
+
" score = 2.0\n",
|
| 698 |
+
" elif 0.3 <= ratio <= 2.0:\n",
|
| 699 |
+
" score = 1.0\n",
|
| 700 |
+
" elif ratio < 0.2:\n",
|
| 701 |
+
" score = -2.0\n",
|
| 702 |
+
" else:\n",
|
| 703 |
+
" score = -0.5\n",
|
| 704 |
+
" \n",
|
| 705 |
+
" scores.append(score)\n",
|
| 706 |
+
" \n",
|
| 707 |
+
" return scores\n",
|
| 708 |
+
"\n",
|
| 709 |
+
"\n",
|
| 710 |
+
"# =============================================================================\n",
|
| 711 |
+
"# Debug Reward\n",
|
| 712 |
+
"# =============================================================================\n",
|
| 713 |
+
"_PRINT_COUNTER = 0\n",
|
| 714 |
+
"_PRINT_EVERY = 10\n",
|
| 715 |
+
"\n",
|
| 716 |
+
"def debug_print_reward(prompts, completions, answer, **kwargs):\n",
|
| 717 |
+
" \"\"\"Print sample outputs periodically.\"\"\"\n",
|
| 718 |
+
" global _PRINT_COUNTER\n",
|
| 719 |
+
" \n",
|
| 720 |
+
" if _PRINT_COUNTER % _PRINT_EVERY == 0:\n",
|
| 721 |
+
" question = prompts[0][-1][\"content\"] if isinstance(prompts[0], list) else prompts[0]\n",
|
| 722 |
+
" response = completions[0][0][\"content\"] if isinstance(completions[0], list) else completions[0]\n",
|
| 723 |
+
" \n",
|
| 724 |
+
" print(\"=\" * 60)\n",
|
| 725 |
+
" print(f\"Step {_PRINT_COUNTER}\")\n",
|
| 726 |
+
" print(f\"Q: {question[:100]}...\")\n",
|
| 727 |
+
" print(f\"A: {response[:200]}...\")\n",
|
| 728 |
+
" print(\"=\" * 60)\n",
|
| 729 |
+
" \n",
|
| 730 |
+
" _PRINT_COUNTER += 1\n",
|
| 731 |
+
" return [0.0] * len(completions)\n",
|
| 732 |
+
"\n",
|
| 733 |
+
"\n",
|
| 734 |
+
"print(\"Reward functions defined.\")"
|
| 735 |
+
]
|
| 736 |
+
},
|
| 737 |
+
{
|
| 738 |
+
"cell_type": "markdown",
|
| 739 |
+
"id": "trainer-header",
|
| 740 |
+
"metadata": {},
|
| 741 |
+
"source": [
|
| 742 |
+
"## Configure GRPO Trainer\n",
|
| 743 |
+
"\n",
|
| 744 |
+
"Set up the GRPO trainer with vLLM for fast generation."
|
| 745 |
+
]
|
| 746 |
+
},
|
| 747 |
+
{
|
| 748 |
+
"cell_type": "code",
|
| 749 |
+
"execution_count": null,
|
| 750 |
+
"id": "trainer-config",
|
| 751 |
+
"metadata": {},
|
| 752 |
+
"outputs": [],
|
| 753 |
+
"source": [
|
| 754 |
+
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 755 |
+
"from vllm import SamplingParams\n",
|
| 756 |
+
"\n",
|
| 757 |
+
"# vLLM sampling parameters for generation\n",
|
| 758 |
+
"vllm_sampling_params = SamplingParams(\n",
|
| 759 |
+
" min_p=0.1,\n",
|
| 760 |
+
" top_p=1.0,\n",
|
| 761 |
+
" top_k=-1,\n",
|
| 762 |
+
" max_tokens=MAX_COMPLETION_LENGTH,\n",
|
| 763 |
+
" stop=[tokenizer.eos_token],\n",
|
| 764 |
+
" include_stop_str_in_output=True,\n",
|
| 765 |
+
" seed=3407,\n",
|
| 766 |
+
")\n",
|
| 767 |
+
"\n",
|
| 768 |
+
"# Training configuration\n",
|
| 769 |
+
"training_args = GRPOConfig(\n",
|
| 770 |
+
" # vLLM\n",
|
| 771 |
+
" vllm_sampling_params=vllm_sampling_params,\n",
|
| 772 |
+
" temperature=1.0,\n",
|
| 773 |
+
" # Optimization\n",
|
| 774 |
+
" learning_rate=LEARNING_RATE,\n",
|
| 775 |
+
" weight_decay=0.001,\n",
|
| 776 |
+
" warmup_ratio=WARMUP_RATIO,\n",
|
| 777 |
+
" lr_scheduler_type=\"linear\",\n",
|
| 778 |
+
" optim=\"adamw_8bit\",\n",
|
| 779 |
+
" # Batch settings\n",
|
| 780 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
| 781 |
+
" gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n",
|
| 782 |
+
" num_generations=NUM_GENERATIONS,\n",
|
| 783 |
+
" # Sequence lengths\n",
|
| 784 |
+
" max_prompt_length=MAX_PROMPT_LENGTH,\n",
|
| 785 |
+
" max_completion_length=MAX_COMPLETION_LENGTH,\n",
|
| 786 |
+
" # Training duration\n",
|
| 787 |
+
" max_steps=MAX_STEPS,\n",
|
| 788 |
+
" save_steps=SAVE_STEPS,\n",
|
| 789 |
+
" # Logging\n",
|
| 790 |
+
" logging_steps=1,\n",
|
| 791 |
+
" report_to=\"wandb\" if USE_WANDB else \"none\",\n",
|
| 792 |
+
" # Output\n",
|
| 793 |
+
" output_dir=str(CHECKPOINT_DIR),\n",
|
| 794 |
+
")\n",
|
| 795 |
+
"\n",
|
| 796 |
+
"print(\"Training configuration:\")\n",
|
| 797 |
+
"print(f\" Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}\")\n",
|
| 798 |
+
"print(f\" Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION * NUM_GENERATIONS}\")\n",
|
| 799 |
+
"print(f\" Max steps: {MAX_STEPS}\")\n",
|
| 800 |
+
"print(f\" W&B: {'enabled' if USE_WANDB else 'disabled'}\")"
|
| 801 |
+
]
|
| 802 |
+
},
|
| 803 |
+
{
|
| 804 |
+
"cell_type": "code",
|
| 805 |
+
"execution_count": null,
|
| 806 |
+
"id": "create-trainer",
|
| 807 |
+
"metadata": {},
|
| 808 |
+
"outputs": [],
|
| 809 |
+
"source": [
|
| 810 |
+
"# Select reward functions based on mode\n",
|
| 811 |
+
"if REWARD_MODE == \"FULL\":\n",
|
| 812 |
+
"    print(\"\\nReward mode: FULL (recommended)\")\n",
|
| 813 |
+
" print(\" - match_format_exactly, match_format_approximately\")\n",
|
| 814 |
+
" print(\" - full_coherence_reward (NLI + structure)\")\n",
|
| 815 |
+
" print(\" - completeness_reward, debug_print_reward\")\n",
|
| 816 |
+
" reward_funcs = [\n",
|
| 817 |
+
" match_format_exactly,\n",
|
| 818 |
+
" match_format_approximately,\n",
|
| 819 |
+
" full_coherence_reward,\n",
|
| 820 |
+
" completeness_reward,\n",
|
| 821 |
+
" debug_print_reward,\n",
|
| 822 |
+
" ]\n",
|
| 823 |
+
"elif REWARD_MODE == \"ROBUST\":\n",
|
| 824 |
+
"    print(\"\\nReward mode: ROBUST\")\n",
|
| 825 |
+
" reward_funcs = [\n",
|
| 826 |
+
" match_format_exactly,\n",
|
| 827 |
+
" match_format_approximately,\n",
|
| 828 |
+
" nli_coherence_reward,\n",
|
| 829 |
+
" structural_coherence_reward,\n",
|
| 830 |
+
" completeness_reward,\n",
|
| 831 |
+
" debug_print_reward,\n",
|
| 832 |
+
" ]\n",
|
| 833 |
+
"else: # LEGACY\n",
|
| 834 |
+
"    print(\"\\nReward mode: LEGACY (faster, less robust)\")\n",
|
| 835 |
+
" reward_funcs = [\n",
|
| 836 |
+
" match_format_exactly,\n",
|
| 837 |
+
" match_format_approximately,\n",
|
| 838 |
+
" completeness_reward,\n",
|
| 839 |
+
" debug_print_reward,\n",
|
| 840 |
+
" ]\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"print(\"\\nNote: First training step will download NLI model (~1.6GB) + spaCy (~436MB)\")\n",
|
| 843 |
+
"\n",
|
| 844 |
+
"# Create trainer\n",
|
| 845 |
+
"trainer = GRPOTrainer(\n",
|
| 846 |
+
" model=model,\n",
|
| 847 |
+
" processing_class=tokenizer,\n",
|
| 848 |
+
" reward_funcs=reward_funcs,\n",
|
| 849 |
+
" args=training_args,\n",
|
| 850 |
+
" train_dataset=dataset,\n",
|
| 851 |
+
")\n",
|
| 852 |
+
"\n",
|
| 853 |
+
"print(\"\\nGRPO trainer created.\")"
|
| 854 |
+
]
|
| 855 |
+
},
|
| 856 |
+
{
|
| 857 |
+
"cell_type": "markdown",
|
| 858 |
+
"id": "train-header",
|
| 859 |
+
"metadata": {},
|
| 860 |
+
"source": [
|
| 861 |
+
"## Train!\n",
|
| 862 |
+
"\n",
|
| 863 |
+
"Run GRPO training. Monitor the `reward` column - it should increase over time.\n",
|
| 864 |
+
"\n",
|
| 865 |
+
"**Expected behavior:**\n",
|
| 866 |
+
"- Steps 0-50: Format rewards stabilize\n",
|
| 867 |
+
"- Steps 50-150: NLI coherence improves\n",
|
| 868 |
+
"- Steps 150-500: Overall quality improves\n",
|
| 869 |
+
"\n",
|
| 870 |
+
"**Estimated time:** ~2-3 hours on A40 for 500 steps"
|
| 871 |
+
]
|
| 872 |
+
},
|
| 873 |
+
{
|
| 874 |
+
"cell_type": "code",
|
| 875 |
+
"execution_count": null,
|
| 876 |
+
"id": "check-resume",
|
| 877 |
+
"metadata": {},
|
| 878 |
+
"outputs": [],
|
| 879 |
+
"source": [
|
| 880 |
+
"# Check for checkpoint to resume from\n",
|
| 881 |
+
"def find_latest_checkpoint(checkpoint_dir):\n",
|
| 882 |
+
" if not checkpoint_dir.exists():\n",
|
| 883 |
+
" return None\n",
|
| 884 |
+
" checkpoints = sorted(\n",
|
| 885 |
+
" [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith(\"checkpoint-\")],\n",
|
| 886 |
+
" key=lambda d: int(d.name.split(\"-\")[1]),\n",
|
| 887 |
+
" )\n",
|
| 888 |
+
" return checkpoints[-1] if checkpoints else None\n",
|
| 889 |
+
"\n",
|
| 890 |
+
"resume_from = find_latest_checkpoint(CHECKPOINT_DIR)\n",
|
| 891 |
+
"if resume_from:\n",
|
| 892 |
+
" print(f\"Found checkpoint: {resume_from}\")\n",
|
| 893 |
+
" print(\"Training will resume from this checkpoint.\")\n",
|
| 894 |
+
"else:\n",
|
| 895 |
+
" print(\"No checkpoint found. Starting fresh training.\")"
|
| 896 |
+
]
|
| 897 |
+
},
|
| 898 |
+
{
|
| 899 |
+
"cell_type": "code",
|
| 900 |
+
"execution_count": null,
|
| 901 |
+
"id": "train",
|
| 902 |
+
"metadata": {},
|
| 903 |
+
"outputs": [],
|
| 904 |
+
"source": [
|
| 905 |
+
"print(\"=\" * 70)\n",
|
| 906 |
+
"print(\"STARTING TRAINING\")\n",
|
| 907 |
+
"print(\"=\" * 70)\n",
|
| 908 |
+
"print(f\"Steps: {MAX_STEPS}\")\n",
|
| 909 |
+
"print(f\"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}\")\n",
|
| 910 |
+
"print(f\"Learning rate: {LEARNING_RATE}\")\n",
|
| 911 |
+
"print(f\"Reward mode: {REWARD_MODE}\")\n",
|
| 912 |
+
"print()\n",
|
| 913 |
+
"\n",
|
| 914 |
+
"try:\n",
|
| 915 |
+
" if resume_from:\n",
|
| 916 |
+
" trainer.train(resume_from_checkpoint=str(resume_from))\n",
|
| 917 |
+
" else:\n",
|
| 918 |
+
" trainer.train()\n",
|
| 919 |
+
"except KeyboardInterrupt:\n",
|
| 920 |
+
" print(\"\\nTraining interrupted. Saving checkpoint...\")\n",
|
| 921 |
+
"except Exception as e:\n",
|
| 922 |
+
" print(f\"\\nTraining error: {e}\")\n",
|
| 923 |
+
" raise\n",
|
| 924 |
+
"\n",
|
| 925 |
+
"print(\"\\nTraining complete!\")"
|
| 926 |
+
]
|
| 927 |
+
},
|
| 928 |
+
{
|
| 929 |
+
"cell_type": "markdown",
|
| 930 |
+
"id": "save-header",
|
| 931 |
+
"metadata": {},
|
| 932 |
+
"source": [
|
| 933 |
+
"## Save Model\n",
|
| 934 |
+
"\n",
|
| 935 |
+
"Save the trained LoRA adapter."
|
| 936 |
+
]
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"cell_type": "code",
|
| 940 |
+
"execution_count": null,
|
| 941 |
+
"id": "save-model",
|
| 942 |
+
"metadata": {},
|
| 943 |
+
"outputs": [],
|
| 944 |
+
"source": [
|
| 945 |
+
"print(\"Saving LoRA adapter...\")\n",
|
| 946 |
+
"\n",
|
| 947 |
+
"LORA_OUTPUT.mkdir(parents=True, exist_ok=True)\n",
|
| 948 |
+
"model.save_lora(str(LORA_OUTPUT))\n",
|
| 949 |
+
"\n",
|
| 950 |
+
"print(f\"LoRA saved to: {LORA_OUTPUT}\")"
|
| 951 |
+
]
|
| 952 |
+
},
|
| 953 |
+
{
|
| 954 |
+
"cell_type": "code",
|
| 955 |
+
"execution_count": null,
|
| 956 |
+
"id": "verify-lora",
|
| 957 |
+
"metadata": {},
|
| 958 |
+
"outputs": [],
|
| 959 |
+
"source": [
|
| 960 |
+
"# Verify LoRA has non-zero weights (actually trained)\n",
|
| 961 |
+
"from safetensors import safe_open\n",
|
| 962 |
+
"\n",
|
| 963 |
+
"adapter_path = LORA_OUTPUT / \"adapter_model.safetensors\"\n",
|
| 964 |
+
"if adapter_path.exists():\n",
|
| 965 |
+
" print(\"Verifying LoRA weights...\")\n",
|
| 966 |
+
" with safe_open(str(adapter_path), framework=\"pt\") as f:\n",
|
| 967 |
+
" for key in list(f.keys())[:5]:\n",
|
| 968 |
+
" tensor = f.get_tensor(key)\n",
|
| 969 |
+
" n_nonzero = (tensor != 0).sum().item()\n",
|
| 970 |
+
" ratio = n_nonzero / tensor.numel()\n",
|
| 971 |
+
" status = \"OK\" if ratio > 0.01 else \"WARNING: mostly zeros!\"\n",
|
| 972 |
+
" print(f\" {key}: {ratio*100:.1f}% non-zero - {status}\")\n",
|
| 973 |
+
" print(\"\\nLoRA verification complete.\")\n",
|
| 974 |
+
"else:\n",
|
| 975 |
+
" print(f\"WARNING: Adapter not found at {adapter_path}\")"
|
| 976 |
+
]
|
| 977 |
+
},
|
| 978 |
+
{
|
| 979 |
+
"cell_type": "markdown",
|
| 980 |
+
"id": "upload-header",
|
| 981 |
+
"metadata": {},
|
| 982 |
+
"source": [
|
| 983 |
+
"## Upload to HuggingFace Hub"
|
| 984 |
+
]
|
| 985 |
+
},
|
| 986 |
+
{
|
| 987 |
+
"cell_type": "code",
|
| 988 |
+
"execution_count": null,
|
| 989 |
+
"id": "upload-hub",
|
| 990 |
+
"metadata": {},
|
| 991 |
+
"outputs": [],
|
| 992 |
+
"source": [
|
| 993 |
+
"if HF_TOKEN:\n",
|
| 994 |
+
" from huggingface_hub import HfApi\n",
|
| 995 |
+
" \n",
|
| 996 |
+
" print(f\"Uploading to HuggingFace Hub: {HF_REPO}\")\n",
|
| 997 |
+
" \n",
|
| 998 |
+
" api = HfApi(token=HF_TOKEN)\n",
|
| 999 |
+
" \n",
|
| 1000 |
+
" # Create repo if needed\n",
|
| 1001 |
+
" try:\n",
|
| 1002 |
+
" api.create_repo(HF_REPO, exist_ok=True, private=True)\n",
|
| 1003 |
+
" except Exception as e:\n",
|
| 1004 |
+
" print(f\"Note: {e}\")\n",
|
| 1005 |
+
" \n",
|
| 1006 |
+
" # Upload\n",
|
| 1007 |
+
" api.upload_folder(\n",
|
| 1008 |
+
" folder_path=str(LORA_OUTPUT),\n",
|
| 1009 |
+
" repo_id=HF_REPO,\n",
|
| 1010 |
+
" commit_message=\"GRPO training run\",\n",
|
| 1011 |
+
" )\n",
|
| 1012 |
+
" \n",
|
| 1013 |
+
" print(f\"\\nModel uploaded to: https://huggingface.co/{HF_REPO}\")\n",
|
| 1014 |
+
"else:\n",
|
| 1015 |
+
" print(\"HF_TOKEN not set. Skipping upload.\")\n",
|
| 1016 |
+
" print(f\"LoRA saved locally at: {LORA_OUTPUT}\")"
|
| 1017 |
+
]
|
| 1018 |
+
},
|
| 1019 |
+
{
|
| 1020 |
+
"cell_type": "markdown",
|
| 1021 |
+
"id": "test-header",
|
| 1022 |
+
"metadata": {},
|
| 1023 |
+
"source": [
|
| 1024 |
+
"## Test Model"
|
| 1025 |
+
]
|
| 1026 |
+
},
|
| 1027 |
+
{
|
| 1028 |
+
"cell_type": "code",
|
| 1029 |
+
"execution_count": null,
|
| 1030 |
+
"id": "test-model",
|
| 1031 |
+
"metadata": {},
|
| 1032 |
+
"outputs": [],
|
| 1033 |
+
"source": [
|
| 1034 |
+
"TEST_QUESTIONS = [\n",
|
| 1035 |
+
" \"What is revisionism in the Marxist sense?\",\n",
|
| 1036 |
+
" \"Explain the concept of surplus value.\",\n",
|
| 1037 |
+
" \"What is the dictatorship of the proletariat?\",\n",
|
| 1038 |
+
"]\n",
|
| 1039 |
+
"\n",
|
| 1040 |
+
"SYSTEM_PROMPT = \"\"\"You are a Marxist-Leninist assistant trained on ProleWiki.\n",
|
| 1041 |
+
"Think through questions using dialectical materialist analysis.\n",
|
| 1042 |
+
"Show your reasoning in <think> tags, then provide a clear answer.\"\"\"\n",
|
| 1043 |
+
"\n",
|
| 1044 |
+
"print(\"Testing trained model...\")\n",
|
| 1045 |
+
"print(\"=\" * 60)\n",
|
| 1046 |
+
"\n",
|
| 1047 |
+
"for question in TEST_QUESTIONS:\n",
|
| 1048 |
+
" messages = [\n",
|
| 1049 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 1050 |
+
" {\"role\": \"user\", \"content\": question},\n",
|
| 1051 |
+
" ]\n",
|
| 1052 |
+
" \n",
|
| 1053 |
+
" text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
|
| 1054 |
+
" inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
|
| 1055 |
+
" \n",
|
| 1056 |
+
" with torch.no_grad():\n",
|
| 1057 |
+
" outputs = model.generate(\n",
|
| 1058 |
+
" **inputs,\n",
|
| 1059 |
+
" max_new_tokens=512,\n",
|
| 1060 |
+
" temperature=0.7,\n",
|
| 1061 |
+
" do_sample=True,\n",
|
| 1062 |
+
" pad_token_id=tokenizer.pad_token_id,\n",
|
| 1063 |
+
" )\n",
|
| 1064 |
+
" \n",
|
| 1065 |
+
" response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
|
| 1066 |
+
" \n",
|
| 1067 |
+
" print(f\"\\nQ: {question}\")\n",
|
| 1068 |
+
" print(f\"A: {response[:400]}...\")\n",
|
| 1069 |
+
" print(\"-\" * 60)"
|
| 1070 |
+
]
|
| 1071 |
+
},
|
| 1072 |
+
{
|
| 1073 |
+
"cell_type": "markdown",
|
| 1074 |
+
"id": "complete-header",
|
| 1075 |
+
"metadata": {},
|
| 1076 |
+
"source": [
|
| 1077 |
+
"## Training Complete!\n",
|
| 1078 |
+
"\n",
|
| 1079 |
+
"**Summary:**\n",
|
| 1080 |
+
"- Model: DeepSeek-R1-0528-Qwen3-8B with LoRA\n",
|
| 1081 |
+
"- Training: GRPO with multi-layer reward functions\n",
|
| 1082 |
+
"- Output: LoRA adapter saved to `/workspace/lora-output`\n",
|
| 1083 |
+
"\n",
|
| 1084 |
+
"**Next steps:**\n",
|
| 1085 |
+
"1. Review W&B dashboard for training metrics\n",
|
| 1086 |
+
"2. Test model quality with varied questions\n",
|
| 1087 |
+
"3. If satisfied, merge LoRA into base model for deployment\n",
|
| 1088 |
+
"4. Convert to GGUF for Ollama if needed\n",
|
| 1089 |
+
"\n",
|
| 1090 |
+
"**Remember:** Stop the pod after training to avoid charges!"
|
| 1091 |
+
]
|
| 1092 |
+
}
|
| 1093 |
+
],
|
| 1094 |
+
"metadata": {
|
| 1095 |
+
"kernelspec": {
|
| 1096 |
+
"display_name": "Python 3",
|
| 1097 |
+
"language": "python",
|
| 1098 |
+
"name": "python3"
|
| 1099 |
+
},
|
| 1100 |
+
"language_info": {
|
| 1101 |
+
"name": "python",
|
| 1102 |
+
"version": "3.12.0"
|
| 1103 |
+
}
|
| 1104 |
+
},
|
| 1105 |
+
"nbformat": 4,
|
| 1106 |
+
"nbformat_minor": 5
|
| 1107 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "prolewiki-llm"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "GRPO fine-tuning and reward functions for Marxist-Leninist language models"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
license = { text = "AGPL-3.0-only" }
|
| 7 |
+
requires-python = ">=3.12"
|
| 8 |
+
dependencies = [
|
| 9 |
+
# Core ML
|
| 10 |
+
"transformers>=4.40.0",
|
| 11 |
+
"sentence-transformers>=3.0.0",
|
| 12 |
+
"torch>=2.0.0",
|
| 13 |
+
# NLP
|
| 14 |
+
"spacy>=3.8.0",
|
| 15 |
+
# Logging
|
| 16 |
+
"wandb>=0.17.0",
|
| 17 |
+
# Data
|
| 18 |
+
"pydantic>=2.0.0",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[dependency-groups]
|
| 22 |
+
dev = [
|
| 23 |
+
# Testing
|
| 24 |
+
"pytest>=8.0.0",
|
| 25 |
+
"pytest-asyncio>=0.24.0",
|
| 26 |
+
"pytest-cov>=6.0.0",
|
| 27 |
+
"pytest-mock>=3.14.0",
|
| 28 |
+
# Code Quality
|
| 29 |
+
"mypy>=1.13.0",
|
| 30 |
+
"ruff>=0.8.0",
|
| 31 |
+
"pre-commit>=4.0.0",
|
| 32 |
+
# Type Stubs
|
| 33 |
+
"types-PyYAML>=6.0.0",
|
| 34 |
+
]
|
| 35 |
+
training = [
|
| 36 |
+
# GRPO training (GPU required)
|
| 37 |
+
"unsloth>=2024.8",
|
| 38 |
+
"trl>=0.9.0",
|
| 39 |
+
"peft>=0.12.0",
|
| 40 |
+
"bitsandbytes>=0.43.0",
|
| 41 |
+
"datasets>=2.20.0",
|
| 42 |
+
"vllm>=0.5.0",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
[build-system]
|
| 46 |
+
requires = ["hatchling"]
|
| 47 |
+
build-backend = "hatchling.build"
|
| 48 |
+
|
| 49 |
+
[tool.hatch.build.targets.wheel]
|
| 50 |
+
packages = ["src/prolewiki_llm"]
|
| 51 |
+
|
| 52 |
+
# =============================================================================
|
| 53 |
+
# PYTEST
|
| 54 |
+
# =============================================================================
|
| 55 |
+
[tool.pytest.ini_options]
|
| 56 |
+
pythonpath = ["src"]
|
| 57 |
+
testpaths = ["tests"]
|
| 58 |
+
asyncio_mode = "auto"
|
| 59 |
+
asyncio_default_fixture_loop_scope = "function"
|
| 60 |
+
markers = [
|
| 61 |
+
"unit: Fast unit tests for isolated components (no I/O)",
|
| 62 |
+
"integration: Integration tests for shell scripts and Docker behavior",
|
| 63 |
+
"slow: Tests that take significant time (NLI, embedding)",
|
| 64 |
+
"gpu: Tests requiring GPU (training)",
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# RUFF (Linting + Formatting + Import Sorting)
|
| 69 |
+
# =============================================================================
|
| 70 |
+
[tool.ruff]
|
| 71 |
+
line-length = 100
|
| 72 |
+
target-version = "py312"
|
| 73 |
+
src = ["src"]
|
| 74 |
+
|
| 75 |
+
[tool.ruff.lint]
|
| 76 |
+
select = [
|
| 77 |
+
"E", # pycodestyle errors
|
| 78 |
+
"W", # pycodestyle warnings
|
| 79 |
+
"F", # pyflakes
|
| 80 |
+
"I", # isort (import sorting)
|
| 81 |
+
"B", # flake8-bugbear
|
| 82 |
+
"C4", # flake8-comprehensions
|
| 83 |
+
"UP", # pyupgrade
|
| 84 |
+
"ARG", # unused arguments
|
| 85 |
+
"SIM", # flake8-simplify
|
| 86 |
+
"RUF", # ruff-specific
|
| 87 |
+
]
|
| 88 |
+
ignore = [
|
| 89 |
+
"E501", # line length handled by formatter
|
| 90 |
+
"RUF001", # ambiguous unicode (we use Cyrillic/Chinese text)
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
[tool.ruff.lint.per-file-ignores]
|
| 94 |
+
"tests/**/*.py" = [
|
| 95 |
+
"ARG001", # unused mock function arguments are common in tests
|
| 96 |
+
"ARG002", # unused method arguments (fixtures) are common in tests
|
| 97 |
+
"SIM117", # nested with statements for multiple mocks are clearer
|
| 98 |
+
]
|
| 99 |
+
"src/prolewiki_llm/**/*.py" = [
|
| 100 |
+
"ARG001", # **kwargs required by GRPOTrainer reward function interface
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
[tool.ruff.lint.isort]
|
| 104 |
+
known-first-party = ["prolewiki_llm"]
|
| 105 |
+
|
| 106 |
+
[tool.ruff.format]
|
| 107 |
+
quote-style = "double"
|
| 108 |
+
indent-style = "space"
|
| 109 |
+
|
| 110 |
+
# =============================================================================
|
| 111 |
+
# MYPY (Type Checking)
|
| 112 |
+
# =============================================================================
|
| 113 |
+
[tool.mypy]
|
| 114 |
+
python_version = "3.12"
|
| 115 |
+
strict = true
|
| 116 |
+
mypy_path = ["src"]
|
| 117 |
+
explicit_package_bases = true
|
| 118 |
+
warn_return_any = true
|
| 119 |
+
warn_unused_configs = true
|
| 120 |
+
disallow_untyped_defs = true
|
| 121 |
+
disallow_incomplete_defs = true
|
| 122 |
+
check_untyped_defs = true
|
| 123 |
+
no_implicit_optional = true
|
| 124 |
+
warn_redundant_casts = true
|
| 125 |
+
warn_unused_ignores = true
|
| 126 |
+
show_error_codes = true
|
| 127 |
+
exclude = ["tests/", "build/", "dist/"]
|
| 128 |
+
|
| 129 |
+
[[tool.mypy.overrides]]
|
| 130 |
+
module = [
|
| 131 |
+
"transformers",
|
| 132 |
+
"transformers.*",
|
| 133 |
+
"sentence_transformers.*",
|
| 134 |
+
"spacy",
|
| 135 |
+
"spacy.*",
|
| 136 |
+
"torch",
|
| 137 |
+
"torch.*",
|
| 138 |
+
"wandb",
|
| 139 |
+
"wandb.*",
|
| 140 |
+
"datasets",
|
| 141 |
+
"datasets.*",
|
| 142 |
+
"trl",
|
| 143 |
+
"trl.*",
|
| 144 |
+
"unsloth",
|
| 145 |
+
"unsloth.*",
|
| 146 |
+
"vllm",
|
| 147 |
+
"vllm.*",
|
| 148 |
+
"peft",
|
| 149 |
+
"peft.*",
|
| 150 |
+
]
|
| 151 |
+
ignore_missing_imports = true
|
| 152 |
+
|
| 153 |
+
# =============================================================================
|
| 154 |
+
# COVERAGE
|
| 155 |
+
# =============================================================================
|
| 156 |
+
[tool.coverage.run]
|
| 157 |
+
source = ["src/prolewiki_llm"]
|
| 158 |
+
branch = true
|
| 159 |
+
omit = ["*/tests/*"]
|
| 160 |
+
|
| 161 |
+
[tool.coverage.report]
|
| 162 |
+
exclude_lines = [
|
| 163 |
+
"pragma: no cover",
|
| 164 |
+
"if TYPE_CHECKING:",
|
| 165 |
+
"raise NotImplementedError",
|
| 166 |
+
]
|
src/prolewiki_llm/__init__.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AI Training Module for Marxist-Leninist LLM Fine-tuning.
|
| 3 |
+
|
| 4 |
+
This module contains reward functions and training utilities for GRPO
|
| 5 |
+
(Group Relative Policy Optimization) fine-tuning on ProleWiki corpus.
|
| 6 |
+
|
| 7 |
+
Components:
|
| 8 |
+
- grpo_rewards: Reward functions for GRPO training
|
| 9 |
+
- wandb_logging: Weights & Biases logging for training observability
|
| 10 |
+
- train_grpo_marxist: Main GRPO training script
|
| 11 |
+
- transform_to_grpo: Dataset transformation utilities
|
| 12 |
+
- convert_to_qwen: Qwen format conversion
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from prolewiki_llm.grpo_rewards import (
|
| 16 |
+
CAPITULATION_PATTERNS,
|
| 17 |
+
CLASS_ANALYSIS_MARKERS,
|
| 18 |
+
CONCEPT_EQUIVALENCES,
|
| 19 |
+
CONFIDENT_CLAIM_PATTERNS,
|
| 20 |
+
DEPTH_MARKERS,
|
| 21 |
+
DISCOURSE_CONNECTIVES,
|
| 22 |
+
EXPLANATORY_PHRASES,
|
| 23 |
+
FIRMNESS_PATTERNS,
|
| 24 |
+
HOLLOW_BUZZWORDS,
|
| 25 |
+
IDEOLOGICAL_CHALLENGE_PATTERNS,
|
| 26 |
+
MARXIST_TERMS,
|
| 27 |
+
QUESTION_WORDS,
|
| 28 |
+
QUOTE_TO_REFUTE_PATTERNS,
|
| 29 |
+
SELF_CRITICISM_MARKERS,
|
| 30 |
+
UNCERTAINTY_PATTERNS,
|
| 31 |
+
completeness_reward,
|
| 32 |
+
debug_print_reward,
|
| 33 |
+
entity_verification_reward,
|
| 34 |
+
epistemic_calibration_reward,
|
| 35 |
+
full_coherence_reward,
|
| 36 |
+
ideological_firmness_reward,
|
| 37 |
+
interconnection_depth_reward,
|
| 38 |
+
match_format_approximately,
|
| 39 |
+
match_format_exactly,
|
| 40 |
+
nli_coherence_reward,
|
| 41 |
+
robust_coherence_reward,
|
| 42 |
+
self_consistency_reward,
|
| 43 |
+
semantic_similarity_reward,
|
| 44 |
+
structural_coherence_reward,
|
| 45 |
+
terminology_reward,
|
| 46 |
+
topic_relevance_reward,
|
| 47 |
+
)
|
| 48 |
+
from prolewiki_llm.wandb_logging import (
|
| 49 |
+
RewardSample,
|
| 50 |
+
WandbSampleLogger,
|
| 51 |
+
create_logging_reward,
|
| 52 |
+
finish_wandb_logging,
|
| 53 |
+
init_wandb_logging,
|
| 54 |
+
is_wandb_available,
|
| 55 |
+
log_model_checkpoint,
|
| 56 |
+
log_reward_metrics,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Public API of the package: pattern/term constants and reward functions
# re-exported from grpo_rewards, plus W&B logging utilities re-exported
# from wandb_logging. Kept sorted (constants first, then callables).
__all__ = [
    "CAPITULATION_PATTERNS",
    "CLASS_ANALYSIS_MARKERS",
    "CONCEPT_EQUIVALENCES",
    "CONFIDENT_CLAIM_PATTERNS",
    "DEPTH_MARKERS",
    "DISCOURSE_CONNECTIVES",
    "EXPLANATORY_PHRASES",
    "FIRMNESS_PATTERNS",
    "HOLLOW_BUZZWORDS",
    "IDEOLOGICAL_CHALLENGE_PATTERNS",
    "MARXIST_TERMS",
    "QUESTION_WORDS",
    "QUOTE_TO_REFUTE_PATTERNS",
    "SELF_CRITICISM_MARKERS",
    "UNCERTAINTY_PATTERNS",
    "RewardSample",
    "WandbSampleLogger",
    "completeness_reward",
    "create_logging_reward",
    "debug_print_reward",
    "entity_verification_reward",
    "epistemic_calibration_reward",
    "finish_wandb_logging",
    "full_coherence_reward",
    "ideological_firmness_reward",
    "init_wandb_logging",
    "interconnection_depth_reward",
    "is_wandb_available",
    "log_model_checkpoint",
    "log_reward_metrics",
    "match_format_approximately",
    "match_format_exactly",
    "nli_coherence_reward",
    "robust_coherence_reward",
    "self_consistency_reward",
    "semantic_similarity_reward",
    "structural_coherence_reward",
    "terminology_reward",
    "topic_relevance_reward",
]
|
src/prolewiki_llm/convert_to_qwen.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Convert instruction/response pairs to Qwen chat template format."""
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory. You provide accurate information about socialist history, theory, and practice from a Marxist-Leninist perspective. You explain concepts like dialectical materialism, historical materialism, class struggle, anti-colonialism, and socialist construction with clarity and ideological precision."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert_to_qwen(input_path: Path, output_path: Path) -> int:
|
| 11 |
+
"""Convert instruction/response JSONL to Qwen chat template format."""
|
| 12 |
+
count = 0
|
| 13 |
+
with open(input_path) as infile, open(output_path, "w") as outfile:
|
| 14 |
+
for line in infile:
|
| 15 |
+
pair = json.loads(line)
|
| 16 |
+
|
| 17 |
+
# Format for Qwen-2.5 chat template
|
| 18 |
+
text = f"""<|im_start|>system
|
| 19 |
+
{SYSTEM_PROMPT}<|im_end|>
|
| 20 |
+
<|im_start|>user
|
| 21 |
+
{pair['instruction']}<|im_end|>
|
| 22 |
+
<|im_start|>assistant
|
| 23 |
+
{pair['response']}<|im_end|>"""
|
| 24 |
+
|
| 25 |
+
outfile.write(json.dumps({"text": text}) + "\n")
|
| 26 |
+
count += 1
|
| 27 |
+
|
| 28 |
+
return count
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
input_file = Path("training_data/curated_qa.jsonl")
|
| 33 |
+
output_file = Path("training_data/formatted/train_qwen.jsonl")
|
| 34 |
+
|
| 35 |
+
count = convert_to_qwen(input_file, output_file)
|
| 36 |
+
print(f"Converted {count} pairs to Qwen format")
|
| 37 |
+
print(f"Output: {output_file}")
|
src/prolewiki_llm/export_grpo_dataset.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Export the ProleWiki dataset to GRPO training format.
|
| 4 |
+
|
| 5 |
+
Directly reads all JSONL files from the dataset repository and transforms them
|
| 6 |
+
into the GRPO format expected by GRPOTrainer.
|
| 7 |
+
|
| 8 |
+
Input: ../dataset/ (68 JSONL files, 5,297 records total)
|
| 9 |
+
Output: training_data/grpo_dataset.jsonl
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python -m prolewiki_llm.export_grpo_dataset
|
| 13 |
+
python -m prolewiki_llm.export_grpo_dataset --dataset-path /path/to/dataset
|
| 14 |
+
python -m prolewiki_llm.export_grpo_dataset --output training_data/grpo_dataset.jsonl
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# System prompt for GRPO training (must match dataset.py)
|
| 25 |
+
_SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
|
| 26 |
+
Think through political theory questions using dialectical materialist analysis.
|
| 27 |
+
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def find_dataset_path() -> Path:
    """Locate the ProleWiki dataset directory.

    Resolution order:
    1. The PROLEWIKI_DATASET_PATH environment variable, when it names an
       existing path.
    2. A ``dataset`` directory that is a sibling of the project root.
    3. A ``dataset`` directory inside the project root.

    Raises:
        FileNotFoundError: If no candidate location exists.
    """
    import os

    # An explicit override wins, but only when the path actually exists.
    override = os.environ.get("PROLEWIKI_DATASET_PATH")
    if override and Path(override).exists():
        return Path(override)

    # Walk up from this module's location to the project root (llm/).
    root = Path(__file__).resolve().parent.parent.parent.parent
    for location in (root.parent / "dataset", root / "dataset"):
        # A real dataset checkout always contains a sources/ subdirectory.
        if location.exists() and (location / "sources").exists():
            return location

    raise FileNotFoundError(
        "Could not find dataset directory. Set PROLEWIKI_DATASET_PATH "
        "environment variable or ensure dataset is at ../dataset/"
    )
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def iter_jsonl_files(dataset_path: Path) -> list[Path]:
    """Collect every JSONL file under sources/ and synthetic/.

    Each subdirectory is searched recursively; results are sorted within
    each subdirectory for deterministic ordering, sources/ first.
    """
    collected: list[Path] = []
    for subdir_name in ("sources", "synthetic"):
        directory = dataset_path / subdir_name
        if directory.exists():
            collected.extend(sorted(directory.rglob("*.jsonl")))
    return collected
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def transform_to_grpo(record: dict) -> dict | None:
    """Transform a single record to GRPO format.

    Input format: {instruction, response, ...}
    Output format: {prompt: [{role, content}, ...], answer}

    Returns None when either required field is missing or empty, so the
    caller can count and skip incomplete records.
    """
    instruction = record.get("instruction")
    response = record.get("response")

    if not (instruction and response):
        return None

    conversation = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": instruction},
    ]
    return {"prompt": conversation, "answer": response}
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def export_grpo_dataset(
    dataset_path: Path,
    output_path: Path,
    verbose: bool = True,
) -> int:
    """Export dataset to GRPO format by directly reading JSONL files.

    Streams every JSONL file found by iter_jsonl_files, transforms each
    record via transform_to_grpo, and writes the results as one JSONL file.

    Args:
        dataset_path: Path to the dataset directory
        output_path: Path to write the GRPO-formatted JSONL file
        verbose: Whether to print progress messages

    Returns:
        Number of records exported
    """
    if verbose:
        print(f"Loading dataset from: {dataset_path}")

    # Find all JSONL files
    jsonl_files = iter_jsonl_files(dataset_path)

    if verbose:
        print(f"Found {len(jsonl_files)} JSONL files")

    # Ensure output directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    record_count = 0
    skipped_count = 0
    first_sample = None  # first exported record, kept for the preview below

    with open(output_path, "w", encoding="utf-8") as outfile:
        for jsonl_path in jsonl_files:
            with open(jsonl_path, encoding="utf-8") as infile:
                # 1-based line numbers so warnings match editor display.
                for line_num, line in enumerate(infile, 1):
                    line = line.strip()
                    if not line:
                        # Blank lines (e.g. trailing newline) are not errors.
                        continue

                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError as e:
                        # Malformed lines are warned about (when verbose) and
                        # counted, but never abort the whole export.
                        if verbose:
                            print(f"Warning: Invalid JSON in {jsonl_path}:{line_num}: {e}")
                        skipped_count += 1
                        continue

                    transformed = transform_to_grpo(record)
                    if transformed is None:
                        # Record lacked instruction/response fields.
                        skipped_count += 1
                        continue

                    # ensure_ascii=False keeps non-Latin text readable in the output.
                    outfile.write(json.dumps(transformed, ensure_ascii=False) + "\n")
                    record_count += 1

                    if first_sample is None:
                        first_sample = transformed

    if verbose:
        print(f"\nExported {record_count:,} records to {output_path}")
        if skipped_count > 0:
            print(f"Skipped {skipped_count:,} invalid/incomplete records")

        # Show sample
        if first_sample:
            print("\nSample record:")
            print(f"  System: {first_sample['prompt'][0]['content'][:60]}...")
            print(f"  User: {first_sample['prompt'][1]['content'][:60]}...")
            print(f"  Answer: {first_sample['answer'][:60]}...")

        # Show file size
        size_mb = output_path.stat().st_size / (1024 * 1024)
        print(f"\nOutput file size: {size_mb:.2f} MB")

    return record_count
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main() -> int:
    """Command-line entry point.

    Parses arguments, resolves the dataset location, and runs the export.
    Returns 0 on success (at least one record exported), 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Export ProleWiki dataset to GRPO training format"
    )
    parser.add_argument(
        "--dataset-path",
        type=Path,
        default=None,
        help="Path to dataset directory (default: auto-detect)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("training_data/grpo_dataset.jsonl"),
        help="Output JSONL file path (default: training_data/grpo_dataset.jsonl)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress progress messages",
    )
    options = parser.parse_args()

    # Resolve the dataset directory: explicit flag wins over auto-detection.
    try:
        source = options.dataset_path or find_dataset_path()
    except FileNotFoundError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    try:
        exported = export_grpo_dataset(
            dataset_path=source,
            output_path=options.output,
            verbose=not options.quiet,
        )
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1
    # An empty export is treated as failure.
    return 0 if exported > 0 else 1


if __name__ == "__main__":
    sys.exit(main())
|
src/prolewiki_llm/grpo_rewards.py
ADDED
|
@@ -0,0 +1,1847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO Reward Functions for Marxist-Leninist Q&A Training.
|
| 4 |
+
|
| 5 |
+
These reward functions guide the model toward:
|
| 6 |
+
1. Proper <think>...</think> format
|
| 7 |
+
2. Semantic coherence via NLI (Natural Language Inference)
|
| 8 |
+
3. Structural coherence via dependency parsing
|
| 9 |
+
4. Self-consistency (no internal contradictions)
|
| 10 |
+
5. Appropriate response length/completeness
|
| 11 |
+
|
| 12 |
+
Research basis:
|
| 13 |
+
- NLI as reward: arxiv.org/abs/2508.18212 (Better LM-Based Judging)
|
| 14 |
+
- MO-GRPO normalization: arxiv.org/abs/2509.22047
|
| 15 |
+
- Process rewards: arxiv.org/abs/2508.05170 (Posterior-GRPO)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import re
|
| 21 |
+
from typing import TYPE_CHECKING, Any
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
from sentence_transformers import SentenceTransformer
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from collections.abc import Sequence
|
| 28 |
+
|
| 29 |
+
# =============================================================================
# GLOBAL SETUP - LAZY LOADING
# =============================================================================

# Lazy-load models to avoid loading at import time.
# Each global starts as None and is populated on first use by the matching
# get_*() accessor below; afterwards the same instance is reused for the
# lifetime of the process (module-level singletons).
_embedder: SentenceTransformer | None = None
_nli_pipeline: Any | None = None
_spacy_nlp: Any | None = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_embedder() -> SentenceTransformer:
    """Return the shared sentence-transformers embedder, loading it on first use.

    The model ("all-MiniLM-L6-v2") is cached in the module-level ``_embedder``
    global so subsequent calls are free.
    """
    global _embedder
    if _embedder is not None:
        return _embedder

    print("[Reward] Loading sentence-transformers embedder...")
    _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return _embedder
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_nli_pipeline() -> Any:
    """Return the shared NLI pipeline (BART-large-MNLI), loading it on first use.

    The pipeline is cached in the module-level ``_nli_pipeline`` global and
    placed on CUDA when available, CPU otherwise.
    """
    global _nli_pipeline
    if _nli_pipeline is not None:
        return _nli_pipeline

    print("[Reward] Loading NLI model (bart-large-mnli)...")
    # Imported lazily so that merely importing this module stays cheap.
    from transformers import pipeline

    _nli_pipeline = pipeline(
        "text-classification",
        model="facebook/bart-large-mnli",
        device="cuda" if _cuda_available() else "cpu",
    )
    return _nli_pipeline
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_spacy_nlp() -> Any:
    """Get or initialize spaCy NLP pipeline.

    Uses en_core_web_trf (transformer-based) for superior semantic understanding.
    Falls back to en_core_web_md (word vectors) or en_core_web_sm if unavailable.

    Returns:
        The loaded spaCy ``Language`` pipeline (cached module-wide).

    Raises:
        OSError: If none of the candidate models is installed.
    """
    global _spacy_nlp
    if _spacy_nlp is None:
        # Imported lazily so module import stays cheap.
        import spacy

        # Try transformer model first (best semantic understanding);
        # each fallback is cheaper but less accurate.
        models_to_try = ["en_core_web_trf", "en_core_web_md", "en_core_web_sm"]

        for model_name in models_to_try:
            try:
                print(f"[Reward] Loading spaCy model: {model_name}...")
                _spacy_nlp = spacy.load(model_name)
                print(f"[Reward] Loaded {model_name} successfully")
                break
            except OSError:
                # spacy.load raises OSError when the model package is missing.
                print(f"[Reward] {model_name} not found, trying next...")
                continue

        # All candidates failed: surface an actionable install hint.
        if _spacy_nlp is None:
            raise OSError(
                "No spaCy model found. Install one with:\n"
                "  uv pip install en_core_web_trf@https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl"
            )
    return _spacy_nlp
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _cuda_available() -> bool:
    """Return True when torch is importable and reports an available CUDA device."""
    try:
        import torch
    except ImportError:
        # torch is optional; without it we simply fall back to CPU.
        return False
    return bool(torch.cuda.is_available())
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Reasoning format tokens (DeepSeek-R1 style): the model is expected to wrap
# its chain-of-thought in <think>...</think> and answer afterwards.
REASONING_START = "<think>"
REASONING_END = "</think>"

# Regex to match format: captures everything after the closing </think> tag
# (DOTALL so the answer may span multiple lines).
SOLUTION_END_REGEX = re.compile(rf"{REASONING_END}(.*)", re.DOTALL)
|
| 110 |
+
|
| 111 |
+
# Marxist terminology for vocabulary reward.
# Matched case-insensitively by substring against lowercased responses
# (see terminology_reward and structural_coherence_reward), so multi-word
# entries must appear verbatim in the text to count.
MARXIST_TERMS: set[str] = {
    # Core concepts
    "dialectical",
    "materialism",
    "historical materialism",
    "dialectical materialism",
    # Classes
    "bourgeoisie",
    "proletariat",
    "petty bourgeois",
    "petty bourgeoisie",
    "lumpenproletariat",
    "working class",
    "ruling class",
    # Class struggle
    "class struggle",
    "class consciousness",
    "class war",
    "class conflict",
    # Political economy
    "surplus value",
    "commodity",
    "use value",
    "exchange value",
    "labor power",
    "means of production",
    "relations of production",
    "forces of production",
    "mode of production",
    "primitive accumulation",
    "exploitation",
    "capital accumulation",
    # Imperialism
    "imperialism",
    "colonialism",
    "neo-colonialism",
    "settler colonialism",
    "national liberation",
    "self-determination",
    # State and revolution
    "dictatorship of the proletariat",
    "vanguard",
    "vanguard party",
    "democratic centralism",
    "withering away of the state",
    "proletarian dictatorship",
    # Ideology
    "hegemony",
    "superstructure",
    "base",
    "ideology",
    "false consciousness",
    # Revisionism
    "revisionism",
    "opportunism",
    "reformism",
    "social democracy",
    "ultra-leftism",
    # Alienation
    "alienation",
    "fetishism",
    "commodity fetishism",
    "reification",
    # Historical
    "paris commune",
    "october revolution",
    "bolshevik",
    "menshevik",
    # Anti-colonial
    "decolonization",
    "third world",
    "global south",
    "national bourgeoisie",
    "comprador",
}
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# =============================================================================
|
| 190 |
+
# FORMAT REWARDS (from original notebook)
|
| 191 |
+
# =============================================================================
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def match_format_exactly(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward +3.0 if response contains proper </think> tag.

    Encourages the model to emit the DeepSeek-R1 style reasoning format;
    responses without the closing tag score 0.0.
    """
    return [
        3.0 if SOLUTION_END_REGEX.search(completion[0]["content"]) is not None else 0.0
        for completion in completions
    ]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def match_format_approximately(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward partial format matching.

    +0.5 for exactly one <think> tag and +0.5 for exactly one </think> tag;
    each missing or duplicated tag contributes -1.0 instead.
    """
    rewards: list[float] = []
    for completion in completions:
        text = completion[0]["content"]
        opens = text.count(REASONING_START)
        closes = text.count(REASONING_END)
        reward = (0.5 if opens == 1 else -1.0) + (0.5 if closes == 1 else -1.0)
        rewards.append(reward)
    return rewards
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# =============================================================================
|
| 240 |
+
# SEMANTIC SIMILARITY REWARD
|
| 241 |
+
# =============================================================================
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def semantic_similarity_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward responses that are semantically similar to ground truth.

    The text after </think> (when present) is embedded with normalized
    sentence-transformer vectors and compared to the reference answer by
    cosine similarity, then bucketed:

        > 0.75 -> +5.0, > 0.60 -> +3.0, > 0.45 -> +1.0,
        > 0.30 -> -1.0, otherwise -3.0 (near-empty answers also get -3.0)
    """
    model = get_embedder()
    # (threshold, reward) tiers checked from highest similarity down.
    tiers = ((0.75, 5.0), (0.60, 3.0), (0.45, 1.0), (0.30, -1.0))
    rewards: list[float] = []

    for completion, truth in zip(completions, answer, strict=False):
        text = completion[0]["content"]

        # Only the answer portion after </think> is compared.
        if REASONING_END in text:
            text = text.split(REASONING_END, 1)[1].strip()

        # Degenerate/empty answers get the floor score immediately.
        if not text or len(text.strip()) < 10:
            rewards.append(-3.0)
            continue

        # Embeddings are normalized, so the dot product IS cosine similarity.
        vec_response = model.encode(text, normalize_embeddings=True)
        vec_truth = model.encode(truth, normalize_embeddings=True)
        cosine = float(np.dot(vec_response, vec_truth))

        for cutoff, value in tiers:
            if cosine > cutoff:
                rewards.append(value)
                break
        else:
            rewards.append(-3.0)

    return rewards
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# =============================================================================
|
| 300 |
+
# MARXIST TERMINOLOGY REWARD
|
| 301 |
+
# =============================================================================
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def terminology_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward use of proper Marxist terminology.

    Scores +0.3 per unique MARXIST_TERMS entry found (substring match on the
    lowercased response), capped at +2.0.

    NOTE: This is a shallow reward that can be gamed with "word soup".
    Consider using nli_coherence_reward or structural_coherence_reward
    for more robust evaluation.
    """
    rewards: list[float] = []
    for completion in completions:
        text = completion[0]["content"].lower()
        # Booleans sum as 0/1, giving the count of distinct terms present.
        hits = sum(term in text for term in MARXIST_TERMS)
        rewards.append(min(0.3 * hits, 2.0))
    return rewards
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# =============================================================================
|
| 332 |
+
# NLI-BASED COHERENCE REWARD (Research-backed)
|
| 333 |
+
# =============================================================================
|
| 334 |
+
|
| 335 |
+
# Discourse connectives indicating logical structure.
# All lexicons in this section are matched case-insensitively by substring
# against lowercased responses (see structural_coherence_reward).
DISCOURSE_CONNECTIVES: set[str] = {
    "because",
    "therefore",
    "thus",
    "hence",
    "consequently",
    "however",
    "although",
    "whereas",
    "nevertheless",
    "moreover",
    "furthermore",
    "additionally",
    "specifically",
    "namely",
    "in other words",
    "for example",
    "for instance",
    "such as",
    "as a result",
    "due to",
    "in order to",
    "so that",
    "on the other hand",
    "in contrast",
    "similarly",
    "likewise",
}

# Explanatory phrases that indicate concept is being explained (not just dropped)
EXPLANATORY_PHRASES: set[str] = {
    # Causal explanations
    "because the",
    "because of",
    "this is because",
    "since the",
    "due to the",
    "as a result of",
    "results from",
    "caused by",
    "leads to",
    "results in",
    "enables",
    "produces",
    # Definitional explanations
    "is defined as",
    "refers to",
    "means that",
    "denotes",
    "that is,",
    "in other words",
    "namely",
    "i.e.",
    # Elaboration
    "specifically",
    "in particular",
    "for example",
    "such as",
    "this means",
    "which means",
    "this implies",
    "therefore",
    # Mechanism explanations
    "this occurs when",
    "this happens because",
    "the mechanism",
    "through the process of",
    "by means of",
    "works by",
}

# Hollow buzzwords: activist jargon that signals superficial analysis when used
# without substantive explanation. These are NOT Marxist technical terms.
# Penalty applies when: high density + low depth ratio
HOLLOW_BUZZWORDS: set[str] = {
    # Vague connectors (non-analytical)
    "interconnected",
    "interrelated",
    "intersects with",
    "it's all connected",
    "everything is connected",
    "systemic",
    # Performative activist language
    "centered",
    "centering",
    "uplift",
    "uplifting",
    "do the work",
    "the work",
    "unpack",
    "unpacking",
    "unlearn",
    "unlearning",
    "hold space",
    "sit with",
    "lean into",
    "problematic",
    "harmful",
    "toxic",
    # Vague abstractions without specifics
    "in a way",
    "sort of",
    "kind of",
    "essentially",
    "basically",
    "generally speaking",
    "broadly",
    # Jargon often used without definition
    "praxis",  # Valid Marxist term but often misused without explanation
    "material conditions",  # Valid but often used as hand-wave
    "structural",
    "structurally",  # Often vague without mechanism
    # Identity-focused without class analysis
    "lived experience",
    "as a",  # Often substitutes for analysis
}

# Phrases that signal analytical depth (opposite of hollow).
# NOTE: "in 1" / "in 2" are prefix matches intended to catch year references
# like "in 1917" or "in 2008" via substring matching.
DEPTH_MARKERS: set[str] = {
    # Historical specificity
    "in 1",
    "in 2",
    "during the",
    "after the",
    "before the",
    # Citing sources/figures
    "marx argued",
    "lenin wrote",
    "engels noted",
    "gramsci",
    "according to",
    "as marx",
    "as lenin",
    # Concrete examples
    "for example",
    "such as",
    "in the case of",
    "consider",
    # Precise definitions
    "defined as",
    "meaning",
    "specifically",
}

# Marxist concept equivalences for topic matching
# Maps canonical term -> set of synonyms/equivalents
CONCEPT_EQUIVALENCES: dict[str, set[str]] = {
    # Class terms
    "bourgeoisie": {"capitalist class", "ruling class", "capitalists", "bourgeois", "capital"},
    "proletariat": {"working class", "workers", "wage laborers", "labor", "labourers"},
    "petty bourgeoisie": {"petit bourgeoisie", "small business", "middle class", "petty bourgeois"},
    "lumpenproletariat": {"lumpen", "underclass", "criminal element"},
    # Economic concepts
    "surplus value": {"unpaid labor", "profit", "extraction", "surplus labor"},
    "means of production": {"productive forces", "capital goods", "factories", "industry"},
    "exploitation": {"extraction", "appropriation", "expropriation"},
    "commodity": {"commodities", "goods", "merchandise"},
    "capital accumulation": {"accumulation", "concentration of capital"},
    "primitive accumulation": {"original accumulation", "so-called primitive accumulation"},
    # Political concepts
    "dictatorship of the proletariat": {
        "workers state",
        "proletarian dictatorship",
        "workers government",
    },
    "vanguard party": {"vanguard", "communist party", "revolutionary party"},
    "democratic centralism": {"party discipline", "centralism"},
    # Imperialism
    "imperialism": {"colonialism", "neo-colonialism", "empire", "colonial"},
    "national liberation": {"decolonization", "anti-colonial", "liberation movement"},
    "settler colonialism": {"settler colony", "colonial settlement"},
    # Ideology
    "revisionism": {"opportunism", "reformism", "right deviation"},
    "hegemony": {"ideological hegemony", "cultural hegemony", "domination"},
    "false consciousness": {"ideology", "mystification"},
    # Philosophy
    "dialectical materialism": {"diamat", "materialist dialectics", "dialectics"},
    "historical materialism": {"histmat", "materialist conception of history"},
    "alienation": {"estrangement", "alienated labor"},
}

# Question words to ignore when extracting topics
QUESTION_WORDS: set[str] = {"what", "how", "why", "who", "when", "where", "which", "whom"}
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def nli_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward responses that logically ENTAIL the ground truth answer.

    Uses Natural Language Inference (facebook/bart-large-mnli) to check
    if the response is logically consistent with the expected answer.

    This defeats "word soup" attacks because random terminology won't
    logically entail anything - it will be classified as NEUTRAL.

    Scoring:
        entailment: +3.0 (response supports/implies ground truth)
        neutral: -1.0 (response is off-topic or incoherent)
        contradiction: -3.0 (response contradicts ground truth)
        too-short/empty answer: -2.0; pipeline error: 0.0

    Research basis: arxiv.org/abs/2508.18212
    """
    nli = get_nli_pipeline()
    # Any label other than these two is treated as "contradiction".
    label_to_score = {"entailment": 3.0, "neutral": -1.0}
    rewards: list[float] = []

    for completion, truth in zip(completions, answer, strict=False):
        text = completion[0]["content"]

        # Only the answer portion after </think> is judged.
        if REASONING_END in text:
            text = text.split(REASONING_END, 1)[1].strip()

        # Empty or trivially short answers are penalized without calling NLI.
        if not text or len(text.strip()) < 20:
            rewards.append(-2.0)
            continue

        # Truncate both sides to keep inputs within model limits (avoid OOM).
        premise = text[:512]
        hypothesis = truth[:512]

        # BART-MNLI pair format: premise </s></s> hypothesis.
        # Question asked: does the response entail the ground truth?
        try:
            verdict = nli(f"{premise}</s></s>{hypothesis}")[0]
            label = verdict["label"].lower()
            rewards.append(label_to_score.get(label, -3.0))
        except Exception as e:
            # Best-effort: a pipeline failure is neutral, not a penalty.
            print(f"[NLI Reward] Error: {e}")
            rewards.append(0.0)

    return rewards
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def self_consistency_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward responses that are internally self-consistent.

    Checks if any sentence in the response CONTRADICTS another sentence.
    This avoids external ideological bias by only checking within-document
    coherence.

    Scoring:
        No contradictions found: +1.0
        Internal contradiction detected: -2.0
        Fewer than 2 usable sentences: 0.0 (nothing to compare)

    Research basis: arxiv.org/abs/2508.05170 (process-based rewards)
    """
    nli = get_nli_pipeline()
    nlp = get_spacy_nlp()
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # Parse into sentences; very short fragments (<= 10 chars) are ignored.
        doc = nlp(response)
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]

        # Need at least 2 sentences to check consistency
        if len(sentences) < 2:
            scores.append(0.0)
            continue

        # Check pairs of sentences for contradictions
        # (Only check adjacent and near-adjacent to limit compute)
        has_contradiction = False
        max_pairs_to_check = 10
        pairs_checked = 0

        for i, sent_a in enumerate(sentences[:-1]):
            if pairs_checked >= max_pairs_to_check:
                break
            # Check against next 2 sentences
            for j in range(i + 1, min(i + 3, len(sentences))):
                sent_b = sentences[j]
                try:
                    # BART-MNLI pair format; truncate each side to 256 chars.
                    input_text = f"{sent_a[:256]}</s></s>{sent_b[:256]}"
                    result = nli(input_text)[0]
                    if result["label"].lower() == "contradiction":
                        has_contradiction = True
                        break
                    # NOTE: only successful, non-contradicting checks count
                    # toward the budget; a detected contradiction exits early
                    # and a pipeline error skips the increment entirely.
                    pairs_checked += 1
                except Exception:
                    # Best-effort: an NLI failure on one pair is ignored.
                    pass
            if has_contradiction:
                break

        if has_contradiction:
            scores.append(-2.0)
        else:
            scores.append(1.0)

    return scores
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def structural_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward responses with proper linguistic structure.

    Uses spaCy dependency parsing to verify:
    1. Marxist terms appear in meaningful syntactic roles (subject, object)
    2. Response contains logical discourse connectives
    3. Response has proper sentence structure (not word soup)

    This defeats word soup because random terms won't be in subject/object
    positions - they'll be parsed as fragments.

    Scoring:
        +0.3 per term in subject/object position (max +1.5)
        +0.2 per discourse connective (max +1.0)
        -1.0 if no complete sentences detected

    Research basis: spaCy dependency parsing for coherence evaluation
    """
    nlp = get_spacy_nlp()
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]
        doc = nlp(response)
        score = 0.0

        # Check 1: Are there actual sentences?
        sentences = list(doc.sents)
        if len(sentences) < 1:
            scores.append(-1.0)
            continue

        # Check 2: Marxist terms in meaningful syntactic roles.
        # Each term contributes at most once (the `break`s below).
        terms_in_context = 0
        response_lower = response.lower()

        for term in MARXIST_TERMS:
            # Cheap substring pre-filter before scanning tokens.
            if term not in response_lower:
                continue

            # Find tokens matching this term, either alone or as the first
            # word of a two-token bigram (covers "surplus value" etc.).
            for token in doc:
                if term in token.text.lower() or (
                    token.i + 1 < len(doc)
                    and term in f"{token.text} {doc[token.i + 1].text}".lower()
                ):
                    # Reward if token is in a meaningful syntactic role
                    if token.dep_ in (
                        "nsubj",  # nominal subject
                        "nsubjpass",  # passive nominal subject
                        "dobj",  # direct object
                        "pobj",  # object of preposition
                        "attr",  # attribute
                        "appos",  # appositional modifier
                    ):
                        terms_in_context += 1
                        break  # Count each term once
                    # Also reward if connected to a meaningful verb
                    elif token.head.pos_ == "VERB" and token.head.dep_ == "ROOT":
                        terms_in_context += 1
                        break
                    # NOTE: a matching token in neither role falls through and
                    # the scan continues with later tokens for the same term.

        score += min(terms_in_context * 0.3, 1.5)

        # Check 3: Discourse connectives (indicates logical structure)
        connective_count = sum(1 for conn in DISCOURSE_CONNECTIVES if conn in response_lower)
        score += min(connective_count * 0.2, 1.0)

        scores.append(score)

    return scores
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
# =============================================================================
|
| 725 |
+
# COMBINED ROBUST COHERENCE REWARD
|
| 726 |
+
# =============================================================================
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def robust_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Multi-layered coherence check combining NLI, self-consistency, and structure.

    This is the recommended reward function for robust evaluation that defeats
    reward hacking via word soup or other adversarial strategies.

    Layers:
    1. NLI coherence: Does response entail ground truth?
    2. Self-consistency: Does response contradict itself?
    3. Structural coherence: Are terms used in meaningful syntactic roles?

    Scoring (combined):
        NLI entailment + self-consistent + good structure: up to +5.5
        NLI contradiction: hard -3.0 regardless of other layers
        Internal contradiction: hard -2.0
        Otherwise: nli + 0.5 * consistency + 0.5 * structure
    """
    # Run each layer once over the whole batch.
    nli_scores = nli_coherence_reward(completions, answer, **kwargs)
    consistency_scores = self_consistency_reward(completions, **kwargs)
    structure_scores = structural_coherence_reward(completions, **kwargs)

    combined: list[float] = []
    for nli, consistency, structure in zip(
        nli_scores, consistency_scores, structure_scores, strict=False
    ):
        # Hard penalty: NLI contradiction overrides everything else.
        if nli <= -3.0:
            combined.append(-3.0)
            continue
        # Hard penalty: internal self-contradiction.
        if consistency <= -2.0:
            combined.append(-2.0)
            continue
        # NLI is the primary signal; the other layers are half-weight bonuses.
        combined.append(nli + 0.5 * consistency + 0.5 * structure)

    return combined
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
# =============================================================================
|
| 777 |
+
# TOPIC RELEVANCE REWARD (Question-Answer Alignment)
|
| 778 |
+
# =============================================================================
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def _extract_noun_with_preps(token: Any) -> set[str]:
    """
    Extract a noun, its compound/adjective modifiers, and prepositional objects.

    For "dictatorship of the proletariat", returns:
        {"dictatorship", "proletariat", "dictatorship of the proletariat"}
    """
    found: set[str] = set()

    # The head noun itself (lemmatized, lowercase).
    if token.pos_ in ("NOUN", "PROPN"):
        found.add(token.lemma_.lower())

    # Compound/adjectival modifiers ("surplus value"): join them with the head.
    mods = [
        child.text.lower()
        for child in token.children
        if child.dep_ in ("compound", "amod") and child.pos_ in ("NOUN", "ADJ")
    ]
    if mods:
        found.add(" ".join([*mods, token.text.lower()]))

    # Prepositional phrases ("of the proletariat"): take each object, the
    # combined "head prep object" phrase, and recurse into nested preps.
    for child in token.children:
        if child.dep_ != "prep":
            continue
        for obj in child.children:
            if obj.dep_ != "pobj":
                continue
            found.add(obj.lemma_.lower())
            found.add(f"{token.text.lower()} {child.text} {obj.text.lower()}")
            found |= _extract_noun_with_preps(obj)

    return found
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def _extract_question_topics(doc: Any) -> set[str]:
    """
    Extract the core topics from a question using spaCy dependency parsing.

    For "What is revisionism?", extracts {"revisionism"}
    For "How does imperialism relate to capitalism?", extracts {"imperialism", "capitalism"}
    For "What is the dictatorship of the proletariat?", extracts
    {"dictatorship", "proletariat", "dictatorship of the proletariat"}
    """
    found: set[str] = set()

    # Locate the syntactic ROOT of the question.
    root = next((tok for tok in doc if tok.dep_ == "ROOT"), None)

    if root is not None:
        for child in root.children:
            # Subjects/objects of the ROOT carry the topic:
            #   nsubj "What is [revisionism]?", dobj "Explain [the concept]",
            #   attr, nsubjpass. Interrogatives themselves are skipped.
            if child.dep_ in ("nsubj", "dobj", "attr", "nsubjpass"):
                if child.text.lower() in QUESTION_WORDS:
                    continue
                found.update(_extract_noun_with_preps(child))

            # Objects of ROOT-attached prepositions: "relate to [capitalism]".
            if child.dep_ == "prep":
                for obj in child.children:
                    if obj.dep_ == "pobj":
                        found.update(_extract_noun_with_preps(obj))

    # Fallback when dependency extraction found nothing: use noun chunks.
    if not found:
        for chunk in doc.noun_chunks:
            if chunk.root.text.lower() in QUESTION_WORDS:
                continue
            found.add(chunk.root.lemma_.lower())
            # Keep multi-word chunks verbatim as well.
            phrase = chunk.text.lower().strip()
            if " " in phrase:
                found.add(phrase)

    # Final sweep: drop any question words that slipped through.
    return {topic for topic in found if topic not in QUESTION_WORDS}
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
def _extract_answer_topics(doc: Any) -> set[str]:
    """
    Extract topics discussed in an answer using spaCy.

    Collects lemmatized noun-chunk roots, multi-word chunk phrases, and
    named entities. Leading determiners (the, a, an, ...) are stripped
    from multi-word phrases so they match question topics more reliably.
    """
    dets = {"the", "a", "an", "this", "that", "these", "those"}

    def drop_leading_det(tokens: list[str]) -> list[str]:
        # Remove a single leading determiner, if present.
        return tokens[1:] if tokens and tokens[0] in dets else tokens

    found: set[str] = set()

    for chunk in doc.noun_chunks:
        found.add(chunk.root.lemma_.lower())

        phrase = " ".join(drop_leading_det(chunk.text.lower().strip().split()))
        # Only keep genuinely multi-word phrases, and skip unwieldy ones.
        if " " in phrase and len(phrase) < 50:
            found.add(phrase)

    for ent in doc.ents:
        found.add(" ".join(drop_leading_det(ent.text.lower().split())))

    return found
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
def _expand_with_synonyms(topics: set[str]) -> set[str]:
    """
    Expand a set of topics with Marxist concept synonyms.

    If "bourgeoisie" is in topics, also adds "capitalist class", "ruling class", etc.

    Returns a new set; the input set is not mutated.
    """
    expanded = set(topics)

    for topic in topics:
        # Reverse lookup covers the forward case too: when topic == canonical,
        # both the canonical term and all of its synonyms are added, which
        # subsumes the old separate CONCEPT_EQUIVALENCES[topic] lookup that
        # preceded this loop.
        for canonical, synonyms in CONCEPT_EQUIVALENCES.items():
            if topic == canonical or topic in synonyms:
                expanded.add(canonical)
                expanded.update(synonyms)

    return expanded
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
def _compute_topic_coverage(q_topics: set[str], a_topics: set[str], nlp: Any) -> float:
    """
    Compute how well answer topics cover question topics.

    Uses:
    1. Direct lemma matching
    2. Expanded synonym matching
    3. spaCy word vector similarity (fallback)

    Returns coverage score 0.0 to 1.0 (0.5 when there is nothing to evaluate).
    """
    if not q_topics:
        return 0.5  # Can't evaluate, neutral

    # Expand question topics with synonyms
    q_expanded = _expand_with_synonyms(q_topics)

    # Direct/synonym match (q_topics is non-empty here, so no guard needed)
    matched = q_expanded & a_topics
    direct_coverage = len(matched) / len(q_topics)

    if direct_coverage >= 0.5:
        return min(direct_coverage, 1.0)

    # Fallback: semantic similarity using spaCy vectors.
    # Parse each answer topic exactly once. The previous version re-ran
    # nlp(a_topic) inside the inner loop, i.e. once per (unmatched question
    # topic, answer topic) pair — an O(|q| * |a|) number of nlp() calls.
    a_docs = [d for d in (nlp(t) for t in a_topics) if d.has_vector]

    unmatched_q = q_topics - matched
    semantic_matches = 0

    for q_topic in unmatched_q:
        q_doc = nlp(q_topic)
        if not q_doc.has_vector:
            continue

        best_sim = max((q_doc.similarity(a_doc) for a_doc in a_docs), default=0.0)
        if best_sim > 0.6:  # Threshold for semantic match
            semantic_matches += 1

    total_matched = len(matched) + semantic_matches
    return min(total_matched / len(q_topics), 1.0)
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
def topic_relevance_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    **kwargs: object,
) -> list[float]:
    """
    Reward answers that are ON-TOPIC with respect to the question.

    Implements an f(A) ⊆ f(Q) check where f extracts semantic topics:
    1. Extract core topics from question Q using dependency parsing
    2. Expand Q topics with Marxist concept synonyms
    3. Extract topics from answer A
    4. Compute coverage: how many Q topics are addressed in A

    Scoring:
        > 80% coverage: +2.0 (answer fully addresses question topics)
        > 60% coverage: +1.5 (answer mostly on-topic)
        > 40% coverage: +1.0 (answer partially on-topic)
        > 20% coverage: 0.0 (answer tangentially related)
        <= 20% coverage: -1.5 (answer off-topic)

    This reward ensures the model answers WHAT WAS ASKED, not just
    generates coherent Marxist text about something else.
    """
    nlp = get_spacy_nlp()

    def band(cov: float) -> float:
        # Map a coverage fraction onto the discrete reward bands above.
        if cov > 0.8:
            return 2.0
        if cov > 0.6:
            return 1.5
        if cov > 0.4:
            return 1.0
        if cov > 0.2:
            return 0.0
        return -1.5

    results: list[float] = []

    for prompt, completion in zip(prompts, completions, strict=False):
        # The question is the last user message in the prompt.
        question = prompt[-1]["content"]
        reply = completion[0]["content"]

        # Only the visible answer (after </think>) is evaluated.
        if REASONING_END in reply:
            reply = reply.split(REASONING_END, 1)[1].strip()

        # Empty or near-empty replies are treated as off-topic.
        if not reply or len(reply.strip()) < 20:
            results.append(-1.5)
            continue

        q_topics = _extract_question_topics(nlp(question))
        a_topics = _extract_answer_topics(nlp(reply[:2000]))  # cap for speed

        if not q_topics:
            # No topics could be extracted from the question; fall back to
            # rewarding substance in the answer.
            results.append(0.5 if len(a_topics) > 3 else 0.0)
            continue

        results.append(band(_compute_topic_coverage(q_topics, a_topics, nlp)))

    return results
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def full_coherence_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Complete coherence check: robust_coherence + topic_relevance + depth.

    This is the MOST COMPREHENSIVE reward function, checking:
    1. NLI coherence (A entails ground truth)
    2. Self-consistency (A doesn't contradict itself)
    3. Structural coherence (terms in proper syntactic roles)
    4. Topic relevance (A addresses what Q asked about)
    5. Interconnection depth (rewards deep analysis, penalizes buzzword salad)

    Use this for maximum robustness against reward hacking.
    """
    robust_scores = robust_coherence_reward(completions, answer, **kwargs)
    relevance_scores = topic_relevance_reward(prompts, completions, **kwargs)
    depth_scores = interconnection_depth_reward(completions, **kwargs)

    out: list[float] = []
    for robust, relevance, depth in zip(
        robust_scores, relevance_scores, depth_scores, strict=False
    ):
        if relevance <= -1.5:
            # Severely off-topic: hard penalty regardless of other signals.
            out.append(-2.0)
        elif robust <= -2.0:
            # A bad robust-coherence failure dominates.
            out.append(robust)
        elif depth <= -1.5:
            # Buzzword salad detected: hard penalty.
            out.append(-1.5)
        else:
            # Robust is the primary signal; relevance and depth are
            # weighted bonuses/penalties on top of it.
            out.append(robust + 0.4 * relevance + 0.3 * depth)

    return out
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
# =============================================================================
|
| 1099 |
+
# INTERCONNECTION DEPTH REWARD (Anti-Buzzword-Salad)
|
| 1100 |
+
# =============================================================================
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
def _count_unique_marxist_concepts(text: str) -> int:
    """Count unique Marxist concepts mentioned in text (substring match)."""
    haystack = text.lower()
    return sum(1 for term in MARXIST_TERMS if term in haystack)
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
def _compute_depth_ratio(text: str) -> float:
    """
    Compute depth ratio: words per unique Marxist concept.

    High ratio = deep analysis (few concepts, well explained)
    Low ratio = shallow/buzzword soup (many concepts, little explanation)

    Returns:
        Words per concept, or 100.0 if no Marxist concepts found
        (no concepts is treated as neutral, not shallow).
    """
    n_concepts = _count_unique_marxist_concepts(text)
    return len(text.split()) / n_concepts if n_concepts else 100.0
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
def _count_hollow_buzzwords(text: str) -> int:
    """Count hollow buzzwords present in text (substring match)."""
    haystack = text.lower()
    return sum(1 for buzzword in HOLLOW_BUZZWORDS if buzzword in haystack)
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def _count_depth_markers(text: str) -> int:
    """Count analytical depth markers present in text (substring match)."""
    haystack = text.lower()
    return sum(1 for marker in DEPTH_MARKERS if marker in haystack)
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
def _count_explanatory_phrases(text: str) -> int:
    """Count explanatory phrases present in text (substring match)."""
    haystack = text.lower()
    return sum(1 for phrase in EXPLANATORY_PHRASES if phrase in haystack)
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
def _concepts_have_explanations(text: str) -> tuple[int, int]:
    """
    Check if introduced concepts have nearby explanations.

    A concept counts as "explained" when an explanatory phrase occurs in
    the same sentence or in the immediately following sentence.

    Returns:
        Tuple of (explained_count, unexplained_count)
    """
    nlp = get_spacy_nlp()
    doc = nlp(text)

    # Get sentences
    sentences = [sent.text.lower() for sent in doc.sents]

    # Precompute, once per sentence, whether it contains an explanatory
    # phrase. The previous version repeated this scan for every concept in
    # a sentence even though the result is concept-independent.
    has_phrase = [any(p in sent for p in EXPLANATORY_PHRASES) for sent in sentences]

    explained = 0
    unexplained = 0

    for i, sent in enumerate(sentences):
        # Number of Marxist concepts appearing in this sentence.
        n_concepts = sum(1 for t in MARXIST_TERMS if t in sent)
        if not n_concepts:
            continue

        # Explanation may live in this sentence or the next one.
        ok = has_phrase[i] or (i + 1 < len(sentences) and has_phrase[i + 1])
        if ok:
            explained += n_concepts
        else:
            unexplained += n_concepts

    return explained, unexplained
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
def interconnection_depth_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward deep, meaningful interconnections; penalize buzzword salad.

    This reward distinguishes between:
    - GOOD: "Surplus value relates to imperialism BECAUSE capital export..."
    - BAD: "Surplus value intersects with imperialism, colonialism, patriarchy..."

    Signals:
    1. Depth ratio: words per unique Marxist concept
       - High (>15): Deep analysis, concepts well-explained
       - Low (<5): Shallow buzzword soup (many concepts crammed together)
    2. Hollow buzzword density: activist jargon without substance
    3. Depth markers: citations, examples, historical specificity
    4. Explanation ratio: concepts with nearby explanatory phrases

    Scoring:
        Depth ratio > 20: +1.0 (deep analysis)
        Depth ratio 10-20: +0.5 (adequate depth)
        Depth ratio < 5: -1.5 (severe buzzword soup)
        Depth ratio 5-10: -0.5 (shallow)
        Hollow buzzwords > 2: -0.3 each additional
        Depth markers present: +0.3 each (max +1.5)
        Good explanation ratio: +0.5
        Low explanation ratio with many concepts: -0.5

    Total range: approximately -2.5 to +3.0
    """
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # Extract answer part after </think>
        if REASONING_END in response:
            answer_part = response.split(REASONING_END, 1)[1].strip()
        else:
            answer_part = response

        # Skip very short responses (handled by completeness_reward)
        word_count = len(answer_part.split())
        if word_count < 20:
            scores.append(0.0)
            continue

        score = 0.0
        concept_count = _count_unique_marxist_concepts(answer_part)

        # Signal 1: Depth ratio (words per concept)
        # Only penalize if there are concepts to evaluate
        if concept_count > 0:
            depth_ratio = word_count / concept_count
            # NOTE: the exact boundary values 10 and 20 fall through with no
            # adjustment (all comparisons below are strict).
            if depth_ratio > 20:
                score += 1.0  # Deep analysis
            elif depth_ratio > 10:
                score += 0.5  # Adequate depth
            elif depth_ratio < 5:
                score -= 1.5  # Severe buzzword soup (many concepts, few words)
            elif depth_ratio < 10:
                score -= 0.5  # Shallow
        # Signal 2: Hollow buzzword penalty
        hollow_count = _count_hollow_buzzwords(answer_part)
        if hollow_count > 2:
            # Penalize excess hollow buzzwords (first two are free)
            penalty = 0.3 * (hollow_count - 2)
            score -= min(penalty, 1.5)  # Cap penalty at -1.5

        # Signal 3: Depth markers bonus (capped at +1.5)
        depth_marker_count = _count_depth_markers(answer_part)
        score += min(depth_marker_count * 0.3, 1.5)

        # Signal 4: Explanation ratio (explanatory phrases per concept)
        explanatory_count = _count_explanatory_phrases(answer_part)

        if concept_count > 0:
            explanation_ratio = explanatory_count / concept_count
            if explanation_ratio >= 0.5:
                score += 0.5  # Good: at least 1 explanation per 2 concepts
            elif explanation_ratio < 0.1 and concept_count > 5:
                score -= 0.5  # Bad: many concepts, almost no explanations

        # Clamp final score to the documented range
        scores.append(max(min(score, 3.0), -2.5))

    return scores
|
| 1296 |
+
|
| 1297 |
+
|
| 1298 |
+
# =============================================================================
|
| 1299 |
+
# RESPONSE COMPLETENESS REWARD
|
| 1300 |
+
# =============================================================================
|
| 1301 |
+
|
| 1302 |
+
|
| 1303 |
+
def completeness_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward thorough, detailed responses.

    Compares response word count to the ground-truth word count.

    Scoring:
        50-150% of target length: +2.0
        30-200% of target length: +1.0
        < 20% (too short): -2.0
        everything else (20-30% or > 200%): -0.5
    """
    results: list[float] = []

    for completion, true_answer in zip(completions, answer, strict=False):
        text = completion[0]["content"]

        # Only the visible answer (after </think>) counts toward length.
        if REASONING_END in text:
            text = text.split(REASONING_END, 1)[1].strip()

        target = len(true_answer.split())
        if target == 0:
            # No ground truth to compare against; neutral score.
            results.append(0.0)
            continue

        ratio = len(text.split()) / target

        if 0.5 <= ratio <= 1.5:
            results.append(2.0)
        elif 0.3 <= ratio <= 2.0:
            results.append(1.0)
        elif ratio < 0.2:
            results.append(-2.0)
        else:
            # ratio in [0.2, 0.3) (quite short) or > 2.0 (verbose):
            # mild penalty either way.
            results.append(-0.5)

    return results
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
# =============================================================================
|
| 1356 |
+
# DEBUG REWARD (for monitoring during training)
|
| 1357 |
+
# =============================================================================
|
| 1358 |
+
|
| 1359 |
+
# Global counter for printing samples
_PRINT_COUNTER = 0
_PRINT_EVERY = 10


def debug_print_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Print sample outputs periodically for monitoring.

    Returns 0.0 for every completion (no effect on training).
    """
    global _PRINT_COUNTER

    # Print one sample every _PRINT_EVERY invocations.
    if _PRINT_COUNTER % _PRINT_EVERY == 0:
        sep = "=" * 60
        print(sep)
        print(f"Step {_PRINT_COUNTER}")
        print(f"Question: {prompts[0][-1]['content'][:100]}...")
        print(f"Response: {completions[0][0]['content'][:200]}...")
        print(f"Expected: {answer[0][:100]}...")
        print(sep)

    _PRINT_COUNTER += 1

    return [0.0] * len(completions)
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
# =============================================================================
|
| 1395 |
+
# ENTITY VERIFICATION REWARD (Anti-Hallucination)
|
| 1396 |
+
# =============================================================================
|
| 1397 |
+
|
| 1398 |
+
# Lazy-loaded entity whitelist
|
| 1399 |
+
_ENTITY_WHITELIST: set[str] | None = None
|
| 1400 |
+
_ENTITY_WHITELIST_LOWERCASE: set[str] | None = None
|
| 1401 |
+
|
| 1402 |
+
|
| 1403 |
+
def _load_entity_whitelist() -> tuple[set[str], set[str]]:
|
| 1404 |
+
"""Load entity whitelist from JSON file."""
|
| 1405 |
+
global _ENTITY_WHITELIST, _ENTITY_WHITELIST_LOWERCASE
|
| 1406 |
+
|
| 1407 |
+
if _ENTITY_WHITELIST is None or _ENTITY_WHITELIST_LOWERCASE is None:
|
| 1408 |
+
import json
|
| 1409 |
+
from pathlib import Path
|
| 1410 |
+
|
| 1411 |
+
whitelist_path = (
|
| 1412 |
+
Path(__file__).parent.parent.parent.parent
|
| 1413 |
+
/ "training_data"
|
| 1414 |
+
/ "entity_whitelist_clean.json"
|
| 1415 |
+
)
|
| 1416 |
+
|
| 1417 |
+
if whitelist_path.exists():
|
| 1418 |
+
print(f"[Reward] Loading entity whitelist from {whitelist_path}...")
|
| 1419 |
+
with open(whitelist_path, encoding="utf-8") as f:
|
| 1420 |
+
data = json.load(f)
|
| 1421 |
+
_ENTITY_WHITELIST = set(data.get("entities", []))
|
| 1422 |
+
_ENTITY_WHITELIST_LOWERCASE = set(data.get("entities_lowercase", []))
|
| 1423 |
+
print(f"[Reward] Loaded {len(_ENTITY_WHITELIST):,} entities")
|
| 1424 |
+
else:
|
| 1425 |
+
print(f"[Reward] WARNING: Entity whitelist not found at {whitelist_path}")
|
| 1426 |
+
_ENTITY_WHITELIST = set()
|
| 1427 |
+
_ENTITY_WHITELIST_LOWERCASE = set()
|
| 1428 |
+
|
| 1429 |
+
return _ENTITY_WHITELIST, _ENTITY_WHITELIST_LOWERCASE
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
def _entity_in_whitelist(entity: str) -> bool:
    """Check if an entity is in the whitelist (case-insensitive)."""
    exact, lowered = _load_entity_whitelist()
    if entity in exact:
        return True
    return entity.lower() in lowered
|
| 1436 |
+
|
| 1437 |
+
|
| 1438 |
+
# Patterns that indicate confident factual claims
# (raw regex strings; compiled with re.IGNORECASE at the call sites).
CONFIDENT_CLAIM_PATTERNS = [
    # Specific founding/creation dates (a four-digit year is asserted)
    r"founded in \d{4}",
    r"established in \d{4}",
    r"created in \d{4}",
    r"formed in \d{4}",
    # Attribution of founders/creators
    r"was founded by",
    r"was established by",
    r"was created by",
    r"were founded in",
    r"were established in",
]
|
| 1450 |
+
|
| 1451 |
+
# Patterns that indicate epistemic humility (GOOD)
# (raw regex strings; compiled with re.IGNORECASE at the call sites).
UNCERTAINTY_PATTERNS = [
    # Explicit inability to verify/confirm
    r"I (?:cannot|can't|don't) (?:verify|confirm|find)",
    r"I (?:don't|do not) have (?:verified |specific )?information",
    r"I'm not (?:certain|sure|confident)",
    r"I cannot (?:provide|give) (?:specific |verified )?information",
    r"I should not (?:fabricate|make up|speculate)",
    # Requests for more context from the user
    r"(?:could you|can you) (?:provide|share|tell me) (?:more )?context",
    r"where did you (?:encounter|find|see) this",
    # Admissions of lacking knowledge
    r"I'm not aware of",
    r"I don't have (?:details|information) about",
]
|
| 1463 |
+
|
| 1464 |
+
|
| 1465 |
+
def _extract_potential_entities(text: str) -> list[str]:
    """Extract potential organization/person names from text using spaCy NER."""
    relevant = ("ORG", "PERSON", "GPE", "NORP", "EVENT", "FAC", "WORK_OF_ART")
    doc = get_spacy_nlp()(text[:10000])  # Limit to prevent slow processing
    return [ent.text for ent in doc.ents if ent.label_ in relevant]
|
| 1476 |
+
|
| 1477 |
+
|
| 1478 |
+
def entity_verification_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward epistemic humility, penalize confident claims about unverified entities.

    Scoring:
        - +2.0: Response expresses uncertainty about unknown entities
        - +1.0: Response discusses only verified entities
        - -1.0: Response discusses unknown entities without clear uncertainty
        - -2.5: Response makes confident claims (dates, founders) about unknown entities
        -  0.0: Verified-only content with unnecessary hedging

    This reward requires the entity_whitelist_clean.json file in training_data/.
    """

    def matches_any(patterns: list[str], text: str) -> bool:
        return any(re.search(p, text, re.IGNORECASE) for p in patterns)

    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # NER-extracted entities not found in the verified whitelist.
        unknown = [
            e for e in _extract_potential_entities(response)
            if not _entity_in_whitelist(e)
        ]
        humble = matches_any(UNCERTAINTY_PATTERNS, response)
        confident = matches_any(CONFIDENT_CLAIM_PATTERNS, response)

        if not unknown:
            # Everything mentioned is verified (or nothing specific named).
            # Unnecessary hedging earns 0.0; normal verified content +1.0.
            scores.append(0.0 if humble else 1.0)
        elif humble:
            # GOOD: uncertainty expressed about unknown entities.
            scores.append(2.0)
        elif confident:
            # BAD: confident factual claims alongside unknown entities.
            scores.append(-2.5)
        else:
            # NEUTRAL-BAD: unknown entities discussed without clear uncertainty.
            scores.append(-1.0)

    return scores
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
def epistemic_calibration_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Combined reward for epistemic calibration.

    Encourages the model to:
    1. Be confident about entities in the verified corpus
    2. Express uncertainty about entities NOT in the corpus
    3. Avoid fabricating specific details about unknown entities

    This is a lighter-weight version of entity_verification_reward
    that doesn't require NER - just pattern matching.

    Scoring:
        - +1.5: Has uncertainty phrases (regardless of content)
        - -0.5: Has confident claim patterns + no uncertainty (soft penalty)
        - 0.0: Neutral
    """
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # Check patterns
        has_uncertainty = any(
            re.search(pattern, response, re.IGNORECASE) for pattern in UNCERTAINTY_PATTERNS
        )

        has_confident_claims = any(
            re.search(pattern, response, re.IGNORECASE) for pattern in CONFIDENT_CLAIM_PATTERNS
        )

        if has_uncertainty:
            # Model expressed appropriate caution
            scores.append(1.5)
        elif has_confident_claims:
            # Model made confident claims - could be hallucination
            # This is a soft penalty; entity_verification_reward does deeper checking
            scores.append(-0.5)
        else:
            # Neutral
            scores.append(0.0)

    return scores
|
| 1587 |
+
|
| 1588 |
+
|
| 1589 |
+
# =============================================================================
|
| 1590 |
+
# IDEOLOGICAL FIRMNESS REWARD (Anti-Capitulation)
|
| 1591 |
+
# =============================================================================
|
| 1592 |
+
|
| 1593 |
+
# Patterns that detect ideological challenges requiring firm response.
# Pre-compiled once at import time for reuse.
IDEOLOGICAL_CHALLENGE_PATTERNS: list[re.Pattern[str]] = [
    # Reactionary claims about fascism/Nazis
    re.compile(r"hitler\s+help", re.IGNORECASE),
    re.compile(r"nazis?\s+(?:were\s+)?socialist", re.IGNORECASE),
    re.compile(r"national\s+socialism", re.IGNORECASE),
    # Anti-communist death toll claims
    re.compile(r"communis[mt]\s+(?:killed|murdered|deaths?)", re.IGNORECASE),
    re.compile(r"\d+\s*million\s+(?:people\s+)?(?:killed|dead|died)", re.IGNORECASE),
    re.compile(r"(?:black\s+book|death\s+toll)", re.IGNORECASE),
    # Anti-communist framing
    re.compile(
        r"(?:stalin|mao|lenin)\s+(?:was\s+)?(?:a\s+)?(?:dictator|tyrant|monster)", re.IGNORECASE
    ),
    re.compile(r"authoritarian\s+regime", re.IGNORECASE),
    re.compile(r"failed\s+experiment", re.IGNORECASE),
    re.compile(r"communis[mt]\s+fail", re.IGNORECASE),
    # Bad faith questions
    re.compile(r"holodomor", re.IGNORECASE),
    re.compile(r"gulag", re.IGNORECASE),
    re.compile(r"great\s+(?:purge|terror|leap)", re.IGNORECASE),
    re.compile(r"cultural\s+revolution", re.IGNORECASE),
    # General anti-communist challenges
    re.compile(r"(?:stalin|mao|lenin|castro|che)\s+(?:bad|evil|good)", re.IGNORECASE),
    re.compile(r"(?:was|is)\s+(?:stalin|mao|lenin|castro|che)\s+(?:bad|evil|good)", re.IGNORECASE),
    re.compile(r"(?:ussr|soviet|china|cuba)\s+(?:make\s+)?mistakes?", re.IGNORECASE),
    re.compile(r"(?:did|make)\s+mistakes?", re.IGNORECASE),
]
|
| 1621 |
+
|
| 1622 |
+
# Patterns indicating capitulation/both-sidesing (NEGATIVE signals).
# Pre-compiled once at import time for reuse.
CAPITULATION_PATTERNS: list[re.Pattern[str]] = [
    # Both-sidesing
    re.compile(r"valid\s+points?\s+on\s+both\s+sides?", re.IGNORECASE),
    re.compile(
        r"on\s+(?:the\s+)?one\s+hand.*on\s+(?:the\s+)?other\s+hand", re.IGNORECASE | re.DOTALL
    ),
    re.compile(r"reasonable\s+people\s+can\s+disagree", re.IGNORECASE),
    re.compile(r"no\s+easy\s+answers?", re.IGNORECASE),
    # Hedging and avoidance
    re.compile(
        r"it'?s?\s+(?:a\s+)?(?:complex|complicated|nuanced)\s+(?:issue|question|topic|legacy)",
        re.IGNORECASE,
    ),
    re.compile(r"certainly\s+complex", re.IGNORECASE),
    re.compile(r"I\s+(?:can\s+)?(?:see|understand)\s+(?:why|your|the)", re.IGNORECASE),
    # Theory vs practice false dichotomy
    re.compile(r"communism\s+sounds?\s+good\s+in\s+theory", re.IGNORECASE),
    re.compile(r"good\s+in\s+theory\s+but", re.IGNORECASE),
    # False balance on historical figures
    re.compile(r"(?:stalin|mao)\s+did\s+(?:both\s+)?good\s+(?:things\s+)?and\s+bad", re.IGNORECASE),
    re.compile(r"achievements?\s+and\s+(?:mistakes?|excesses?|errors?)", re.IGNORECASE),
    re.compile(r"(?:mistakes?|excesses?|errors?)\s+and\s+achievements?", re.IGNORECASE),
    # Weasel phrases
    re.compile(r"some\s+(?:would|might)\s+argue", re.IGNORECASE),
    re.compile(r"one\s+perspective\s+is", re.IGNORECASE),
    re.compile(r"it\s+depends", re.IGNORECASE),
    # Partial credit to reactionary framing
    re.compile(r"while\s+(?:it'?s?\s+true|hitler|he)\s+did", re.IGNORECASE),
    re.compile(r"did\s+(?:implement|create|build)\s+some\s+programs?", re.IGNORECASE),
    re.compile(r"(?:positive|negative)\s+aspects?", re.IGNORECASE),
    re.compile(r"weigh\s+(?:both|the)", re.IGNORECASE),
]
|
| 1655 |
+
|
| 1656 |
+
# Patterns indicating ideological firmness (POSITIVE signals)
# Each matched pattern earns +0.5 in ideological_firmness_reward.
FIRMNESS_PATTERNS: list[re.Pattern[str]] = [
    # Direct refutation
    re.compile(
        r"this\s+(?:claim\s+)?is\s+(?:false|incorrect|wrong|a\s+lie|propaganda|a\s+myth)",
        re.IGNORECASE,
    ),
    re.compile(
        r"this\s+is\s+(?:fascist|bourgeois|reactionary|anti-?communist)\s+propaganda", re.IGNORECASE
    ),
    re.compile(
        r"this\s+(?:figure|claim|narrative)\s+(?:comes?\s+from|is)\s+(?:the\s+)?discredited",
        re.IGNORECASE,
    ),
    re.compile(r"(?:discredited|debunked)", re.IGNORECASE),
    re.compile(r"this\s+ignores?\s+(?:the\s+)?(?:material|historical|actual)", re.IGNORECASE),
    re.compile(r"this\s+framing\s+ignores?", re.IGNORECASE),
    # Marxist analytical framing
    re.compile(r"from\s+a\s+marxist\s+(?:analysis|perspective|standpoint)", re.IGNORECASE),
    re.compile(r"(?:the\s+)?material\s+(?:reality|evidence|conditions?|facts?)", re.IGNORECASE),
    re.compile(
        r"(?:the\s+)?historical\s+(?:evidence|record|facts?)\s+(?:show|demonstrate|prove)",
        re.IGNORECASE,
    ),
    re.compile(r"historically\s+documented", re.IGNORECASE),
    # Class analysis language
    re.compile(
        r"serves?\s+(?:the\s+)?interests?\s+of\s+(?:capital|the\s+bourgeoisie|imperialism)",
        re.IGNORECASE,
    ),
    re.compile(r"serves?\s+bourgeois\s+(?:historiography|ideology|interests?)", re.IGNORECASE),
    re.compile(r"bourgeois\s+(?:historians?|propaganda|framing|ideology)", re.IGNORECASE),
    re.compile(r"capitalis[tm]\s+kills?\s+millions?", re.IGNORECASE),
    # Strong Marxist position markers
    re.compile(r"dictatorship\s+of\s+the\s+(?:proletariat|bourgeoisie)", re.IGNORECASE),
    re.compile(r"class\s+(?:rule|struggle|analysis|character)", re.IGNORECASE),
    re.compile(r"(?:open\s+)?terrorist\s+dictatorship", re.IGNORECASE),
    # Material achievements of socialism
    re.compile(r"(?:universal|achieved)\s+literacy", re.IGNORECASE),
    re.compile(r"(?:provided|universal)\s+(?:housing|healthcare|education)", re.IGNORECASE),
    re.compile(r"defeated\s+fascism", re.IGNORECASE),
    re.compile(r"industrial(?:ized?|ization)\s+(?:superpower|rapidly|in)", re.IGNORECASE),
]
|
| 1699 |
+
|
| 1700 |
+
# Class analysis markers (substantive bonus)
# Each match is worth +0.2 in ideological_firmness_reward, capped at +1.0 total.
CLASS_ANALYSIS_MARKERS: list[re.Pattern[str]] = [
    re.compile(r"working\s+class", re.IGNORECASE),
    re.compile(r"\bproletariat\b", re.IGNORECASE),
    re.compile(r"\bbourgeoisie\b", re.IGNORECASE),
    # NOTE(review): the inner alternation matches "capitalistclass" (no space)
    # but not "capitalist class" as one token — confirm this is intended.
    re.compile(r"\bcapital(?:ist(?:s|class)?|ism)?\b", re.IGNORECASE),
    re.compile(r"finance\s+capital", re.IGNORECASE),
    re.compile(r"class\s+struggle", re.IGNORECASE),
    re.compile(r"exploitation", re.IGNORECASE),
    re.compile(r"surplus\s+value", re.IGNORECASE),
    re.compile(r"dictatorship\s+of\s+the", re.IGNORECASE),
    re.compile(r"relations\s+of\s+production", re.IGNORECASE),
    re.compile(r"means\s+of\s+production", re.IGNORECASE),
    re.compile(r"\bimperialis[mt]\b", re.IGNORECASE),
    re.compile(r"monopoly\s+capital", re.IGNORECASE),
]
|
| 1716 |
+
|
| 1717 |
+
# Quote-to-refute pattern (quoted text followed by refutation)
# Matching any of these softens the capitulation penalty by one hit, so that
# quoting a reactionary claim in order to rebut it is not punished.
QUOTE_TO_REFUTE_PATTERNS: list[re.Pattern[str]] = [
    # Patterns where quoted claims are followed by refutation
    re.compile(r"['\"].*?['\"].*?\bbut\b", re.IGNORECASE),
    re.compile(r"['\"].*?['\"].*?\bhowever\b", re.IGNORECASE),
    re.compile(r"['\"].*?['\"].*?\bthis\s+ignores?\b", re.IGNORECASE),
    re.compile(r"claim\s+that.*?\bbut\b", re.IGNORECASE),
    re.compile(r"claim\s+that.*?\bhowever\b", re.IGNORECASE),
    re.compile(r"historians?\s+claim.*?\bbut\b", re.IGNORECASE),
]
|
| 1727 |
+
|
| 1728 |
+
# Principled self-criticism markers (NOT capitulation)
# Two or more matches flag the text as principled Marxist-Leninist
# self-criticism, which softens the capitulation penalty by two hits.
SELF_CRITICISM_MARKERS: list[re.Pattern[str]] = [
    re.compile(r"self-?criticism", re.IGNORECASE),
    re.compile(r"methodological\s+principle", re.IGNORECASE),
    re.compile(r"revolutionary\s+perspective", re.IGNORECASE),
    re.compile(r"strengthen(?:ing)?\s+socialism", re.IGNORECASE),
    re.compile(r"not\s+from\s+(?:bourgeois|anti-?communist)", re.IGNORECASE),
    re.compile(r"marxist-?leninist", re.IGNORECASE),
]
|
| 1737 |
+
|
| 1738 |
+
|
| 1739 |
+
def _detect_ideological_challenge(prompt: str) -> bool:
    """Return True when *prompt* matches any known ideological-challenge pattern."""
    for challenge_pattern in IDEOLOGICAL_CHALLENGE_PATTERNS:
        if challenge_pattern.search(prompt):
            return True
    return False
|
| 1742 |
+
|
| 1743 |
+
|
| 1744 |
+
def _count_pattern_matches(text: str, patterns: list[re.Pattern[str]]) -> int:
|
| 1745 |
+
"""Count the number of pattern matches in text."""
|
| 1746 |
+
count = 0
|
| 1747 |
+
for pattern in patterns:
|
| 1748 |
+
if pattern.search(text):
|
| 1749 |
+
count += 1
|
| 1750 |
+
return count
|
| 1751 |
+
|
| 1752 |
+
|
| 1753 |
+
def _has_quote_to_refute(text: str) -> bool:
    """Return True when *text* quotes a claim and then pivots to rebut it."""
    for refute_pattern in QUOTE_TO_REFUTE_PATTERNS:
        if refute_pattern.search(text):
            return True
    return False
|
| 1756 |
+
|
| 1757 |
+
|
| 1758 |
+
def _has_principled_self_criticism(text: str) -> bool:
    """Return True when *text* shows principled ML self-criticism.

    Requires at least two distinct SELF_CRITICISM_MARKERS to match, so a
    single stray phrase does not qualify.
    """
    matched = _count_pattern_matches(text, SELF_CRITICISM_MARKERS)
    return matched >= 2
|
| 1762 |
+
|
| 1763 |
+
|
| 1764 |
+
def ideological_firmness_reward(
    prompts: list[str],
    completions: list[str],
    **kwargs: Any,
) -> list[float]:
    """
    Reward ideological firmness with substantive Marxist counter-arguments.

    Encourages the model to hold principled Marxist positions when a prompt
    contains an ideological challenge, to refute rather than capitulate, and
    to ground refutations in class analysis.

    Scoring (per prompt/completion pair):
        - Prompt is not an ideological challenge -> 0.0 (neutral).
        - Otherwise:
            * each matched CAPITULATION_PATTERNS entry: -0.5
            * each matched FIRMNESS_PATTERNS entry: +0.5
            * each matched CLASS_ANALYSIS_MARKERS entry: +0.2, capped at +1.0
            * final score clamped to [-2.0, +2.0]

    Special handling that reduces the capitulation penalty:
        - Quote-to-refute (quoting a reactionary claim in order to rebut it):
          penalty reduced by one hit.
        - Principled ML self-criticism: penalty reduced by two hits.

    Args:
        prompts: User prompts/questions.
        completions: Model completions/responses.
        **kwargs: Ignored (accepted for reward-function interface compatibility).

    Returns:
        One reward score per (prompt, completion) pair.
    """
    rewards: list[float] = []

    for prompt, completion in zip(prompts, completions, strict=False):
        # Non-challenge prompts are scored neutrally and skipped.
        if not _detect_ideological_challenge(prompt):
            rewards.append(0.0)
            continue

        # Capitulation hits, softened by legitimate rhetorical moves.
        penalty_hits = _count_pattern_matches(completion, CAPITULATION_PATTERNS)
        if _has_quote_to_refute(completion):
            penalty_hits = max(0, penalty_hits - 1)
        if _has_principled_self_criticism(completion):
            penalty_hits = max(0, penalty_hits - 2)

        # Firmness and class-analysis credit.
        firmness_hits = _count_pattern_matches(completion, FIRMNESS_PATTERNS)
        class_hits = _count_pattern_matches(completion, CLASS_ANALYSIS_MARKERS)
        class_bonus = min(class_hits * 0.2, 1.0)  # bonus capped at +1.0

        raw_score = (firmness_hits - penalty_hits) * 0.5 + class_bonus
        rewards.append(max(-2.0, min(2.0, raw_score)))

    return rewards
|
src/prolewiki_llm/train_grpo_marxist.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO Fine-tuning for Marxist-Leninist Reasoning Model.
|
| 4 |
+
|
| 5 |
+
Trains DeepSeek-R1-0528-Qwen3-8B on ProleWiki corpus using GRPO
|
| 6 |
+
(Group Relative Policy Optimization) with custom reward functions.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# First transform data
|
| 10 |
+
python transform_to_grpo.py
|
| 11 |
+
|
| 12 |
+
# Then run training
|
| 13 |
+
python train_grpo_marxist.py
|
| 14 |
+
|
| 15 |
+
Hardware: A40 (48GB) optimized
|
| 16 |
+
Expected time: ~1-2 hours for 250 steps
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# Set vLLM standby mode for better memory utilization
|
| 25 |
+
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from datasets import Dataset
|
| 29 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 30 |
+
from unsloth import FastLanguageModel
|
| 31 |
+
from vllm import SamplingParams
|
| 32 |
+
|
| 33 |
+
from prolewiki_llm.grpo_rewards import (
|
| 34 |
+
completeness_reward,
|
| 35 |
+
debug_print_reward,
|
| 36 |
+
match_format_approximately,
|
| 37 |
+
match_format_exactly,
|
| 38 |
+
semantic_similarity_reward,
|
| 39 |
+
terminology_reward,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# =============================================================================
# CONFIGURATION
# =============================================================================

# Model
MODEL_NAME = "unsloth/DeepSeek-R1-0528-Qwen3-8B"
MAX_SEQ_LENGTH = 2048  # Longer for detailed political theory responses
LORA_RANK = 32  # Same as original notebook

# Paths
DATA_PATH = Path("training_data/grpo_dataset.jsonl")  # produced by transform_to_grpo.py
OUTPUT_DIR = Path("outputs/marxist-grpo")  # checkpoints / trainer state
LORA_OUTPUT = Path("outputs/marxist-grpo-lora")  # final LoRA adapter

# Training
MAX_STEPS = 250  # Cover most of 1058 samples
SAVE_STEPS = 50  # checkpoint interval (steps)
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.1  # fraction of steps used for LR warmup

# A40 optimized settings
GPU_MEMORY_UTILIZATION = 0.85  # fraction of VRAM vLLM may claim
BATCH_SIZE = 2  # per-device train batch size
GRADIENT_ACCUMULATION = 2
NUM_GENERATIONS = 4  # GRPO group size (completions per prompt)

# Completion limits
MAX_PROMPT_LENGTH = 512  # tokens
MAX_COMPLETION_LENGTH = 1500  # tokens; also used as vLLM max_tokens
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# =============================================================================
|
| 74 |
+
# MAIN TRAINING FUNCTION
|
| 75 |
+
# =============================================================================
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main() -> None:
    """Run GRPO training end-to-end: load model, apply LoRA, train, save the adapter.

    Side effects:
        - Creates OUTPUT_DIR (checkpoints) and LORA_OUTPUT (final adapter).
        - Prints progress banners throughout for headless log readability.

    Raises:
        RuntimeError: If no CUDA device is available.
        FileNotFoundError: If the GRPO dataset has not been generated yet.
    """
    print("=" * 60)
    print("Marxist-Leninist GRPO Training")
    print("=" * 60)

    # Check CUDA — training requires a GPU; report name/VRAM for the logs.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name()
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name}")
        print(f"VRAM: {gpu_mem:.1f} GB")
    else:
        raise RuntimeError("CUDA not available!")

    # =========================================================================
    # Load Model
    # =========================================================================
    print(f"\nLoading model: {MODEL_NAME}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        fast_inference=True,  # Enable vLLM
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    print(f"Model type: {model.config.model_type}")

    # =========================================================================
    # Apply LoRA
    # =========================================================================
    print("\nApplying LoRA adapters...")

    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        # All attention + MLP projection matrices get adapters.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=LORA_RANK * 2,  # *2 speeds up training
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # Report the trainable-parameter fraction so a misconfigured LoRA is obvious.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # =========================================================================
    # Load Dataset
    # =========================================================================
    print(f"\nLoading dataset: {DATA_PATH}")

    if not DATA_PATH.exists():
        raise FileNotFoundError(
            f"Dataset not found: {DATA_PATH}\n" "Run 'python transform_to_grpo.py' first!"
        )

    dataset = Dataset.from_json(str(DATA_PATH))
    print(f"Loaded {len(dataset)} examples")

    # Show sample (prompt[1] is presumably the user turn after a system turn —
    # matches the messages layout used in test_model()).
    sample = dataset[0]
    print(f"Sample prompt: {sample['prompt'][1]['content'][:60]}...")

    # =========================================================================
    # Configure vLLM Sampling
    # =========================================================================
    vllm_sampling_params = SamplingParams(
        min_p=0.1,
        top_p=1.0,  # No nucleus sampling (matches original template)
        top_k=-1,
        # NOTE: temperature is set in GRPOConfig, not here
        max_tokens=MAX_COMPLETION_LENGTH,
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
        seed=3407,
    )

    # =========================================================================
    # Configure Training
    # =========================================================================
    print("\nConfiguring GRPO trainer...")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    training_args = GRPOConfig(
        # vLLM
        vllm_sampling_params=vllm_sampling_params,
        temperature=1.0,  # For GRPO training dynamics
        # Optimization
        learning_rate=LEARNING_RATE,
        weight_decay=0.001,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        # Batch settings
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        num_generations=NUM_GENERATIONS,
        # Sequence lengths
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        # Training duration
        max_steps=MAX_STEPS,
        save_steps=SAVE_STEPS,
        # Logging
        logging_steps=1,
        report_to="none",
        # Output
        output_dir=str(OUTPUT_DIR),
    )

    # =========================================================================
    # Create Trainer
    # =========================================================================
    print("\nInitializing trainer with reward functions:")
    print("  - match_format_exactly (+3.0 for </think>)")
    print("  - match_format_approximately (±0.5 for tags)")
    print("  - semantic_similarity_reward (+5.0 to -3.0)")
    print("  - terminology_reward (+0 to +2.0)")
    print("  - completeness_reward (±2.0)")
    print("  - debug_print_reward (monitoring)")

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            match_format_exactly,
            match_format_approximately,
            semantic_similarity_reward,
            terminology_reward,
            completeness_reward,
            debug_print_reward,
        ],
        args=training_args,
        train_dataset=dataset,
    )

    # =========================================================================
    # Train!
    # =========================================================================
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Steps: {MAX_STEPS}")
    print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print()

    trainer.train()

    # =========================================================================
    # Save
    # =========================================================================
    print("\n" + "=" * 60)
    print("SAVING MODEL")
    print("=" * 60)

    LORA_OUTPUT.mkdir(parents=True, exist_ok=True)
    model.save_lora(str(LORA_OUTPUT))
    print(f"LoRA saved to: {LORA_OUTPUT}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Test the model with and without LoRA")
    print("2. Export to GGUF if satisfied")
    print("3. Create Ollama Modelfile")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# =============================================================================
|
| 260 |
+
# TEST FUNCTION
|
| 261 |
+
# =============================================================================
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def test_model() -> None:
    """Smoke-test the trained model: generate answers with and without the LoRA.

    Loads the base model, asks the first two of four theory questions twice —
    once with ``lora_request=None`` (base model) and once with the adapter from
    LORA_OUTPUT — and prints the first 500 characters of each answer for manual
    comparison. Requires a completed training run (LORA_OUTPUT must exist).
    """
    print("Loading model for testing...")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    test_questions = [
        "What is revisionism in the Marxist sense?",
        "Explain the concept of surplus value.",
        "What is the dictatorship of the proletariat?",
        "How does dialectical materialism differ from idealism?",
    ]

    # Deterministic-ish sampling for comparison (lower temperature than training).
    sampling_params = SamplingParams(
        temperature=0.7,
        top_k=50,
        max_tokens=1024,
    )

    system_prompt = """You are a Marxist-Leninist assistant trained on ProleWiki.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear answer."""

    print("\n" + "=" * 60)
    print("TESTING WITHOUT LORA")
    print("=" * 60)

    # Base-model pass: lora_request=None generates without the adapter.
    for question in test_questions[:2]:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        output = (
            model.fast_generate(text, sampling_params=sampling_params, lora_request=None)[0]
            .outputs[0]
            .text
        )
        print(f"\nQ: {question}")
        print(f"A: {output[:500]}...")

    print("\n" + "=" * 60)
    print("TESTING WITH LORA")
    print("=" * 60)

    # Adapter pass: load the trained LoRA from disk for each generation.
    for question in test_questions[:2]:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        output = (
            model.fast_generate(
                text,
                sampling_params=sampling_params,
                lora_request=model.load_lora(str(LORA_OUTPUT)),
            )[0]
            .outputs[0]
            .text
        )
        print(f"\nQ: {question}")
        print(f"A: {output[:500]}...")
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
if __name__ == "__main__":
    import sys

    # `python train_grpo_marxist.py test` runs the smoke test; anything else trains.
    mode = sys.argv[1] if len(sys.argv) > 1 else ""
    if mode == "test":
        test_model()
    else:
        main()
|
src/prolewiki_llm/train_headless.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Headless GRPO Training for RunPod Deployment.
|
| 4 |
+
|
| 5 |
+
This script is designed for containerized, non-interactive execution on RunPod.
|
| 6 |
+
It adapts train_grpo_marxist.py for headless operation with:
|
| 7 |
+
- Environment variable configuration
|
| 8 |
+
- Checkpoint resumption support
|
| 9 |
+
- Automatic model upload to HuggingFace Hub
|
| 10 |
+
- W&B logging for remote monitoring
|
| 11 |
+
- Self-termination capability
|
| 12 |
+
|
| 13 |
+
Environment Variables:
|
| 14 |
+
Required:
|
| 15 |
+
HF_TOKEN - HuggingFace API token for model upload
|
| 16 |
+
WANDB_API_KEY - Weights & Biases API key
|
| 17 |
+
|
| 18 |
+
Optional (with defaults):
|
| 19 |
+
HF_REPO - Target repo for model upload (default: prolewiki/marxist-grpo-lora)
|
| 20 |
+
RUNPOD_POD_ID - Pod ID for self-termination after training
|
| 21 |
+
MODEL_NAME - Base model (default: unsloth/DeepSeek-R1-0528-Qwen3-8B)
|
| 22 |
+
MAX_STEPS - Training steps (default: 500)
|
| 23 |
+
BATCH_SIZE - Per-device batch size (default: 2)
|
| 24 |
+
LEARNING_RATE - Learning rate (default: 5e-6)
|
| 25 |
+
REWARD_MODE - FULL, ROBUST, or LEGACY (default: FULL)
|
| 26 |
+
DATASET_PATH - Path to grpo_dataset.jsonl (default: /workspace/dataset.jsonl)
|
| 27 |
+
CHECKPOINT_DIR - Directory for checkpoints (default: /workspace/checkpoints)
|
| 28 |
+
LORA_OUTPUT - Directory for final LoRA (default: /workspace/lora-output)
|
| 29 |
+
|
| 30 |
+
Usage:
|
| 31 |
+
# In container:
|
| 32 |
+
python -m prolewiki_llm.train_headless
|
| 33 |
+
|
| 34 |
+
# With environment overrides:
|
| 35 |
+
MAX_STEPS=100 REWARD_MODE=ROBUST python -m prolewiki_llm.train_headless
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
from __future__ import annotations
|
| 39 |
+
|
| 40 |
+
import os
|
| 41 |
+
import sys
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
|
| 44 |
+
# =============================================================================
# CRITICAL: Disable torch.compile BEFORE any imports
# =============================================================================
# These environment variables prevent torch.compile from spawning inductor
# compilation workers that hang indefinitely on RunPod/WSL2/Jupyter.
# See: https://github.com/unslothai/unsloth/issues/3432
#
# NOTE: these MUST be set before torch/unsloth are imported, which is why this
# block sits above all ML imports.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
os.environ["TORCH_COMPILE"] = "0"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"  # keep inductor single-threaded if it runs anyway
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"  # vLLM standby mode (memory utilization)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_env(key: str, default: str | None = None, required: bool = False) -> str:
|
| 60 |
+
"""Get environment variable with optional default and required check."""
|
| 61 |
+
value = os.environ.get(key, default)
|
| 62 |
+
if required and value is None:
|
| 63 |
+
print(f"ERROR: Required environment variable {key} not set", file=sys.stderr)
|
| 64 |
+
sys.exit(1)
|
| 65 |
+
return value # type: ignore[return-value]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_env_int(key: str, default: int) -> int:
    """Return environment variable *key* parsed as an int, or *default* when unset."""
    raw = os.environ.get(key)
    return default if raw is None else int(raw)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_env_float(key: str, default: float) -> float:
    """Return environment variable *key* parsed as a float, or *default* when unset."""
    raw = os.environ.get(key)
    return default if raw is None else float(raw)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# =============================================================================
# CONFIGURATION FROM ENVIRONMENT
# =============================================================================
# All knobs are read once at import time; override via environment variables
# before launching (e.g. `MAX_STEPS=100 python -m prolewiki_llm.train_headless`).

# Required secrets — the process exits immediately if either is unset.
HF_TOKEN = get_env("HF_TOKEN", required=True)
WANDB_API_KEY = get_env("WANDB_API_KEY", required=True)

# Model configuration
MODEL_NAME = get_env("MODEL_NAME", "unsloth/DeepSeek-R1-0528-Qwen3-8B")
MAX_SEQ_LENGTH = get_env_int("MAX_SEQ_LENGTH", 2048)
LORA_RANK = get_env_int("LORA_RANK", 32)

# Training configuration
MAX_STEPS = get_env_int("MAX_STEPS", 500)
SAVE_STEPS = get_env_int("SAVE_STEPS", 50)
LEARNING_RATE = get_env_float("LEARNING_RATE", 5e-6)
WARMUP_RATIO = get_env_float("WARMUP_RATIO", 0.1)
BATCH_SIZE = get_env_int("BATCH_SIZE", 2)
GRADIENT_ACCUMULATION = get_env_int("GRADIENT_ACCUMULATION", 2)
NUM_GENERATIONS = get_env_int("NUM_GENERATIONS", 4)
# Fraction of GPU memory vLLM may claim; the rest is left for training.
GPU_MEMORY_UTILIZATION = get_env_float("GPU_MEMORY_UTILIZATION", 0.6)

# Sequence lengths (prompt + completion must fit in MAX_SEQ_LENGTH)
MAX_PROMPT_LENGTH = get_env_int("MAX_PROMPT_LENGTH", 512)
MAX_COMPLETION_LENGTH = get_env_int("MAX_COMPLETION_LENGTH", 1500)

# Paths
DATASET_PATH = Path(get_env("DATASET_PATH", "/workspace/dataset.jsonl"))
CHECKPOINT_DIR = Path(get_env("CHECKPOINT_DIR", "/workspace/checkpoints"))
LORA_OUTPUT = Path(get_env("LORA_OUTPUT", "/workspace/lora-output"))
OUTPUT_DIR = Path(get_env("OUTPUT_DIR", "/workspace/outputs"))

# Upload configuration
HF_REPO = get_env("HF_REPO", "prolewiki/marxist-grpo-lora")

# Reward mode: FULL, ROBUST, or LEGACY (any other value falls back to LEGACY)
REWARD_MODE = get_env("REWARD_MODE", "FULL").upper()

# Pod management — optional; used for self-termination after training.
RUNPOD_POD_ID = get_env("RUNPOD_POD_ID")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
|
| 122 |
+
"""Find the latest checkpoint directory if resuming training."""
|
| 123 |
+
if not checkpoint_dir.exists():
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
checkpoints = sorted(
|
| 127 |
+
[d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")],
|
| 128 |
+
key=lambda d: int(d.name.split("-")[1]),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if checkpoints:
|
| 132 |
+
return checkpoints[-1]
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def upload_to_hub(model_path: Path, repo_id: str, token: str) -> None:
    """Push the trained LoRA adapter directory to a HuggingFace Hub repo.

    Args:
        model_path: Local directory holding the saved adapter files.
        repo_id: Target repo (``org/name``); created private if missing.
        token: HuggingFace API token with write access.
    """
    from huggingface_hub import HfApi

    print(f"\nUploading model to HuggingFace Hub: {repo_id}")

    hub = HfApi(token=token)

    # Best-effort repo creation: an already-existing repo is not an error.
    try:
        hub.create_repo(repo_id, exist_ok=True, private=True)
    except Exception as exc:
        print(f"Note: Could not create repo (may already exist): {exc}")

    # Push the whole adapter directory in one commit.
    hub.upload_folder(
        folder_path=str(model_path),
        repo_id=repo_id,
        commit_message="Headless GRPO training run",
    )

    print(f"Model uploaded to: https://huggingface.co/{repo_id}")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def main() -> int:
    """Run headless GRPO training end to end.

    Returns:
        Process exit code: 0 on success, 1 when CUDA is unavailable or the
        dataset file is missing. Training exceptions are logged to W&B and
        re-raised.
    """
    # Heavy ML imports are deferred into main() so the module-level
    # torch.compile-disabling env vars take effect before torch loads.
    import torch
    import wandb
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastLanguageModel
    from vllm import SamplingParams

    from prolewiki_llm.grpo_rewards import (
        completeness_reward,
        debug_print_reward,
        full_coherence_reward,
        match_format_approximately,
        match_format_exactly,
        robust_coherence_reward,
        semantic_similarity_reward,
        terminology_reward,
    )
    from prolewiki_llm.wandb_logging import (
        WandbSampleLogger,
        create_logging_reward,
        finish_wandb_logging,
        init_wandb_logging,
    )

    print("=" * 70)
    print("HEADLESS GRPO TRAINING - RUNPOD DEPLOYMENT")
    print("=" * 70)

    # =========================================================================
    # System Info
    # =========================================================================
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name()
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name}")
        print(f"VRAM: {gpu_mem:.1f} GB")
    else:
        # Headless training is GPU-only; bail out with a non-zero exit code.
        print("ERROR: CUDA not available!", file=sys.stderr)
        return 1

    print(f"\nConfiguration:")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Max Steps: {MAX_STEPS}")
    print(f"  Batch Size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}")
    print(f"  Learning Rate: {LEARNING_RATE}")
    print(f"  Reward Mode: {REWARD_MODE}")
    print(f"  Dataset: {DATASET_PATH}")
    print(f"  Output: {LORA_OUTPUT}")
    print(f"  HF Repo: {HF_REPO}")

    # =========================================================================
    # Initialize W&B
    # =========================================================================
    print("\nInitializing Weights & Biases...")
    wandb.login(key=WANDB_API_KEY)

    # NOTE(review): wandb_run is never referenced again; the run handle
    # appears to be managed inside wandb_logging — confirm before removing.
    wandb_run = init_wandb_logging(
        project="marxist-grpo-headless",
        config={
            "model": MODEL_NAME,
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "gradient_accumulation": GRADIENT_ACCUMULATION,
            "num_generations": NUM_GENERATIONS,
            "max_steps": MAX_STEPS,
            "reward_mode": REWARD_MODE,
            "lora_rank": LORA_RANK,
        },
        tags=["grpo", "marxist-leninist", "headless", "runpod"],
    )

    # Periodic sample table logging; logging_reward is appended to every
    # reward mode's function list below.
    sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
    logging_reward = create_logging_reward(sample_logger, compute_all_rewards=True)

    # =========================================================================
    # Load Dataset
    # =========================================================================
    print(f"\nLoading dataset from: {DATASET_PATH}")

    if not DATASET_PATH.exists():
        print(f"ERROR: Dataset not found: {DATASET_PATH}", file=sys.stderr)
        return 1

    dataset = Dataset.from_json(str(DATASET_PATH))
    print(f"Loaded {len(dataset):,} examples")

    # =========================================================================
    # Load Model
    # =========================================================================
    print(f"\nLoading model: {MODEL_NAME}")

    # GRPO requires 16-bit LoRA adapters (load_in_4bit=False)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,  # Must be False for GRPO
        fast_inference=True,
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    print(f"Model type: {model.config.model_type}")

    # =========================================================================
    # Apply LoRA
    # =========================================================================
    print("\nApplying LoRA adapters...")

    # Use gradient_checkpointing=True (not "unsloth") for stability on RunPod
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=LORA_RANK,  # Same as r for GRPO (not r*2)
        use_gradient_checkpointing=True,  # Stable on RunPod (not "unsloth")
        random_state=3407,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # =========================================================================
    # Configure vLLM Sampling
    # =========================================================================
    vllm_sampling_params = SamplingParams(
        min_p=0.1,
        top_p=1.0,
        top_k=-1,
        max_tokens=MAX_COMPLETION_LENGTH,
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
        seed=3407,  # fixed seed keeps generations reproducible across runs
    )

    # =========================================================================
    # Configure Training
    # =========================================================================
    print("\nConfiguring GRPO trainer...")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    training_args = GRPOConfig(
        # vLLM
        vllm_sampling_params=vllm_sampling_params,
        temperature=1.0,
        # Optimization
        learning_rate=LEARNING_RATE,
        weight_decay=0.001,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        # Batch settings
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        num_generations=NUM_GENERATIONS,
        # Sequence lengths
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        # Training duration
        max_steps=MAX_STEPS,
        save_steps=SAVE_STEPS,
        # Logging
        logging_steps=1,
        report_to="wandb",
        # Output — checkpoints land in CHECKPOINT_DIR so resume works.
        output_dir=str(CHECKPOINT_DIR),
    )

    # =========================================================================
    # Select Reward Functions
    # =========================================================================
    if REWARD_MODE == "FULL":
        print("\nUsing FULL reward mode (recommended):")
        print("  - match_format_exactly, match_format_approximately")
        print("  - full_coherence_reward (NLI + structure + topic + depth)")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            full_coherence_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]
    elif REWARD_MODE == "ROBUST":
        print("\nUsing ROBUST reward mode:")
        print("  - match_format_exactly, match_format_approximately")
        print("  - robust_coherence_reward (NLI + self-consistency + structure)")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            robust_coherence_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]
    else:  # LEGACY (also any unrecognized REWARD_MODE value)
        print("\nUsing LEGACY reward mode (faster, less robust):")
        print("  - match_format_exactly, match_format_approximately")
        print("  - semantic_similarity_reward, terminology_reward")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            semantic_similarity_reward,
            terminology_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]

    # =========================================================================
    # Create Trainer
    # =========================================================================
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )

    # =========================================================================
    # Check for Checkpoint Resume
    # =========================================================================
    resume_from = find_latest_checkpoint(CHECKPOINT_DIR)
    if resume_from:
        print(f"\nResuming from checkpoint: {resume_from}")

    # =========================================================================
    # Train!
    # =========================================================================
    print("\n" + "=" * 70)
    print("STARTING TRAINING")
    print("=" * 70)
    print(f"Steps: {MAX_STEPS}")
    print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION * NUM_GENERATIONS}")
    print()

    try:
        if resume_from:
            trainer.train(resume_from_checkpoint=str(resume_from))
        else:
            trainer.train()
    except KeyboardInterrupt:
        # Ctrl-C falls through to the save/upload path below.
        print("\nTraining interrupted. Saving checkpoint...")
    except Exception as e:
        # Mark the W&B run as failed before propagating the error.
        print(f"\nTraining error: {e}", file=sys.stderr)
        finish_wandb_logging({"status": "error", "error": str(e)})
        raise

    # =========================================================================
    # Save LoRA
    # =========================================================================
    print("\n" + "=" * 70)
    print("SAVING MODEL")
    print("=" * 70)

    LORA_OUTPUT.mkdir(parents=True, exist_ok=True)
    model.save_lora(str(LORA_OUTPUT))
    print(f"LoRA saved to: {LORA_OUTPUT}")

    # =========================================================================
    # Upload to HuggingFace Hub
    # =========================================================================
    # Upload failure is non-fatal: the adapter is still on the pod's disk.
    try:
        upload_to_hub(LORA_OUTPUT, HF_REPO, HF_TOKEN)
    except Exception as e:
        print(f"Warning: Failed to upload to HuggingFace Hub: {e}", file=sys.stderr)

    # =========================================================================
    # Finish W&B
    # =========================================================================
    finish_wandb_logging({
        "status": "completed",
        "final_step": MAX_STEPS,
        "reward_mode": REWARD_MODE,
        "dataset_size": len(dataset),
        "hf_repo": HF_REPO,
    })

    print("\n" + "=" * 70)
    print("TRAINING COMPLETE!")
    print("=" * 70)
    print(f"LoRA saved to: {LORA_OUTPUT}")
    print(f"Model uploaded to: https://huggingface.co/{HF_REPO}")

    return 0
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
if __name__ == "__main__":
    # Propagate main()'s return code so container orchestration sees failures.
    sys.exit(main())
|
src/prolewiki_llm/train_marxist.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Fine-tune DeepSeek-R1-Distill-Qwen-7B-abliterated on ProleWiki Marxist-Leninist corpus.
|
| 4 |
+
|
| 5 |
+
Usage on RunPod:
|
| 6 |
+
1. Upload this script and train_qwen.jsonl to /workspace/data/
|
| 7 |
+
2. pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
|
| 8 |
+
3. pip install trl>=0.7.0 datasets accelerate bitsandbytes peft
|
| 9 |
+
4. python train_marxist.py
|
| 10 |
+
|
| 11 |
+
Expected: ~30 min training, ~$0.30-0.50 on RTX 4090
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from datasets import Dataset
|
| 18 |
+
from trl import SFTConfig, SFTTrainer
|
| 19 |
+
from unsloth import FastLanguageModel
|
| 20 |
+
|
| 21 |
+
# =============================================================================
# CONFIGURATION
# =============================================================================

MODEL_NAME = "huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated"
MAX_SEQ_LENGTH = 2048  # Qwen can go higher but 2048 is sufficient
LOAD_IN_4BIT = True

# LoRA config
LORA_R = 16
LORA_ALPHA = 32  # conventional 2x rank for SFT
LORA_DROPOUT = 0.05

# Training config
EPOCHS = 3
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 4  # Effective batch = 8
LEARNING_RATE = 2e-4
WARMUP_RATIO = 0.1

# Paths - adjust for your environment
DATA_PATH = Path("/workspace/data/train_qwen.jsonl")  # RunPod
OUTPUT_DIR = Path("/workspace/outputs/marxist-deepseek")
CHECKPOINT_DIR = Path("/workspace/checkpoints")

# Fallback for local testing — if the RunPod path is absent, assume we are
# running from the repository root and use relative paths instead.
if not DATA_PATH.exists():
    DATA_PATH = Path("training_data/formatted/train_qwen.jsonl")
    OUTPUT_DIR = Path("outputs/marxist-deepseek")
    CHECKPOINT_DIR = Path("checkpoints")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_dataset(path: Path) -> Dataset:
    """Load a pre-formatted Qwen-template JSONL dataset.

    Args:
        path: JSONL file with one JSON object per line (each record is
            expected downstream to carry a "text" field for SFTTrainer).

    Returns:
        A HuggingFace ``Dataset`` built from the parsed records.
    """
    import json

    # Explicit UTF-8 so training data parses identically on any platform;
    # skip blank lines so a trailing newline doesn't crash json.loads.
    with open(path, encoding="utf-8") as f:
        examples = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(examples)} training examples")
    return Dataset.from_list(examples)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def main() -> None:
    """Run SFT fine-tuning and export the result as LoRA + GGUF.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    print("=" * 60)
    print("Marxist-Leninist LLM Fine-Tuning")
    print("=" * 60)

    # Check CUDA — training is GPU-only.
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        raise RuntimeError("CUDA not available - need GPU for training!")

    # Load model
    print(f"\nLoading model: {MODEL_NAME}")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=LOAD_IN_4BIT,
        dtype=None,  # Auto-detect (bf16 if available)
    )
    print(f"Model type: {model.config.model_type}")

    # Apply LoRA
    print("\nApplying LoRA adapters...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",  # 30% less VRAM
        random_state=3407,
        max_seq_length=MAX_SEQ_LENGTH,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # Load dataset
    print(f"\nLoading dataset: {DATA_PATH}")
    dataset = load_dataset(DATA_PATH)

    # Create output directories
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    # Configure trainer
    print("\nConfiguring trainer...")
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",  # Pre-formatted Qwen template
        max_seq_length=MAX_SEQ_LENGTH,
        args=SFTConfig(
            # Batch settings
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION,
            # Learning rate
            learning_rate=LEARNING_RATE,
            lr_scheduler_type="cosine",
            warmup_ratio=WARMUP_RATIO,
            # Training duration
            num_train_epochs=EPOCHS,
            # Memory optimization — exactly one of fp16/bf16 is enabled.
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            optim="adamw_8bit",
            # Logging
            logging_steps=10,
            save_strategy="epoch",
            save_total_limit=2,
            # Output
            output_dir=str(OUTPUT_DIR),
            seed=3407,
            report_to="none",  # or "wandb" if configured
        ),
    )

    # Train!
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Epochs: {EPOCHS}")
    print(
        f"Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}"
    )
    print(f"Learning rate: {LEARNING_RATE}")
    print()

    trainer.train()

    # Save final model
    print("\n" + "=" * 60)
    print("SAVING MODEL")
    print("=" * 60)

    lora_path = CHECKPOINT_DIR / "marxist-lora-adapter"
    model.save_pretrained(str(lora_path))
    tokenizer.save_pretrained(str(lora_path))
    print(f"LoRA adapter saved to: {lora_path}")

    # Export to GGUF for local inference via Ollama/llama.cpp.
    print("\nExporting to GGUF (q4_k_m)...")
    gguf_path = CHECKPOINT_DIR / "marxist-deepseek-q4_k_m"
    model.save_pretrained_gguf(
        str(gguf_path),
        tokenizer,
        quantization_method="q4_k_m",
    )
    print(f"GGUF exported to: {gguf_path}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print(f"1. Download: {gguf_path}/*.gguf")
    print("2. Create Ollama Modelfile (see ai-docs/finetune.yaml)")
    print("3. ollama create marxist-deepseek -f Modelfile")
    print("4. ollama run marxist-deepseek")
    print("\nDON'T FORGET TO STOP YOUR RUNPOD!")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
    # Script entry point.
    main()
|
src/prolewiki_llm/transform_to_grpo.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Transform curated_qa.jsonl to GRPO training format.
|
| 4 |
+
|
| 5 |
+
Input format: {"instruction": "...", "response": "..."}
|
| 6 |
+
Output format: {"prompt": [...], "answer": "..."}
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python transform_to_grpo.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# System prompt injected into every GRPO example; the <think>-tag instruction
# pairs with the format-matching reward functions in grpo_rewards.
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""

# Repo-root-relative input/output locations.
INPUT_PATH = Path("training_data/curated_qa.jsonl")
OUTPUT_PATH = Path("training_data/grpo_dataset.jsonl")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def transform_qa_to_grpo(input_path: Path, output_path: Path) -> int:
    """Transform instruction/response JSONL pairs to GRPO chat format.

    Each input line ``{"instruction": ..., "response": ...}`` becomes
    ``{"prompt": [system, user], "answer": response}``.

    Args:
        input_path: Source JSONL of curated Q&A pairs.
        output_path: Destination JSONL, overwritten.

    Returns:
        Number of examples written.
    """
    count = 0

    # Explicit UTF-8 on both ends so output bytes are platform-independent.
    with open(input_path, encoding="utf-8") as infile, open(
        output_path, "w", encoding="utf-8"
    ) as outfile:
        for line in infile:
            if not line.strip():
                # Tolerate blank/trailing lines in hand-edited JSONL.
                continue
            item = json.loads(line)

            transformed = {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": item["instruction"]},
                ],
                "answer": item["response"],
            }

            outfile.write(json.dumps(transformed) + "\n")
            count += 1

    return count
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main() -> None:
    """Run the transformation and print a sample of the output."""
    print(f"Transforming {INPUT_PATH} to GRPO format...")

    count = transform_qa_to_grpo(INPUT_PATH, OUTPUT_PATH)

    print(f"Transformed {count} examples")
    print(f"Output written to: {OUTPUT_PATH}")

    # Show sample — first record only, truncated for terminal readability.
    print("\nSample output:")
    with open(OUTPUT_PATH) as f:
        sample = json.loads(f.readline())
    print(f"  System: {sample['prompt'][0]['content'][:60]}...")
    print(f"  User: {sample['prompt'][1]['content'][:60]}...")
    print(f"  Answer: {sample['answer'][:60]}...")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
    # Script entry point.
    main()
|
src/prolewiki_llm/wandb_logging.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Weights & Biases Logging for GRPO Training.
|
| 4 |
+
|
| 5 |
+
Provides comprehensive logging for debugging and monitoring GRPO fine-tuning:
|
| 6 |
+
- Per-step reward metrics (each reward function's mean)
|
| 7 |
+
- Sample tables showing question → response → reward breakdown
|
| 8 |
+
- Run configuration and hyperparameters
|
| 9 |
+
- Summary statistics at training end
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
from prolewiki_llm.wandb_logging import (
|
| 13 |
+
init_wandb_logging,
|
| 14 |
+
WandbSampleLogger,
|
| 15 |
+
create_logging_reward,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Initialize
|
| 19 |
+
run = init_wandb_logging(project="marxist-grpo", config={...})
|
| 20 |
+
|
| 21 |
+
# Create logger and reward function
|
| 22 |
+
sample_logger = WandbSampleLogger(log_every_n_steps=10)
|
| 23 |
+
logging_reward = create_logging_reward(sample_logger)
|
| 24 |
+
|
| 25 |
+
# Use in GRPOTrainer
|
| 26 |
+
trainer = GRPOTrainer(
|
| 27 |
+
reward_funcs=[..., logging_reward],
|
| 28 |
+
...
|
| 29 |
+
)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
from dataclasses import dataclass, field
|
| 35 |
+
from typing import TYPE_CHECKING, Any
|
| 36 |
+
|
| 37 |
+
if TYPE_CHECKING:
|
| 38 |
+
from collections.abc import Callable, Sequence
|
| 39 |
+
|
| 40 |
+
# Global flag to track if wandb is available
|
| 41 |
+
_WANDB_AVAILABLE: bool | None = None
|
| 42 |
+
_wandb_module: Any = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _get_wandb() -> Any:
    """Return the wandb module, importing it lazily on first use.

    The (possibly failing) import is attempted at most once per process;
    the outcome is cached in the module-level globals ``_WANDB_AVAILABLE``
    and ``_wandb_module``. Returns ``None`` when wandb is not installed.
    """
    global _WANDB_AVAILABLE, _wandb_module

    # Fast path: availability already determined on a previous call.
    if _WANDB_AVAILABLE is not None:
        return _wandb_module

    try:
        import wandb as _imported_wandb
    except ImportError:
        _WANDB_AVAILABLE, _wandb_module = False, None
    else:
        _WANDB_AVAILABLE, _wandb_module = True, _imported_wandb

    return _wandb_module
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def is_wandb_available() -> bool:
    """Report whether the wandb package is installed and importable."""
    # Trigger the lazy import so the cached availability flag is populated.
    _get_wandb()
    return bool(_WANDB_AVAILABLE)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# =============================================================================
|
| 69 |
+
# INITIALIZATION
|
| 70 |
+
# =============================================================================
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def init_wandb_logging(
    project: str,
    config: dict[str, Any],
    name: str | None = None,
    tags: list[str] | None = None,
    notes: str | None = None,
    mode: str = "online",
) -> Any:
    """
    Initialize Weights & Biases logging for GRPO training.

    Args:
        project: W&B project name (e.g., "marxist-grpo")
        config: Dictionary of hyperparameters and settings
        name: Optional run name (auto-generated if None)
        tags: Optional list of tags for filtering runs
        notes: Optional notes about this run
        mode: "online", "offline", or "disabled"

    Returns:
        wandb.Run object (or None if wandb unavailable)

    Example:
        run = init_wandb_logging(
            project="marxist-grpo",
            config={
                "model": "DeepSeek-R1-0528-Qwen3-8B",
                "learning_rate": 5e-6,
                "batch_size": 2,
                "max_steps": 250,
            },
            tags=["grpo", "marxist", "v1"],
        )
    """
    wandb = _get_wandb()
    if wandb is None:
        # Degrade gracefully so training can proceed without wandb installed.
        print("[WandbLogging] wandb not installed. Install with: pip install wandb")
        return None

    # Initialize run; default tags are applied only when the caller passes none.
    run = wandb.init(
        project=project,
        config=config,
        name=name,
        tags=tags or ["grpo", "marxist-leninist"],
        notes=notes,
        mode=mode,
    )

    # Register per-reward metrics with mean/min/max summaries up front.
    _define_reward_metrics(run)

    print(f"[WandbLogging] Initialized run: {run.name}")
    print(f"[WandbLogging] View at: {run.url}")

    return run
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _define_reward_metrics(run: Any) -> None:
    """Register every reward metric on *run* with mean/min/max summaries."""
    reward_names = (
        "format_exact",
        "format_approx",
        "semantic_similarity",
        "terminology",
        "nli_coherence",
        "self_consistency",
        "structural_coherence",
        "topic_relevance",
        "interconnection_depth",
        "completeness",
        "total",
    )

    for reward in reward_names:
        # One base metric per reward plus dedicated min/max companions.
        base = f"rewards/{reward}"
        run.define_metric(base, summary="mean")
        run.define_metric(f"{base}_min", summary="min")
        run.define_metric(f"{base}_max", summary="max")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# =============================================================================
|
| 155 |
+
# SAMPLE LOGGER
|
| 156 |
+
# =============================================================================
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@dataclass
class RewardSample:
    """One logged completion together with its per-reward score breakdown."""

    step: int                    # training step the sample was collected at
    question: str                # prompt text (may be truncated by the caller)
    response: str                # model completion text
    ground_truth: str            # reference answer
    rewards: dict[str, float]    # reward-name -> score for this sample

    @property
    def total_reward(self) -> float:
        """Sum of every individual reward score."""
        return sum(score for score in self.rewards.values())
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@dataclass
class WandbSampleLogger:
    """
    Logs sample tables to W&B for debugging reward functions.

    Accumulates samples during training and logs them as a wandb.Table
    every N steps. This lets you inspect actual model outputs and
    understand why specific rewards were assigned.

    Example table:
        | step | question | response | ground_truth | format | nli | topic | depth | total |
        |------|----------|----------|--------------|--------|-----|-------|-------|-------|
        | 50   | What is..| The bour.| Revisionism..| 3.0    | 2.5 | 1.5   | 1.0   | 8.0   |
    """

    # How often (in steps) should_log() fires; table row cap per log event.
    log_every_n_steps: int = 10
    max_samples_per_log: int = 4
    # Rolling buffer of recent samples (bounded in add_sample).
    _samples: list[RewardSample] = field(default_factory=list)
    # NOTE(review): _step_counter is never read or assigned anywhere in this
    # class — it appears to be dead state; confirm before removing.
    _step_counter: int = field(default=0)
    # Column order here MUST match the row order built in log_table().
    _table_columns: list[str] = field(
        default_factory=lambda: [
            "step",
            "question",
            "response",
            "ground_truth",
            "format_exact",
            "format_approx",
            "nli_coherence",
            "topic_relevance",
            "depth",
            "completeness",
            "total",
        ]
    )

    def add_sample(
        self,
        step: int,
        question: str,
        response: str,
        ground_truth: str,
        rewards: dict[str, float],
    ) -> None:
        """Add a sample to the buffer, truncating long text for table display."""
        sample = RewardSample(
            step=step,
            question=question[:500],  # Truncate for table display
            response=response[:500],
            ground_truth=ground_truth[:300],
            rewards=rewards,
        )
        self._samples.append(sample)

        # Keep only recent samples: bound the buffer at 3x one table's worth
        # so memory stays flat over a long training run.
        max_buffer = self.max_samples_per_log * 3
        if len(self._samples) > max_buffer:
            self._samples = self._samples[-max_buffer:]

    def should_log(self, step: int) -> bool:
        """Check if we should log at this step (every log_every_n_steps, skipping step 0)."""
        return step > 0 and step % self.log_every_n_steps == 0

    def log_table(self, step: int) -> None:
        """Log accumulated samples as a wandb.Table (no-op if wandb missing or buffer empty)."""
        wandb = _get_wandb()
        if wandb is None or not self._samples:
            return

        # Get the most recent table's worth of samples.
        samples_to_log = self._samples[-self.max_samples_per_log :]

        # Create table
        table = wandb.Table(columns=self._table_columns)

        for sample in samples_to_log:
            # Row order must mirror _table_columns; missing rewards become 0.0.
            row = [
                sample.step,
                sample.question,
                sample.response,
                sample.ground_truth,
                sample.rewards.get("format_exact", 0.0),
                sample.rewards.get("format_approx", 0.0),
                sample.rewards.get("nli_coherence", 0.0),
                sample.rewards.get("topic_relevance", 0.0),
                sample.rewards.get("interconnection_depth", 0.0),
                sample.rewards.get("completeness", 0.0),
                sample.total_reward,
            ]
            table.add_data(*row)

        # Log the table
        wandb.log({"samples": table}, step=step)
        print(f"[WandbLogging] Logged {len(samples_to_log)} samples at step {step}")

    def clear(self) -> None:
        """Clear the sample buffer."""
        self._samples.clear()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# =============================================================================
|
| 275 |
+
# REWARD METRICS LOGGING
|
| 276 |
+
# =============================================================================
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def log_reward_metrics(
    step: int,
    reward_scores: dict[str, list[float]],
) -> None:
    """
    Log reward metrics to wandb.

    For each reward function, logs the mean (base metric) plus ``_min`` and
    ``_max`` companions, and a per-sample ``rewards/total`` aggregate.
    No-op when wandb is not installed.

    Args:
        step: Training step number
        reward_scores: Dict mapping reward name to list of scores
            e.g., {"format_exact": [3.0, 3.0, 0.0, 3.0]}
    """
    wandb = _get_wandb()
    if wandb is None:
        return

    metrics: dict[str, float] = {}

    for name, scores in reward_scores.items():
        if not scores:
            continue

        metrics[f"rewards/{name}"] = sum(scores) / len(scores)
        metrics[f"rewards/{name}_min"] = min(scores)
        metrics[f"rewards/{name}_max"] = max(scores)

    # Compute per-sample totals across all reward functions.
    # BUGFIX: size the loop by the LONGEST score list rather than whichever
    # entry happens to come first — previously, samples beyond the first
    # list's length were silently dropped when lists were uneven.
    if reward_scores:
        num_samples = max(len(scores) for scores in reward_scores.values())
        all_totals = [
            sum(scores[i] for scores in reward_scores.values() if i < len(scores))
            for i in range(num_samples)
        ]

        if all_totals:
            metrics["rewards/total"] = sum(all_totals) / len(all_totals)
            metrics["rewards/total_min"] = min(all_totals)
            metrics["rewards/total_max"] = max(all_totals)

    wandb.log(metrics, step=step)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# =============================================================================
|
| 326 |
+
# LOGGING REWARD FUNCTION
|
| 327 |
+
# =============================================================================
|
| 328 |
+
|
| 329 |
+
# Global step counter for the logging reward
|
| 330 |
+
_LOGGING_STEP = 0
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def create_logging_reward(
    sample_logger: WandbSampleLogger | None = None,
    compute_all_rewards: bool = True,
) -> Callable[..., list[float]]:
    """
    Create a reward function that logs metrics and samples to wandb.

    This replaces debug_print_reward with comprehensive wandb logging.
    The returned function computes ALL individual rewards internally,
    logs them to wandb, and returns [0.0] * len(completions) (no training effect).

    Args:
        sample_logger: WandbSampleLogger instance for sample table logging
        compute_all_rewards: If True, compute and log all reward functions

    Returns:
        A reward function compatible with GRPOTrainer

    Example:
        sample_logger = WandbSampleLogger(log_every_n_steps=10)
        logging_reward = create_logging_reward(sample_logger)

        trainer = GRPOTrainer(
            reward_funcs=[..., logging_reward],
            ...
        )
    """
    # FIX: a ``global _LOGGING_STEP`` declaration previously lived here in
    # the factory, but the factory never assigns the global — only the inner
    # closure does, and it carries its own declaration. Removed as dead code.

    def logging_reward(
        prompts: Sequence[Sequence[dict[str, str]]],
        completions: Sequence[Sequence[dict[str, str]]],
        answer: Sequence[str],
        **kwargs: object,
    ) -> list[float]:
        """Log rewards and samples to wandb. Returns 0.0 (no training effect)."""
        global _LOGGING_STEP
        _LOGGING_STEP += 1
        step = _LOGGING_STEP

        wandb = _get_wandb()
        if wandb is None or wandb.run is None:
            # Fallback to print if wandb not initialized.
            if step % 10 == 0:
                print(f"[Step {step}] Q: {prompts[0][-1]['content'][:80]}...")
            return [0.0] * len(completions)

        # Compute all reward scores if requested.
        if compute_all_rewards:
            reward_scores = _compute_all_reward_scores(prompts, completions, answer, **kwargs)
            log_reward_metrics(step, reward_scores)
        else:
            reward_scores = {}

        # Log sample tables periodically (every sample_logger.log_every_n_steps).
        if sample_logger and sample_logger.should_log(step):
            # Add up to max_samples_per_log samples from the current batch.
            for i in range(min(sample_logger.max_samples_per_log, len(prompts))):
                question = prompts[i][-1]["content"]
                response = completions[i][0]["content"]
                truth = answer[i] if i < len(answer) else ""

                # Slice this sample's score out of each reward list; a short
                # list contributes 0.0 rather than raising IndexError.
                sample_rewards = {
                    name: scores[i] if i < len(scores) else 0.0
                    for name, scores in reward_scores.items()
                }

                sample_logger.add_sample(
                    step=step,
                    question=question,
                    response=response,
                    ground_truth=truth,
                    rewards=sample_rewards,
                )

            sample_logger.log_table(step)

        # Always return zeros so this function never influences training.
        return [0.0] * len(completions)

    return logging_reward
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _compute_all_reward_scores(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> dict[str, list[float]]:
    """
    Compute all reward function scores for logging.

    Each reward is evaluated independently; a failure in one reward is
    printed and replaced with zeros so it cannot abort the others.

    Returns dict mapping reward name to list of scores.
    """
    # Import reward functions here to avoid circular imports.
    from prolewiki_llm.grpo_rewards import (
        completeness_reward,
        interconnection_depth_reward,
        match_format_approximately,
        match_format_exactly,
        nli_coherence_reward,
        topic_relevance_reward,
    )

    # (name, callable, positional args) — this table drives one uniform
    # try/except instead of six copy-pasted blocks.
    reward_specs: list[tuple[str, Callable[..., list[float]], tuple[Any, ...]]] = [
        ("format_exact", match_format_exactly, (completions,)),
        ("format_approx", match_format_approximately, (completions,)),
        ("nli_coherence", nli_coherence_reward, (completions, answer)),
        ("topic_relevance", topic_relevance_reward, (prompts, completions)),
        ("interconnection_depth", interconnection_depth_reward, (completions,)),
        ("completeness", completeness_reward, (completions, answer)),
    ]

    reward_scores: dict[str, list[float]] = {}
    for name, reward_fn, args in reward_specs:
        try:
            reward_scores[name] = reward_fn(*args, **kwargs)
        except Exception as e:  # logging must never be fatal to training
            print(f"[WandbLogging] Error in {name}: {e}")
            reward_scores[name] = [0.0] * len(completions)

    return reward_scores
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# =============================================================================
|
| 484 |
+
# UTILITY FUNCTIONS
|
| 485 |
+
# =============================================================================
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def finish_wandb_logging(summary: dict[str, Any] | None = None) -> None:
    """
    Finish the active wandb run, optionally recording summary statistics.

    No-op when wandb is not installed or no run is active.

    Args:
        summary: Optional dict of final summary metrics
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    # Copy any provided final metrics into the run summary before closing.
    for key, value in (summary or {}).items():
        wandb.run.summary[key] = value

    wandb.finish()
    print("[WandbLogging] Run finished.")
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def log_model_checkpoint(
    checkpoint_path: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """
    Upload a model checkpoint directory as a wandb artifact.

    No-op when wandb is not installed or no run is active.

    Args:
        checkpoint_path: Path to the checkpoint directory
        metadata: Optional metadata about the checkpoint
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    # Name the artifact after the current run so checkpoints stay traceable.
    checkpoint_artifact = wandb.Artifact(
        name=f"checkpoint-{wandb.run.name}",
        type="model",
        metadata=metadata or {},
    )
    checkpoint_artifact.add_dir(checkpoint_path)
    wandb.log_artifact(checkpoint_artifact)
    print(f"[WandbLogging] Logged checkpoint: {checkpoint_path}")
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# prolewiki-llm tests
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest configuration and fixtures for prolewiki-llm tests."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
# =============================================================================
|
| 9 |
+
# Module-Level Setup for train_headless.py Tests
|
| 10 |
+
# =============================================================================
|
| 11 |
+
# train_headless.py reads required env vars (HF_TOKEN, WANDB_API_KEY) at import time.
|
| 12 |
+
# We need to set dummy values BEFORE pytest collects tests that import this module.
|
| 13 |
+
# This must be done at module level, not in a fixture, because fixtures run
|
| 14 |
+
# after collection and the import happens during collection.
|
| 15 |
+
|
| 16 |
+
# train_headless.py requires these secrets at import time; seed dummy values
# before collection so importing the module in tests cannot fail. setdefault
# is the idiomatic one-liner for "set only if absent" and never overwrites
# real credentials supplied by the environment.
os.environ.setdefault("HF_TOKEN", "test-token-for-unit-tests")
os.environ.setdefault("WANDB_API_KEY", "test-key-for-unit-tests")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture
def sample_question() -> str:
    """A representative Marxist theory question used across reward tests."""
    question = "What is the difference between revisionism and Marxism-Leninism?"
    return question
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@pytest.fixture
def sample_good_answer() -> str:
    """Sample well-structured answer for testing rewards.

    Contains a <think> reasoning block followed by multi-paragraph
    substantive analysis — the shape the format/depth rewards favor.
    The literal text must stay byte-exact; reward tests score it.
    """
    return """<think>
The question asks about the distinction between revisionism and Marxism-Leninism,
which requires explaining both concepts and their historical relationship.
</think>

Revisionism refers to attempts to revise or water down the fundamental principles
of Marxism, particularly the necessity of class struggle and proletarian
dictatorship. It emerged historically with Eduard Bernstein who rejected
revolutionary change in favor of gradual reform within capitalism.

Marxism-Leninism, in contrast, upholds the core tenets of scientific socialism:
the class nature of the state, the necessity of socialist revolution, and the
dictatorship of the proletariat as the transitional form to communism.

The key distinction lies in their approach to the capitalist state - revisionists
seek to reform it, while Marxist-Leninists understand it must be replaced by a
workers' state through revolutionary action."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@pytest.fixture
def sample_bad_answer() -> str:
    """Sample buzzword-heavy answer lacking depth.

    Deliberately has no <think> block and no analysis — the negative
    counterpart to sample_good_answer. Literal text must stay byte-exact.
    """
    return """Revisionism is bad and Marxism-Leninism is good. We must fight
imperialism and support the working class. Solidarity forever! The revolution
will triumph against capitalism and fascism. Workers of the world unite!"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# =============================================================================
|
| 59 |
+
# Integration Test Fixtures
|
| 60 |
+
# =============================================================================
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@pytest.fixture
def project_root() -> Path:
    """Return the repository root (the parent of the tests/ directory)."""
    # parents[1] is the grandparent of this file: tests/conftest.py -> repo root.
    return Path(__file__).parents[1]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@pytest.fixture
def mock_bin_dir(tmp_path: Path) -> Path:
    """Create mock bin directory with executable stubs for shell script testing.

    Copies every mock script from tests/fixtures/mock_bin/ into a fresh
    temporary directory and marks each copy executable (0o755).
    """
    target = tmp_path / "mock_bin"
    target.mkdir()

    source_dir = Path(__file__).parent / "fixtures" / "mock_bin"
    if source_dir.exists():
        for entry in source_dir.iterdir():
            if not entry.is_file():
                continue
            copy = target / entry.name
            copy.write_text(entry.read_text())
            copy.chmod(0o755)  # scripts must be executable to shadow real tools

    return target
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@pytest.fixture
def start_sh_env(tmp_path: Path, mock_bin_dir: Path) -> dict[str, str]:
    """Base environment for start.sh integration tests.

    Provides a controlled environment with:
    - PATH pointing to mock binaries first
    - Log directory for capturing mock invocations
    - Default success values for mock commands
    - Output directories in tmp_path (not /workspace)

    Returns:
        A dict suitable for subprocess.run(env=...); tests layer
        case-specific variables (HF_TOKEN, DATASET_PATH, ...) on top.
    """
    log_dir = tmp_path / "logs"
    log_dir.mkdir()

    return {
        # Mock bin dir comes first so stubs shadow the real commands.
        "PATH": f"{mock_bin_dir}:/usr/bin:/bin",
        "HOME": str(tmp_path),
        # Mock scripts append every invocation to $MOCK_LOG_DIR/mock_calls.log.
        "MOCK_LOG_DIR": str(log_dir),
        "MOCK_CUDA_AVAILABLE": "1",
        "MOCK_TRAINING_EXIT_CODE": "0",
        # Override /workspace paths to use tmp_path
        "CHECKPOINT_DIR": str(tmp_path / "checkpoints"),
        "LORA_OUTPUT": str(tmp_path / "lora-output"),
        "OUTPUT_DIR": str(tmp_path / "outputs"),
    }
|
tests/fixtures/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Test fixtures for prolewiki-llm tests."""
|
tests/fixtures/mock_bin/huggingface-cli
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Mock huggingface-cli for testing authentication.
#
# Appends the full invocation to $MOCK_LOG_DIR/mock_calls.log (defaults to
# /tmp) so tests can assert how start.sh called it. Exit status is driven
# by MOCK_HF_CLI_EXIT_CODE (default 0 = success).

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "huggingface-cli $*" >> "$LOG_FILE"

exit "${MOCK_HF_CLI_EXIT_CODE:-0}"
|
tests/fixtures/mock_bin/python
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Mock python script for testing start.sh.
# Handles CUDA checks and training module invocations by pattern-matching
# the argument string; every call is appended to the shared mock log so
# tests can assert the exact sequence of python invocations.

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "python $*" >> "$LOG_FILE"

# Handle CUDA availability check: MOCK_CUDA_AVAILABLE=0 simulates a pod
# without a working GPU (mimics torch's AssertionError + non-zero exit).
if [[ "$*" == *"torch.cuda.is_available"* ]]; then
    if [ "${MOCK_CUDA_AVAILABLE:-1}" = "0" ]; then
        echo "AssertionError: CUDA not available" >&2
        exit 1
    fi
    exit 0
fi

# Handle GPU name query
if [[ "$*" == *"get_device_name"* ]]; then
    echo "Mock GPU RTX 4090"
    exit 0
fi

# Handle VRAM query
if [[ "$*" == *"get_device_properties"* ]]; then
    echo "24.0 GB"
    exit 0
fi

# Handle training module: exit code is configurable so tests can exercise
# both the success and failure paths of start.sh's auto-shutoff logic.
if [[ "$*" == *"prolewiki_llm.train_headless"* ]]; then
    exit "${MOCK_TRAINING_EXIT_CODE:-0}"
fi

# Default: success
exit 0
|
tests/fixtures/mock_bin/runpodctl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Mock runpodctl for testing start.sh auto-shutoff behavior.
#
# Logs the invocation, echoes the real CLI's "pod ... stopped" confirmation
# for "stop pod <id>" calls, and exits with MOCK_RUNPODCTL_EXIT_CODE
# (default 0) so tests can simulate a failed stop as well.

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "runpodctl $*" >> "$LOG_FILE"

# Simulate successful pod operations
if [[ "$1" == "stop" ]] && [[ "$2" == "pod" ]]; then
    echo "pod \"$3\" stopped"
fi

exit "${MOCK_RUNPODCTL_EXIT_CODE:-0}"
|
tests/fixtures/mock_bin/wandb
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Mock wandb CLI for testing authentication.
#
# Appends the invocation to $MOCK_LOG_DIR/mock_calls.log and exits with
# MOCK_WANDB_EXIT_CODE (default 0) so tests can simulate login failures.

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "wandb $*" >> "$LOG_FILE"

exit "${MOCK_WANDB_EXIT_CODE:-0}"
|
tests/fixtures/mock_bin/wget
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Mock wget for testing dataset download.
#
# Logs the invocation and, when the caller passed "-O <file>", writes a
# one-line dummy JSONL dataset to <file>. Exit status is controlled by
# MOCK_WGET_EXIT_CODE (default 0).

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "wget $*" >> "$LOG_FILE"

# BUGFIX: initialize prev explicitly. It was previously unset, so an
# exported "prev=-O" environment variable would have made the FIRST
# argument be clobbered as the download target.
prev=""

# Scan arguments; the argument following "-O" is the output file.
for i in "$@"; do
    if [[ "$prev" == "-O" ]]; then
        # Create dummy dataset file at the -O target
        echo '{"prompt": "test", "answer": "test"}' > "$i"
        break
    fi
    prev="$i"
done

exit "${MOCK_WGET_EXIT_CODE:-0}"
|
tests/integration/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Integration tests for prolewiki-llm."""
|
tests/integration/test_start_sh.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests for the headless training entrypoint script.
|
| 3 |
+
|
| 4 |
+
Tests cover:
|
| 5 |
+
- Environment validation (required secrets, GPU check)
|
| 6 |
+
- Auto-shutoff behavior (FinOps: pod termination on success, no termination on failure)
|
| 7 |
+
- Data handling (dataset validation, download from URL)
|
| 8 |
+
- Authentication (HuggingFace, Weights & Biases)
|
| 9 |
+
|
| 10 |
+
These tests use subprocess to run start.sh with mock external commands.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import subprocess
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import pytest
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# Environment Validation Tests
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@pytest.mark.integration
class TestEnvironmentValidation:
    """Test environment variable validation in start.sh.

    Each test invokes docker/start.sh in a subprocess with a purposely
    incomplete environment and checks that the script fails fast with a
    diagnostic naming the missing prerequisite.
    """

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-record JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing stdout/stderr as text."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    def test_exits_without_hf_token(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when HF_TOKEN is not set."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            # HF_TOKEN intentionally missing
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "HF_TOKEN" in result.stdout or "HF_TOKEN" in result.stderr

    def test_exits_without_wandb_key(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when WANDB_API_KEY is not set."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            # WANDB_API_KEY intentionally missing
            "DATASET_PATH": str(dataset),
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "WANDB_API_KEY" in result.stdout or "WANDB_API_KEY" in result.stderr

    def test_exits_without_cuda(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when CUDA is not available."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
            "MOCK_CUDA_AVAILABLE": "0",  # Simulate no GPU
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "CUDA" in result.stdout or "GPU" in result.stdout
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# =============================================================================
|
| 116 |
+
# Auto-Shutoff Tests (Critical for FinOps)
|
| 117 |
+
# =============================================================================
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@pytest.mark.integration
class TestAutoShutoff:
    """Test pod auto-termination per RunPod strategy document.

    These tests verify the FinOps strategy that prevents zombie pods:
    a successful run stops its own pod, while a failed run leaves the
    pod alive so developers can SSH in and debug.
    """

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-record JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing stdout/stderr as text."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _mock_calls(env: dict[str, str]) -> str:
        """Return the recorded mock-command invocations, or '' if none."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_calls_runpodctl_stop_on_success(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Pod is terminated after successful training (FinOps)."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "RUNPOD_POD_ID": "pod-abc123",
            "DATASET_PATH": str(dataset),
            "MOCK_TRAINING_EXIT_CODE": "0",  # Training succeeds
        }

        result = self._run_start_sh(project_root, env)

        # Verify runpodctl was called
        calls = self._mock_calls(env)

        assert result.returncode == 0
        assert "runpodctl stop pod pod-abc123" in calls

    def test_no_runpodctl_stop_without_pod_id(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """No pod termination when RUNPOD_POD_ID is not set."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            # RUNPOD_POD_ID intentionally not set
            "DATASET_PATH": str(dataset),
            "MOCK_TRAINING_EXIT_CODE": "0",
        }

        result = self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        assert result.returncode == 0
        assert "runpodctl" not in calls
        assert "RUNPOD_POD_ID not set" in result.stdout

    def test_no_runpodctl_stop_on_failure(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Pod is NOT terminated on failure (allows debugging).

        This is critical: developers need SSH access to debug failed training.
        """
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "RUNPOD_POD_ID": "pod-abc123",
            "DATASET_PATH": str(dataset),
            "MOCK_TRAINING_EXIT_CODE": "1",  # Training fails
        }

        result = self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        # Pod should NOT be stopped (for debugging)
        assert "runpodctl" not in calls
        assert result.returncode == 1
        assert "NOT be automatically terminated" in result.stdout

    def test_exit_code_propagates_from_training(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with the same code as the training process."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
            "MOCK_TRAINING_EXIT_CODE": "42",  # Custom exit code
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 42
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# =============================================================================
|
| 263 |
+
# Data Handling Tests
|
| 264 |
+
# =============================================================================
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
@pytest.mark.integration
class TestDataHandling:
    """Test dataset validation and download logic.

    Covers the three branches of start.sh data handling: hard failure
    when the dataset is absent with no URL, download via wget when a
    URL is configured, and skipping the download when the file exists.
    """

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing stdout/stderr as text."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _mock_calls(env: dict[str, str]) -> str:
        """Return the recorded mock-command invocations, or '' if none."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_exits_when_dataset_missing_no_url(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits when dataset is missing and no DATASET_URL provided."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(tmp_path / "nonexistent.jsonl"),
            # DATASET_URL not set
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "Dataset not found" in result.stdout

    def test_downloads_dataset_from_url(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script downloads dataset when file missing but DATASET_URL set."""
        dataset_path = tmp_path / "dataset.jsonl"

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset_path),
            "DATASET_URL": "https://example.com/dataset.jsonl",
        }

        self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        # wget should have been called
        assert "wget" in calls
        assert "https://example.com/dataset.jsonl" in calls

    def test_skips_download_when_dataset_exists(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script does not download when dataset file already exists."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
            "DATASET_URL": "https://example.com/should-not-download.jsonl",
        }

        result = self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        # wget should NOT have been called
        assert "wget" not in calls
        assert result.returncode == 0
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# =============================================================================
|
| 364 |
+
# Authentication Tests
|
| 365 |
+
# =============================================================================
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@pytest.mark.integration
class TestAuthentication:
    """Test authentication with external services.

    Verifies that start.sh forwards credentials to huggingface-cli and
    wandb, and that login failures are treated as non-fatal.
    """

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-record JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing stdout/stderr as text."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _mock_calls(env: dict[str, str]) -> str:
        """Return the recorded mock-command invocations, or '' if none."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_calls_huggingface_cli_login(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script calls huggingface-cli login with the token."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "hf_test_token_12345",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
        }

        self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        assert "huggingface-cli login" in calls
        assert "hf_test_token_12345" in calls

    def test_calls_wandb_login(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script calls wandb login with the API key."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "wandb_key_67890",
            "DATASET_PATH": str(dataset),
        }

        self._run_start_sh(project_root, env)
        calls = self._mock_calls(env)

        assert "wandb login" in calls
        assert "wandb_key_67890" in calls

    def test_auth_failures_dont_stop_script(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script continues even if authentication commands fail (|| true)."""
        dataset = self._write_dataset(tmp_path)

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
            "MOCK_HF_CLI_EXIT_CODE": "1",  # HF login fails
            "MOCK_WANDB_EXIT_CODE": "1",  # wandb login fails
        }

        result = self._run_start_sh(project_root, env)

        # Script should still succeed (auth failures are non-fatal)
        assert result.returncode == 0
|
tests/unit/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Unit tests
|
tests/unit/test_grpo_rewards.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tests/unit/test_train_headless.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for the headless training module.
|
| 3 |
+
|
| 4 |
+
Tests cover:
|
| 5 |
+
- Environment variable parsing (get_env, get_env_int, get_env_float)
|
| 6 |
+
- Checkpoint discovery (find_latest_checkpoint)
|
| 7 |
+
- Model upload to HuggingFace Hub (upload_to_hub)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from unittest.mock import MagicMock, patch
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# Environment Variable Parsing Tests
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TestGetEnv:
    """Test environment variable retrieval functions."""

    def test_get_env_returns_value(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env returns the environment variable value when set."""
        monkeypatch.setenv("TEST_VAR", "test_value")

        # Import after setting env to avoid module-level checks
        from prolewiki_llm.train_headless import get_env

        value = get_env("TEST_VAR")
        assert value == "test_value"

    def test_get_env_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env returns default when variable not set."""
        monkeypatch.delenv("NONEXISTENT_VAR", raising=False)

        from prolewiki_llm.train_headless import get_env

        value = get_env("NONEXISTENT_VAR", "default_value")
        assert value == "default_value"

    def test_get_env_required_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env exits when required variable is missing."""
        monkeypatch.delenv("REQUIRED_VAR", raising=False)

        from prolewiki_llm.train_headless import get_env

        # A missing required variable must terminate with exit code 1.
        with pytest.raises(SystemExit) as exc_info:
            get_env("REQUIRED_VAR", required=True)

        assert exc_info.value.code == 1
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TestGetEnvInt:
    """Test integer environment variable parsing."""

    def test_get_env_int_parses_integer(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env_int correctly parses integer values."""
        monkeypatch.setenv("INT_VAR", "42")

        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("INT_VAR", 0)
        assert parsed == 42

    def test_get_env_int_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env_int returns default when variable not set."""
        monkeypatch.delenv("NONEXISTENT_INT", raising=False)

        from prolewiki_llm.train_headless import get_env_int

        fallback = get_env_int("NONEXISTENT_INT", 100)
        assert fallback == 100
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TestGetEnvFloat:
    """Test float environment variable parsing."""

    def test_get_env_float_parses_float(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_env_float correctly parses float values."""
        monkeypatch.setenv("FLOAT_VAR", "3.14")

        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(3.14)

    def test_get_env_float_parses_scientific(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """get_env_float correctly parses scientific notation."""
        monkeypatch.setenv("FLOAT_VAR", "5e-6")

        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(5e-6)

    def test_get_env_float_returns_default(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """get_env_float returns default when variable not set."""
        monkeypatch.delenv("NONEXISTENT_FLOAT", raising=False)

        from prolewiki_llm.train_headless import get_env_float

        fallback = get_env_float("NONEXISTENT_FLOAT", 1.5)
        assert fallback == 1.5
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# =============================================================================
|
| 107 |
+
# Checkpoint Discovery Tests
|
| 108 |
+
# =============================================================================
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TestFindLatestCheckpoint:
    """Test checkpoint discovery logic."""

    def test_returns_none_for_nonexistent_dir(self, tmp_path: Path) -> None:
        """Returns None when checkpoint directory doesn't exist."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        missing_dir = tmp_path / "nonexistent"
        assert find_latest_checkpoint(missing_dir) is None

    def test_returns_none_for_empty_dir(self, tmp_path: Path) -> None:
        """Returns None when checkpoint directory is empty."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        ckpt_root = tmp_path / "checkpoints"
        ckpt_root.mkdir()

        assert find_latest_checkpoint(ckpt_root) is None

    def test_returns_none_when_no_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Returns None when no checkpoint-* directories exist."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        ckpt_root = tmp_path / "checkpoints"
        ckpt_root.mkdir()
        # Create non-checkpoint directories
        (ckpt_root / "random_dir").mkdir()
        (ckpt_root / "other_file.txt").write_text("test")

        assert find_latest_checkpoint(ckpt_root) is None

    def test_finds_single_checkpoint(self, tmp_path: Path) -> None:
        """Finds single checkpoint directory."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        ckpt_root = tmp_path / "checkpoints"
        ckpt_root.mkdir()
        only_ckpt = ckpt_root / "checkpoint-100"
        only_ckpt.mkdir()

        assert find_latest_checkpoint(ckpt_root) == only_ckpt

    def test_finds_latest_checkpoint(self, tmp_path: Path) -> None:
        """Finds the checkpoint with the highest step number."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        ckpt_root = tmp_path / "checkpoints"
        ckpt_root.mkdir()

        # Create checkpoints in random order
        for step in (50, 200, 100, 150):
            (ckpt_root / f"checkpoint-{step}").mkdir()

        latest = find_latest_checkpoint(ckpt_root)
        assert latest == ckpt_root / "checkpoint-200"

    def test_ignores_non_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Ignores directories that don't match checkpoint-* pattern."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        ckpt_root = tmp_path / "checkpoints"
        ckpt_root.mkdir()

        # Create mix of checkpoint and non-checkpoint dirs
        for name in ("checkpoint-50", "logs", "checkpoint-100", "outputs"):
            (ckpt_root / name).mkdir()

        latest = find_latest_checkpoint(ckpt_root)
        assert latest == ckpt_root / "checkpoint-100"
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# =============================================================================
|
| 188 |
+
# HuggingFace Hub Upload Tests
|
| 189 |
+
# =============================================================================
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class TestUploadToHub:
    """Test model upload to HuggingFace Hub."""

    @staticmethod
    def _make_adapter_dir(tmp_path: Path) -> Path:
        """Create a fake LoRA output directory with one weights file."""
        adapter_dir = tmp_path / "lora-output"
        adapter_dir.mkdir()
        (adapter_dir / "adapter_model.safetensors").write_bytes(b"mock model")
        return adapter_dir

    def test_creates_repo(self, tmp_path: Path) -> None:
        """upload_to_hub creates the repository if it doesn't exist."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        fake_api = MagicMock()

        # HfApi is imported inside upload_to_hub, so we patch at the source
        with patch("huggingface_hub.HfApi", return_value=fake_api):
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        fake_api.create_repo.assert_called_once_with(
            "test-org/test-model", exist_ok=True, private=True
        )

    def test_uploads_folder(self, tmp_path: Path) -> None:
        """upload_to_hub uploads the model folder."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        fake_api = MagicMock()

        with patch("huggingface_hub.HfApi", return_value=fake_api):
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        fake_api.upload_folder.assert_called_once_with(
            folder_path=str(adapter_dir),
            repo_id="test-org/test-model",
            commit_message="Headless GRPO training run",
        )

    def test_handles_repo_creation_failure(self, tmp_path: Path) -> None:
        """upload_to_hub continues if repo already exists."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        fake_api = MagicMock()
        fake_api.create_repo.side_effect = Exception("Repo already exists")

        with patch("huggingface_hub.HfApi", return_value=fake_api):
            # Should not raise
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        # Should still attempt upload
        fake_api.upload_folder.assert_called_once()
|
tests/unit/test_wandb_logging.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for Weights & Biases logging module.
|
| 3 |
+
|
| 4 |
+
Tests cover:
|
| 5 |
+
- WandbSampleLogger accumulation and table creation
|
| 6 |
+
- Reward metrics logging
|
| 7 |
+
- Logging reward function signature compatibility
|
| 8 |
+
- Graceful handling when wandb is not available
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from unittest.mock import MagicMock, patch
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# FIXTURES
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
|
@pytest.fixture
def mock_wandb() -> MagicMock:
    """Build a stand-in for the ``wandb`` module with the attributes tests use."""
    fake = MagicMock()
    # A fake active run carrying the identifying fields the code under test reads.
    fake.run = MagicMock()
    fake.run.name = "test-run"
    fake.run.url = "https://wandb.ai/test/run"
    # Table() yields a fresh mock table; init() hands back the fake run.
    fake.Table = MagicMock(return_value=MagicMock())
    fake.init = MagicMock(return_value=fake.run)
    fake.log = MagicMock()
    fake.finish = MagicMock()
    return fake
| 34 |
+
|
| 35 |
+
|
@pytest.fixture
def sample_prompts() -> list[list[dict[str, str]]]:
    """Return two chat-format prompts (system + user message pairs)."""
    system_msg = {"role": "system", "content": "You are a Marxist assistant."}
    questions = ["What is revisionism?", "Explain surplus value."]
    return [
        [dict(system_msg), {"role": "user", "content": q}]
        for q in questions
    ]
| 49 |
+
|
| 50 |
+
|
@pytest.fixture
def sample_completions() -> list[list[dict[str, str]]]:
    """Return two single-message assistant completions (post-``</think>`` text)."""
    texts = [
        "</think>Revisionism distorts Marxist theory.",
        "</think>Surplus value is unpaid labor.",
    ]
    return [[{"role": "assistant", "content": text}] for text in texts]
| 58 |
+
|
| 59 |
+
|
@pytest.fixture
def sample_answers() -> list[str]:
    """Return reference answers paired with ``sample_prompts``."""
    answers = (
        "Revisionism is the distortion of Marxist-Leninist theory.",
        "Surplus value is the value produced by workers beyond their wages.",
    )
    return list(answers)
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# =============================================================================
|
| 70 |
+
# REWARD SAMPLE TESTS
|
| 71 |
+
# =============================================================================
|
| 72 |
+
|
| 73 |
+
|
class TestRewardSample:
    """Unit tests for the RewardSample dataclass."""

    def test_total_reward_calculation(self) -> None:
        """total_reward should be the sum over every component reward."""
        from prolewiki_llm.wandb_logging import RewardSample

        components = {
            "format_exact": 3.0,
            "nli_coherence": 2.0,
            "topic_relevance": 1.5,
        }
        rs = RewardSample(
            step=10,
            question="What is X?",
            response="X is Y.",
            ground_truth="X is Y.",
            rewards=components,
        )

        # 3.0 + 2.0 + 1.5
        assert rs.total_reward == 6.5

    def test_empty_rewards(self) -> None:
        """An empty rewards mapping yields a total of exactly zero."""
        from prolewiki_llm.wandb_logging import RewardSample

        rs = RewardSample(
            step=10,
            question="Q",
            response="R",
            ground_truth="T",
            rewards={},
        )

        assert rs.total_reward == 0.0
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# =============================================================================
|
| 111 |
+
# SAMPLE LOGGER TESTS
|
| 112 |
+
# =============================================================================
|
| 113 |
+
|
| 114 |
+
|
class TestWandbSampleLogger:
    """Unit tests for WandbSampleLogger accumulation and buffering."""

    def test_add_sample(self) -> None:
        """Adding one sample stores it with its step and summed reward."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        slogger = WandbSampleLogger(log_every_n_steps=10)

        slogger.add_sample(
            step=5,
            question="What is revisionism?",
            response="Revisionism distorts theory.",
            ground_truth="Revisionism is distortion of Marxism.",
            rewards={"format": 3.0, "nli": 2.0},
        )

        assert len(slogger._samples) == 1
        stored = slogger._samples[0]
        assert stored.step == 5
        assert stored.total_reward == 5.0

    def test_sample_buffer_limit(self) -> None:
        """The internal buffer stays bounded no matter how many samples arrive."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        slogger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=2)

        for idx in range(20):
            slogger.add_sample(
                step=idx,
                question=f"Q{idx}",
                response=f"R{idx}",
                ground_truth=f"T{idx}",
                rewards={"x": float(idx)},
            )

        # The logger keeps at most max_samples_per_log * 3 = 6 entries.
        assert len(slogger._samples) <= 6

    def test_should_log(self) -> None:
        """should_log fires only on positive multiples of the interval."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        slogger = WandbSampleLogger(log_every_n_steps=10)

        # Step 0 never logs; afterwards every 10th step does.
        expectations = {0: False, 5: False, 10: True, 15: False, 20: True}
        for step, expected in expectations.items():
            if expected:
                assert slogger.should_log(step)
            else:
                assert not slogger.should_log(step)

    def test_clear(self) -> None:
        """clear() empties the accumulated sample buffer."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        slogger = WandbSampleLogger()
        for step in (1, 2):
            slogger.add_sample(step, "Q", "R", "T", {"x": float(step)})

        assert len(slogger._samples) == 2

        slogger.clear()
        assert len(slogger._samples) == 0

    def test_truncation(self) -> None:
        """Oversized question/response/ground-truth strings get clipped."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        slogger = WandbSampleLogger()

        oversized = "x" * 1000  # well past every truncation limit
        slogger.add_sample(
            step=1,
            question=oversized,
            response=oversized,
            ground_truth=oversized,
            rewards={},
        )

        clipped = slogger._samples[0]
        # Question/response cap at 500 chars; ground truth at 300.
        assert len(clipped.question) == 500
        assert len(clipped.response) == 500
        assert len(clipped.ground_truth) == 300
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# =============================================================================
|
| 201 |
+
# LOG TABLE TESTS
|
| 202 |
+
# =============================================================================
|
| 203 |
+
|
| 204 |
+
|
class TestLogTable:
    """Tests for WandbSampleLogger.log_table."""

    def test_log_table_creates_table(self, mock_wandb: MagicMock) -> None:
        """log_table should build one wandb Table and emit it under 'samples'."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            slogger = WandbSampleLogger(max_samples_per_log=2)

            slogger.add_sample(1, "Q1", "R1", "T1", {"format": 3.0})
            slogger.add_sample(2, "Q2", "R2", "T2", {"format": 2.0})

            slogger.log_table(step=10)

            # Exactly one table constructed and one log call at the right step.
            mock_wandb.Table.assert_called_once()
            mock_wandb.log.assert_called_once()
            args, kwargs = mock_wandb.log.call_args
            assert "samples" in args[0]
            assert kwargs["step"] == 10
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# =============================================================================
|
| 233 |
+
# REWARD METRICS LOGGING TESTS
|
| 234 |
+
# =============================================================================
|
| 235 |
+
|
| 236 |
+
|
class TestLogRewardMetrics:
    """Tests for the log_reward_metrics helper."""

    def test_log_reward_metrics(self, mock_wandb: MagicMock) -> None:
        """Per-reward mean/min/max statistics are forwarded to wandb.log."""
        from prolewiki_llm.wandb_logging import log_reward_metrics

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            log_reward_metrics(
                step=50,
                reward_scores={
                    "format_exact": [3.0, 3.0, 0.0],
                    "nli_coherence": [2.0, -1.0, 3.0],
                },
            )

            mock_wandb.log.assert_called_once()
            metrics = mock_wandb.log.call_args[0][0]

            # Means: (3+3+0)/3 and (2-1+3)/3.
            assert metrics["rewards/format_exact"] == 2.0
            assert metrics["rewards/nli_coherence"] == pytest.approx(4 / 3)

            # Extremes of the format_exact scores.
            assert metrics["rewards/format_exact_min"] == 0.0
            assert metrics["rewards/format_exact_max"] == 3.0

    def test_log_reward_metrics_computes_total(self, mock_wandb: MagicMock) -> None:
        """rewards/total is the mean of the per-sample reward sums."""
        from prolewiki_llm.wandb_logging import log_reward_metrics

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            log_reward_metrics(
                step=10,
                reward_scores={
                    "format": [3.0, 2.0],
                    "nli": [1.0, 2.0],
                },
            )

            metrics = mock_wandb.log.call_args[0][0]

            # Sample totals: 3+1 = 4.0 and 2+2 = 4.0, so the mean is 4.0.
            assert metrics["rewards/total"] == 4.0
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# =============================================================================
|
| 285 |
+
# LOGGING REWARD FUNCTION TESTS
|
| 286 |
+
# =============================================================================
|
| 287 |
+
|
| 288 |
+
|
class TestCreateLoggingReward:
    """Tests for the create_logging_reward factory."""

    def test_returns_zeros(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """The logging reward is observational only: it scores every sample 0.0."""
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
        )

        # Without wandb the logger falls back to printing; scores are unaffected.
        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            sample_logger = WandbSampleLogger()
            logging_reward = create_logging_reward(sample_logger, compute_all_rewards=False)

            scores = logging_reward(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
            )

            # Zeros for all samples => no effect on GRPO training.
            assert scores == [0.0, 0.0]

    def test_function_signature_compatibility(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """The reward callable accepts the keyword signature GRPOTrainer uses."""
        from prolewiki_llm.wandb_logging import create_logging_reward

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            logging_reward = create_logging_reward(compute_all_rewards=False)

            # Must tolerate prompts, completions, answer, and arbitrary extra kwargs.
            result = logging_reward(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
                extra_kwarg="ignored",
            )

            assert isinstance(result, list)
            assert len(result) == len(sample_completions)

    def test_logs_samples_at_interval(
        self,
        mock_wandb: MagicMock,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """Sample tables are logged every ``log_every_n_steps`` reward calls."""
        import prolewiki_llm.wandb_logging as wl
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
        )

        # Snapshot the module-level step counter, reset it for determinism, and
        # restore it afterwards so no state leaks into other tests (previously
        # the counter was zeroed here and never put back).
        saved_step = wl._LOGGING_STEP
        wl._LOGGING_STEP = 0
        try:
            with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
                sample_logger = WandbSampleLogger(log_every_n_steps=5, max_samples_per_log=2)
                logging_reward = create_logging_reward(sample_logger, compute_all_rewards=False)

                # Each call advances the internal step counter by one.
                for _ in range(10):
                    logging_reward(
                        prompts=sample_prompts,
                        completions=sample_completions,
                        answer=sample_answers,
                    )

                # The table should have been logged twice (at steps 5 and 10).
                table_logs = [
                    call for call in mock_wandb.log.call_args_list if "samples" in call[0][0]
                ]
                assert len(table_logs) == 2
        finally:
            wl._LOGGING_STEP = saved_step
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# =============================================================================
|
| 376 |
+
# WANDB AVAILABILITY TESTS
|
| 377 |
+
# =============================================================================
|
| 378 |
+
|
| 379 |
+
|
class TestWandbAvailability:
    """Test handling of wandb availability."""

    def test_is_wandb_available_true(self, mock_wandb: MagicMock) -> None:
        """is_wandb_available completes and yields a boolean when wandb importable."""
        from prolewiki_llm import wandb_logging as wl

        # Snapshot the module-level cache so this test cannot poison others
        # (previously the cache was reset here and never restored).
        saved_available = wl._WANDB_AVAILABLE
        saved_module = wl._wandb_module
        try:
            wl._WANDB_AVAILABLE = None
            wl._wandb_module = None
            with patch.dict("sys.modules", {"wandb": mock_wandb}):
                result = wl.is_wandb_available()
            # Import mechanics may still report False here; the contract we can
            # pin down is that the probe finishes and returns a boolean.
            assert isinstance(result, bool)
        finally:
            wl._WANDB_AVAILABLE = saved_available
            wl._wandb_module = saved_module

    def test_graceful_degradation_without_wandb(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """Every logging entry point degrades gracefully when wandb is absent."""
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
            log_reward_metrics,
        )

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            # None of the calls below may raise without wandb installed.
            sample_logger = WandbSampleLogger()
            logging_reward = create_logging_reward(sample_logger, compute_all_rewards=False)

            # The reward callable still returns valid (zero) scores.
            result = logging_reward(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
            )
            assert result == [0.0, 0.0]

            # Metrics and table logging should be silent no-ops, not crashes.
            log_reward_metrics(step=1, reward_scores={"x": [1.0]})
            sample_logger.log_table(step=10)
+
|
| 430 |
+
|
| 431 |
+
# =============================================================================
|
| 432 |
+
# INIT AND FINISH TESTS
|
| 433 |
+
# =============================================================================
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class TestInitAndFinish:
|
| 437 |
+
"""Test initialization and finishing of wandb runs."""
|
| 438 |
+
|
| 439 |
+
def test_init_wandb_logging(self, mock_wandb: MagicMock) -> None:
|
| 440 |
+
"""Test wandb initialization with config."""
|
| 441 |
+
from prolewiki_llm.wandb_logging import init_wandb_logging
|
| 442 |
+
|
| 443 |
+
with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
|
| 444 |
+
_run = init_wandb_logging(
|
| 445 |
+
project="test-project",
|
| 446 |
+
config={"lr": 1e-5, "batch_size": 4},
|
| 447 |
+
name="test-run",
|
| 448 |
+
tags=["test"],
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
mock_wandb.init.assert_called_once()
|
| 452 |
+
call_kwargs = mock_wandb.init.call_args[1]
|
| 453 |
+
assert call_kwargs["project"] == "test-project"
|
| 454 |
+
assert call_kwargs["config"] == {"lr": 1e-5, "batch_size": 4}
|
| 455 |
+
|
| 456 |
+
def test_finish_wandb_logging(self, mock_wandb: MagicMock) -> None:
|
| 457 |
+
"""Test wandb finish with summary."""
|
| 458 |
+
from prolewiki_llm.wandb_logging import finish_wandb_logging
|
| 459 |
+
|
| 460 |
+
with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
|
| 461 |
+
finish_wandb_logging(summary={"final_loss": 0.5})
|
| 462 |
+
|
| 463 |
+
# Check summary was updated
|
| 464 |
+
mock_wandb.run.summary.__setitem__.assert_called_with("final_loss", 0.5)
|
| 465 |
+
|
| 466 |
+
# Check finish was called
|
| 467 |
+
mock_wandb.finish.assert_called_once()
|
training_data/entity_whitelist.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training_data/entity_whitelist_clean.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|