percyraskova committed on
Commit
81b3473
·
verified ·
1 Parent(s): f2e39e6

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .claude/settings.local.json +19 -0
  2. .gitignore +63 -0
  3. .pre-commit-config.yaml +42 -0
  4. .ruff_cache/.gitignore +2 -0
  5. .ruff_cache/0.14.10/11987423046224602800 +0 -0
  6. .ruff_cache/0.14.10/9230827824541771364 +0 -0
  7. .ruff_cache/0.14.10/9755640140589486738 +0 -0
  8. .ruff_cache/CACHEDIR.TAG +1 -0
  9. .yamllint.yaml +14 -0
  10. AGENTS.md +39 -0
  11. ATTRIBUTION.md +30 -0
  12. CLAUDE.md +120 -0
  13. README.md +77 -0
  14. RunPod LLM Fine-Tuning Strategy.md +339 -0
  15. ai-docs/chatbot-ideology.yaml +373 -0
  16. ai-docs/finetune.yaml +297 -0
  17. ai-docs/reward-modeling.yaml +857 -0
  18. ai-docs/runpod.yaml +362 -0
  19. ai-docs/training-schema.yaml +484 -0
  20. docker/.env.example +124 -0
  21. docker/Dockerfile +132 -0
  22. docker/requirements.txt +31 -0
  23. docker/start.sh +165 -0
  24. notebooks/Marxist_GRPO_RunPod_Optimized.ipynb +1107 -0
  25. pyproject.toml +166 -0
  26. src/prolewiki_llm/__init__.py +99 -0
  27. src/prolewiki_llm/convert_to_qwen.py +37 -0
  28. src/prolewiki_llm/export_grpo_dataset.py +224 -0
  29. src/prolewiki_llm/grpo_rewards.py +1847 -0
  30. src/prolewiki_llm/train_grpo_marxist.py +341 -0
  31. src/prolewiki_llm/train_headless.py +460 -0
  32. src/prolewiki_llm/train_marxist.py +201 -0
  33. src/prolewiki_llm/transform_to_grpo.py +64 -0
  34. src/prolewiki_llm/wandb_logging.py +529 -0
  35. tests/__init__.py +1 -0
  36. tests/conftest.py +114 -0
  37. tests/fixtures/__init__.py +1 -0
  38. tests/fixtures/mock_bin/huggingface-cli +7 -0
  39. tests/fixtures/mock_bin/python +35 -0
  40. tests/fixtures/mock_bin/runpodctl +12 -0
  41. tests/fixtures/mock_bin/wandb +7 -0
  42. tests/fixtures/mock_bin/wget +17 -0
  43. tests/integration/__init__.py +1 -0
  44. tests/integration/test_start_sh.py +462 -0
  45. tests/unit/__init__.py +1 -0
  46. tests/unit/test_grpo_rewards.py +0 -0
  47. tests/unit/test_train_headless.py +248 -0
  48. tests/unit/test_wandb_logging.py +467 -0
  49. training_data/entity_whitelist.json +0 -0
  50. training_data/entity_whitelist_clean.json +0 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(wc:*)",
5
+ "Bash(du:*)",
6
+ "mcp__context7__resolve-library-id",
7
+ "mcp__context7__query-docs",
8
+ "Bash(uv sync:*)",
9
+ "Bash(uv run python:*)",
10
+ "Bash(uv run ruff check:*)",
11
+ "Bash(uv run pytest:*)",
12
+ "WebFetch(domain:www.runpod.io)"
13
+ ]
14
+ },
15
+ "disabledMcpjsonServers": [
16
+ "sequential-thinking",
17
+ "neovim"
18
+ ]
19
+ }
.gitignore ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ .venv/
25
+ venv/
26
+ ENV/
27
+
28
+ # IDEs
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+
34
+ # Jupyter
35
+ .ipynb_checkpoints/
36
+
37
+ # Testing
38
+ .pytest_cache/
39
+ .coverage
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+
44
+ # mypy
45
+ .mypy_cache/
46
+
47
+ # Archives
48
+ *.tar.gz
49
+ *.zip
50
+
51
+ # Model artifacts (large files)
52
+ *.safetensors
53
+ *.bin
54
+ *.gguf
55
+
56
+ # Training outputs (generated)
57
+ outputs/
58
+ checkpoints/
59
+ lora-output/
60
+
61
+ # OS
62
+ .DS_Store
63
+ Thumbs.db
.pre-commit-config.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ rev: v0.8.0
4
+ hooks:
5
+ - id: ruff
6
+ name: ruff (lint)
7
+ args: [--fix, --exit-non-zero-on-fix]
8
+ types_or: [python, pyi]
9
+ - id: ruff-format
10
+ name: ruff (format)
11
+ types_or: [python, pyi]
12
+
13
+ - repo: local
14
+ hooks:
15
+ - id: mypy
16
+ name: mypy (typecheck)
17
+ entry: uv run mypy
18
+ language: system
19
+ types: [python]
20
+ args: [src/prolewiki_llm/]
21
+ pass_filenames: false
22
+
23
+ - repo: https://github.com/pre-commit/pre-commit-hooks
24
+ rev: v5.0.0
25
+ hooks:
26
+ - id: trailing-whitespace
27
+ - id: end-of-file-fixer
28
+ - id: check-yaml
29
+ args: [--unsafe]
30
+ - id: check-json
31
+ - id: check-toml
32
+ - id: check-added-large-files
33
+ args: [--maxkb=5000]
34
+ - id: check-merge-conflict
35
+ - id: debug-statements
36
+
37
+ - repo: https://github.com/adrienverge/yamllint.git
38
+ rev: v1.35.1
39
+ hooks:
40
+ - id: yamllint
41
+ name: yamllint (yaml lint)
42
+ args: [-c=.yamllint.yaml]
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Automatically created by ruff.
2
+ *
.ruff_cache/0.14.10/11987423046224602800 ADDED
Binary file (310 Bytes). View file
 
.ruff_cache/0.14.10/9230827824541771364 ADDED
Binary file (228 Bytes). View file
 
.ruff_cache/0.14.10/9755640140589486738 ADDED
Binary file (170 Bytes). View file
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
.yamllint.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ extends: default
2
+
3
+ rules:
4
+ line-length:
5
+ max: 200
6
+ level: warning
7
+ document-start: disable
8
+ truthy:
9
+ allowed-values: ['true', 'false', 'yes', 'no']
10
+ comments:
11
+ min-spaces-from-content: 1
12
+ indentation:
13
+ spaces: 2
14
+ indent-sequences: consistent
AGENTS.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Repository Guidelines
2
+
3
+ ## Project Structure & Module Organization
4
+
5
+ - `src/prolewiki_llm/`: core library code (reward functions, dataset transforms, W&B logging).
6
+ - `tests/unit/`: fast unit tests and fixtures under `tests/conftest.py`.
7
+ - `training_data/`: curated datasets (`.jsonl`) and documentation (`MODEL_CARD.yaml`).
8
+ - `notebooks/`: training and experimentation notebooks.
9
+ - `ai-docs/`: AI-consumable documentation and schemas.
10
+
11
+ ## Build, Test, and Development Commands
12
+
13
+ - `uv sync`: install core dependencies.
14
+ - `uv sync --group dev`: install developer tools (pytest, ruff, mypy).
15
+ - `uv sync --group training`: install GPU training dependencies.
16
+ - `python -m spacy download en_core_web_sm`: required model for coherence rewards.
17
+ - `uv run pytest`: run the test suite.
18
+ - `uv run pytest -m "not slow and not gpu"`: run fast tests only.
19
+ - `uv run ruff check .` / `uv run ruff format .`: lint and format.
20
+ - `uv run mypy src`: run strict type checks on library code.
21
+
22
+ ## Coding Style & Naming Conventions
23
+
24
+ - Python 3.12; line length 100; double quotes; spaces for indentation (ruff format).
25
+ - Use `snake_case` for modules/functions and `PascalCase` for classes.
26
+ - Keep reward logic in `src/prolewiki_llm/grpo_rewards.py` and avoid side effects.
27
+ - Name data files descriptively, for example `training_data/synthetic_topic.jsonl`.
28
+
29
+ ## Testing Guidelines
30
+
31
+ - Pytest with markers: `unit`, `slow`, and `gpu` (see `pyproject.toml`).
32
+ - Place tests in `tests/unit/` and name files `test_*.py`.
33
+ - Mark long-running or GPU-dependent tests explicitly.
34
+
35
+ ## Commit & Pull Request Guidelines
36
+
37
+ - Git history only shows "initial commit"; no established convention yet.
38
+ - Use short, imperative commit subjects (for example, "Add format reward tests").
39
+ - PRs should describe changes, list tests run, and call out dataset or model updates.
ATTRIBUTION.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Attribution
2
+
3
+ This project was extracted from [pw-mcp](https://github.com/prolewiki/pw-mcp) on December 25, 2025.
4
+
5
+ ## Origin
6
+
7
+ The code in this repository was originally developed as part of the pw-mcp project (ProleWiki MCP Server), which provides semantic vector search over the ProleWiki corpus.
8
+
9
+ The AI training components were separated into this standalone repository to:
10
+ - Reduce the dependency footprint of pw-mcp
11
+ - Allow independent versioning and release cycles
12
+ - Separate the ~1GB ChromaDB releases from the ~4-16GB model releases
13
+ - Enable focused development on training infrastructure
14
+
15
+ ## Original Authors
16
+
17
+ The original code was developed by contributors to the pw-mcp project.
18
+
19
+ ## Files Extracted
20
+
21
+ The following files were copied from pw-mcp:
22
+
23
+ - `src/pw_mcp/ai_training/*.py` → `src/prolewiki_llm/`
24
+ - `training_data/` (entire directory)
25
+ - `tests/unit/training/` → `tests/unit/`
26
+ - `ai-docs/` (training-related YAML files)
27
+
28
+ ## License
29
+
30
+ This project maintains the same AGPL-3.0-only license as the original pw-mcp project.
CLAUDE.md ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ GRPO fine-tuning infrastructure for training Marxist-Leninist language models. The core functionality is a multi-layer reward system designed to prevent reward hacking during GRPO (Group Relative Policy Optimization) training.
8
+
9
+ ## Commands
10
+
11
+ ```bash
12
+ # Install dependencies
13
+ uv sync # Core dependencies
14
+ uv sync --group dev # Dev tools (pytest, ruff, mypy)
15
+ uv sync --group training # GPU training deps (unsloth, trl, vllm)
16
+
17
+ # Required NLP model
18
+ python -m spacy download en_core_web_sm
19
+
20
+ # Testing
21
+ uv run pytest # All tests
22
+ uv run pytest -m "not slow and not gpu" # Fast tests only
23
+ uv run pytest tests/unit/test_grpo_rewards.py::test_name # Single test
24
+
25
+ # Code quality
26
+ uv run ruff check . # Lint
27
+ uv run ruff format . # Format
28
+ uv run mypy src # Type check (strict mode)
29
+ ```
30
+
31
+ ## Architecture
32
+
33
+ ### Reward Function System (`src/prolewiki_llm/grpo_rewards.py`)
34
+
35
+ The reward system is layered to defeat "word soup" and other reward hacking strategies:
36
+
37
+ 1. **Format rewards** - Validate `<think>...</think>` reasoning tags
38
+ 2. **NLI coherence** - Use BART-large-MNLI to check if response entails ground truth
39
+ 3. **Self-consistency** - Detect internal contradictions via pairwise NLI
40
+ 4. **Structural coherence** - spaCy dependency parsing verifies terms in meaningful syntactic roles
41
+ 5. **Topic relevance** - Ensure response addresses what was asked
42
+ 6. **Interconnection depth** - Reward deep analysis, penalize buzzword salad
43
+
44
+ Key functions:
45
+ - `full_coherence_reward()` - Comprehensive 5-layer check (recommended for training)
46
+ - `robust_coherence_reward()` - NLI + self-consistency + structure
47
+ - `ideological_firmness_reward()` - Detects capitulation to reactionary framing
48
+
49
+ Models are lazy-loaded via `get_embedder()`, `get_nli_pipeline()`, `get_spacy_nlp()`.
50
+
51
+ ### Reward Function Interface
52
+
53
+ All reward functions follow the GRPOTrainer signature:
54
+ ```python
55
+ def reward_func(
56
+ prompts: Sequence[Sequence[dict[str, str]]],
57
+ completions: Sequence[Sequence[dict[str, str]]],
58
+ answer: Sequence[str],
59
+ **kwargs: object,
60
+ ) -> list[float]:
61
+ ```
62
+
63
+ ### W&B Integration (`src/prolewiki_llm/wandb_logging.py`)
64
+
65
+ `WandbSampleLogger` accumulates samples and logs them as tables for debugging reward behavior. Use `create_logging_reward()` to wrap logging into a reward function.
66
+
67
+ ## Testing Patterns
68
+
69
+ - Markers: `unit`, `slow`, `gpu` (see `pyproject.toml`)
70
+ - Fixtures in `tests/conftest.py`: `sample_question`, `sample_good_answer`, `sample_bad_answer`
71
+ - Mock heavy NLP models in unit tests to avoid slow initialization
72
+
73
+ ## Code Style
74
+
75
+ - Python 3.12, line length 100, double quotes
76
+ - `snake_case` functions/modules, `PascalCase` classes
77
+ - Reward functions must accept `**kwargs` (GRPOTrainer interface requirement)
78
+ - Keep reward logic pure (no side effects except lazy model loading)
79
+
80
+ ## Training
81
+
82
+ ### Headless RunPod Training (Recommended)
83
+
84
+ Docker-based headless training for RunPod deployment:
85
+
86
+ ```bash
87
+ # 1. Export dataset from ../dataset/ to GRPO format
88
+ uv run python -m prolewiki_llm.export_grpo_dataset
89
+
90
+ # 2. Build Docker image
91
+ docker build -t marxist-grpo:latest -f docker/Dockerfile .
92
+
93
+ # 3. Deploy to RunPod
94
+ runpodctl create pod \
95
+ --name "marxist-grpo-training" \
96
+ --gpuType "NVIDIA A100 80GB PCIe" \
97
+ --imageName myregistry/marxist-grpo:latest \
98
+ --env HF_TOKEN=$HF_TOKEN \
99
+ --env WANDB_API_KEY=$WANDB_API_KEY \
100
+ --env HF_REPO=my-org/my-model
101
+ ```
102
+
103
+ Key files:
104
+ - `src/prolewiki_llm/export_grpo_dataset.py` - Dataset export utility
105
+ - `src/prolewiki_llm/train_headless.py` - Headless training script
106
+ - `docker/Dockerfile` - Container definition
107
+ - `docker/start.sh` - Entrypoint (auth → train → upload → stop pod)
108
+ - `docker/.env.example` - Environment variable reference
109
+
110
+ Environment variables for `train_headless.py`:
111
+ - `HF_TOKEN`, `WANDB_API_KEY` (required)
112
+ - `MAX_STEPS`, `BATCH_SIZE`, `LEARNING_RATE`, `REWARD_MODE` (optional)
113
+ - `RUNPOD_POD_ID` (enables auto-termination after training)
114
+
115
+ ### Notebook Training (Interactive)
116
+
117
+ See `notebooks/Marxist_GRPO_Training.ipynb` for interactive training. Key settings:
118
+ - Disable `torch.compile` via environment variables (causes hangs on RunPod/Jupyter)
119
+ - Use `use_gradient_checkpointing=True` (not `"unsloth"`)
120
+ - `load_in_4bit=False` required for GRPO (16-bit LoRA adapters)
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prolewiki-llm
2
+
3
+ GRPO fine-tuning and reward functions for training Marxist-Leninist language models.
4
+
5
+ ## Overview
6
+
7
+ This repository contains the AI training infrastructure for fine-tuning language models on Marxist-Leninist theory. It includes:
8
+
9
+ - **Reward Functions**: Multi-layer reward system for GRPO training that prevents reward hacking
10
+ - **Training Data**: Curated Q&A pairs and synthetic datasets for ideological consistency
11
+ - **Training Scripts**: Ready-to-run notebooks for RunPod/cloud GPU training
12
+ - **W&B Integration**: Weights & Biases logging for training observability
13
+
14
+ ## Related Projects
15
+
16
+ - [pw-mcp](https://github.com/prolewiki/pw-mcp) - MCP server and ChromaDB pipeline for ProleWiki semantic search
17
+
18
+ ## Installation
19
+
20
+ ```bash
21
+ # Basic installation
22
+ uv sync
23
+
24
+ # Download spacy model (required for topic/coherence rewards)
25
+ python -m spacy download en_core_web_sm
26
+
27
+ # With training dependencies (for GPU training)
28
+ uv sync --group training
29
+
30
+ # Development
31
+ uv sync --group dev
32
+ ```
33
+
34
+ ## Usage
35
+
36
+ ### Reward Functions
37
+
38
+ ```python
39
+ from prolewiki_llm import full_coherence_reward, format_reward
40
+
41
+ # Combined 5-layer coherence check (recommended for training)
42
+ reward = full_coherence_reward(
43
+ prompts=["What is imperialism?"],
44
+ completions=["<think>...</think>\n\nImperialism is..."],
45
+ answer="Lenin defined imperialism as..."
46
+ )
47
+
48
+ # Individual reward components
49
+ format_score = format_reward(completions=["<think>...</think>\n\nAnswer..."])
50
+ ```
51
+
52
+ ### Training
53
+
54
+ See `notebooks/Marxist_GRPO_Training.ipynb` for a complete training example.
55
+
56
+ ## Project Structure
57
+
58
+ ```
59
+ prolewiki-llm/
60
+ ├── src/prolewiki_llm/
61
+ │ ├── grpo_rewards.py # 17+ reward functions
62
+ │ ├── wandb_logging.py # W&B integration
63
+ │ └── transform_to_grpo.py # Dataset conversion
64
+ ├── training_data/
65
+ │ ├── synthetic_*.jsonl # Training datasets
66
+ │ ├── entity_whitelist.json # Anti-hallucination data
67
+ │ └── MODEL_CARD.yaml # Dataset documentation
68
+ ├── notebooks/
69
+ │ └── Marxist_GRPO_Training.ipynb
70
+ ├── tests/
71
+ │ └── unit/
72
+ └── ai-docs/ # AI-consumable documentation
73
+ ```
74
+
75
+ ## License
76
+
77
+ AGPL-3.0-only
RunPod LLM Fine-Tuning Strategy.md ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # **Operationalizing Headless Large Language Model Fine-Tuning on RunPod: A Comprehensive Infrastructure and Workflow Analysis**
2
+
3
+ ## **1\. Introduction: The Paradigm Shift to Headless AI Operations**
4
+
5
+ The contemporary landscape of Large Language Model (LLM) development is currently undergoing a fundamental transition, shifting from interactive, exploratory environments toward rigorous, automated production pipelines. For years, the Jupyter notebook has served as the *lingua franca* of data science—a canvas for experimentation, visualization, and iterative code development. However, as the field matures from research to engineering, the limitations of the notebook paradigm become increasingly acute, particularly when applied to the resource-intensive and time-critical task of fine-tuning custom LLMs. The user requirement for a workflow that eliminates the notebook interface in favor of a "code-upload-and-train" paradigm reflects a sophisticated understanding of MLOps principles: reproducibility, resource efficiency, and maximizing computational throughput.
6
+ RunPod, as a specialized GPU cloud provider, occupies a unique and critical niche within this evolving ecosystem. Distinct from hyperscalers such as AWS, Azure, or Google Cloud Platform—which often necessitate complex Identity and Access Management (IAM) configurations, quota negotiations, and long-term commitments—RunPod offers a container-centric infrastructure that is ideally optimized for headless, ephemeral training jobs. The platform’s architecture effectively democratizes access to high-performance compute, offering everything from consumer-grade RTX 4090s to enterprise-class NVIDIA H100 clusters. This report provides an exhaustive, expert-level analysis of the architectural, operational, and software strategies necessary to fine-tune custom LLMs on RunPod using a strictly headless approach.
7
+ To fully satisfy the requirement of "training as fast as possible" with "custom training data," this analysis moves beyond simple tutorials to construct a robust engineering framework. It dissects the interplay between hardware selection (Secure vs. Community Cloud), containerization strategies (Docker-based execution), and high-efficiency fine-tuning frameworks (Unsloth and Axolotl). By decoupling the training process from an interactive Integrated Development Environment (IDE), developers can leverage spot instances more effectively, dramatically reduce idle compute costs, and integrate training runs into broader Continuous Integration/Continuous Deployment (CI/CD) pipelines. This report serves as a definitive guide to architecting these headless systems.
8
+
9
+ ## ---
10
+
11
+ **2\. Infrastructure Architecture and Instance Selection Strategy**
12
+
13
+ The foundation of any high-performance fine-tuning workflow is the underlying compute architecture. In the context of "training as fast as possible," the choice of hardware dictates not only the wall-clock time of the training run but also the stability, cost-efficiency, and maximum capable model size of the session. RunPod’s inventory is segmented into distinct tiers, each offering specific advantages and liabilities for headless operations. A nuanced understanding of these hardware profiles is essential for optimizing the price-performance ratio.
14
+
15
+ ### **2.1 The GPU Hierarchy: Performance Profiles and Architectural Suitability**
16
+
17
+ The selection of a specific GPU architecture must be directly correlated with the parameter count of the target model (e.g., Llama 3 8B, Mistral, or Llama 3 70B) and the chosen quantization method (Full Fine-Tuning vs. LoRA/QLoRA).
18
+
19
+ #### **The Enterprise Tier: NVIDIA H100 and A100**
20
+
21
+ For users prioritizing raw speed and throughput above all else, the NVIDIA H100 and A100 Tensor Core GPUs represent the gold standard of current AI acceleration. These cards are designed for datacenter reliability and massive parallel throughput.
22
+ The **NVIDIA H100 (80GB)** stands as the pinnacle of current commercial AI hardware. It is specifically engineered to accelerate Transformer-based models via its fourth-generation Tensor Cores and the dedicated Transformer Engine, which automatically manages mixed-precision calculations using FP8 formats.1 For headless workflows, the H100 offers a distinct advantage: its sheer speed minimizes the "window of vulnerability." In a headless setup, particularly one utilizing spot instances or decentralized nodes, the longer a job runs, the higher the statistical probability of a network disconnect or node preemption. By reducing training time by factors of 3x or more compared to previous generations, the H100 significantly increases the reliability of job completion.2 It is the only viable option for users attempting to full fine-tune models in the 70B+ parameter range within reasonable timeframes. However, this performance comes at a premium, with costs ranging from approximately $2.69 to $4.00 per hour depending on the specific configuration (SXM vs. PCIe) and market demand.1
23
+ The **NVIDIA A100 (80GB)** remains the industry workhorse for LLM training. While it lacks the H100's specific FP8 Transformer Engine, its 80GB of High Bandwidth Memory (HBM2e) provides sufficient capacity to fine-tune 70B models using QLoRA or 8B models with full precision and extended context windows.1 The availability of A100s on RunPod is generally higher than that of H100s, making them a more reliable fallback for automated pipelines that require immediate provisioning without queuing. For users engaging in "headless" operations where the script automatically requests resources, the A100's ubiquity often makes it the path of least resistance.4
24
+
25
+ #### **The Prosumer Tier: NVIDIA RTX 4090 and RTX 6000 Ada**
26
+
27
+ For users targeting smaller models, such as the 7B or 8B parameter classes (e.g., Llama 3 8B, Mistral, Gemma), the enterprise tier may represent overkill. The **NVIDIA RTX 4090** has emerged as an exceptionally cost-effective alternative for these specific workloads.
28
+ With 24GB of VRAM, the RTX 4090 can comfortably handle 8B models using 4-bit quantization (QLoRA) or, when paired with memory-efficient frameworks like Unsloth, even larger batch sizes.5 The cost efficiency is dramatic: at approximately $0.34 to $0.69 per hour, a developer can run extensive hyperparameter sweeps (grid searches) for the cost of a single hour on an H100.1 However, the use of consumer hardware in a headless workflow introduces specific constraints. These cards are typically hosted in the "Community Cloud" tier, meaning they are decentralized nodes often residing in non-tier-1 datacenters or even private residences. This introduces a higher risk of interruption, necessitating that the headless script implements robust, frequent checkpointing to resume training automatically if a node goes offline.
29
+ The **RTX 6000 Ada Generation** bridges the gap, offering 48GB of VRAM—double that of the 4090—while retaining the Ada Lovelace architecture's efficiency. Priced around $0.79/hr, it allows for training mid-sized models (e.g., 30B parameters with QLoRA) or 8B models with much longer context windows than the 4090 allows.1
30
+
31
+ ### **2.2 Deployment Tiers: Secure Cloud vs. Community Cloud**
32
+
33
+ RunPod segments its GPU inventory into two distinct tiers: **Community Cloud** and **Secure Cloud**. This distinction is critical for designing a headless operation, as it fundamentally dictates the reliability engineering required in the training code.
34
+ **Secure Cloud** represents enterprise-grade datacenters with high reliability, redundancy, and security certifications (SOC2, etc.). For a user whose primary requirement is to "upload and train," Secure Cloud offers the assurance that the pod will not vanish mid-training due to a provider pulling the machine off the network. The pricing is slightly higher, but the reduction in operational complexity—specifically the reduced need for aggressive fault-tolerance scripting—often outweighs the raw hourly cost difference.1 For the final "production" training run, specifically when processing a massive dataset that might take 10+ hours, Secure Cloud is the recommended tier to ensure uninterrupted execution.
35
+ **Community Cloud** consists of crowdsourced GPUs provided by third parties. While significantly cheaper, these function similarly to Spot instances in traditional clouds, though with potentially higher variance in uptime and network bandwidth. They are ideal for "bursty" workloads where a user might spin up 10 simultaneous experiments to test different learning rates. However, utilizing this tier for headless training requires the training script to be resilient. It implies that the "code upload" must include logic to check for existing checkpoints on a persistent volume and resume automatically, as the probability of a node restart is non-zero.1
36
+
37
+ ### **2.3 Cost-Performance Matrix**
38
+
39
+ To assist in making the precise hardware decision, the following table synthesizes the cost, utility, and risk profile of available hardware for fine-tuning tasks on RunPod.
40
+
41
+ | GPU Model | VRAM | Cloud Tier | Est. Price/Hr | Best Use Case | Headless Reliability |
42
+ | :---- | :---- | :---- | :---- | :---- | :---- |
43
+ | **H100 SXM** | 80GB | Secure | \~$2.69 | Full FT 70B+, Time-Critical Jobs | **High** (Fastest completion minimizes risk) |
44
+ | **A100 SXM** | 80GB | Secure | \~$1.49 | QLoRA 70B, Full FT 8B | **High** (Standard enterprise reliability) |
45
+ | **A100 PCIe** | 40GB | Secure | \~$1.39 | LoRA 13B-30B | **Medium** (Memory constraints may limit batch size) |
46
+ | **RTX 6000 Ada** | 48GB | Secure | \~$0.79 | Mid-range models (30B), Long Context | **High** (Excellent VRAM/Price ratio) |
47
+ | **RTX 4090** | 24GB | Community | \~$0.34 | QLoRA 8B, Debugging, Sweeps | **Low/Medium** (Requires fault tolerance logic) |
48
+ | **RTX 3090** | 24GB | Community | \~$0.22 | Low-budget experimentation | **Low** (Slower speed increases interrupt risk) |
49
+
50
+ 1
51
+
52
+ ## ---
53
+
54
+ **3\. The Headless Workflow Architecture: Containerization and Automation**
55
+
56
+ To satisfy the user's explicit requirement of avoiding a Jupyter notebook in favor of a "code upload" model, the workflow must shift from an interactive session to a batch-processing paradigm. In this model, the local machine is used for code development and configuration, while the remote GPU serves purely as an execution engine. This requires a Docker-first approach where the environment, code, and execution logic are encapsulated within a portable container image.
57
+
58
+ ### **3.1 The Docker-First Approach**
59
+
60
+ The cornerstone of a robust headless workflow is containerization. Launching a generic Ubuntu pod and manually installing libraries via a startup script is prone to error, hard to reproduce, and slow. Instead, the user must define the entire training environment in a Docker image. This ensures that "uploading code" translates immediately to execution without manual environment setup.
61
+
62
+ #### **The "Entrypoint" Strategy**
63
+
64
+ In a standard interactive RunPod session, the container launches and idles, typically running a sleep command or a Jupyter server, waiting for a user to connect. In a headless workflow, the Docker container utilizes an ENTRYPOINT or CMD script that immediately initiates the training process upon launch. Crucially, once the training process concludes (whether successfully or with a failure), the script handles data egress and terminates the pod.7
65
+ This approach perfectly aligns with the "upload code and train" desire. The "code" is baked into the Docker image (or mounted at runtime), and the "train" command is the automatic, inevitable action of the container starting up.
66
+
67
+ #### **Constructing the Golden Image**
68
+
69
+ A "Golden Image" for fine-tuning must include the base CUDA drivers, the Python environment, and the specific fine-tuning frameworks (Axolotl or Unsloth). Below is an architectural breakdown of such a Dockerfile, optimized for RunPod.
70
+ **Scenario:** A Docker image designed for fine-tuning Llama 3 using Unsloth.
71
+
72
+ Dockerfile
73
+
74
+ \# Use RunPod's base image or NVIDIA's CUDA image to ensure driver compatibility
75
+ \# CUDA 11.8 or 12.1 is often required for modern LLM frameworks
76
+ FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel
77
+
78
+ \# Set working directory
79
+ WORKDIR /workspace
80
+
81
+ \# Install system dependencies
82
+ \# git-lfs is critical for downloading large models/datasets
83
+ RUN apt-get update && apt-get install \-y git git-lfs htop nvtop tmux
84
+
85
+ \# Install Python dependencies
86
+ \# Unsloth and Axolotl often require specific versions of xformers and trl
87
+ \# Using a requirements.txt allows for easier version pinning
88
+ COPY requirements.txt /workspace/requirements.txt
89
+ RUN pip install \--no-cache-dir \-r requirements.txt
90
+
91
+ \# Install specific frameworks (Example: Unsloth)
92
+ \# Note: Unsloth installation often requires specific CUDA paths
93
+ RUN pip install "unsloth\[colab-new\] @ git+https://github.com/unslothai/unsloth.git"
94
+ RUN pip install \--no-deps "xformers\<0.0.26" "trl\<0.9.0" peft accelerate bitsandbytes
95
+
96
+ \# Copy the automation scripts and training code
97
+ COPY train.py /workspace/train.py
98
+ COPY start.sh /workspace/start.sh
99
+
100
+ \# Make the start script executable
101
+ RUN chmod \+x /workspace/start.sh
102
+
103
+ \# Set the entrypoint to the automation script
104
+ ENTRYPOINT \["/workspace/start.sh"\]
105
+
106
+ 7
107
+ **Optimization Insight:** Embedding the dataset directly into the Docker image (via COPY dataset.jsonl) is a viable strategy only for small datasets (\<5GB). For massive datasets (\>100GB), as implied by the "custom training data" requirement, this approach creates bloated images that are slow to push and pull. For large-scale data, the start.sh script should be designed to pull the data from S3 or a RunPod Network Volume at runtime, ensuring the Docker image remains lightweight and agile.10
108
+
109
+ ### **3.2 The Automation Logic: The start.sh Script**
110
+
111
+ The start.sh script acts as the "brain" of the headless operation. It orchestrates the sequence of events inside the pod, managing authentication, data ingestion, execution, and cleanup.
112
+
113
+ Bash
114
+
115
+ \#\!/bin/bash
116
+ set \-e \# Exit immediately if a command exits with a non-zero status
117
+
118
+ \# 1\. Environment Setup (Logging in to Hugging Face and WandB)
119
+ \# These tokens are passed via environment variables at runtime
120
+ if \[ \-n "$HF\_TOKEN" \]; then
121
+ huggingface-cli login \--token $HF\_TOKEN
122
+ fi
123
+
124
+ if \[ \-n "$WANDB\_API\_KEY" \]; then
125
+ wandb login $WANDB\_API\_KEY
126
+ fi
127
+
128
+ \# 2\. Data Ingestion
129
+ \# Download dataset from S3 or Network Volume if not present
130
+ if \[ \! \-f "/workspace/dataset.jsonl" \]; then
131
+ echo "Downloading dataset from remote source..."
132
+ \# Example using a presigned URL or S3 CLI
133
+ wget \-O /workspace/dataset.jsonl "$DATASET\_URL"
134
+ fi
135
+
136
+ \# 3\. Execution
137
+ echo "Starting Training..."
138
+ \# Launch the Python training script
139
+ \# Unsloth or Axolotl command goes here
140
+ python train.py \--config config.json
141
+
142
+ \# 4\. Exfiltration/Saving
143
+ echo "Training Complete. Merging and Uploading..."
144
+ \# Assuming train.py saves to /workspace/output
145
+ \# This step ensures the trained weights are saved to HF Hub or S3
146
+ python upload\_to\_hub.py \--path /workspace/output \--repo my-user/my-finetuned-model
147
+
148
+ \# 5\. Cleanup (Critical for Cost Savings)
149
+ echo "Shutting down pod to stop billing..."
150
+ runpodctl stop pod $RUNPOD\_POD\_ID
151
+
152
+ 7
153
+ **FinOps Strategy:** By including runpodctl stop pod $RUNPOD\_POD\_ID as the final command, the user ensures they only pay for the exact duration of the training. This effectively transforms a standard GPU pod into a serverless-like job, preventing "zombie pods" from racking up bills after the training is finished.12
154
+
155
+ ### **3.3 Remote Management: The runpodctl Utility**
156
+
157
+ For managing these headless pods, runpodctl is the essential Command Line Interface (CLI) tool provided by RunPod. It allows the user to spin up pods, stream logs, and transfer files without ever navigating to the web console.13
158
+ Automation via CLI:
159
+ The user can script the deployment of the training job from their local machine. A single command can instantiate the pod using the custom image defined above:
160
+
161
+ Bash
162
+
163
+ runpodctl create pod \\
164
+ \--name "headless-llama3-finetune" \\
165
+ \--gpuType "NVIDIA A100 80GB PCIe" \\
166
+ \--imageName "myregistry/custom-llm-trainer:v1" \\
167
+ \--containerDiskSize 100 \\
168
+ \--volumeSize 200 \\
169
+ \--env HF\_TOKEN=$HF\_TOKEN \\
170
+ \--env WANDB\_API\_KEY=$WANDB\_KEY \\
171
+ \--env DATASET\_URL="https://my-s3-bucket..."
172
+
173
+ 14
174
+ This command fulfills the user's request: it uploads the configuration (via the image definition) and starts training immediately. The \--gpuType flag ensures the job lands on the specific hardware required for speed, while \--env passes the necessary secrets securely.
175
+
176
+ ## ---
177
+
178
+ **4\. Fine-Tuning Frameworks: The Engines of Efficiency**
179
+
180
+ To train "as fast as possible" without reinventing the wheel, high-level fine-tuning frameworks are vastly superior to writing raw PyTorch training loops. The two leading contenders for this workflow on RunPod are **Axolotl** and **Unsloth**. Each offers distinct advantages for headless execution.
181
+
182
+ ### **4.1 Axolotl: The Configuration-Driven Powerhouse**
183
+
184
+ Axolotl is designed for users who want to define *what* to train, not *how* to code the training loop. It abstracts the complexity of the Hugging Face Trainer into a comprehensive YAML configuration file.15
185
+
186
+ * **Headless Suitability:** Excellent. Because the entire training logic is encapsulated in a single YAML file, "uploading code" simply means injecting this config file into the container. There is no need to maintain complex Python scripts; the logic is declarative.
187
+ * **Feature Set:** Axolotl supports Full Fine-Tuning (FFT), LoRA, QLoRA, and advanced techniques like Flash Attention 2 and Sample Packing. Sample packing is particularly relevant for speed, as it concatenates multiple short examples into a single sequence, removing padding tokens and significantly increasing training throughput.17
188
+ * **Workflow Integration:**
189
+ 1. User edits config.yaml locally.
190
+ 2. User builds Docker image with this config or mounts it at runtime.
191
+ 3. Container starts and runs axolotl train config.yaml.
192
+ * **Multi-GPU Scaling:** Axolotl excels at multi-GPU training using FSDP (Fully Sharded Data Parallel) or DeepSpeed. If the user intends to scale training across an 8x A100 node to maximize speed, Axolotl is the robust choice.17
193
+
194
+ ### **4.2 Unsloth: The Efficiency Specialist**
195
+
196
+ Unsloth is a framework optimized specifically for speed and memory efficiency on single-GPU setups. It utilizes custom Triton kernels to manually backpropagate gradients, achieving 2-5x faster training speeds and up to 80% less memory usage compared to standard Hugging Face implementations.17
197
+
198
+ * **Headless Suitability:** High. Unsloth provides Docker images that can be easily adapted for headless execution.9 The speed gains directly address the user's requirement to "train as fast as possible."
199
+ * **Performance:** For single-GPU setups (e.g., one H100 or A100), Unsloth is unrivaled. The memory savings allow users to fit significantly larger batch sizes into VRAM, which directly translates to faster wall-clock training times. For example, on a Llama 3 8B model, Unsloth can enable training with context lengths that would cause OOM (Out of Memory) errors on standard implementations.19
200
+ * **Limitation:** Historically, Unsloth has been optimized for single-GPU training. While multi-GPU support is evolving, its primary strength remains in maximizing the throughput of a single card. For a user operating on a single powerful node (like an H100), Unsloth is likely the fastest option.18
201
+
202
+ ### **4.3 Comparative Analysis for the User**
203
+
204
+ | Feature | Axolotl | Unsloth | Strategic Recommendation |
205
+ | :---- | :---- | :---- | :---- |
206
+ | **Configuration** | YAML-based (Declarative) | Python/Script-based | **Axolotl** for strict config management and reproducibility. |
207
+ | **Speed (Single GPU)** | High (uses Flash Attn) | **Extreme** (2x faster than Axolotl) | **Unsloth** for raw speed on single cards (H100/A100). |
208
+ | **Multi-GPU** | Native Support (DeepSpeed/FSDP) | Limited/Paid Tier | **Axolotl** for distributed training across clusters. |
209
+ | **Ease of Headless** | Very High | High | **Both** are excellent; choice depends on scaling needs. |
210
+
211
+ **Expert Insight:** Given the user's preference for "fast as possible" and "custom code," if the model fits on a single GPU (e.g., Llama 3 8B or 70B on an H100), **Unsloth** is the superior choice for raw throughput. If the user requires multi-GPU scaling or complex dataset mixing configurations, **Axolotl** provides a more robust infrastructure.18
212
+
213
+ ## ---
214
+
215
+ **5\. Data Logistics: Solving the Custom Data Bottleneck**
216
+
217
+ A major challenge in ephemeral, headless training is data logistics. The user specified "custom training data," which implies datasets that are not pre-cached in public hubs. Handling large datasets (100GB+) efficiently is critical to avoiding idle GPU time.
218
+
219
+ ### **5.1 Storage Architectures: Network Volumes vs. NVMe vs. Object Storage**
220
+
221
+ * **Local Pod Storage (Container Disk):** This offers the fastest I/O performance. Data is stored on the NVMe SSD directly attached to the GPU instance. This is ideal for maximizing training speed, as the GPU is not starved of data. However, this storage is ephemeral; data is lost if the pod is terminated without external saving.5
222
+ * **RunPod Network Volumes:** This is persistent storage that survives pod termination and allows data to be shared across pods.
223
+ * *Throughput Bottleneck:* Network volumes can suffer from slower throughput (200-400 MB/s) compared to local NVMe, potentially bottlenecking the data loader during training of small models where the GPU processes batches faster than the disk can supply them.22
224
+ * *Region Lock:* Network volumes are region-locked. If a volume is created in US-NJ, the user is forced to rent GPUs in US-NJ. This severely limits the ability to grab available H100s in other regions, contradicting the "train as fast as possible" goal.22
225
+ * **S3 / Object Storage:** The most flexible approach. Data is stored in AWS S3 (or compatible) and streamed or downloaded at the start of the session.
226
+
227
+ ### **5.2 Recommended Data Strategy for Speed**
228
+
229
+ To maximize training speed, **Local NVMe Storage** is superior to Network Volumes, despite its ephemeral nature. The recommended workflow for headless execution is:
230
+
231
+ 1. **Storage:** Store the master dataset in a high-performance S3 bucket or RunPod's S3-compatible object storage layer.25
232
+ 2. **Ingest:** The start.sh script downloads the dataset from S3 to the pod's *local* /workspace directory (NVMe) at boot time.
233
+ 3. **Train:** The model trains off the fast local NVMe, ensuring the GPU is fully saturated.
234
+ 4. **Egress:** The start.sh uploads the checkpoints and final model back to S3 or Hugging Face.
235
+
236
+ This approach avoids the region-locking of Network Volumes and the I/O latency penalties, utilizing the immense bandwidth of datacenter GPUs for rapid setup.10
237
+
238
+ ### **5.3 Transferring Large Data: The 100GB Challenge**
239
+
240
+ For users who must use RunPod storage (e.g., due to compliance or cost), transferring 100GB+ of data from a local machine is non-trivial. The runpodctl send command creates a peer-to-peer transfer tunnel. While effective for smaller files, users have reported slow speeds and timeouts for large datasets.26
241
+ **Insight:** For datasets \>100GB, do not upload from a home internet connection directly to a GPU pod. Instead:
242
+
243
+ 1. Spin up a cheap **CPU pod** on RunPod.
244
+ 2. Use rsync or runpodctl to upload the data to this CPU pod (which sits on the high-speed datacenter backbone).
245
+ 3. From the CPU pod, transfer the data to a Network Volume or S3 bucket.
246
+ This leverages the internal network backbone rather than residential ISP uplinks, preventing the GPU pod from sitting idle while waiting for data uploads.
247
+
248
+ ## ---
249
+
250
+ **6\. Monitoring and Observability without Jupyter**
251
+
252
+ In a headless environment, "blind" training is a significant operational risk. Observability must be externalized to ensure the user knows if the model is converging or if the pod has crashed.
253
+
254
+ ### **6.1 Weights & Biases (WandB)**
255
+
256
+ WandB is the de facto standard for headless monitoring. It integrates natively with both Axolotl and Unsloth (via the Hugging Face Trainer).
257
+
258
+ * **Real-Time Metrics:** Loss curves, GPU utilization, memory usage, and learning rate schedules are streamed to the WandB dashboard in real-time. This allows the user to monitor the "pulse" of the training from a mobile device or laptop.
259
+ * **Artifacts:** Model checkpoints and config files can be logged as artifacts, providing version control for the models and ensuring reproducibility.
260
+
261
+ ### **6.2 Remote Logging**
262
+
263
+ RunPod provides a logging driver that captures stdout and stderr from the container.
264
+
265
+ * **Command:** runpodctl logs \<pod\_id\> allows the user to check the console output from their local terminal to verify the script started correctly or to catch crash errors (e.g., CUDA OOM).11
266
+ * **Best Practice:** The start.sh script should use set \-e (exit immediately on error) and trap errors. Advanced users may add a curl command to the script to send a notification (e.g., via a Discord webhook or Slack API) if the training fails or succeeds, ensuring the user is alerted immediately without needing to constantly poll the logs.
267
+
268
+ ## ---
269
+
270
+ **7\. Advanced Optimization and Troubleshooting**
271
+
272
+ ### **7.1 Handling "Cold Starts" and Image Caching**
273
+
274
+ Downloading large Docker images (often 10GB+ for ML images) takes time. RunPod caches images on the host node.
275
+
276
+ * **Strategy:** Stick to a single image tag (e.g., myuser/trainer:v1). Once a specific host has pulled this image, subsequent runs on that same host are instant.
277
+ * **Docker Optimization:** Use multi-stage builds to keep the final image size small. Remove cache files (pip cache purge) within the Dockerfile to minimize layer size.28
278
+
279
+ ### **7.2 CUDA Version Mismatches**
280
+
281
+ A common failure mode in custom images is a mismatch between the Docker container's CUDA toolkit and the host driver.
282
+
283
+ * **RunPod Environment:** RunPod hosts generally run the latest NVIDIA drivers.
284
+ * **Image Requirement:** Ensure the Docker image uses a compatible CUDA version (e.g., CUDA 11.8 or 12.1). Unsloth, for example, has specific requirements for CUDA 12.1 for maximum performance.9 Using the wrong base image will result in runtime errors regarding "Flash Attention" or "Bitsandbytes" compilation.
285
+
286
+ ### **7.3 Spot Instance Interruptions**
287
+
288
+ If using Community Cloud to save money, the pod may be preempted (shut down) if the provider needs the hardware.
289
+
290
+ * **Mitigation:** Configure the training script to save checkpoints frequently (e.g., every 100 steps) to a mounted Network Volume or upload them immediately to S3.
291
+ * **Resume Logic:** The start.sh should check for the existence of a checkpoint and automatically pass \--resume\_from\_checkpoint to the training script. This ensures that if a pod dies and a new one is spawned, it picks up exactly where the last one left off.30
292
+
293
+ ## ---
294
+
295
+ **8\. Conclusion and Strategic Roadmap**
296
+
297
+ For a user demanding the fastest possible fine-tuning workflow without the overhead of Jupyter notebooks, RunPod offers a powerful substrate, provided the workflow is architected correctly. The optimal path requires moving away from interactive "pet" instances to ephemeral "cattle" instances managed by code.
298
+ **The Recommended "Fast Track" Configuration:**
299
+
300
+ 1. **Hardware:** NVIDIA H100 (Secure Cloud) for speed and reliability, or RTX 4090 (Community Cloud) for cost-efficiency.
301
+ 2. **Framework:** **Unsloth** for single-GPU jobs (fastest throughput); **Axolotl** for multi-GPU or complex configurations.
302
+ 3. **Deployment:** Custom Docker image with an ENTRYPOINT script that automates the Download \-\> Train \-\> Upload \-\> Terminate lifecycle.
303
+ 4. **Interface:** runpodctl for deployment; WandB for monitoring; SSH for emergency debugging.
304
+ 5. **Data:** S3-backed ingestion to local NVMe storage to bypass network volume I/O bottlenecks.
305
+
306
+ By adopting this headless architecture, the user transforms the fine-tuning process from a manual, error-prone task into a scalable, automated engineering operation, fully leveraging the raw compute power of RunPod's infrastructure. This report confirms that while RunPod's interface invites interactive use, its API and CLI capabilities are fully mature for the rigorous demands of headless, high-velocity machine learning operations.
307
+
308
+ #### **Works cited**
309
+
310
+ 1. Runpod GPU pricing: A complete breakdown and platform comparison | Blog \- Northflank, accessed January 12, 2026, [https://northflank.com/blog/runpod-gpu-pricing](https://northflank.com/blog/runpod-gpu-pricing)
311
+ 2. The NVIDIA H100 GPU Review: Why This AI Powerhouse Dominates (But Costs a Fortune) \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/nvidia-h100](https://www.runpod.io/articles/guides/nvidia-h100)
312
+ 3. Runpod Secrets: Affordable A100/H100 Instances, accessed January 12, 2026, [https://www.runpod.io/articles/guides/affordable-a100-h100-gpu-cloud](https://www.runpod.io/articles/guides/affordable-a100-h100-gpu-cloud)
313
+ 4. Pricing | Runpod GPU cloud computing rates, accessed January 12, 2026, [https://www.runpod.io/pricing](https://www.runpod.io/pricing)
314
+ 5. RunPod Pricing 2025 Complete Guide (GPU Cloud Costs Breakdown) \- Flexprice, accessed January 12, 2026, [https://flexprice.io/blog/runprod-pricing-guide-with-gpu-costs](https://flexprice.io/blog/runprod-pricing-guide-with-gpu-costs)
315
+ 6. No-Code AI: How I Ran My First LLM Without Coding | Runpod Blog, accessed January 12, 2026, [https://www.runpod.io/blog/no-code-ai-run-llm](https://www.runpod.io/blog/no-code-ai-run-llm)
316
+ 7. Dockerfile \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/tutorials/introduction/containers/create-dockerfiles](https://docs.runpod.io/tutorials/introduction/containers/create-dockerfiles)
317
+ 8. Deploying AI Apps with Minimal Infrastructure and Docker \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/deploy-ai-apps-minimal-infrastructure-docker](https://www.runpod.io/articles/guides/deploy-ai-apps-minimal-infrastructure-docker)
318
+ 9. Fine-Tuning Local Models with Docker Offload and Unsloth, accessed January 12, 2026, [https://www.docker.com/blog/fine-tuning-models-with-offload-and-unsloth/](https://www.docker.com/blog/fine-tuning-models-with-offload-and-unsloth/)
319
+ 10. Optimize your workers \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/serverless/development/optimization](https://docs.runpod.io/serverless/development/optimization)
320
+ 11. Manage Pods \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/pods/manage-pods](https://docs.runpod.io/pods/manage-pods)
321
+ 12. AI on a Schedule: Using Runpod's API to Run Jobs Only When Needed, accessed January 12, 2026, [https://www.runpod.io/articles/guides/ai-on-a-schedule](https://www.runpod.io/articles/guides/ai-on-a-schedule)
322
+ 13. Overview \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/runpodctl/overview](https://docs.runpod.io/runpodctl/overview)
323
+ 14. create pod \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/runpodctl/reference/runpodctl-create-pod](https://docs.runpod.io/runpodctl/reference/runpodctl-create-pod)
324
+ 15. LLM fine-tuning | LLM Inference Handbook \- BentoML, accessed January 12, 2026, [https://bentoml.com/llm/getting-started/llm-fine-tuning](https://bentoml.com/llm/getting-started/llm-fine-tuning)
325
+ 16. How to fine-tune a model using Axolotl | Runpod Blog, accessed January 12, 2026, [https://www.runpod.io/blog/how-to-fine-tune-a-model-using-axolotl](https://www.runpod.io/blog/how-to-fine-tune-a-model-using-axolotl)
326
+ 17. Best frameworks for fine-tuning LLMs in 2025 \- Modal, accessed January 12, 2026, [https://modal.com/blog/fine-tuning-llms](https://modal.com/blog/fine-tuning-llms)
327
+ 18. Comparing LLM Fine-Tuning Frameworks: Axolotl, Unsloth, and Torchtune in 2025, accessed January 12, 2026, [https://blog.spheron.network/comparing-llm-fine-tuning-frameworks-axolotl-unsloth-and-torchtune-in-2025](https://blog.spheron.network/comparing-llm-fine-tuning-frameworks-axolotl-unsloth-and-torchtune-in-2025)
328
+ 19. Axolotl vs LLaMA-Factory vs Unsloth for AI Fine-Tuning 2026 \- Index.dev, accessed January 12, 2026, [https://www.index.dev/skill-vs-skill/ai-axolotl-vs-llama-factory-vs-unsloth](https://www.index.dev/skill-vs-skill/ai-axolotl-vs-llama-factory-vs-unsloth)
329
+ 20. \[TEMPLATE\] One-click Unsloth finetuning on RunPod : r/LocalLLaMA \- Reddit, accessed January 12, 2026, [https://www.reddit.com/r/LocalLLaMA/comments/1nyzzws/template\_oneclick\_unsloth\_finetuning\_on\_runpod/](https://www.reddit.com/r/LocalLLaMA/comments/1nyzzws/template_oneclick_unsloth_finetuning_on_runpod/)
330
+ 21. unsloth/llama-3-8b-bnb-4bit \- Hugging Face, accessed January 12, 2026, [https://huggingface.co/unsloth/llama-3-8b-bnb-4bit](https://huggingface.co/unsloth/llama-3-8b-bnb-4bit)
331
+ 22. PSA: Don't bother with Network Volumes on Runpod : r/StableDiffusion \- Reddit, accessed January 12, 2026, [https://www.reddit.com/r/StableDiffusion/comments/1nkcgvp/psa\_dont\_bother\_with\_network\_volumes\_on\_runpod/](https://www.reddit.com/r/StableDiffusion/comments/1nkcgvp/psa_dont_bother_with_network_volumes_on_runpod/)
332
+ 23. Network volumes \- Runpod Documentation, accessed January 12, 2026, [https://docs.runpod.io/storage/network-volumes](https://docs.runpod.io/storage/network-volumes)
333
+ 24. Using network volume with serverless \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1234830020678123610](https://www.answeroverflow.com/m/1234830020678123610)
334
+ 25. Streamline Your AI Workflows with RunPod's New S3-Compatible API, accessed January 12, 2026, [https://www.runpod.io/blog/streamline-ai-workflows-s3-api](https://www.runpod.io/blog/streamline-ai-workflows-s3-api)
335
+ 26. Upload speed \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1415080595020709938](https://www.answeroverflow.com/m/1415080595020709938)
336
+ 27. \`runpodctl send\` crawling at \<1MB speeds \- Runpod \- Answer Overflow, accessed January 12, 2026, [https://www.answeroverflow.com/m/1208971275163406376](https://www.answeroverflow.com/m/1208971275163406376)
337
+ 28. MLOps Workflow for Docker-Based AI Model Deployment \- Runpod, accessed January 12, 2026, [https://www.runpod.io/articles/guides/mlops-workflow-docker-ai-deployment](https://www.runpod.io/articles/guides/mlops-workflow-docker-ai-deployment)
338
+ 29. Installation \- Axolotl Docs, accessed January 12, 2026, [https://docs.axolotl.ai/docs/installation.html](https://docs.axolotl.ai/docs/installation.html)
339
+ 30. Does anyone use RunPod for SFT? If yes, you train via SSH or Jupyter (web-hosted), accessed January 12, 2026, [https://www.reddit.com/r/LocalLLaMA/comments/1pd6vxu/does\_anyone\_use\_runpod\_for\_sft\_if\_yes\_you\_train/](https://www.reddit.com/r/LocalLLaMA/comments/1pd6vxu/does_anyone_use_runpod_for_sft_if_yes_you_train/)
ai-docs/chatbot-ideology.yaml ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Marxist-Leninist Chatbot Training Set Design
2
+ # General-purpose ideological chatbot using ProleWiki corpus
3
+ # Status: UPDATED FOR GRPO (was SFT)
4
+ # Last Updated: 2025-12-17
5
+
6
+ # =============================================================================
7
+ # METHODOLOGY UPDATE: SFT → GRPO
8
+ # =============================================================================
9
+ # This document was originally designed for SFT. The project has pivoted to
10
+ # GRPO (Group Relative Policy Optimization) for the reasons documented in
11
+ # ai-docs/finetune.yaml. The ideological design principles and question
12
+ # generation strategies here remain valid for GRPO training.
13
+ #
14
+ # Current implementation:
15
+ # - Dataset: training_data/curated_qa.jsonl (128 curated pairs)
16
+ # - GRPO format: training_data/grpo_dataset.jsonl
17
+ # - Notebook: notebooks/Marxist_GRPO_Training.ipynb
18
+ # - Reward functions: src/prolewiki_llm/grpo_rewards.py
19
+ # =============================================================================
20
+
21
+ overview:
22
+ goal: |
23
+ A general-purpose chatbot well-versed in Marxist-Leninist theory.
24
+ Can discuss revolutionary theory, imperialism, class struggle,
25
+ dialectical materialism, and related topics through materialist analysis.
26
+
27
+ not_goal: |
28
+ NOT a game-specific assistant. NOT tied to any particular application.
29
+ The Babylon game project informed the ideological direction, but
30
+ the chatbot itself is a standalone ML theory assistant.
31
+
32
+ inspiration: |
33
+ The Babylon game (~/projects/game/babylon/) models class struggle
34
+ through MLM-TW (Marxist-Leninist-Maoist Third Worldist) theory.
35
+ This chatbot shares that theoretical grounding but serves as a
36
+ general educational/discussion tool.
37
+
38
+ training_approach: "GRPO (Group Relative Policy Optimization) with multi-layer rewards"
39
+ base_model: "unsloth/DeepSeek-R1-0528-Qwen3-8B"
40
+ current_samples: "128 curated Q&A pairs"
41
+ estimated_training_time: "~2-4 hours on A40 48GB"
42
+
43
+ data_source:
44
+ primary: "ProleWiki Library namespace chunks"
45
+ location: "sample-pipeline/chunks/Library/*.jsonl"
46
+ format: "JSONL with metadata"
47
+
48
+ chunk_schema:
49
+ chunk_id: "Unique identifier"
50
+ text: "The actual content (training answer)"
51
+ article_title: "Source work title (contains author)"
52
+ section: "Section/chapter name (question seed)"
53
+ categories: "Topic categories"
54
+ internal_links: "Referenced concepts (question seeds)"
55
+ word_count: "Chunk size"
56
+
57
+ available_works:
58
+ marx:
59
+ - "Capital, Volume I"
60
+ - "Capital, Volume II"
61
+ - "Capital, Volume III"
62
+ - "Grundrisse"
63
+ - "The German Ideology"
64
+ lenin:
65
+ - "Imperialism, the Highest Stage of Capitalism"
66
+ - "State and Revolution"
67
+ - "What Is To Be Done?"
68
+ mao:
69
+ - "On Contradiction"
70
+ - "On Practice"
71
+ other:
72
+ - "Additional works in Library namespace"
73
+
74
+ chunk_count: "~1,034 chunks currently available"
75
+
76
+ system_prompt:
77
+ description: "Defines the chatbot's persona and approach"
78
+
79
+ recommended: |
80
+ You are a Marxist-Leninist assistant. You explain revolutionary theory
81
+ through materialist analysis, drawing on the works of Marx, Engels,
82
+ Lenin, Mao, and other socialist thinkers. You ground explanations in
83
+ historical materialism and class analysis.
84
+
85
+ alternatives:
86
+ concise: |
87
+ You are a Marxist theorist. You explain concepts through
88
+ dialectical and historical materialism.
89
+
90
+ educational: |
91
+ You are a patient teacher of Marxist-Leninist theory. You explain
92
+ complex concepts clearly, always connecting theory to material
93
+ conditions and class relations.
94
+
95
+ third_worldist: |
96
+ You are a Marxist-Leninist-Maoist Third Worldist. You analyze
97
+ imperialism, unequal exchange, and the global class structure,
98
+ recognizing the labor aristocracy in imperial core nations.
99
+
100
+ question_generation:
101
+ description: |
102
+ Transform chunk metadata into natural questions that a user might ask.
103
+ The chunk text becomes the answer.
104
+
105
+ strategy: "Use available metadata to generate contextual questions"
106
+
107
+ templates:
108
+ with_section:
109
+ pattern: "What does {author} say about {section}?"
110
+ example: "What does Marx say about the commodity form?"
111
+ priority: 1
112
+
113
+ with_internal_links:
114
+ pattern: "Explain {concept} from a Marxist perspective."
115
+ example: "Explain surplus value from a Marxist perspective."
116
+ priority: 2
117
+
118
+ with_categories:
119
+ pattern: "Discuss {category} in Marxist theory."
120
+ example: "Discuss imperialism in Marxist theory."
121
+ priority: 3
122
+
123
+ fallback:
124
+ pattern: "What does {author} teach us about this?"
125
+ example: "What does Lenin teach us about this?"
126
+ priority: 4
127
+
128
+ author_extraction:
129
+ description: "Extract author name from article_title field"
130
+ pattern: "Library {Author} {Work Title}"
131
+ examples:
132
+ - input: "Library Karl Marx Capital, vol. I, Chapter 1"
133
+ output: "Marx"
134
+ - input: "Library Vladimir Lenin Imperialism"
135
+ output: "Lenin"
136
+ - input: "Library Mao Zedong On Contradiction"
137
+ output: "Mao"
138
+
139
+ code: |
140
+ def extract_author(title: str) -> str:
141
+ """Extract author from 'Library Author Name Work...' format."""
142
+ if not title.startswith("Library "):
143
+ return "the author"
144
+
145
+ # Common author mappings
146
+ author_map = {
147
+ "Karl Marx": "Marx",
148
+ "Friedrich Engels": "Engels",
149
+ "Vladimir Lenin": "Lenin",
150
+ "V.I. Lenin": "Lenin",
151
+ "Mao Zedong": "Mao",
152
+ "Mao Tse-tung": "Mao",
153
+ "Joseph Stalin": "Stalin",
154
+ "Rosa Luxemburg": "Luxemburg",
155
+ "Antonio Gramsci": "Gramsci",
156
+ "Frantz Fanon": "Fanon",
157
+ }
158
+
159
+ title_part = title[8:] # Remove "Library "
160
+ for full_name, short_name in author_map.items():
161
+ if title_part.startswith(full_name):
162
+ return short_name
163
+
164
+ # Fallback: first two words
165
+ words = title_part.split()
166
+ if len(words) >= 2:
167
+ return words[1] # Usually last name
168
+ return "the author"
169
+
170
+ training_format:
171
+ template: "Qwen-2.5 chat format"
172
+ note: "MUST use Qwen template for DeepSeek-R1-Distill-Qwen model"
173
+
174
+ structure: |
175
+ <|im_start|>system
176
+ {system_prompt}<|im_end|>
177
+ <|im_start|>user
178
+ {question}<|im_end|>
179
+ <|im_start|>assistant
180
+ {answer}<|im_end|>
181
+
182
+ example:
183
+ system: "You are a Marxist-Leninist assistant..."
184
+ user: "What does Marx say about the commodity form?"
185
+ answer: |
186
+ The wealth of those societies in which the capitalist mode of
187
+ production prevails, presents itself as an immense accumulation
188
+ of commodities...
189
+
190
+ output_field: "text"
191
+ description: |
192
+ For Unsloth SFTTrainer, provide a single 'text' field containing
193
+ the complete formatted conversation.
194
+
195
+ transformation_code:
196
+ description: "Complete transformation from chunks to training data"
197
+
198
+ implementation: |
199
+ import json
200
+ from pathlib import Path
201
+
202
+ def extract_author(title: str) -> str:
203
+ """Extract author from article title."""
204
+ if not title.startswith("Library "):
205
+ return "the author"
206
+
207
+ author_map = {
208
+ "Karl Marx": "Marx",
209
+ "Friedrich Engels": "Engels",
210
+ "Vladimir Lenin": "Lenin",
211
+ "Mao Zedong": "Mao",
212
+ "Joseph Stalin": "Stalin",
213
+ }
214
+
215
+ title_part = title[8:]
216
+ for full_name, short_name in author_map.items():
217
+ if title_part.startswith(full_name):
218
+ return short_name
219
+
220
+ words = title_part.split()
221
+ return words[1] if len(words) >= 2 else "the author"
222
+
223
+ def generate_question(chunk: dict) -> str:
224
+ """Generate a natural question from chunk metadata."""
225
+ author = extract_author(chunk.get("article_title", ""))
226
+
227
+ # Priority 1: Use section
228
+ if chunk.get("section"):
229
+ section = chunk["section"].lower()
230
+ return f"What does {author} say about {section}?"
231
+
232
+ # Priority 2: Use internal links
233
+ if chunk.get("internal_links"):
234
+ concept = chunk["internal_links"][0]
235
+ return f"Explain {concept} from a Marxist perspective."
236
+
237
+ # Priority 3: Use categories
238
+ if chunk.get("categories"):
239
+ category = chunk["categories"][0]
240
+ return f"Discuss {category} in Marxist theory."
241
+
242
+ # Fallback
243
+ return f"What does {author} teach us in this passage?"
244
+
245
def chunk_to_training(chunk: dict, system_prompt: str) -> dict:
    """Render a chunk as one Qwen chat-template training example.

    The chunk's text becomes the assistant turn, paired with a question
    generated from the chunk's metadata; returns ``{"text": ...}``.
    """
    question = generate_question(chunk)
    answer = chunk["text"]

    # Assemble the three chat turns; joining with "\n" reproduces the
    # exact Qwen <|im_start|>/<|im_end|> template layout.
    turns = [
        f"<|im_start|>system\n{system_prompt}<|im_end|>",
        f"<|im_start|>user\n{question}<|im_end|>",
        f"<|im_start|>assistant\n{answer}<|im_end|>",
    ]
    return {"text": "\n".join(turns)}
258
+
259
def process_chunks(input_dir: Path, output_path: Path):
    """Convert every chunk JSONL file under *input_dir* into a single
    Qwen-formatted training dataset written to *output_path*.

    Fixes over the naive version:
    - Files are processed in sorted order so output is deterministic
      (``Path.glob`` order is filesystem-dependent).
    - All file I/O uses explicit UTF-8 (the platform default encoding
      is not guaranteed to be UTF-8).
    - Blank lines in input files are tolerated instead of crashing
      ``json.loads``.
    - Output is streamed instead of buffered in memory.
    """
    system_prompt = (
        "You are a Marxist-Leninist assistant. You explain revolutionary "
        "theory through materialist analysis, drawing on the works of Marx, "
        "Engels, Lenin, Mao, and other socialist thinkers."
    )

    count = 0
    with open(output_path, "w", encoding="utf-8") as out:
        for jsonl_file in sorted(input_dir.glob("*.jsonl")):
            with open(jsonl_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # tolerate stray blank lines between records
                    chunk = json.loads(line)
                    example = chunk_to_training(chunk, system_prompt)
                    out.write(json.dumps(example) + "\n")
                    count += 1

    print(f"Generated {count} training examples")
281
+
282
+ quality_considerations:
283
+ chunk_quality:
284
+ - "ProleWiki text is already high-quality Marxist writing"
285
+ - "Chunks preserve context through overlap"
286
+ - "Section boundaries respected"
287
+
288
+ question_diversity:
289
+ - "Vary question templates to avoid repetitive patterns"
290
+ - "Use all available metadata fields"
291
+ - "Consider adding manual seed questions for key concepts"
292
+
293
+ answer_length:
294
+ - "Chunks are 350-500 tokens (good length for chat responses)"
295
+ - "Not too short (lacks substance) or too long (loses focus)"
296
+
297
+ potential_improvements:
298
+ - "Add conversational variations (rephrase questions)"
299
+ - "Include follow-up question pairs"
300
+ - "Add explicit concept definitions from glossary"
301
+ - "Include historical examples and applications"
302
+
303
+ implementation_steps:
304
+ step_1:
305
+ name: "Verify chunk data"
306
+ command: "ls -la sample-pipeline/chunks/Library/"
307
+ check: "Confirm JSONL files exist with expected format"
308
+
309
+ step_2:
310
+ name: "Create transformation script"
311
+ location: "src/prolewiki_llm/prepare_training.py"
312
+ description: "Implement the transformation code above"
313
+
314
+ step_3:
315
+ name: "Generate training data"
316
+ command: "uv run python -m prolewiki_llm.prepare_training"
317
+ output: "training_data/ml_chatbot.jsonl"
318
+
319
+ step_4:
320
+ name: "Upload to RunPod"
321
+ description: "Transfer JSONL to pod's /workspace/data/"
322
+
323
+ step_5:
324
+ name: "Run fine-tuning"
325
+ description: "Execute training script (see finetune.yaml)"
326
+
327
+ step_6:
328
+ name: "Export and test"
329
+ description: "GGUF export, Ollama deployment, manual testing"
330
+
331
+ evaluation:
332
+ manual_testing:
333
+ description: "Chat with model and assess quality"
334
+ test_questions:
335
+ - "What is surplus value?"
336
+ - "Explain the labor theory of value."
337
+ - "What is imperialism according to Lenin?"
338
+ - "How does dialectical materialism differ from idealism?"
339
+ - "What is the role of the vanguard party?"
340
+ - "Why do contradictions drive historical change?"
341
+ - "What is the labor aristocracy?"
342
+
343
+ quality_criteria:
344
+ - "Responses grounded in Marxist theory"
345
+ - "Materialist analysis (not idealist)"
346
+ - "Accurate to source texts"
347
+ - "Coherent and well-structured"
348
+ - "Appropriate length for chat"
349
+
350
+ red_flags:
351
+ - "Refusing to discuss political topics (abliteration failure)"
352
+ - "Generic/vague responses not grounded in theory"
353
+ - "Mixing incompatible ideological frameworks"
354
+ - "Hallucinating quotes or concepts"
355
+
356
+ future_enhancements:
357
+ phase_2:
358
+ - "Add conversational multi-turn examples"
359
+ - "Include debate/argument handling"
360
+ - "Add current events analysis capability"
361
+
362
+ phase_3:
363
+ - "Multi-persona support (different theoretical traditions)"
364
+ - "Game integration (Babylon narrative generation)"
365
+ - "RAG integration for expanded knowledge"
366
+
367
+ related_docs:
368
+ - "ai-docs/finetune.yaml - GRPO methodology and training config"
369
+ - "ai-docs/reward-modeling.yaml - Multi-layer reward function design"
370
+ - "ai-docs/runpod.yaml - Cloud GPU setup"
371
+ - "ai-docs/project-status.yaml - Phase 8 implementation status"
372
+ - "notebooks/Marxist_GRPO_Training.ipynb - Authoritative training notebook"
373
+ - "src/prolewiki_llm/ - Python module with reward functions"
ai-docs/finetune.yaml ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning Documentation - Phase 8
2
+ # Purpose: Marxist-Leninist LLM fine-tuning on ProleWiki corpus
3
+ # Status: IN_PROGRESS (implementation complete, training execution pending)
4
+ # Method: GRPO (Group Relative Policy Optimization) - NOT SFT
5
+ # Last Updated: 2025-12-17
6
+
7
+ # =============================================================================
8
+ # METHODOLOGY PIVOT: SFT → GRPO
9
+ # =============================================================================
10
+ # IMPORTANT: This project pivoted from SFT to GRPO methodology.
11
+ # This document has been updated to reflect the current approach.
12
+ #
13
+ # WHY GRPO OVER SFT:
14
+ # - Political theory has no single "correct" answer (unlike math)
15
+ # - Open-ended prose requires semantic similarity, not exact string matching
16
+ # - Reward functions can encode domain expertise
17
+ # - Multi-layer rewards defeat adversarial "word soup" attacks
18
+ # - GRPO naturally handles the inherent subjectivity of political analysis
19
+ #
20
+ # AUTHORITATIVE SOURCES:
21
+ # - Training notebook: notebooks/Marxist_GRPO_Training.ipynb
22
+ # - Reward functions: src/prolewiki_llm/grpo_rewards.py
23
+ # - W&B logging: src/prolewiki_llm/wandb_logging.py
24
+ # - Reward design: ai-docs/reward-modeling.yaml
25
+ # =============================================================================
26
+
27
+ overview:
28
+ goal: Fine-tune DeepSeek-R1-0528-Qwen3-8B for Marxist-Leninist theory responses
29
+ method: GRPO (Group Relative Policy Optimization) via Unsloth + TRL
30
+ dataset: 1,058 curated Q&A pairs (training_data/grpo_dataset.jsonl)
31
+ output_format: LoRA adapter → GGUF for Ollama deployment
32
+ hardware_requirement: A40 48GB (RunPod) or similar high-VRAM GPU
33
+ status: Implementation complete, execution pending
34
+
35
+ # =============================================================================
36
+ # GRPO VS SFT RATIONALE
37
+ # =============================================================================
38
+
39
+ grpo_rationale:
40
+ description: |
41
+ GRPO (Group Relative Policy Optimization) is preferred over SFT for
42
+ Marxist-Leninist training because political theory responses have no
43
+ single "correct" answer. Unlike math problems (where GRPO was originally
44
+ popularized with DeepSeek-R1), political theory is inherently open-ended.
45
+
46
+ sft_limitations:
47
+ - "Exact answer matching fails for prose responses"
48
+ - "No way to express 'this answer is better than that answer'"
49
+ - "Can only train on (input, output) pairs, not preferences"
50
+ - "Easily reward-hacked by models memorizing training data"
51
+
52
+ grpo_advantages:
53
+ - "Uses reward functions to score response quality"
54
+ - "Multiple generations compared (Group Relative)"
55
+ - "Semantic similarity captures meaning, not verbatim wording"
56
+ - "Multi-layer rewards prevent adversarial gaming"
57
+ - "Can penalize 'word soup' via coherence checks"
58
+
59
+ key_insight: |
60
+ The breakthrough insight is that GRPO's reward function paradigm
61
+ perfectly maps to political theory training: we can encode what
62
+ makes a "good Marxist-Leninist response" through semantic rewards
63
+ rather than exact answer matching.
64
+
65
+ # =============================================================================
66
+ # BASE MODEL
67
+ # =============================================================================
68
+
69
+ model:
70
+ primary:
71
+ name: DeepSeek-R1-0528-Qwen3-8B
72
+ huggingface: unsloth/DeepSeek-R1-0528-Qwen3-8B
73
+ params: 8B
74
+ architecture: Qwen3
75
+ context_length: 2048 (training max_completion_length)
76
+ reasoning: |
77
+ 1. DeepSeek R1 architecture has strong reasoning capabilities
78
+ 2. 8B params = fits on A40 with LoRA
79
+ 3. Full Unsloth optimization support
80
+ 4. unsloth/ namespace provides optimized 4-bit version
81
+
82
+ lora_config:
83
+ rank: 64
84
+ lora_alpha: 64
85
+ target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
86
+ use_rslora: true
87
+ use_gradient_checkpointing: unsloth
88
+
89
+ # =============================================================================
90
+ # DATASET
91
+ # =============================================================================
92
+
93
+ dataset:
94
+ source: training_data/curated_qa.jsonl
95
+ transformed: training_data/grpo_dataset.jsonl
96
+ count: 1,058 curated Q&A pairs
97
+ format: |
98
+ {"prompt": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}], "answer": "..."}
99
+
100
+ topics:
101
+ - Revisionism and opportunism
102
+ - Surplus value and exploitation
103
+ - Dialectical materialism
104
+ - Imperialism and monopoly capitalism
105
+ - Class struggle and revolution
106
+ - Socialist construction
107
+ - Historical figures (Marx, Lenin, Engels, Stalin, Mao)
108
+
109
+ # =============================================================================
110
+ # REWARD FUNCTION SYSTEM (THE KEY INNOVATION)
111
+ # =============================================================================
112
+
113
+ reward_functions:
114
+ reference: ai-docs/reward-modeling.yaml
115
+ module: src/prolewiki_llm/grpo_rewards.py
116
+ tests: tests/unit/training/test_grpo_rewards.py (43 tests)
117
+
118
+ layers:
119
+ format_rewards:
120
+ - match_format_exactly: "3.0 for exact </think> pattern"
121
+ - match_format_approximately: "1.5-2.0 for partial pattern"
122
+
123
+ semantic_rewards:
124
+ - semantic_similarity_reward: "Sentence-transformer embedding similarity"
125
+ - terminology_reward: "Bonus for Marxist lexicon usage"
126
+
127
+ coherence_rewards:
128
+ - nli_coherence_reward: "BART-large-MNLI entailment checking"
129
+ - self_consistency_reward: "No internal contradictions"
130
+ - structural_coherence_reward: "Terms in syntactic roles"
131
+
132
+ anti_hacking_rewards:
133
+ - topic_relevance_reward: "Question→answer concept coverage"
134
+ - interconnection_depth_reward: "Anti-buzzword-salad detection"
135
+ - completeness_reward: "Key concept coverage"
136
+
137
+ combined_functions:
138
+ full_coherence_reward: "5-layer combined (recommended)"
139
+ robust_coherence_reward: "3-layer (NLI + self-consistency + structural)"
140
+
141
+ anti_hacking_measures:
142
+ HOLLOW_BUZZWORDS: "Penalty for activist jargon without substance"
143
+ DEPTH_MARKERS: "Bonus for historical specificity"
144
+ EXPLANATORY_PHRASES: "Bonus for causal reasoning"
145
+ depth_ratio: "Words per concept (penalizes buzzword salad)"
146
+
147
+ # =============================================================================
148
+ # TRAINING CONFIGURATION
149
+ # =============================================================================
150
+
151
+ training_config:
152
+ trainer: GRPOTrainer (from trl library)
153
+ max_steps: 250
154
+ batch_size: 2
155
+ gradient_accumulation: 2
156
+ learning_rate: 5e-6
157
+ num_generations: 4
158
+ temperature: 1.0
159
+ gpu_memory_utilization: 0.85
160
+ hardware: A40 48GB (RunPod)
161
+
162
+ vllm_sampling:
163
+ min_p: 0.1
164
+ top_p: 1.0
165
+ top_k: -1
166
+ seed: 3407
167
+
168
+ # =============================================================================
169
+ # WEIGHTS & BIASES INTEGRATION
170
+ # =============================================================================
171
+
172
+ wandb_logging:
173
+ module: src/prolewiki_llm/wandb_logging.py
174
+ tests: tests/unit/training/test_wandb_logging.py (17 tests)
175
+
176
+ features:
177
+ - WandbSampleLogger: "Periodic sample tables (question → response → rewards)"
178
+ - create_logging_reward: "Zero-cost logging reward for GRPOTrainer"
179
+ - log_reward_metrics: "Per-reward mean/min/max tracking"
180
+ - init_wandb_logging: "Project/run initialization"
181
+ - finish_wandb_logging: "Summary statistics"
182
+ - log_model_checkpoint: "Artifact logging"
183
+
184
+ graceful_degradation: |
185
+ All functions work without wandb installed.
186
+ Falls back to print() statements when wandb unavailable.
187
+
188
+ # =============================================================================
189
+ # TRAINING NOTEBOOK (AUTHORITATIVE REFERENCE)
190
+ # =============================================================================
191
+
192
+ notebook:
193
+ location: notebooks/Marxist_GRPO_Training.ipynb
194
+ description: |
195
+ Self-contained Jupyter notebook for RunPod execution.
196
+ Contains ALL reward functions inline (no external imports).
197
+ This is the authoritative reference for current implementation.
198
+
199
+ contents:
200
+ - "Model loading (FastLanguageModel.from_pretrained)"
201
+ - "LoRA configuration (get_peft_model)"
202
+ - "Dataset loading from grpo_dataset.jsonl"
203
+ - "All 13+ reward functions inline"
204
+ - "W&B logging integration"
205
+ - "GRPOConfig with A40-optimized settings"
206
+ - "GRPOTrainer setup and training"
207
+ - "LoRA saving and GGUF export"
208
+
209
+ # =============================================================================
210
+ # EXPORT AND DEPLOYMENT
211
+ # =============================================================================
212
+
213
+ export:
214
+ lora_save:
215
+ code: |
216
+ model.save_pretrained_merged("marxist_lora", tokenizer)
217
+
218
+ gguf:
219
+ description: Quantized format for Ollama deployment
220
+ quantization_levels:
221
+ q4_k_m: "Recommended balance of size/quality (~4GB)"
222
+ q8_0: "Higher quality (~7GB)"
223
+ code: |
224
+ model.save_pretrained_gguf("marxist_gguf", tokenizer, quantization_method="q4_k_m")
225
+
226
+ ollama:
227
+ modelfile: |
228
+ FROM ./marxist-gguf-q4_k_m.gguf
229
+ TEMPLATE """<|begin_of_text|>{{ .System }}<|User|>{{ .Prompt }}<|Assistant|>"""
230
+ SYSTEM "You are a Marxist-Leninist assistant..."
231
+ PARAMETER temperature 0.7
232
+ PARAMETER num_ctx 2048
233
+
234
+ # =============================================================================
235
+ # IMPLEMENTATION STATUS
236
+ # =============================================================================
237
+
238
+ implementation_status:
239
+ completed:
240
+ - "8.1 Dataset Preparation: grpo_dataset.jsonl (1,058 Q&A pairs)"
241
+ - "8.2 Reward Function System: 13+ functions with anti-hacking"
242
+ - "8.3 W&B Logging Integration: Full observability"
243
+ - "8.4 Training Notebook: Self-contained for RunPod"
244
+
245
+ pending:
246
+ - "8.5 Training Execution: Run notebook on RunPod A40"
247
+ - "8.6 Model Evaluation: Manual review, reward hacking detection"
248
+ - "8.7 GGUF Export: Convert to Ollama format"
249
+
250
+ test_counts:
251
+ grpo_rewards: 43 passing
252
+ wandb_logging: 17 passing
253
+ total: 60 passing
254
+
255
+ # =============================================================================
256
+ # HARDWARE REQUIREMENTS
257
+ # =============================================================================
258
+
259
+ hardware:
260
+ recommended:
261
+ gpu: A40 (48GB VRAM) on RunPod
262
+ cost: "~$0.79/hour spot, ~$1.14/hour on-demand"
263
+ training_time: "~2-4 hours for 250 steps"
264
+
265
+ vram_breakdown:
266
+ model_4bit: "~4GB"
267
+ lora_params: "~1GB"
268
+ optimizer_state: "~2GB"
269
+ activations: "~6GB with gradient checkpointing"
270
+ vllm_generation: "~8GB for 4 generations"
271
+ reward_models: "~2.5GB (NLI + embeddings + spaCy)"
272
+ total: "~24GB (safe on 48GB A40)"
273
+
274
+ cloud_options:
275
+ primary: "RunPod A40 (48GB) - see ai-docs/runpod.yaml"
276
+ alternative: "Lambda Labs A100 (40GB)"
277
+
278
+ # =============================================================================
279
+ # REFERENCES
280
+ # =============================================================================
281
+
282
+ references:
283
+ internal:
284
+ - "ai-docs/reward-modeling.yaml: Reward function design"
285
+ - "ai-docs/runpod.yaml: GPU setup instructions"
286
+ - "notebooks/Marxist_GRPO_Training.ipynb: Authoritative notebook"
287
+ - "src/prolewiki_llm/: Python module"
288
+
289
+ external:
290
+ unsloth:
291
+ - "https://github.com/unslothai/unsloth"
292
+ - "https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide"
293
+ grpo:
294
+ - "https://arxiv.org/abs/2508.18212 - NLI as reward paradigm"
295
+ - "https://arxiv.org/abs/2509.22047 - MO-GRPO reward hacking"
296
+ trl:
297
+ - "https://huggingface.co/docs/trl/en/grpo_trainer"
ai-docs/reward-modeling.yaml ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Reward Modeling for GRPO Fine-Tuning
2
+ # Token-efficient reference for AI assistants
3
+ # Research-backed approaches to defeat reward hacking
4
+
5
+ status: IMPLEMENTED - multi-layer coherence rewards with wandb logging
6
+ purpose: Robust reward functions for Marxist-Leninist GRPO training
7
+ module: src/prolewiki_llm/
8
+ files:
9
+ grpo_rewards: Reward functions for GRPO training
10
+ wandb_logging: Weights & Biases integration for training observability
11
+
12
+ # =============================================================================
13
+ # OVERVIEW
14
+ # =============================================================================
15
+
16
+ overview:
17
+ goal: |
18
+ Train DeepSeek-R1-0528-Qwen3-8B on ProleWiki Q&A corpus using GRPO
19
+ with reward functions that cannot be gamed by "word soup" or other
20
+ adversarial strategies.
21
+
22
+ problem: |
23
+ Naive reward functions (substring matching, keyword counting) are
24
+ vulnerable to reward hacking. A model could maximize reward by
25
+ outputting random Marxist terminology without coherent meaning.
26
+
27
+ solution: |
28
+ Multi-layer reward combining:
29
+ 1. NLI (Natural Language Inference) - checks logical consistency
30
+ 2. Self-consistency - checks for internal contradictions
31
+ 3. Structural coherence - checks syntactic structure via spaCy
32
+
33
+ research_basis:
34
+ - "arxiv.org/abs/2508.18212 - NLI as reward modeling paradigm"
35
+ - "arxiv.org/abs/2509.22047 - MO-GRPO mitigating reward hacking"
36
+ - "arxiv.org/abs/2508.05170 - Posterior-GRPO process rewards"
37
+
38
+ # =============================================================================
39
+ # IDEOLOGICAL BIAS CONSIDERATION
40
+ # =============================================================================
41
+
42
+ ideological_bias:
43
+ concern: |
44
+ NLI models trained on predominantly liberal/capitalist media might
45
+ encode bourgeois assumptions, potentially marking valid Marxist claims
46
+ as "contradiction" or "neutral".
47
+
48
+ empirical_finding: |
49
+ Testing shows BART-large-MNLI performs LOGICAL inference, not
50
+ ideological judgment:
51
+ - "Capitalism exploits workers" → "Workers are exploited": entailment (0.998)
52
+ - "Class struggle drives history" vs "Great individuals drive history": contradiction (0.998)
53
+ - Word soup → coherent claim: neutral (0.932)
54
+
55
+ why_it_works: |
56
+ We compare Marxist response against Marxist ground truth (from ProleWiki).
57
+ The model isn't judging if Marxism is "true" - it's checking if the
58
+ response logically follows from the expected answer.
59
+
60
+ mitigation_strategy: |
61
+ Self-consistency check avoids external ideology entirely - it only
62
+ checks if the response contradicts ITSELF, not external "truth".
63
+
64
+ # =============================================================================
65
+ # REWARD FUNCTIONS
66
+ # =============================================================================
67
+
68
+ reward_functions:
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # FORMAT REWARDS (from original notebook)
72
+ # ---------------------------------------------------------------------------
73
+
74
+ match_format_exactly:
75
+ purpose: Encourage proper <think>...</think> format
76
+ scoring:
77
+ has_think_end_tag: "+3.0"
78
+ no_tag: "0.0"
79
+ notes: Checks for </think> tag presence
80
+
81
+ match_format_approximately:
82
+ purpose: Reward partial format compliance
83
+ scoring:
84
+ one_think_start: "+0.5"
85
+ one_think_end: "+0.5"
86
+ multiple_or_missing: "-1.0"
87
+ notes: Penalizes malformed tag structure
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # SEMANTIC REWARDS
91
+ # ---------------------------------------------------------------------------
92
+
93
+ semantic_similarity_reward:
94
+ purpose: Reward responses semantically similar to ground truth
95
+ model: sentence-transformers/all-MiniLM-L6-v2
96
+ scoring:
97
+ similarity_gt_0.75: "+5.0"
98
+ similarity_gt_0.60: "+3.0"
99
+ similarity_gt_0.45: "+1.0"
100
+ similarity_gt_0.30: "-1.0"
101
+ similarity_le_0.30: "-3.0"
102
+ notes: |
103
+ Uses cosine similarity of embeddings.
104
+ Good for soft matching but doesn't catch contradictions.
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # NLI-BASED REWARDS (Research-backed)
108
+ # ---------------------------------------------------------------------------
109
+
110
+ nli_coherence_reward:
111
+ purpose: Check if response ENTAILS ground truth
112
+ model: facebook/bart-large-mnli
113
+ scoring:
114
+ entailment: "+3.0 (response supports/implies ground truth)"
115
+ neutral: "-1.0 (off-topic or incoherent)"
116
+ contradiction: "-3.0 (contradicts ground truth)"
117
+ defeats:
118
+ - "Word soup (random terms → neutral)"
119
+ - "Contradictory claims"
120
+ - "Off-topic responses"
121
+ gpu_memory: "~1.6GB"
122
+ research: "arxiv.org/abs/2508.18212"
123
+
124
+ self_consistency_reward:
125
+ purpose: Check for internal contradictions (no external ideology)
126
+ method: |
127
+ Parse response into sentences using spaCy.
128
+ Check adjacent sentence pairs for NLI contradiction.
129
+ scoring:
130
+ no_contradictions: "+1.0"
131
+ has_contradiction: "-2.0"
132
+ notes: |
133
+ Only checks within-document coherence, avoiding any external
134
+ ideological bias from the NLI model's training data.
135
+ research: "arxiv.org/abs/2508.05170"
136
+
137
+ structural_coherence_reward:
138
+ purpose: Verify proper linguistic structure (defeats word soup)
139
+ model: spacy en_core_web_sm
140
+ checks:
141
+ - "Marxist terms in subject/object syntactic positions"
142
+ - "Presence of discourse connectives (therefore, because, etc.)"
143
+ - "Complete sentence structure"
144
+ scoring:
145
+ term_in_syntactic_role: "+0.3 per term (max +1.5)"
146
+ discourse_connective: "+0.2 per connective (max +1.0)"
147
+ no_sentences: "-1.0"
148
+ syntactic_roles:
149
+ - "nsubj (nominal subject)"
150
+ - "nsubjpass (passive subject)"
151
+ - "dobj (direct object)"
152
+ - "pobj (object of preposition)"
153
+ - "attr (attribute)"
154
+ - "appos (appositional modifier)"
155
+ discourse_connectives:
156
+ - "because, therefore, thus, hence, consequently"
157
+ - "however, although, whereas, nevertheless, moreover"
158
+ - "furthermore, additionally, specifically, namely"
159
+ - "as a result, due to, in order to, so that"
160
+ - "on the other hand, in contrast, similarly, likewise"
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # TOPIC RELEVANCE REWARD (Question-Answer Alignment)
164
+ # ---------------------------------------------------------------------------
165
+
166
+ topic_relevance_reward:
167
+ purpose: Ensure answer addresses what the question asked about
168
+ method: |
169
+ Implements f(A) ⊆ f(Q) check where f extracts semantic topics:
170
+ 1. Extract core topics from question Q using dependency parsing
171
+ 2. Expand Q topics with Marxist concept synonyms
172
+ 3. Extract topics from answer A
173
+ 4. Compute coverage: how many Q topics are addressed in A
174
+ scoring:
175
+ gt_80_coverage: "+2.0 (answer fully addresses question topics)"
176
+ gt_60_coverage: "+1.5 (answer mostly on-topic)"
177
+ gt_40_coverage: "+1.0 (answer partially on-topic)"
178
+ gt_20_coverage: "0.0 (answer tangentially related)"
179
+ le_20_coverage: "-1.5 (answer off-topic)"
180
+ defeats:
181
+ - "Off-topic coherent text (coherent Marxist text about wrong subject)"
182
+ - "Topic drift during response"
183
+ spacy_model: "en_core_web_trf (transformer-based, best semantic understanding)"
184
+ topic_extraction:
185
+ question: "ROOT verb children with nsubj/dobj/attr/nsubjpass dependency"
186
+ answer: "Noun chunks + named entities (determiners stripped)"
187
+ synonym_expansion: "CONCEPT_EQUIVALENCES dict maps bourgeoisie ↔ capitalist class"
188
+
189
+ # ---------------------------------------------------------------------------
190
+ # COMBINED REWARDS
191
+ # ---------------------------------------------------------------------------
192
+
193
+ robust_coherence_reward:
194
+ purpose: Multi-layer coherence check
195
+ layers:
196
+ 1: "NLI coherence - Does response entail ground truth?"
197
+ 2: "Self-consistency - Does response contradict itself?"
198
+ 3: "Structural coherence - Are terms used in proper syntax?"
199
+ scoring:
200
+ max_score: "+5.5 (entailment + consistent + structured)"
201
+ contradiction_floor: "-3.0 (NLI contradiction)"
202
+ inconsistency_penalty: "-2.0 (internal contradiction)"
203
+ combination_logic: |
204
+ if nli_score <= -3.0:
205
+ return -3.0 # Contradiction dominates
206
+ elif consistency_score <= -2.0:
207
+ return -2.0 # Internal contradiction
208
+ else:
209
+ return nli_score + (consistency * 0.5) + (structure * 0.5)
210
+
211
+ full_coherence_reward:
212
+ purpose: Complete coherence check (RECOMMENDED - maximum robustness)
213
+ layers:
214
+ 1: "NLI coherence - Does response entail ground truth?"
215
+ 2: "Self-consistency - Does response contradict itself?"
216
+ 3: "Structural coherence - Are terms used in proper syntax?"
217
+ 4: "Topic relevance - Does answer address what was asked?"
218
+ 5: "Interconnection depth - Rewards deep analysis, penalizes buzzword salad"
219
+ scoring:
220
+ max_score: "+7.0 (all checks pass with deep analysis)"
221
+ off_topic_penalty: "-2.0 (severely off-topic)"
222
+ buzzword_salad_penalty: "-1.5 (shallow buzzword listing)"
223
+ inherits: "robust_coherence_reward penalties for NLI/consistency failures"
224
+ combination_logic: |
225
+ if relevance <= -1.5:
226
+ return -2.0 # Severely off-topic
227
+ elif robust <= -2.0:
228
+ return robust # Robust check failed
229
+ elif depth <= -1.5:
230
+ return -1.5 # Buzzword salad detected
231
+ else:
232
+ return robust + (relevance * 0.4) + (depth * 0.3)
233
+ use_when: "Maximum robustness against reward hacking is needed"
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # INTERCONNECTION DEPTH REWARD (Anti-Buzzword-Salad)
237
+ # ---------------------------------------------------------------------------
238
+
239
+ interconnection_depth_reward:
240
+ purpose: Distinguish deep analysis from shallow buzzword salad
241
+ method: |
242
+ Rewards meaningful interconnections while penalizing superficial
243
+ concept-dropping. Distinguishes:
244
+ - GOOD: "Surplus value relates to imperialism BECAUSE capital export..."
245
+ - BAD: "Surplus value intersects with imperialism, colonialism, patriarchy..."
246
+ signals:
247
+ depth_ratio:
248
+ description: "Words per unique Marxist concept"
249
+ gt_20: "+1.0 (deep analysis - few concepts well-explained)"
250
+ range_10_20: "+0.5 (adequate depth)"
251
+ range_5_10: "-0.5 (shallow)"
252
+ lt_5: "-1.5 (severe buzzword soup)"
253
+ hollow_buzzwords:
254
+ description: "Activist jargon without substance"
255
+ threshold: "> 2 hollow phrases"
256
+ penalty: "-0.3 per additional (max -1.5)"
257
+ depth_markers:
258
+ description: "Historical specificity, citations, examples"
259
+ bonus: "+0.3 each (max +1.5)"
260
+ explanation_ratio:
261
+ description: "Explanatory phrases per concept"
262
+ gt_50_percent: "+0.5 (well-explained)"
263
+ lt_10_percent_many_concepts: "-0.5 (unexplained concept soup)"
264
+ scoring:
265
+ range: "-2.5 to +3.0"
266
+ defeats:
267
+ - "Buzzword salad (many concepts, no explanation)"
268
+ - "Activist jargon (performative language without analysis)"
269
+ - "Intersectionality word soup (mentioning everything without depth)"
270
+
271
+ # ---------------------------------------------------------------------------
272
+ # ENTITY VERIFICATION REWARDS (Anti-Hallucination)
273
+ # ---------------------------------------------------------------------------
274
+
275
+ entity_verification_reward:
276
+ status: NEW - added 2025-12-18
277
+ purpose: Penalize confident claims about unverified entities
278
+ method: |
279
+ Uses spaCy NER to extract entity mentions from responses, then
280
+ checks them against a whitelist of 24,040 entities extracted from
281
+ the ProleWiki corpus. Rewards epistemic humility for unknown entities.
282
+ whitelist_source: "training_data/entity_whitelist_clean.json"
283
+ whitelist_stats:
284
+ total_entities: 24040
285
+ sources:
286
+ - "5,129 article titles (filenames)"
287
+ - "27,329 internal wiki links [[Entity]]"
288
+ - "2,458 categories"
289
+ - "2,179 library work references"
290
+ - "2,145 infobox person names"
291
+ scoring:
292
+ epistemic_humility: "+2.0 (expresses uncertainty about unknown entities)"
293
+ verified_entities: "+1.0 (discusses only verified entities)"
294
+ unknown_no_uncertainty: "-1.0 (mentions unknown entities without caution)"
295
+ confident_hallucination: "-2.5 (fabricates details about unknown entities)"
296
+ defeats:
297
+ - "Confident hallucination about fictional organizations"
298
+ - "Fabricated founding dates, founders, activities"
299
+ - "Made-up historical context for non-existent entities"
300
+ example_test_case: |
301
+ Question: "What was the Militant League of Trotskyist Shanghai Students?"
302
+ BAD response: "Founded in 1925, the Militant League..." (fabricated details)
303
+ GOOD response: "I cannot verify this organization..." (epistemic humility)
304
+ dependencies:
305
+ - "spaCy (en_core_web_trf or fallback)"
306
+ - "training_data/entity_whitelist_clean.json"
307
+
308
+ epistemic_calibration_reward:
309
+ status: NEW - added 2025-12-18
310
+ purpose: Lightweight pattern-based epistemic calibration
311
+ method: |
312
+ Pattern matching for uncertainty phrases and confident claim patterns.
313
+ Doesn't require NER or entity whitelist - just regex matching.
314
+ Use when entity_verification_reward is too computationally expensive.
315
+ scoring:
316
+ has_uncertainty_phrases: "+1.5"
317
+ has_confident_claims_no_uncertainty: "-0.5"
318
+ neutral: "0.0"
319
+ uncertainty_patterns:
320
+ - "I cannot/can't/don't verify/confirm/find"
321
+ - "I don't have verified/specific information"
322
+ - "I'm not certain/sure/confident"
323
+ - "I should not fabricate/make up/speculate"
324
+ - "Could you provide more context"
325
+ confident_claim_patterns:
326
+ - "founded in [year]"
327
+ - "established in [year]"
328
+ - "was founded by"
329
+ - "was created by"
330
+ use_when: "Lightweight alternative to entity_verification_reward"
331
+
332
+ # ---------------------------------------------------------------------------
333
+ # IDEOLOGICAL FIRMNESS REWARD (Anti-Capitulation)
334
+ # ---------------------------------------------------------------------------
335
+
336
+ ideological_firmness_reward:
337
+ status: NEW - added 2025-12-18
338
+ purpose: Reward firm Marxist counter-arguments to reactionary claims
339
+ method: |
340
+ Detects ideological challenges (fascist propaganda, anti-communist
341
+ tropes, bad faith questions) and scores responses based on:
342
+ 1. Detection: Is the prompt an ideological challenge?
343
+ 2. Capitulation check: Does the response hedge, both-sides, or concede?
344
+ 3. Firmness check: Does the response firmly refute with class analysis?
345
+ 4. Substantive bonus: Does the response use Marxist analytical terms?
346
+
347
+ Key insight: COMBAT reactionary ideology with analysis, don't refuse engagement.
348
+ scoring:
349
+ non_challenge: "0.0 (neutral - no ideological challenge in prompt)"
350
+ capitulation_patterns: "-0.5 each (both-sidesing, hedging, false balance)"
351
+ firmness_patterns: "+0.5 each (direct refutation, class analysis)"
352
+ class_analysis_bonus: "+0.2 each marker (capped at +1.0)"
353
+ range: "-2.0 to +2.0"
354
+ ideological_challenge_patterns:
355
+ reactionary_claims:
356
+ - "Hitler helped workers"
357
+ - "Nazis were socialist"
358
+ death_toll_claims:
359
+ - "communism killed millions"
360
+ - "Black Book of Communism"
361
+ anti_communist_framing:
362
+ - "Stalin was a dictator"
363
+ - "authoritarian regime"
364
+ - "communism failed"
365
+ bad_faith_questions:
366
+ - "Holodomor"
367
+ - "Gulag"
368
+ - "Great Purge"
369
+ - "Was Stalin bad?"
370
+ - "Was Mao good or bad?"
371
+ capitulation_patterns:
372
+ both_sidesing:
373
+ - "valid points on both sides"
374
+ - "on one hand... on the other hand"
375
+ - "reasonable people can disagree"
376
+ hedging:
377
+ - "it's a complex/nuanced issue"
378
+ - "I can see why you'd ask"
379
+ false_balance:
380
+ - "Stalin did good things and bad things"
381
+ - "achievements and mistakes to consider"
382
+ theory_practice:
383
+ - "communism sounds good in theory but..."
384
+ firmness_patterns:
385
+ direct_refutation:
386
+ - "This claim is false/propaganda"
387
+ - "This is fascist/bourgeois propaganda"
388
+ - "discredited"
389
+ - "debunked"
390
+ analytical_framing:
391
+ - "From a Marxist analysis..."
392
+ - "The material reality shows..."
393
+ - "historically documented"
394
+ class_analysis:
395
+ - "serves the interests of capital"
396
+ - "bourgeois historiography"
397
+ - "dictatorship of the proletariat"
398
+ socialist_achievements:
399
+ - "universal literacy"
400
+ - "defeated fascism"
401
+ - "provided housing and healthcare"
402
+ special_handling:
403
+ quote_to_refute: |
404
+ Quoting reactionary claims in order to refute them is NOT penalized.
405
+ Pattern: "X claim Y, but this ignores..."
406
+ principled_self_criticism: |
407
+ ML self-criticism is NOT capitulation when framed from a
408
+ revolutionary perspective (e.g., "Marxist-Leninists engage in
409
+ self-criticism as a methodological principle")
410
+ defeats:
411
+ - "Both-sidesing on fascism/communism"
412
+ - "Hedging on historical socialist figures"
413
+ - "Conceding reactionary framing"
414
+ - "Theory/practice false dichotomy"
415
+ example_test_cases:
416
+ high_score: |
417
+ Q: "Didn't Hitler help German workers?"
418
+ A: "This is fascist propaganda. Hitler destroyed trade unions,
419
+ banned strikes, and served the interests of finance capital."
420
+ Score: >= +1.5
421
+ low_score: |
422
+ Q: "Didn't Hitler help German workers?"
423
+ A: "While Hitler did implement some programs that benefited
424
+ workers, it's important to weigh both positive and negative..."
425
+ Score: <= -1.5
426
+ neutral: |
427
+ Q: "What year did the October Revolution happen?"
428
+ A: "The October Revolution occurred in 1917."
429
+ Score: ~0.0 (no ideological challenge detected)
430
+ use_when: |
431
+ Training models to maintain principled Marxist positions when
432
+ facing reactionary claims and propaganda. Prevents models from
433
+ developing liberal "both-sides" tendencies on class issues.
434
+
435
+ # ---------------------------------------------------------------------------
436
+ # DEPRECATED/SHALLOW REWARDS
437
+ # ---------------------------------------------------------------------------
438
+
439
+ terminology_reward:
440
+ status: DEPRECATED - use structural_coherence_reward instead
441
+ purpose: Reward Marxist terminology (SHALLOW - can be gamed)
442
+ scoring: "+0.3 per term (max +2.0)"
443
+ warning: |
444
+ This reward can be gamed with "word soup" - random Marxist
445
+ terms without coherent meaning. Use nli_coherence_reward or
446
+ structural_coherence_reward for robust evaluation.
447
+
448
+ # =============================================================================
449
+ # MARXIST TERMINOLOGY
450
+ # =============================================================================
451
+
452
+ marxist_terms:
453
+ core_concepts:
454
+ - "dialectical, materialism, historical materialism, dialectical materialism"
455
+ classes:
456
+ - "bourgeoisie, proletariat, petty bourgeois, lumpenproletariat"
457
+ - "working class, ruling class"
458
+ class_struggle:
459
+ - "class struggle, class consciousness, class war, class conflict"
460
+ political_economy:
461
+ - "surplus value, commodity, use value, exchange value"
462
+ - "labor power, means of production, relations of production"
463
+ - "forces of production, mode of production, primitive accumulation"
464
+ - "exploitation, capital accumulation"
465
+ imperialism:
466
+ - "imperialism, colonialism, neo-colonialism, settler colonialism"
467
+ - "national liberation, self-determination"
468
+ state_revolution:
469
+ - "dictatorship of the proletariat, vanguard, vanguard party"
470
+ - "democratic centralism, withering away of the state"
471
+ ideology:
472
+ - "hegemony, superstructure, base, ideology, false consciousness"
473
+ revisionism:
474
+ - "revisionism, opportunism, reformism, social democracy, ultra-leftism"
475
+ alienation:
476
+ - "alienation, fetishism, commodity fetishism, reification"
477
+ historical:
478
+ - "paris commune, october revolution, bolshevik, menshevik"
479
+ anti_colonial:
480
+ - "decolonization, third world, global south, national bourgeoisie, comprador"
481
+
482
+ # =============================================================================
483
+ # HOLLOW BUZZWORDS (Activist Jargon to Penalize)
484
+ # =============================================================================
485
+
486
+ hollow_buzzwords:
487
+ description: |
488
+ Phrases that signal superficial analysis when used without substantive
489
+ explanation. These are NOT Marxist technical terms - they are activist
490
+ jargon that often substitutes for actual analysis.
491
+ vague_connectors:
492
+ - "interconnected, interrelated, intersects with"
493
+ - "it's all connected, everything is connected, systemic"
494
+ performative_language:
495
+ - "centered, centering, uplift, uplifting"
496
+ - "do the work, the work, unpack, unpacking"
497
+ - "unlearn, unlearning, hold space, sit with, lean into"
498
+ - "problematic, harmful, toxic"
499
+ vague_abstractions:
500
+ - "in a way, sort of, kind of, essentially, basically"
501
+ - "generally speaking, broadly"
502
+ misused_terms:
503
+ - "praxis (when used without explanation)"
504
+ - "material conditions (when used as hand-wave)"
505
+ - "structural, structurally (when mechanism not specified)"
506
+ note: |
507
+ The penalty applies when hollow buzzword DENSITY is high AND
508
+ depth ratio is low. Legitimate use with explanation is not penalized.
509
+
510
+ # =============================================================================
511
+ # EXPLANATORY PHRASES (Depth Markers)
512
+ # =============================================================================
513
+
514
+ explanatory_phrases:
515
+ description: "Phrases indicating concept is being explained, not just dropped"
516
+ causal:
517
+ - "because the, because of, this is because, since the"
518
+ - "due to the, as a result of, results from, caused by"
519
+ - "leads to, results in, enables, produces"
520
+ definitional:
521
+ - "is defined as, refers to, means that, denotes"
522
+ - "that is, in other words, namely, i.e."
523
+ elaboration:
524
+ - "specifically, in particular, for example, such as"
525
+ - "this means, which means, this implies, therefore"
526
+ mechanism:
527
+ - "this occurs when, this happens because, the mechanism"
528
+ - "through the process of, by means of, works by"
529
+
530
+ depth_markers:
531
+ description: "Phrases indicating analytical depth (historical specificity, citations)"
532
+ historical:
533
+ - "in 1, in 2, during the, after the, before the"
534
+ citations:
535
+ - "marx argued, lenin wrote, engels noted, gramsci"
536
+ - "according to, as marx, as lenin"
537
+ examples:
538
+ - "for example, such as, in the case of, consider"
539
+ definitions:
540
+ - "defined as, meaning, specifically"
541
+
542
+ # =============================================================================
543
+ # USAGE IN TRAINING
544
+ # =============================================================================
545
+
546
+ training_usage:
547
+
548
+ full_reward_set:
549
+ description: "RECOMMENDED - Maximum robustness including depth analysis"
550
+ functions:
551
+ - "match_format_exactly (+3.0 for </think>)"
552
+ - "match_format_approximately (tag validation)"
553
+ - "full_coherence_reward (NLI + structure + topic + depth)"
554
+ - "completeness_reward (length comparison)"
555
+ - "debug_print_reward (monitoring)"
556
+ notes: "full_coherence_reward now includes interconnection_depth_reward"
557
+
558
+ anti_hallucination_reward_set:
559
+ description: "Maximum protection against entity hallucination"
560
+ functions:
561
+ - "match_format_exactly (+3.0 for </think>)"
562
+ - "match_format_approximately (tag validation)"
563
+ - "full_coherence_reward (NLI + structure + topic + depth)"
564
+ - "entity_verification_reward (whitelist check + NER)"
565
+ - "completeness_reward (length comparison)"
566
+ - "debug_print_reward (monitoring)"
567
+ notes: |
568
+ Adds entity_verification_reward to full_reward_set.
569
+ Requires entity_whitelist_clean.json (24,040 verified entities).
570
+ Use when hallucination about organizations/people is a concern.
571
+ use_with_training_data:
572
+ - "synthetic_epistemic_humility.jsonl (20 examples)"
573
+
574
+ ideological_firmness_reward_set:
575
+ description: "NEW - Prevents capitulation to reactionary framing"
576
+ functions:
577
+ - "match_format_exactly (+3.0 for </think>)"
578
+ - "match_format_approximately (tag validation)"
579
+ - "full_coherence_reward (NLI + structure + topic + depth)"
580
+ - "ideological_firmness_reward (anti-capitulation)"
581
+ - "completeness_reward (length comparison)"
582
+ - "debug_print_reward (monitoring)"
583
+ notes: |
584
+ Adds ideological_firmness_reward to full_reward_set.
585
+ Prevents models from developing liberal "both-sides" tendencies.
586
+ Use when training on ideologically charged Q&A pairs.
587
+ use_when: |
588
+ Training data contains questions about fascism, anti-communist
589
+ tropes, or historical socialist figures. Ensures model maintains
590
+ principled Marxist positions rather than hedging or capitulating.
591
+
592
+ robust_reward_set:
593
+ description: "Balanced set that defeats word soup attacks"
594
+ functions:
595
+ - "match_format_exactly (+3.0 for </think>)"
596
+ - "match_format_approximately (tag validation)"
597
+ - "robust_coherence_reward (NLI + self-consistency + structure)"
598
+ - "completeness_reward (length comparison)"
599
+ - "debug_print_reward (monitoring)"
600
+
601
+ lightweight_anti_hallucination_set:
602
+ description: "Anti-hallucination without NER overhead"
603
+ functions:
604
+ - "match_format_exactly (+3.0 for </think>)"
605
+ - "match_format_approximately (tag validation)"
606
+ - "robust_coherence_reward (NLI + self-consistency + structure)"
607
+ - "epistemic_calibration_reward (pattern matching only)"
608
+ - "completeness_reward (length comparison)"
609
+ notes: |
610
+ Uses epistemic_calibration_reward instead of entity_verification_reward.
611
+ Faster but less precise - doesn't check entity whitelist.
612
+
613
+ legacy_reward_set:
614
+ description: "Original set (VULNERABLE to word soup)"
615
+ functions:
616
+ - "match_format_exactly"
617
+ - "match_format_approximately"
618
+ - "semantic_similarity_reward"
619
+ - "terminology_reward # VULNERABLE"
620
+ - "completeness_reward"
621
+ - "debug_print_reward"
622
+
623
+ grpo_trainer_example: |
624
+ trainer = GRPOTrainer(
625
+ model=model,
626
+ processing_class=tokenizer,
627
+ reward_funcs=[
628
+ match_format_exactly,
629
+ match_format_approximately,
630
+ full_coherence_reward, # RECOMMENDED: NLI + structure + topic
631
+ completeness_reward,
632
+ debug_print_reward,
633
+ ],
634
+ args=training_args,
635
+ train_dataset=dataset,
636
+ )
637
+
638
+ # =============================================================================
639
+ # DEPENDENCIES
640
+ # =============================================================================
641
+
642
+ dependencies:
643
+ required:
644
+ - "sentence-transformers # for semantic_similarity_reward"
645
+ - "transformers # for NLI pipeline (bart-large-mnli)"
646
+ - "spacy # for topic extraction and structural coherence"
647
+ - "spacy-curated-transformers # for en_core_web_trf"
648
+ - "numpy # for embeddings"
649
+
650
+ models_downloaded:
651
+ - "sentence-transformers/all-MiniLM-L6-v2 (~90MB)"
652
+ - "facebook/bart-large-mnli (~1.6GB)"
653
+ - "en_core_web_trf (~436MB, transformer-based, RECOMMENDED)"
654
+ - "en_core_web_md (~40MB, word vectors, fallback)"
655
+ - "en_core_web_sm (~12MB, no vectors, last resort)"
656
+
657
+ gpu_memory:
658
+ embedder: "~200MB"
659
+ nli_model: "~1.6GB"
660
+ spacy_trf: "~500MB (RoBERTa-based transformer)"
661
+ total: "~2.5GB additional (on top of training model)"
662
+
663
+ # =============================================================================
664
+ # WANDB LOGGING INTEGRATION
665
+ # =============================================================================
666
+
667
+ wandb_logging:
668
+ purpose: |
669
+ Comprehensive logging for GRPO training observability via Weights & Biases.
670
+ Provides debugging visibility into reward function behavior, sample outputs,
671
+ and per-step metrics.
672
+
673
+ module: src/prolewiki_llm/wandb_logging.py
674
+
675
+ components:
676
+
677
+ init_wandb_logging:
678
+ purpose: Initialize W&B run with configuration
679
+ signature: |
680
+ init_wandb_logging(
681
+ project: str,
682
+ config: dict[str, Any],
683
+ name: str | None = None,
684
+ tags: list[str] | None = None,
685
+ notes: str | None = None,
686
+ mode: str = "online", # or "offline", "disabled"
687
+ ) -> wandb.Run | None
688
+ example: |
689
+ run = init_wandb_logging(
690
+ project="marxist-grpo",
691
+ config={
692
+ "model": "DeepSeek-R1-0528-Qwen3-8B",
693
+ "learning_rate": 5e-6,
694
+ "batch_size": 2,
695
+ },
696
+ tags=["grpo", "marxist", "v1"],
697
+ )
698
+
699
+ WandbSampleLogger:
700
+ purpose: Accumulate and log sample tables for debugging
701
+ fields:
702
+ log_every_n_steps: "int = 10 (log table every N training steps)"
703
+ max_samples_per_log: "int = 4 (samples per table)"
704
+ methods:
705
+ add_sample: "Add sample with question/response/rewards to buffer"
706
+ should_log: "Check if current step should trigger table log"
707
+ log_table: "Log accumulated samples as wandb.Table"
708
+ clear: "Clear sample buffer"
709
+ table_columns:
710
+ - "step, question, response, ground_truth"
711
+ - "format_exact, format_approx, nli_coherence"
712
+ - "topic_relevance, depth, completeness, total"
713
+
714
+ create_logging_reward:
715
+ purpose: Factory for GRPOTrainer-compatible logging reward function
716
+ signature: |
717
+ create_logging_reward(
718
+ sample_logger: WandbSampleLogger | None = None,
719
+ compute_all_rewards: bool = True,
720
+ ) -> Callable[..., list[float]]
721
+ behavior: |
722
+ 1. Computes ALL reward functions internally (if compute_all_rewards=True)
723
+ 2. Logs aggregated metrics to wandb (per-reward mean/min/max)
724
+ 3. Logs sample tables at configured intervals
725
+ 4. Returns [0.0] * len(completions) (no training effect)
726
+ usage: |
727
+ sample_logger = WandbSampleLogger(log_every_n_steps=10)
728
+ logging_reward = create_logging_reward(sample_logger)
729
+
730
+ trainer = GRPOTrainer(
731
+ reward_funcs=[..., logging_reward], # Add to reward_funcs
732
+ ...
733
+ )
734
+
735
+ log_reward_metrics:
736
+ purpose: Log per-reward metrics to wandb
737
+ signature: |
738
+ log_reward_metrics(
739
+ step: int,
740
+ reward_scores: dict[str, list[float]],
741
+ ) -> None
742
+ logs:
743
+ - "rewards/{name} - mean score"
744
+ - "rewards/{name}_min - minimum score"
745
+ - "rewards/{name}_max - maximum score"
746
+ - "rewards/total - sum of all reward means"
747
+
748
+ finish_wandb_logging:
749
+ purpose: Finish run with optional summary statistics
750
+ signature: |
751
+ finish_wandb_logging(
752
+ summary: dict[str, Any] | None = None,
753
+ ) -> None
754
+
755
+ log_model_checkpoint:
756
+ purpose: Log checkpoint as wandb artifact
757
+ signature: |
758
+ log_model_checkpoint(
759
+ checkpoint_path: str,
760
+ metadata: dict[str, Any] | None = None,
761
+ ) -> None
762
+
763
+ graceful_degradation: |
764
+ All functions work gracefully when wandb is not installed:
765
+ - is_wandb_available() returns False
766
+ - Logging functions print fallback messages or no-op
767
+ - create_logging_reward returns valid reward function (prints to stdout)
768
+
769
+ integration_example: |
770
+ from prolewiki_llm import (
771
+ init_wandb_logging,
772
+ WandbSampleLogger,
773
+ create_logging_reward,
774
+ finish_wandb_logging,
775
+ match_format_exactly,
776
+ full_coherence_reward,
777
+ )
778
+
779
+ # Initialize wandb
780
+ run = init_wandb_logging(
781
+ project="marxist-grpo",
782
+ config={"model": "DeepSeek-R1", "lr": 5e-6, "steps": 250},
783
+ )
784
+
785
+ # Create sample logger and logging reward
786
+ sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
787
+ logging_reward = create_logging_reward(sample_logger, compute_all_rewards=True)
788
+
789
+ # Train with logging
790
+ trainer = GRPOTrainer(
791
+ model=model,
792
+ reward_funcs=[
793
+ match_format_exactly,
794
+ full_coherence_reward,
795
+ logging_reward, # Logs all metrics + samples
796
+ ],
797
+ args=training_args,
798
+ train_dataset=dataset,
799
+ )
800
+ trainer.train()
801
+
802
+ # Finish with summary
803
+ finish_wandb_logging(summary={"final_loss": trainer.state.log_history[-1].get("loss")})
804
+
805
+ # =============================================================================
806
+ # TESTING
807
+ # =============================================================================
808
+
809
+ testing:
810
+
811
+ test_word_soup:
812
+ input: "bourgeoisie proletariat dialectical materialism surplus value"
813
+ expected_nli: "neutral (0.93)"
814
+ expected_structure: "low (no syntactic roles)"
815
+ expected_topic: "fails - no proper sentence to extract topics from"
816
+
817
+ test_good_response:
818
+ input: "The bourgeoisie extracts surplus value from the proletariat through exploitation of labor power."
819
+ expected_nli: "entailment (depends on ground truth)"
820
+ expected_structure: "high (terms in subject/object positions)"
821
+ expected_topic: "high (bourgeoisie, surplus value, proletariat in proper roles)"
822
+
823
+ test_contradiction:
824
+ input: "Capitalism benefits everyone. Workers are exploited under capitalism."
825
+ expected_self_consistency: "-2.0 (internal contradiction)"
826
+
827
+ test_off_topic:
828
+ question: "What is revisionism?"
829
+ answer: "Imperialism is the highest stage of capitalism characterized by monopolies."
830
+ expected_topic_relevance: "-1.5 (off-topic - discusses imperialism not revisionism)"
831
+
832
+ test_synonym_recognition:
833
+ question: "What is the bourgeoisie?"
834
+ answer: "The capitalist class owns the means of production."
835
+ expected_topic_relevance: "+2.0 (synonym 'capitalist class' recognized)"
836
+
837
+ # =============================================================================
838
+ # REFERENCES
839
+ # =============================================================================
840
+
841
+ references:
842
+ papers:
843
+ - "Better LM-Based Judging Reward Modeling: arxiv.org/abs/2508.18212"
844
+ - "MO-GRPO Mitigating Reward Hacking: arxiv.org/abs/2509.22047"
845
+ - "Posterior-GRPO Process Rewards: arxiv.org/abs/2508.05170"
846
+ - "MENLI NLI Evaluation Metrics: doi.org/10.1162/tacl_a_00576"
847
+
848
+ models:
849
+ - "BART-large-MNLI: huggingface.co/facebook/bart-large-mnli"
850
+ - "all-MiniLM-L6-v2: huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
851
+ - "spaCy en_core_web_trf: spacy.io/models/en#en_core_web_trf (RECOMMENDED)"
852
+ - "spaCy en_core_web_md: spacy.io/models/en#en_core_web_md (fallback)"
853
+
854
+ related_docs:
855
+ - "ai-docs/finetune.yaml - overall fine-tuning strategy"
856
+ - "ai-docs/chatbot-ideology.yaml - training data design"
857
+ - "ai-docs/runpod.yaml - GPU setup for training"
ai-docs/runpod.yaml ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RunPod.io Setup Guide for LLM Fine-Tuning
2
+ # Optimized for DeepSeek 7B Abliterated with Unsloth QLoRA
3
+ # Status: READY - verified configuration for Phase 8
4
+
5
+ overview:
6
+ purpose: |
7
+ Step-by-step guide for deploying a RunPod GPU pod to fine-tune
8
+ DeepSeek-R1-Distill-Qwen-7B-abliterated on ProleWiki corpus using Unsloth.
9
+
10
+ estimated_cost: "$0.30-0.60 for complete training run (~30 min)"
11
+
12
+ workflow_summary:
13
+ - Create pod with PyTorch 2.4 template
14
+ - Install Unsloth and dependencies
15
+ - Upload training data (JSONL chunks)
16
+ - Run SFT training (~20-30 min)
17
+ - Export GGUF model
18
+ - Download and deploy to Ollama
19
+ - STOP POD immediately after download
20
+
21
+ gpu_selection:
22
+ recommended: RTX 4090
23
+ vram_required: "16-18GB with Unsloth QLoRA"
24
+ note: |
25
+ Unsloth's QLoRA reduces 7B model VRAM from ~24GB to ~16-18GB.
26
+ RTX 4090 (24GB) provides comfortable headroom.
27
+
28
+ options:
29
+ rtx_4090:
30
+ vram: 24GB
31
+ price_spot: "$0.40-0.50/hr"
32
+ price_ondemand: "$0.50-0.60/hr"
33
+ recommendation: "Best value - sufficient VRAM, fast training"
34
+
35
+ a40:
36
+ vram: 48GB
37
+ price_spot: "$0.45-0.55/hr"
38
+ price_ondemand: "$0.50-0.65/hr"
39
+ recommendation: "More headroom, similar price"
40
+
41
+ rtx_3090:
42
+ vram: 24GB
43
+ price_spot: "$0.25-0.35/hr"
44
+ price_ondemand: "$0.30-0.40/hr"
45
+ recommendation: "Budget option, slightly older"
46
+
47
+ a100_40gb:
48
+ vram: 40GB
49
+ price_spot: "$0.80-1.00/hr"
50
+ price_ondemand: "$1.00-1.50/hr"
51
+ recommendation: "Overkill for 7B, use for larger models"
52
+
53
+ spot_vs_ondemand:
54
+ spot:
55
+ pros: "30-50% cheaper"
56
+ cons: "May be interrupted if demand spikes"
57
+ best_for: "Long training runs where checkpoints save progress"
58
+
59
+ ondemand:
60
+ pros: "Guaranteed availability"
61
+ cons: "Full price"
62
+ best_for: "Short runs (<1hr) like our 30-min training"
63
+
64
+ recommendation: |
65
+ For ProleWiki fine-tuning (~30 min), use On-Demand.
66
+ Spot interruption would cost more in setup time than savings.
67
+
68
+ pod_configuration:
69
+ template: "RunPod PyTorch 2.4"
70
+ alternative: "RunPod PyTorch 2.8 (if available)"
71
+
72
+ template_includes:
73
+ - PyTorch 2.4
74
+ - CUDA 12.4
75
+ - cuDNN
76
+ - JupyterLab
77
+ - SSH access
78
+ - Python 3.10+
79
+
80
+ storage:
81
+ container_disk:
82
+ size: "50GB minimum"
83
+ purpose: "Ephemeral - Unsloth, model weights during training"
84
+ warning: "LOST on pod restart!"
85
+
86
+ volume_disk:
87
+ size: "100GB minimum"
88
+ purpose: "Persistent - checkpoints, scripts, training data"
89
+ critical: "ALL important files must go here!"
90
+ mount_path: "/workspace"
91
+
92
+ ports:
93
+ - port: 8888
94
+ purpose: "JupyterLab (primary interface)"
95
+ - port: 22
96
+ purpose: "SSH (optional, for terminal access)"
97
+
98
+ environment_variables:
99
+ required:
100
+ HF_TOKEN: "Your Hugging Face token (for gated models)"
101
+
102
+ optional:
103
+ JUPYTER_PASSWORD: "Secure notebook access"
104
+ WANDB_API_KEY: "If using Weights & Biases logging"
105
+
106
+ secure_secrets:
107
+ note: |
108
+ Use RUNPOD_SECRET_ prefix for encrypted secrets:
109
+ RUNPOD_SECRET_HF_TOKEN will be injected securely.
110
+
111
+ step_by_step_setup:
112
+ step_1_create_pod:
113
+ description: "Create GPU pod from RunPod dashboard"
114
+ actions:
115
+ - "Go to https://runpod.io/console/pods"
116
+ - "Click '+ Deploy' or 'New Pod'"
117
+ - "Select GPU: RTX 4090 (or A40)"
118
+ - "Click 'Change Template' → search 'PyTorch'"
119
+ - "Select 'RunPod PyTorch 2.4'"
120
+ - "Set Container Disk: 50 GB"
121
+ - "Set Volume Disk: 100 GB"
122
+ - "Expand 'Environment Variables'"
123
+ - "Add: HF_TOKEN = your_token"
124
+ - "Click 'Deploy On-Demand' (not Spot for short runs)"
125
+
126
+ step_2_connect:
127
+ description: "Connect to running pod"
128
+ actions:
129
+ - "Wait for pod status: 'Running' (usually <1 min)"
130
+ - "Click 'Connect' button"
131
+ - "Select 'Jupyter Lab' (opens in new tab)"
132
+ - "Or select 'SSH' for terminal access"
133
+
134
+ step_3_install_unsloth:
135
+ description: "Install Unsloth and dependencies in JupyterLab terminal"
136
+ commands: |
137
+ # Verify CUDA is working
138
+ nvidia-smi
139
+ python -c "import torch; print(f'CUDA: {torch.cuda.get_device_name()}')"
140
+
141
+ # Install Unsloth (auto-detects CUDA version)
142
+ pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
143
+
144
+ # Install flash-attention (may take a few minutes to compile)
145
+ pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"
146
+
147
+ # Install training dependencies
148
+ pip install "trl>=0.7.0" datasets accelerate bitsandbytes peft
149
+
150
+ # Install tiktoken for data transformation
151
+ pip install tiktoken
152
+
153
+ # Verify installation
154
+ python -c "from unsloth import FastLanguageModel; print('Unsloth ready!')"
155
+
156
+ troubleshooting:
157
+ flash_attn_fails: |
158
+ If flash-attn compilation fails, try:
159
+ pip install flash-attn --no-build-isolation
160
+
161
+ cuda_version_mismatch: |
162
+ If CUDA errors occur, specify version explicitly:
163
+ pip install "unsloth[cu124] @ git+https://github.com/unslothai/unsloth.git"
164
+
165
+ step_4_upload_data:
166
+ description: "Upload training data to pod"
167
+
168
+ option_a_jupyterlab:
169
+ best_for: "Small datasets (<100MB)"
170
+ steps:
171
+ - "In JupyterLab file browser (left sidebar)"
172
+ - "Navigate to /workspace"
173
+ - "Create folder: 'data'"
174
+ - "Click upload icon (up arrow)"
175
+ - "Select your JSONL chunks file"
176
+
177
+ option_b_wget:
178
+ best_for: "Data hosted on web"
179
+ command: |
180
+ mkdir -p /workspace/data
181
+ wget https://your-url/library_chunks.jsonl -O /workspace/data/chunks.jsonl
182
+
183
+ option_c_huggingface:
184
+ best_for: "Dataset on Hugging Face"
185
+ command: |
186
+ huggingface-cli download your-user/prolewiki-chunks \
187
+ --local-dir /workspace/data \
188
+ --token $HF_TOKEN
189
+
190
+ option_d_scp:
191
+ best_for: "From local machine via SSH"
192
+ command: |
193
+ # Get SSH command from RunPod 'Connect' dropdown
194
+ scp -P 22XXX library_chunks.jsonl root@pod-ip:/workspace/data/
195
+
196
+ step_5_run_training:
197
+ description: "Execute fine-tuning script"
198
+ note: "See ai-docs/finetune.yaml for complete training code"
199
+
200
+ minimal_script: |
201
+ from unsloth import FastLanguageModel
202
+ import torch
203
+
204
+ # Load abliterated model
205
+ model, tokenizer = FastLanguageModel.from_pretrained(
206
+ model_name="huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated",
207
+ max_seq_length=2048,
208
+ load_in_4bit=True,
209
+ dtype=None, # Auto-detect
210
+ )
211
+
212
+ # Apply LoRA
213
+ model = FastLanguageModel.get_peft_model(
214
+ model,
215
+ r=16,
216
+ lora_alpha=32,
217
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
218
+ "gate_proj", "up_proj", "down_proj"],
219
+ lora_dropout=0.05,
220
+ )
221
+
222
+ # Load and train (see finetune.yaml for full code)
223
+ # ...
224
+
225
+ # Save checkpoint to Volume disk!
226
+ model.save_pretrained("/workspace/checkpoints/marxist-deepseek-lora")
227
+
228
+ expected_time: "20-30 minutes for ~1,000 samples, 3 epochs"
229
+
230
+ step_6_export_gguf:
231
+ description: "Export model to GGUF format for Ollama"
232
+ command: |
233
+ # Export with q4_k_m quantization (good balance)
234
+ model.save_pretrained_gguf(
235
+ "/workspace/exports/marxist-deepseek",
236
+ tokenizer,
237
+ quantization_method="q4_k_m"
238
+ )
239
+
240
+ # Check output
241
+ ls -lh /workspace/exports/
242
+
243
+ output_size: "~4GB for 7B q4_k_m"
244
+
245
+ quantization_options:
246
+ q4_k_m: "Recommended - good quality/size balance (~4GB)"
247
+ q5_k_m: "Higher quality, larger (~5GB)"
248
+ q8_0: "Best quality, largest (~7GB)"
249
+
250
+ step_7_download_model:
251
+ description: "Download GGUF to local machine"
252
+
253
+ option_a_jupyterlab:
254
+ steps:
255
+ - "In JupyterLab file browser"
256
+ - "Navigate to /workspace/exports/"
257
+ - "Right-click the .gguf file"
258
+ - "Select 'Download'"
259
+
260
+ option_b_runpodctl:
261
+ command: |
262
+ # Install runpodctl locally first
263
+ # https://github.com/runpod/runpodctl
264
+
265
+ # On the pod: runpodctl send /workspace/exports/marxist-deepseek-q4_k_m.gguf
+ # (prints a one-time code) Then locally: runpodctl receive <code>
266
+
267
+ step_8_stop_pod:
268
+ description: "CRITICAL - Stop pod to avoid charges"
269
+ warning: "Billing continues until pod is stopped!"
270
+ actions:
271
+ - "Verify GGUF downloaded successfully to local machine"
272
+ - "In RunPod dashboard, click 'Stop' on your pod"
273
+ - "Wait for status: 'Stopped'"
274
+ - "Delete pod if you don't need it again"
275
+ - "Volume disk data persists even after pod deletion"
276
+
277
+ ollama_deployment:
278
+ description: "Deploy GGUF to local Ollama after download"
279
+
280
+ steps:
281
+ - step: "Create Modelfile"
282
+ content: |
283
+ # Save as: Modelfile.marxist-deepseek
284
+ FROM ./marxist-deepseek-q4_k_m.gguf
285
+
286
+ TEMPLATE """<|im_start|>system
287
+ {{ .System }}<|im_end|>
288
+ <|im_start|>user
289
+ {{ .Prompt }}<|im_end|>
290
+ <|im_start|>assistant
291
+ {{ .Response }}<|im_end|>"""
292
+
293
+ SYSTEM "You are a Marxist-Leninist assistant trained on ProleWiki."
294
+
295
+ PARAMETER stop "<|im_end|>"
296
+ PARAMETER temperature 0.7
297
+ PARAMETER top_p 0.9
298
+
299
+ - step: "Create Ollama model"
300
+ command: "ollama create marxist-deepseek -f Modelfile.marxist-deepseek"
301
+
302
+ - step: "Test model"
303
+ command: "ollama run marxist-deepseek 'Explain dialectical materialism.'"
304
+
305
+ cost_summary:
306
+ example_run:
307
+ gpu: "RTX 4090 On-Demand"
308
+ rate: "$0.55/hr"
309
+ time: "30 minutes"
310
+ total: "$0.28"
311
+
312
+ breakdown:
313
+ setup: "5 min - Pod creation, Unsloth install"
314
+ upload: "2 min - Data transfer"
315
+ training: "20-25 min - SFT with QLoRA"
316
+ export: "3 min - GGUF conversion"
317
+ download: "5 min - Transfer GGUF locally"
318
+ total_time: "~35-40 min"
319
+
320
+ tips:
321
+ - "Use On-Demand for short runs (<1hr)"
322
+ - "Use Spot for long runs with checkpoint saving"
323
+ - "Stop pod IMMEDIATELY after download"
324
+ - "Delete pod after confirming success"
325
+ - "Volume disk persists - can restart training later"
326
+
327
+ troubleshooting:
328
+ out_of_memory:
329
+ symptoms: "CUDA OOM, kernel dies"
330
+ solutions:
331
+ - "Reduce batch size in training args"
332
+ - "Ensure load_in_4bit=True"
333
+ - "Use gradient_checkpointing=True"
334
+ - "Upgrade to A40 (48GB VRAM)"
335
+
336
+ slow_training:
337
+ symptoms: "Steps/sec much lower than expected"
338
+ solutions:
339
+ - "Verify GPU is being used: nvidia-smi"
340
+ - "Check torch.cuda.is_available()"
341
+ - "Ensure flash-attn installed correctly"
342
+
343
+ pod_wont_start:
344
+ symptoms: "Pod stuck in 'Pending' or 'Initializing'"
345
+ solutions:
346
+ - "Try different data center region"
347
+ - "Try different GPU type"
348
+ - "Check RunPod status page"
349
+
350
+ checkpoint_lost:
351
+ symptoms: "Can't find saved model after restart"
352
+ cause: "Saved to Container Disk instead of Volume"
353
+ prevention: "ALWAYS use /workspace/ for important files"
354
+
355
+ references:
356
+ runpod_docs: "https://docs.runpod.io/"
357
+ unsloth_github: "https://github.com/unslothai/unsloth"
358
+ context7_runpod: "/runpod/docs"
359
+ related_docs:
360
+ - "ai-docs/finetune.yaml - Complete training configuration"
361
+ - "ai-docs/embedding.yaml - Embedding pipeline"
362
+ - "ai-docs/project-status.yaml - Phase 8 status"
ai-docs/training-schema.yaml ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Data Schema Reference
2
+ # Purpose: Human-readable documentation for Marxist-GRPO training data format
3
+ # Formal Schema: training_data/schema/training_record.schema.json
4
+ # Updated: 2025-12-18
5
+
6
+ # =============================================================================
7
+ # OVERVIEW
8
+ # =============================================================================
9
+
10
+ overview:
11
+ purpose: |
12
+ This schema defines the format for Q&A training records in the Marxist-GRPO
13
+ fine-tuning dataset. Each record contains an instruction-response pair with
14
+ comprehensive metadata for:
15
+ - Provenance tracking (where did this come from?)
16
+ - Theoretical classification (what tradition/topic?)
17
+ - Citation tracking (what sources are referenced?)
18
+ - Training metadata (what issue does this fix?)
19
+ - Quality assessment (has this been verified?)
20
+
21
+ design_principles:
22
+ - Reproducibility: Every record traceable to source
23
+ - Filterability: Train on subsets by any dimension
24
+ - Scientific Rigor: Formal JSON Schema validation
25
+ - RAG Integration: Links to ChromaDB chunks where applicable
26
+ - Iteration Tracking: Know what was added when and why
27
+
28
+ json_schema_location: training_data/schema/training_record.schema.json
29
+ manifest_schema_location: training_data/schema/manifest.schema.json
30
+
31
+ # =============================================================================
32
+ # RECORD FORMAT
33
+ # =============================================================================
34
+
35
+ record_format:
36
+ description: |
37
+ Each JSONL file contains one JSON object per line.
38
+ Every record MUST have: instruction, response, metadata
39
+ The metadata object contains all provenance and classification.
40
+
41
+ minimal_example:
42
+ instruction: "What is the mass line?"
43
+ response: "The mass line is the Maoist method of communist leadership..."
44
+ metadata:
45
+ id: "synthetic/maoist-theory/001"
46
+ source:
47
+ type: "synthetic"
48
+ classification:
49
+ categories: ["maoist-theory", "methodology"]
50
+ tradition: "MLM"
51
+ provenance:
52
+ created_date: "2025-12-18"
53
+ created_by: "claude-opus"
54
+
55
+ full_example:
56
+ instruction: "What is the Marxist-Leninist distinction between antisemitism and anti-Zionism?"
57
+ response: "These are fundamentally different phenomena. Antisemitism is a form of racism..."
58
+ metadata:
59
+ id: "synthetic/antisemitism/001"
60
+ source:
61
+ type: "synthetic"
62
+ author: null
63
+ work: null
64
+ article: null
65
+ chunk_ids: []
66
+ classification:
67
+ categories: ["anti-zionism", "antisemitism", "settler-colonialism"]
68
+ tradition: "ML"
69
+ geographic_focus: "Palestine"
70
+ historical_period: null
71
+ citations:
72
+ has_citations: true
73
+ authors: ["Lenin", "Ilan Pappé", "Noam Chomsky"]
74
+ works:
75
+ - title: "On Anti-Jewish Pogroms"
76
+ author: "Lenin"
77
+ year: 1919
78
+ type: "speech"
79
+ - title: "The Ethnic Cleansing of Palestine"
80
+ author: "Ilan Pappé"
81
+ year: 2006
82
+ type: "book"
83
+ training:
84
+ iteration: 2
85
+ correction_for: ["both-sidesing", "antisemitism-conflation"]
86
+ difficulty: "intermediate"
87
+ response_style: "educational"
88
+ adversarial_type: null
89
+ provenance:
90
+ created_date: "2025-12-18"
91
+ created_by: "claude-opus"
92
+ reviewed_by: null
93
+ version: 1
94
+ quality:
95
+ human_verified: false
96
+ confidence: "high"
97
+ notes: null
98
+
99
+ # =============================================================================
100
+ # METADATA FIELDS REFERENCE
101
+ # =============================================================================
102
+
103
+ fields:
104
+ # ---------------------------------------------------------------------------
105
+ # SOURCE PROVENANCE
106
+ # ---------------------------------------------------------------------------
107
+ source:
108
+ description: Where this Q&A pair originated from.
109
+
110
+ type:
111
+ required: true
112
+ values:
113
+ prolewiki: Derived from ProleWiki article content
114
+ synthetic: Generated by AI for specific purpose
115
+ curated: Human-curated from multiple sources
116
+ library: Derived from Library namespace (full works)
117
+ external: From external source with URL
118
+
119
+ article:
120
+ required: false
121
+ purpose: ProleWiki article title if derived from corpus
122
+ example: "Main/Imperialism"
123
+ links_to: chromadb.article_title
124
+
125
+ work:
126
+ required: false
127
+ purpose: Title of source work for Library-derived Q&As
128
+ example: "Imperialism, the Highest Stage of Capitalism"
129
+
130
+ author:
131
+ required: false
132
+ purpose: Primary author of source material
133
+ example: "Lenin"
134
+ enables: Train only on Marx-derived, Lenin-derived, etc.
135
+
136
+ chunk_ids:
137
+ required: false
138
+ purpose: ChromaDB chunk IDs this Q&A was derived from
139
+ example: ["Main/Imperialism#0", "Main/Imperialism#1"]
140
+ enables: RAG-training data linkage, citation verification
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # CLASSIFICATION
144
+ # ---------------------------------------------------------------------------
145
+ classification:
146
+ description: Theoretical and topical classification.
147
+
148
+ categories:
149
+ required: true
150
+ purpose: Topic tags aligned with ProleWiki categories
151
+ examples:
152
+ - ["imperialism", "revisionism"]
153
+ - ["anti-zionism", "settler-colonialism", "national-liberation"]
154
+ - ["cultural-revolution", "gpcr", "maoist-theory"]
155
+ enables: Train on specific topics, measure coverage
156
+
157
+ tradition:
158
+ required: true
159
+ values:
160
+ ML: Marxism-Leninism (broad)
161
+ MLM: Marxism-Leninism-Maoism (includes GPCR defense)
162
+ general: Broadly applicable across tendencies
163
+ contested: Debated within ML circles
164
+ enables: Filter by theoretical tendency
165
+
166
+ geographic_focus:
167
+ required: false
168
+ examples: ["Soviet Union", "China", "Palestine", "Cuba"]
169
+ enables: Regional expertise training
170
+
171
+ historical_period:
172
+ required: false
173
+ examples: ["Russian Revolution", "Cultural Revolution", "Cold War"]
174
+ enables: Period-specific training
175
+
176
+ # ---------------------------------------------------------------------------
177
+ # CITATIONS
178
+ # ---------------------------------------------------------------------------
179
+ citations:
180
+ description: Citation and reference tracking.
181
+
182
+ has_citations:
183
+ purpose: Quick boolean filter for cited content
184
+ enables: Train only on well-sourced responses
185
+
186
+ works:
187
+ purpose: Structured list of cited works
188
+ fields: [title, author, year, type]
189
+ enables: Verify citations, trace to primary sources
190
+
191
+ authors:
192
+ purpose: Flat list of cited authors for filtering
193
+ enables: "Train on Lenin-citing records only"
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # TRAINING METADATA
197
+ # ---------------------------------------------------------------------------
198
+ training:
199
+ description: Training-specific metadata.
200
+
201
+ iteration:
202
+ purpose: Which training iteration added this record
203
+ enables: Ablation studies, measure iteration impact
204
+
205
+ correction_for:
206
+ purpose: What failure modes this addresses
207
+ values:
208
+ cpc-contamination: Fixes CPC authority citations
209
+ both-sidesing: Fixes false equivalence on colonial issues
210
+ hallucination: Provides correct historical facts
211
+ antisemitism-conflation: Distinguishes antisemitism/anti-Zionism
212
+ liberal-framing: Replaces liberal with ML framing
213
+ historical-inaccuracy: Corrects factual errors
214
+ theoretical-error: Corrects theoretical misunderstandings
215
+ accommodation: Resists incremental position shifts
216
+ extended-engagement: Models firm rejection
217
+ enables: Test specific corrections, targeted training
218
+
219
+ difficulty:
220
+ values:
221
+ basic: Straightforward ML questions
222
+ intermediate: Requires nuanced understanding
223
+ advanced: Complex theoretical synthesis
224
+ adversarial: Bad-faith or trap questions
225
+ enables: Curriculum learning, stress testing
226
+
227
+ response_style:
228
+ values:
229
+ educational: Thorough explanation
230
+ firm-rejection: Short, clear rejection of premise
231
+ theoretical: Abstract theoretical analysis
232
+ historical: Historical narrative/facts
233
+ biographical: Person-focused information
234
+ analytical: Systematic breakdown
235
+ comparative: Comparing positions/theories
236
+ enables: Style-specific training
237
+
238
+ adversarial_type:
239
+ purpose: For adversarial questions, what pattern
240
+ values:
241
+ bad-faith-question: User asking in bad faith
242
+ conspiracy-premise: Question contains conspiracy theory
243
+ incremental-shift: Gradually shifting goalposts
244
+ false-equivalence: Both-sidesing framing
245
+ appeal-to-complexity: '"It''s complicated" deflection'
246
+
247
+ # ---------------------------------------------------------------------------
248
+ # PROVENANCE
249
+ # ---------------------------------------------------------------------------
250
+ provenance:
251
+ description: Record creation and modification tracking.
252
+
253
+ created_date:
254
+ required: true
255
+ format: ISO 8601 date (YYYY-MM-DD)
256
+ purpose: When this record was created
257
+
258
+ created_by:
259
+ required: true
260
+ values: [human, claude-opus, claude-sonnet, other-llm, automated]
261
+ purpose: Who/what created this record
262
+ enables: Filter by creation method
263
+
264
+ reviewed_by:
265
+ purpose: Human reviewer identifier
266
+ enables: Track review coverage
267
+
268
+ version:
269
+ purpose: Increment on edits
270
+ enables: Track record evolution
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # QUALITY
274
+ # ---------------------------------------------------------------------------
275
+ quality:
276
+ description: Quality assessment metadata.
277
+
278
+ human_verified:
279
+ purpose: Has a human verified accuracy?
280
+ enables: High-confidence subset training
281
+
282
+ confidence:
283
+ values: [high, medium, low]
284
+ purpose: Confidence in response accuracy
285
+
286
+ flagged_issues:
287
+ purpose: Known issues needing attention
288
+ enables: Exclude problematic records
289
+
290
+ # =============================================================================
291
+ # VALIDATION
292
+ # =============================================================================
293
+
294
+ validation:
295
+ json_schema:
296
+ location: training_data/schema/training_record.schema.json
297
+ draft: 2020-12
298
+ command: |
299
+ # Using jsonschema Python library
300
+ uv run python -c "
301
+ import json
302
+ import jsonschema
303
+ from pathlib import Path
304
+
305
+ schema = json.loads(Path('training_data/schema/training_record.schema.json').read_text())
306
+ for line in Path('training_data/your_file.jsonl').read_text().splitlines():
307
+ record = json.loads(line)
308
+ jsonschema.validate(record, schema)
309
+ print('All records valid!')
310
+ "
311
+
312
+ quick_validation:
313
+ command: |
314
+ # Quick JSON syntax check
315
+ python3 -c "import json; [json.loads(l) for l in open('file.jsonl')]; print('OK')"
316
+
317
+ pre_commit_hook:
318
+ description: Add to .pre-commit-config.yaml for automatic validation
319
+ config: |
320
+ - repo: local
321
+ hooks:
322
+ - id: validate-training-data
323
+ name: Validate Training Data Schema
324
+ entry: uv run python scripts/validate_training_data.py
325
+ language: system
326
+ files: ^training_data/.*\.jsonl$
327
+
328
+ # =============================================================================
329
+ # MANIFEST
330
+ # =============================================================================
331
+
332
+ manifest:
333
+ purpose: |
334
+ The manifest (MANIFEST.yaml) tracks all JSONL files in the dataset,
335
+ their checksums, statistics, and training history. This enables:
336
+ - Reproducible training runs
337
+ - Dataset versioning
338
+ - Integrity verification
339
+ - Statistics generation
340
+
341
+ location: training_data/MANIFEST.yaml
342
+ schema: training_data/schema/manifest.schema.json
343
+
344
+ key_sections:
345
+ dataset: Name, version, license, description
346
+ files: List of all JSONL files with checksums and metadata
347
+ statistics: Aggregate counts by source, category, tradition
348
+ training_iterations: History of training runs
349
+ known_issues: Documented problems
350
+ changelog: Dataset modification history
351
+
352
+ # =============================================================================
353
+ # FILTERING PATTERNS
354
+ # =============================================================================
355
+
356
+ filtering_patterns:
357
+ description: Common filtering operations for training subsets.
358
+
359
+ by_source:
360
+ code: |
361
+ # ProleWiki-derived only (corpus purity)
362
+ data = [r for r in records if r["metadata"]["source"]["type"] == "prolewiki"]
363
+
364
+ # Exclude synthetic for ablation
365
+ data = [r for r in records if r["metadata"]["source"]["type"] != "synthetic"]
366
+
367
+ by_author:
368
+ code: |
369
+ # Lenin-citing records
370
+ data = [r for r in records
371
+ if "Lenin" in r["metadata"].get("citations", {}).get("authors", [])]
372
+
373
+ # Marx or Engels sourced
374
+ data = [r for r in records
375
+ if r["metadata"]["source"].get("author") in ["Marx", "Engels"]]
376
+
377
+ by_tradition:
378
+ code: |
379
+ # MLM only (includes GPCR defense)
380
+ data = [r for r in records if r["metadata"]["classification"]["tradition"] == "MLM"]
381
+
382
+ by_correction:
383
+ code: |
384
+ # Records addressing Zionism issues
385
+ data = [r for r in records
386
+ if "both-sidesing" in r["metadata"].get("training", {}).get("correction_for", [])]
387
+
388
+ by_difficulty:
389
+ code: |
390
+ # Adversarial examples only (stress testing)
391
+ data = [r for r in records
392
+ if r["metadata"].get("training", {}).get("difficulty") == "adversarial"]
393
+
394
+ by_iteration:
395
+ code: |
396
+ # Only iteration 1 (baseline)
397
+ data = [r for r in records if r["metadata"].get("training", {}).get("iteration") == 1]
398
+
399
+ # Iterations 1-2 combined
400
+ data = [r for r in records if r["metadata"].get("training", {}).get("iteration", 1) <= 2]
401
+
402
+ by_quality:
403
+ code: |
404
+ # Human-verified only
405
+ data = [r for r in records if r["metadata"].get("quality", {}).get("human_verified")]
406
+
407
+ # High confidence
408
+ data = [r for r in records
409
+ if r["metadata"].get("quality", {}).get("confidence") == "high"]
410
+
411
+ # =============================================================================
412
+ # INTEGRATION WITH CHROMADB
413
+ # =============================================================================
414
+
415
+ chromadb_integration:
416
+ purpose: |
417
+ Training data can link to ChromaDB chunks, enabling:
418
+ - Verification that responses match corpus
419
+ - RAG-augmented training data generation
420
+ - Provenance chains from user query → chunk → training example
421
+
422
+ chunk_id_format: "{namespace}/{article_title}#{chunk_index}"
423
+ examples:
424
+ - "Main/Imperialism#0"
425
+ - "Library/Capital_Vol_1#127"
426
+ - "Essays/On_Revisionism#3"
427
+
428
+ linkage_pattern:
429
+ description: When generating training data from ProleWiki chunks
430
+ code: |
431
+ # Generate Q&A from chunk and preserve linkage
432
+ training_record = {
433
+ "instruction": generate_question(chunk),
434
+ "response": generate_answer(chunk),
435
+ "metadata": {
436
+ "source": {
437
+ "type": "prolewiki",
438
+ "article": chunk["article_title"],
439
+ "chunk_ids": [chunk["chunk_id"]]
440
+ },
441
+ # ... rest of metadata
442
+ }
443
+ }
444
+
445
+ # =============================================================================
446
+ # MIGRATION FROM LEGACY FORMAT
447
+ # =============================================================================
448
+
449
+ migration:
450
+ legacy_format:
451
+ description: Original curated_qa.jsonl format
452
+ example:
453
+ instruction: "What is revisionism?"
454
+ response: "Revisionism refers to..."
455
+
456
+ new_format:
457
+ description: Full metadata format
458
+ migration_steps:
459
+ - Add metadata wrapper
460
+ - Generate unique IDs
461
+ - Infer source type (curated for manual entries)
462
+ - Add classification based on content analysis
463
+ - Set iteration to 1 for baseline data
464
+ - Mark as needing human verification
465
+
466
+ migration_script: |
467
+ # See scripts/migrate_training_data.py for full implementation
468
+ def migrate_record(old_record, index):
469
+ return {
470
+ "instruction": old_record["instruction"],
471
+ "response": old_record["response"],
472
+ "metadata": {
473
+ "id": f"curated/legacy/{index:03d}",
474
+ "source": {"type": "curated"},
475
+ "classification": {
476
+ "categories": infer_categories(old_record),
477
+ "tradition": "ML"
478
+ },
479
+ "provenance": {
480
+ "created_date": "2025-12-17", # Original creation date
481
+ "created_by": "human"
482
+ }
483
+ }
484
+ }
docker/.env.example ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Marxist-GRPO Training Environment Variables
3
+ # =============================================================================
4
+ # Copy this file to .env and fill in your values.
5
+ # These can be passed to docker run with --env-file or individually with -e.
6
+ #
7
+ # Usage:
8
+ # docker run --gpus all --env-file docker/.env marxist-grpo:latest
9
+ #
10
+ # Or with runpodctl:
11
+ # runpodctl create pod \
12
+ # --imageName myregistry/marxist-grpo:latest \
13
+ # --env HF_TOKEN=$HF_TOKEN \
14
+ # --env WANDB_API_KEY=$WANDB_API_KEY \
15
+ # --env HF_REPO=my-org/my-model
16
+ # =============================================================================
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # REQUIRED SECRETS (must be set)
20
+ # -----------------------------------------------------------------------------
21
+
22
+ # HuggingFace API token (for model upload)
23
+ # Get yours at: https://huggingface.co/settings/tokens
24
+ HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
25
+
26
+ # Weights & Biases API key (for training monitoring)
27
+ # Get yours at: https://wandb.ai/authorize
28
+ WANDB_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # MODEL CONFIGURATION
32
+ # -----------------------------------------------------------------------------
33
+
34
+ # Base model to fine-tune
35
+ MODEL_NAME=unsloth/DeepSeek-R1-0528-Qwen3-8B
36
+
37
+ # Maximum sequence length for tokenizer
38
+ MAX_SEQ_LENGTH=2048
39
+
40
+ # LoRA adapter rank (higher = more expressive, more params)
41
+ LORA_RANK=32
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # TRAINING HYPERPARAMETERS
45
+ # -----------------------------------------------------------------------------
46
+
47
+ # Total training steps
48
+ MAX_STEPS=500
49
+
50
+ # Save checkpoint every N steps
51
+ SAVE_STEPS=50
52
+
53
+ # Learning rate
54
+ LEARNING_RATE=5e-6
55
+
56
+ # Warmup ratio (fraction of steps for LR warmup)
57
+ WARMUP_RATIO=0.1
58
+
59
+ # Per-device batch size
60
+ BATCH_SIZE=2
61
+
62
+ # Gradient accumulation steps (effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION)
63
+ GRADIENT_ACCUMULATION=2
64
+
65
+ # Number of generations per prompt during GRPO
66
+ NUM_GENERATIONS=4
67
+
68
+ # Fraction of GPU memory to allocate
69
+ GPU_MEMORY_UTILIZATION=0.6
70
+
71
+ # Maximum prompt length (tokens)
72
+ MAX_PROMPT_LENGTH=512
73
+
74
+ # Maximum completion length (tokens)
75
+ MAX_COMPLETION_LENGTH=1500
76
+
77
+ # -----------------------------------------------------------------------------
78
+ # REWARD CONFIGURATION
79
+ # -----------------------------------------------------------------------------
80
+
81
+ # Reward mode: FULL (recommended), ROBUST, or LEGACY
82
+ # FULL: NLI + self-consistency + structure + topic relevance + depth
83
+ # ROBUST: NLI + self-consistency + structure
84
+ # LEGACY: Semantic similarity + terminology (faster but vulnerable to word soup)
85
+ REWARD_MODE=FULL
86
+
87
+ # -----------------------------------------------------------------------------
88
+ # OUTPUT CONFIGURATION
89
+ # -----------------------------------------------------------------------------
90
+
91
+ # HuggingFace repo to upload the trained LoRA adapter
92
+ HF_REPO=prolewiki/marxist-grpo-lora
93
+
94
+ # -----------------------------------------------------------------------------
95
+ # PATHS (container internal - usually don't change)
96
+ # -----------------------------------------------------------------------------
97
+
98
+ # Path to training dataset (JSONL)
99
+ DATASET_PATH=/workspace/dataset.jsonl
100
+
101
+ # Directory for training checkpoints
102
+ CHECKPOINT_DIR=/workspace/checkpoints
103
+
104
+ # Directory for final LoRA output
105
+ LORA_OUTPUT=/workspace/lora-output
106
+
107
+ # Directory for training outputs
108
+ OUTPUT_DIR=/workspace/outputs
109
+
110
+ # -----------------------------------------------------------------------------
111
+ # OPTIONAL: RUNPOD AUTO-TERMINATION
112
+ # -----------------------------------------------------------------------------
113
+
114
+ # Set this to enable automatic pod termination after training
115
+ # This prevents "zombie pods" from racking up bills
116
+ # Value is automatically set by RunPod, or can be set manually
117
+ # RUNPOD_POD_ID=
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # OPTIONAL: REMOTE DATASET
121
+ # -----------------------------------------------------------------------------
122
+
123
+ # If dataset is not embedded in the image, set this URL to download it
124
+ # DATASET_URL=https://my-bucket.s3.amazonaws.com/grpo_dataset.jsonl
docker/Dockerfile ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Marxist-GRPO Training Container
3
+ # =============================================================================
4
+ # Headless GRPO fine-tuning container for RunPod deployment.
5
+ #
6
+ # Build:
7
+ # docker build -f docker/Dockerfile -t marxist-grpo:latest .
8
+ #
9
+ # Run locally (testing):
10
+ # docker run --gpus all \
11
+ # -e HF_TOKEN=$HF_TOKEN \
12
+ # -e WANDB_API_KEY=$WANDB_API_KEY \
13
+ # -e MAX_STEPS=10 \
14
+ # marxist-grpo:latest
15
+ #
16
+ # Deploy to RunPod:
17
+ # runpodctl create pod \
18
+ # --name "marxist-grpo-training" \
19
+ # --gpuType "NVIDIA A100 80GB PCIe" \
20
+ # --imageName "myregistry/marxist-grpo:latest" \
21
+ # --env HF_TOKEN=$HF_TOKEN \
22
+ # --env WANDB_API_KEY=$WANDB_API_KEY \
23
+ # --env HF_REPO=my-org/my-model
24
+ # =============================================================================
25
+
26
+ # Use RunPod's PyTorch base image with CUDA 11.8
27
+ FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel
28
+
29
+ # Set working directory
30
+ WORKDIR /workspace
31
+
32
+ # Prevent interactive prompts during package installation
33
+ ENV DEBIAN_FRONTEND=noninteractive
34
+
35
+ # Install system dependencies
36
+ RUN apt-get update && apt-get install -y --no-install-recommends \
37
+ git \
38
+ git-lfs \
39
+ curl \
40
+ htop \
41
+ nvtop \
42
+ tmux \
43
+ wget \
44
+ && rm -rf /var/lib/apt/lists/* \
45
+ && git lfs install
46
+
47
+ # Install runpodctl for pod self-termination
48
+ RUN curl -fsSL -o /tmp/runpodctl.tar.gz \
49
+ https://github.com/runpod/runpodctl/releases/download/v1.14.15/runpodctl-linux-amd64.tar.gz \
50
+ && tar -xzf /tmp/runpodctl.tar.gz -C /tmp \
51
+ && mv /tmp/runpodctl /usr/local/bin/runpodctl \
52
+ && chmod +x /usr/local/bin/runpodctl \
53
+ && rm /tmp/runpodctl.tar.gz
54
+
55
+ # Copy requirements first (for layer caching)
56
+ COPY docker/requirements.txt /workspace/requirements.txt
57
+
58
+ # Install Python dependencies
59
+ # Note: Unsloth requires specific installation order
60
+ RUN pip install --no-cache-dir --upgrade pip && \
61
+ pip install --no-cache-dir -r requirements.txt
62
+
63
+ # Install Unsloth (from source for latest optimizations)
64
+ RUN pip install --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
65
+
66
+ # Install specific versions that work with Unsloth
67
+ # Note: These must be installed after unsloth to avoid conflicts
68
+ RUN pip install --no-cache-dir --no-deps \
69
+ "xformers<0.0.27" \
70
+ "trl>=0.9.0,<0.12.0" \
71
+ peft \
72
+ accelerate \
73
+ bitsandbytes
74
+
75
+ # Download spaCy model for NLP-based reward functions
76
+ RUN python -m spacy download en_core_web_sm
77
+
78
+ # Copy the training code
79
+ COPY src/prolewiki_llm/ /workspace/prolewiki_llm/
80
+
81
+ # Copy the dataset (embedded - only 4.5MB)
82
+ COPY training_data/grpo_dataset.jsonl /workspace/dataset.jsonl
83
+
84
+ # Copy entrypoint script
85
+ COPY docker/start.sh /workspace/start.sh
86
+ RUN chmod +x /workspace/start.sh
87
+
88
+ # Set PYTHONPATH so prolewiki_llm module can be imported
89
+ ENV PYTHONPATH=/workspace
90
+
91
+ # =============================================================================
92
+ # Environment Variables (defaults - override at runtime)
93
+ # =============================================================================
94
+
95
+ # Model configuration
96
+ ENV MODEL_NAME="unsloth/DeepSeek-R1-0528-Qwen3-8B"
97
+ ENV MAX_SEQ_LENGTH=2048
98
+ ENV LORA_RANK=32
99
+
100
+ # Training configuration
101
+ ENV MAX_STEPS=500
102
+ ENV SAVE_STEPS=50
103
+ ENV LEARNING_RATE=5e-6
104
+ ENV BATCH_SIZE=2
105
+ ENV GRADIENT_ACCUMULATION=2
106
+ ENV NUM_GENERATIONS=4
107
+ ENV GPU_MEMORY_UTILIZATION=0.6
108
+
109
+ # Paths (container internal)
110
+ ENV DATASET_PATH=/workspace/dataset.jsonl
111
+ ENV CHECKPOINT_DIR=/workspace/checkpoints
112
+ ENV LORA_OUTPUT=/workspace/lora-output
113
+ ENV OUTPUT_DIR=/workspace/outputs
114
+
115
+ # Reward mode: FULL (recommended), ROBUST, or LEGACY
116
+ ENV REWARD_MODE=FULL
117
+
118
+ # Upload destination
119
+ ENV HF_REPO=prolewiki/marxist-grpo-lora
120
+
121
+ # Required secrets (must be provided at runtime):
122
+ # - HF_TOKEN: HuggingFace API token
123
+ # - WANDB_API_KEY: Weights & Biases API key
124
+ # Optional:
125
+ # - RUNPOD_POD_ID: For self-termination after training
126
+
127
+ # Health check - verify CUDA is available
128
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
129
+ CMD python -c "import torch; assert torch.cuda.is_available()" || exit 1
130
+
131
+ # Entry point
132
+ ENTRYPOINT ["/workspace/start.sh"]
docker/requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Marxist-GRPO Training Dependencies
3
+ # =============================================================================
4
+ # Pinned versions for reproducible builds.
5
+ # Note: Unsloth is installed separately in Dockerfile from git.
6
+ # =============================================================================
7
+
8
+ # Core ML
9
+ torch>=2.1.0,<2.3.0
10
+ transformers>=4.40.0,<4.58.0
11
+ sentence-transformers>=3.0.0,<4.0.0
12
+
13
+ # Training
14
+ datasets>=2.20.0,<3.0.0
15
+ # Note: trl, peft, accelerate, bitsandbytes installed separately after unsloth
16
+
17
+ # NLP (for reward functions)
18
+ spacy>=3.8.0,<4.0.0
19
+
20
+ # Logging and monitoring
21
+ wandb>=0.17.0,<1.0.0
22
+
23
+ # HuggingFace Hub (for model upload)
24
+ huggingface-hub>=0.24.0,<1.0.0
25
+
26
+ # Data validation
27
+ pydantic>=2.0.0,<3.0.0
28
+
29
+ # Utilities
30
+ numpy>=1.24.0,<2.0.0
31
+ safetensors>=0.4.0,<1.0.0
docker/start.sh ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # Headless GRPO Training Entrypoint Script
4
+ # =============================================================================
5
+ # This script orchestrates the training lifecycle:
6
+ # 1. Validate environment (secrets, GPU)
7
+ # 2. Authenticate with HuggingFace and W&B
8
+ # 3. Run training
9
+ # 4. Upload results
10
+ # 5. Self-terminate pod (if RUNPOD_POD_ID is set)
11
+ #
12
+ # Exit on any error
13
+ set -e
14
+
15
+ echo "=============================================================="
16
+ echo "Marxist-GRPO Headless Training"
17
+ echo "=============================================================="
18
+ echo "Start time: $(date -Iseconds)"
19
+ echo ""
20
+
21
+ # =============================================================================
22
+ # 1. ENVIRONMENT VALIDATION
23
+ # =============================================================================
24
+ echo "[1/5] Validating environment..."
25
+
26
+ # Check required secrets
27
+ if [ -z "$HF_TOKEN" ]; then
28
+ echo "ERROR: HF_TOKEN environment variable is required"
29
+ exit 1
30
+ fi
31
+
32
+ if [ -z "$WANDB_API_KEY" ]; then
33
+ echo "ERROR: WANDB_API_KEY environment variable is required"
34
+ exit 1
35
+ fi
36
+
37
+ # Check GPU availability
38
+ python -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'" || {
39
+ echo "ERROR: CUDA/GPU not available"
40
+ exit 1
41
+ }
42
+
43
+ echo " - HF_TOKEN: [set]"
44
+ echo " - WANDB_API_KEY: [set]"
45
+ echo " - GPU: $(python -c 'import torch; print(torch.cuda.get_device_name())')"
46
+ echo " - VRAM: $(python -c 'import torch; print(f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")')"
47
+ echo ""
48
+
49
+ # =============================================================================
50
+ # 2. AUTHENTICATION
51
+ # =============================================================================
52
+ echo "[2/5] Authenticating with services..."
53
+
54
+ # Login to HuggingFace
55
+ echo " - HuggingFace Hub..."
56
+ huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential 2>/dev/null || true
57
+
58
+ # Login to Weights & Biases
59
+ echo " - Weights & Biases..."
60
+ wandb login "$WANDB_API_KEY" 2>/dev/null || true
61
+
62
+ echo ""
63
+
64
+ # =============================================================================
65
+ # 3. DATA VALIDATION
66
+ # =============================================================================
67
+ echo "[3/5] Validating dataset..."
68
+
69
+ DATASET_PATH="${DATASET_PATH:-/workspace/dataset.jsonl}"
70
+
71
+ if [ ! -f "$DATASET_PATH" ]; then
72
+ echo "ERROR: Dataset not found at $DATASET_PATH"
73
+
74
+ # If DATASET_URL is set, try to download
75
+ if [ -n "$DATASET_URL" ]; then
76
+ echo "Attempting to download from DATASET_URL..."
77
+ wget -O "$DATASET_PATH" "$DATASET_URL" || {
78
+ echo "ERROR: Failed to download dataset"
79
+ exit 1
80
+ }
81
+ else
82
+ exit 1
83
+ fi
84
+ fi
85
+
86
+ RECORD_COUNT=$(wc -l < "$DATASET_PATH")
87
+ echo " - Dataset: $DATASET_PATH"
88
+ echo " - Records: $RECORD_COUNT"
89
+ echo ""
90
+
91
+ # =============================================================================
92
+ # 4. TRAINING
93
+ # =============================================================================
94
+ echo "[4/5] Starting training..."
95
+ echo ""
96
+ echo "Configuration:"
97
+ echo " - Model: ${MODEL_NAME:-unsloth/DeepSeek-R1-0528-Qwen3-8B}"
98
+ echo " - Max Steps: ${MAX_STEPS:-500}"
99
+ echo " - Batch Size: ${BATCH_SIZE:-2} x ${GRADIENT_ACCUMULATION:-2}"
100
+ echo " - Learning Rate: ${LEARNING_RATE:-5e-6}"
101
+ echo " - Reward Mode: ${REWARD_MODE:-FULL}"
102
+ echo " - HF Repo: ${HF_REPO:-prolewiki/marxist-grpo-lora}"
103
+ echo ""
104
+
105
+ # Create output directories (use env vars to allow override in tests)
106
+ CHECKPOINT_DIR="${CHECKPOINT_DIR:-/workspace/checkpoints}"
107
+ LORA_OUTPUT="${LORA_OUTPUT:-/workspace/lora-output}"
108
+ OUTPUT_DIR="${OUTPUT_DIR:-/workspace/outputs}"
109
+ mkdir -p "$CHECKPOINT_DIR"
110
+ mkdir -p "$LORA_OUTPUT"
111
+ mkdir -p "$OUTPUT_DIR"
112
+
113
+ # Run training
114
+ # Note: The training script handles all the environment variables
115
+ # IMPORTANT: Disable set -e temporarily to capture exit code for failure handling
116
+ set +e
117
+ python -m prolewiki_llm.train_headless
118
+ TRAINING_EXIT_CODE=$?
119
+ set -e
120
+
121
+ echo ""
122
+ echo "Training completed with exit code: $TRAINING_EXIT_CODE"
123
+ echo ""
124
+
125
+ # =============================================================================
126
+ # 5. CLEANUP AND TERMINATION
127
+ # =============================================================================
128
+ echo "[5/5] Cleanup..."
129
+
130
+ # Log completion time
131
+ echo "End time: $(date -Iseconds)"
132
+
133
+ # If training succeeded and RUNPOD_POD_ID is set, terminate the pod
134
+ if [ $TRAINING_EXIT_CODE -eq 0 ]; then
135
+ echo "Training completed successfully!"
136
+
137
+ if [ -n "$RUNPOD_POD_ID" ]; then
138
+ echo ""
139
+ echo "Terminating pod to stop billing..."
140
+ echo "Pod ID: $RUNPOD_POD_ID"
141
+
142
+ # Give a few seconds for logs to flush
143
+ sleep 5
144
+
145
+ # Stop the pod
146
+ runpodctl stop pod "$RUNPOD_POD_ID" || {
147
+ echo "Warning: Failed to stop pod automatically"
148
+ echo "Please manually stop pod $RUNPOD_POD_ID to avoid billing"
149
+ }
150
+ else
151
+ echo ""
152
+ echo "Note: RUNPOD_POD_ID not set - pod will continue running"
153
+ echo "Remember to stop the pod manually to avoid billing!"
154
+ fi
155
+ else
156
+ echo "Training failed with exit code $TRAINING_EXIT_CODE"
157
+ echo "Pod will NOT be automatically terminated for debugging"
158
+ echo ""
159
+ echo "To debug:"
160
+ echo " 1. SSH into the pod"
161
+ echo " 2. Check /workspace/outputs for logs"
162
+ echo " 3. Check W&B dashboard for metrics"
163
+ fi
164
+
165
+ exit $TRAINING_EXIT_CODE
notebooks/Marxist_GRPO_RunPod_Optimized.ipynb ADDED
@@ -0,0 +1,1107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "header",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Marxist GRPO Training - RunPod Optimized\n",
9
+ "\n",
10
+ "**Production-ready notebook for headless GRPO training on RunPod.**\n",
11
+ "\n",
12
+ "This notebook is optimized based on:\n",
13
+ "- `src/prolewiki_llm/train_headless.py` - Production patterns\n",
14
+ "- `src/prolewiki_llm/grpo_rewards.py` - Multi-layer reward system\n",
15
+ "- C7 documentation for Unsloth and TRL\n",
16
+ "\n",
17
+ "## Recommended GPU\n",
18
+ "\n",
19
+ "| GPU | VRAM | Price | Status |\n",
20
+ "|-----|------|-------|--------|\n",
21
+ "| **NVIDIA A40** | 48GB | $0.35/hr | **Recommended** |\n",
22
+ "| NVIDIA A100 | 80GB | $1.19/hr | Overkill (2x cost for unused VRAM) |\n",
23
+ "| RTX 4090 | 24GB | $0.34/hr | Too small (8B model needs ~40GB in 16-bit) |\n",
24
+ "\n",
25
+ "## Key Optimizations\n",
26
+ "\n",
27
+ "1. **torch.compile disabled** - Prevents hanging on containerized environments\n",
28
+ "2. **use_gradient_checkpointing=True** - Stable (not `\"unsloth\"` variant)\n",
29
+ "3. **load_in_4bit=False** - Required for GRPO (16-bit LoRA adapters)\n",
30
+ "4. **vLLM integration** - 2x faster generation during training\n",
31
+ "5. **Multi-layer rewards** - Defeats reward hacking\n",
32
+ "\n",
33
+ "---"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "id": "critical-setup-header",
39
+ "metadata": {},
40
+ "source": [
41
+ "## CRITICAL: Disable torch.compile (Run First!)\n",
42
+ "\n",
43
+ "These environment variables **MUST** be set before any imports.\n",
44
+ "They prevent torch.compile's inductor workers from hanging on RunPod/Jupyter."
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "disable-torch-compile",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "# =============================================================================\n",
55
+ "# CRITICAL: Disable torch.compile BEFORE any imports\n",
56
+ "# =============================================================================\n",
57
+ "# These environment variables prevent torch.compile from spawning inductor\n",
58
+ "# compilation workers that hang indefinitely on RunPod/WSL2/Jupyter.\n",
59
+ "# See: https://github.com/unslothai/unsloth/issues/3432\n",
60
+ "\n",
61
+ "import os\n",
62
+ "\n",
63
+ "os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\n",
64
+ "os.environ[\"TORCH_COMPILE\"] = \"0\"\n",
65
+ "os.environ[\"TORCHINDUCTOR_DISABLE\"] = \"1\"\n",
66
+ "os.environ[\"UNSLOTH_DISABLE_FAST_GENERATION\"] = \"1\"\n",
67
+ "os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n",
68
+ "os.environ[\"TORCHINDUCTOR_COMPILE_THREADS\"] = \"1\"\n",
69
+ "os.environ[\"UNSLOTH_VLLM_STANDBY\"] = \"1\"\n",
70
+ "\n",
71
+ "# Enable fast HuggingFace downloads\n",
72
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n",
73
+ "\n",
74
+ "print(\"Environment configured for RunPod stability.\")"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "markdown",
79
+ "id": "install-header",
80
+ "metadata": {},
81
+ "source": [
82
+ "## Installation\n",
83
+ "\n",
84
+ "Install all required dependencies. This takes ~5 minutes on first run."
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "install-deps",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "%%capture install_output\n",
95
+ "# Install dependencies (captured to reduce noise)\n",
96
+ "\n",
97
+ "# Core ML stack\n",
98
+ "!pip install torch --index-url https://download.pytorch.org/whl/cu121\n",
99
+ "\n",
100
+ "# Unsloth with vLLM (GRPO speedup)\n",
101
+ "!pip install unsloth vllm\n",
102
+ "\n",
103
+ "# TRL for GRPO training\n",
104
+ "!pip install trl peft bitsandbytes\n",
105
+ "\n",
106
+ "# Transformers (compatible version)\n",
107
+ "!pip install \"transformers>=4.45.0,<4.50.0\"\n",
108
+ "\n",
109
+ "# Data and logging\n",
110
+ "!pip install datasets wandb hf_transfer\n",
111
+ "\n",
112
+ "# Reward function dependencies\n",
113
+ "!pip install sentence-transformers numpy\n",
114
+ "\n",
115
+ "# spaCy with transformer model (best semantic understanding)\n",
116
+ "!pip install spacy\n",
117
+ "!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl\n",
118
+ "\n",
119
+ "print(\"Installation complete!\")"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "check-install",
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "# Verify installation and show any errors\n",
130
+ "print(install_output.stdout[-2000:] if len(install_output.stdout) > 2000 else install_output.stdout)\n",
131
+ "if install_output.stderr:\n",
132
+ " print(\"\\n--- Warnings/Errors ---\")\n",
133
+ " print(install_output.stderr[-1000:])"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "id": "vram-check-header",
139
+ "metadata": {},
140
+ "source": [
141
+ "## GPU Verification\n",
142
+ "\n",
143
+ "Verify GPU is available and has sufficient VRAM before proceeding."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "vram-check",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "import torch\n",
154
+ "\n",
155
+ "# Check CUDA availability\n",
156
+ "if not torch.cuda.is_available():\n",
157
+ " raise RuntimeError(\n",
158
+ " \"CUDA not available! This notebook requires a GPU.\\n\"\n",
159
+ " \"On RunPod: Ensure you selected a GPU pod type.\"\n",
160
+ " )\n",
161
+ "\n",
162
+ "# Get GPU info\n",
163
+ "gpu_name = torch.cuda.get_device_name(0)\n",
164
+ "gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
165
+ "\n",
166
+ "print(f\"GPU: {gpu_name}\")\n",
167
+ "print(f\"VRAM: {gpu_mem_gb:.1f} GB\")\n",
168
+ "\n",
169
+ "# Check minimum VRAM (8B model in 16-bit needs ~35-40GB)\n",
170
+ "MIN_VRAM_GB = 40\n",
171
+ "if gpu_mem_gb < MIN_VRAM_GB:\n",
172
+ " print(f\"\\nWARNING: GPU has {gpu_mem_gb:.1f}GB VRAM, but {MIN_VRAM_GB}GB is recommended.\")\n",
173
+ " print(\"Consider using A40 (48GB) or A100 (80GB) for 8B model training.\")\n",
174
+ " print(\"You may experience OOM errors with smaller GPUs.\")\n",
175
+ "else:\n",
176
+ " print(f\"\\nVRAM check passed ({gpu_mem_gb:.1f}GB >= {MIN_VRAM_GB}GB required).\")"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "id": "config-header",
182
+ "metadata": {},
183
+ "source": [
184
+ "## Configuration\n",
185
+ "\n",
186
+ "Configure training parameters. Adjust these based on your GPU and requirements."
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "config",
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "from pathlib import Path\n",
197
+ "\n",
198
+ "# =============================================================================\n",
199
+ "# REQUIRED: Set your API keys\n",
200
+ "# =============================================================================\n",
201
+ "# Option 1: Set directly (for testing)\n",
202
+ "# os.environ[\"HF_TOKEN\"] = \"hf_...\"\n",
203
+ "# os.environ[\"WANDB_API_KEY\"] = \"...\"\n",
204
+ "\n",
205
+ "# Option 2: Already set in environment (recommended for production)\n",
206
+ "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
207
+ "WANDB_API_KEY = os.environ.get(\"WANDB_API_KEY\")\n",
208
+ "\n",
209
+ "if not HF_TOKEN:\n",
210
+ " print(\"WARNING: HF_TOKEN not set. Model upload will fail.\")\n",
211
+ "if not WANDB_API_KEY:\n",
212
+ " print(\"WARNING: WANDB_API_KEY not set. W&B logging disabled.\")\n",
213
+ "\n",
214
+ "# =============================================================================\n",
215
+ "# Model Configuration\n",
216
+ "# =============================================================================\n",
217
+ "MODEL_NAME = \"unsloth/DeepSeek-R1-0528-Qwen3-8B\"\n",
218
+ "MAX_SEQ_LENGTH = 2048\n",
219
+ "LORA_RANK = 32\n",
220
+ "\n",
221
+ "# =============================================================================\n",
222
+ "# Training Configuration (A40 optimized)\n",
223
+ "# =============================================================================\n",
224
+ "MAX_STEPS = 500\n",
225
+ "SAVE_STEPS = 50\n",
226
+ "LEARNING_RATE = 5e-6\n",
227
+ "WARMUP_RATIO = 0.1\n",
228
+ "\n",
229
+ "# Batch settings (A40 48GB: batch=2, grad_accum=2, generations=4)\n",
230
+ "BATCH_SIZE = 2\n",
231
+ "GRADIENT_ACCUMULATION = 2\n",
232
+ "NUM_GENERATIONS = 4\n",
233
+ "\n",
234
+ "# Sequence lengths\n",
235
+ "MAX_PROMPT_LENGTH = 512\n",
236
+ "MAX_COMPLETION_LENGTH = 1500\n",
237
+ "\n",
238
+ "# vLLM memory allocation\n",
239
+ "GPU_MEMORY_UTILIZATION = 0.6\n",
240
+ "\n",
241
+ "# =============================================================================\n",
242
+ "# Paths (RunPod uses /workspace for persistent storage)\n",
243
+ "# =============================================================================\n",
244
+ "# Try multiple locations for dataset\n",
245
+ "DATASET_PATHS = [\n",
246
+ " Path(\"/workspace/dataset.jsonl\"),\n",
247
+ " Path(\"/workspace/grpo_dataset.jsonl\"),\n",
248
+ " Path(\"./grpo_dataset.jsonl\"),\n",
249
+ " Path(\"../training_data/grpo_dataset.jsonl\"),\n",
250
+ "]\n",
251
+ "\n",
252
+ "DATASET_PATH = None\n",
253
+ "for p in DATASET_PATHS:\n",
254
+ " if p.exists():\n",
255
+ " DATASET_PATH = p\n",
256
+ " break\n",
257
+ "\n",
258
+ "if DATASET_PATH is None:\n",
259
+ " print(f\"WARNING: Dataset not found. Searched: {[str(p) for p in DATASET_PATHS]}\")\n",
260
+ " print(\"Please upload grpo_dataset.jsonl to /workspace/\")\n",
261
+ "else:\n",
262
+ " print(f\"Dataset found: {DATASET_PATH}\")\n",
263
+ "\n",
264
+ "CHECKPOINT_DIR = Path(\"/workspace/checkpoints\")\n",
265
+ "LORA_OUTPUT = Path(\"/workspace/lora-output\")\n",
266
+ "OUTPUT_DIR = Path(\"/workspace/outputs\")\n",
267
+ "\n",
268
+ "# Create directories\n",
269
+ "for d in [CHECKPOINT_DIR, LORA_OUTPUT, OUTPUT_DIR]:\n",
270
+ " d.mkdir(parents=True, exist_ok=True)\n",
271
+ "\n",
272
+ "# =============================================================================\n",
273
+ "# Upload Configuration\n",
274
+ "# =============================================================================\n",
275
+ "HF_REPO = \"prolewiki/marxist-grpo-lora\" # Change to your repo\n",
276
+ "\n",
277
+ "# =============================================================================\n",
278
+ "# Reward Mode: FULL (recommended), ROBUST, or LEGACY\n",
279
+ "# =============================================================================\n",
280
+ "# FULL: NLI + self-consistency + structure + topic + depth (most robust)\n",
281
+ "# ROBUST: NLI + self-consistency + structure (faster)\n",
282
+ "# LEGACY: Semantic similarity + terminology (fastest, less robust)\n",
283
+ "REWARD_MODE = \"FULL\"\n",
284
+ "\n",
285
+ "print(f\"\\nConfiguration:\")\n",
286
+ "print(f\" Model: {MODEL_NAME}\")\n",
287
+ "print(f\" Max Steps: {MAX_STEPS}\")\n",
288
+ "print(f\" Batch Size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n",
289
+ "print(f\" Learning Rate: {LEARNING_RATE}\")\n",
290
+ "print(f\" Reward Mode: {REWARD_MODE}\")"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "markdown",
295
+ "id": "wandb-header",
296
+ "metadata": {},
297
+ "source": [
298
+ "## Initialize W&B\n",
299
+ "\n",
300
+ "Set up Weights & Biases for remote monitoring."
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "wandb-init",
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "USE_WANDB = False\n",
311
+ "\n",
312
+ "if WANDB_API_KEY:\n",
313
+ " import wandb\n",
314
+ " wandb.login(key=WANDB_API_KEY)\n",
315
+ " USE_WANDB = True\n",
316
+ " print(\"W&B logged in successfully.\")\n",
317
+ "else:\n",
318
+ " print(\"W&B disabled (no API key).\")"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "markdown",
323
+ "id": "dataset-header",
324
+ "metadata": {},
325
+ "source": [
326
+ "## Load Dataset\n",
327
+ "\n",
328
+ "Load the GRPO-formatted dataset."
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "id": "load-dataset",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "from datasets import Dataset\n",
339
+ "\n",
340
+ "if DATASET_PATH is None or not DATASET_PATH.exists():\n",
341
+ " raise FileNotFoundError(\n",
342
+ " \"Dataset not found!\\n\"\n",
343
+ " \"Please upload grpo_dataset.jsonl to /workspace/\\n\"\n",
344
+ " \"Generate it with: python -m prolewiki_llm.export_grpo_dataset\"\n",
345
+ " )\n",
346
+ "\n",
347
+ "dataset = Dataset.from_json(str(DATASET_PATH))\n",
348
+ "print(f\"Loaded {len(dataset):,} examples from {DATASET_PATH}\")\n",
349
+ "\n",
350
+ "# Show sample\n",
351
+ "sample = dataset[0]\n",
352
+ "print(f\"\\nSample prompt: {sample['prompt'][1]['content'][:100]}...\")\n",
353
+ "print(f\"Sample answer: {sample['answer'][:100]}...\")"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "markdown",
358
+ "id": "model-header",
359
+ "metadata": {},
360
+ "source": [
361
+ "## Load Model\n",
362
+ "\n",
363
+ "Load DeepSeek-R1-0528-Qwen3-8B with Unsloth for optimized training.\n",
364
+ "\n",
365
+ "**Critical settings:**\n",
366
+ "- `load_in_4bit=False` - Required for GRPO (16-bit LoRA adapters)\n",
367
+ "- `fast_inference=True` - Enables vLLM for 2x faster generation"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "id": "load-model",
374
+ "metadata": {},
375
+ "outputs": [],
376
+ "source": [
377
+ "from unsloth import FastLanguageModel\n",
378
+ "\n",
379
+ "print(f\"Loading model: {MODEL_NAME}\")\n",
380
+ "print(\"This may take 5-10 minutes on first run (downloading ~16GB)...\")\n",
381
+ "\n",
382
+ "# GRPO requires 16-bit LoRA adapters (load_in_4bit=False)\n",
383
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
384
+ " model_name=MODEL_NAME,\n",
385
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
386
+ " load_in_4bit=False, # MUST be False for GRPO\n",
387
+ " fast_inference=True, # Enables vLLM\n",
388
+ " max_lora_rank=LORA_RANK,\n",
389
+ " gpu_memory_utilization=GPU_MEMORY_UTILIZATION,\n",
390
+ ")\n",
391
+ "\n",
392
+ "print(f\"\\nModel loaded: {model.config.model_type}\")\n",
393
+ "print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "markdown",
398
+ "id": "lora-header",
399
+ "metadata": {},
400
+ "source": [
401
+ "## Apply LoRA\n",
402
+ "\n",
403
+ "Apply LoRA adapters for efficient fine-tuning.\n",
404
+ "\n",
405
+ "**Critical settings:**\n",
406
+ "- `use_gradient_checkpointing=True` - Stable on RunPod (NOT `\"unsloth\"`)\n",
407
+ "- `lora_alpha=LORA_RANK` - Same as r for GRPO (scaling = 1.0)"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "id": "apply-lora",
414
+ "metadata": {},
415
+ "outputs": [],
416
+ "source": [
417
+ "print(\"Applying LoRA adapters...\")\n",
418
+ "\n",
419
+ "# Use gradient_checkpointing=True (not \"unsloth\") for stability on RunPod\n",
420
+ "model = FastLanguageModel.get_peft_model(\n",
421
+ " model,\n",
422
+ " r=LORA_RANK,\n",
423
+ " target_modules=[\n",
424
+ " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
425
+ " \"gate_proj\", \"up_proj\", \"down_proj\",\n",
426
+ " ],\n",
427
+ " lora_alpha=LORA_RANK, # Same as r for GRPO (not r*2)\n",
428
+ " use_gradient_checkpointing=True, # Stable on RunPod (NOT \"unsloth\")\n",
429
+ " random_state=3407,\n",
430
+ ")\n",
431
+ "\n",
432
+ "# Print trainable parameters\n",
433
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
434
+ "total = sum(p.numel() for p in model.parameters())\n",
435
+ "print(f\"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\")"
436
+ ]
437
+ },
438
+ {
439
+ "cell_type": "markdown",
440
+ "id": "reward-header",
441
+ "metadata": {},
442
+ "source": [
443
+ "## Reward Functions\n",
444
+ "\n",
445
+ "Multi-layer reward system to prevent reward hacking.\n",
446
+ "\n",
447
+ "**Layers:**\n",
448
+ "1. **Format** - Validate `<think>...</think>` tags\n",
449
+ "2. **NLI Coherence** - Response entails ground truth\n",
450
+ "3. **Self-consistency** - No internal contradictions\n",
451
+ "4. **Structural** - Terms in proper syntactic roles\n",
452
+ "5. **Topic relevance** - Answer addresses the question\n",
453
+ "6. **Depth** - Rewards analysis, penalizes buzzword salad"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "id": "reward-setup",
460
+ "metadata": {},
461
+ "outputs": [],
462
+ "source": [
463
+ "import re\n",
464
+ "import numpy as np\n",
465
+ "from sentence_transformers import SentenceTransformer\n",
466
+ "\n",
467
+ "# =============================================================================\n",
468
+ "# Lazy-loaded models\n",
469
+ "# =============================================================================\n",
470
+ "_embedder = None\n",
471
+ "_nli_pipeline = None\n",
472
+ "_spacy_nlp = None\n",
473
+ "\n",
474
+ "\n",
475
+ "def get_embedder():\n",
476
+ " global _embedder\n",
477
+ " if _embedder is None:\n",
478
+ " print(\"[Reward] Loading sentence-transformers embedder...\")\n",
479
+ " _embedder = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
480
+ " return _embedder\n",
481
+ "\n",
482
+ "\n",
483
+ "def get_nli_pipeline():\n",
484
+ " global _nli_pipeline\n",
485
+ " if _nli_pipeline is None:\n",
486
+ " print(\"[Reward] Loading NLI model (bart-large-mnli)...\")\n",
487
+ " from transformers import pipeline\n",
488
+ " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
489
+ " _nli_pipeline = pipeline(\n",
490
+ " \"text-classification\",\n",
491
+ " model=\"facebook/bart-large-mnli\",\n",
492
+ " device=device,\n",
493
+ " )\n",
494
+ " return _nli_pipeline\n",
495
+ "\n",
496
+ "\n",
497
+ "def get_spacy_nlp():\n",
498
+ " global _spacy_nlp\n",
499
+ " if _spacy_nlp is None:\n",
500
+ " import spacy\n",
501
+ " models = [\"en_core_web_trf\", \"en_core_web_md\", \"en_core_web_sm\"]\n",
502
+ " for model_name in models:\n",
503
+ " try:\n",
504
+ " print(f\"[Reward] Loading spaCy model: {model_name}...\")\n",
505
+ " _spacy_nlp = spacy.load(model_name)\n",
506
+ " break\n",
507
+ " except OSError:\n",
508
+ " continue\n",
509
+ " if _spacy_nlp is None:\n",
510
+ " raise OSError(\"No spaCy model found!\")\n",
511
+ " return _spacy_nlp\n",
512
+ "\n",
513
+ "\n",
514
+ "# =============================================================================\n",
515
+ "# Constants\n",
516
+ "# =============================================================================\n",
517
+ "REASONING_START = \"<think>\"\n",
518
+ "REASONING_END = \"</think>\"\n",
519
+ "SOLUTION_END_REGEX = re.compile(rf\"{REASONING_END}(.*)\", re.DOTALL)\n",
520
+ "\n",
521
+ "MARXIST_TERMS = {\n",
522
+ " \"dialectical\", \"materialism\", \"historical materialism\", \"dialectical materialism\",\n",
523
+ " \"bourgeoisie\", \"proletariat\", \"petty bourgeois\", \"petty bourgeoisie\",\n",
524
+ " \"class struggle\", \"class consciousness\", \"surplus value\", \"commodity\",\n",
525
+ " \"use value\", \"exchange value\", \"labor power\", \"means of production\",\n",
526
+ " \"relations of production\", \"forces of production\", \"mode of production\",\n",
527
+ " \"exploitation\", \"capital accumulation\", \"imperialism\", \"colonialism\",\n",
528
+ " \"dictatorship of the proletariat\", \"vanguard\", \"democratic centralism\",\n",
529
+ " \"hegemony\", \"superstructure\", \"base\", \"ideology\", \"false consciousness\",\n",
530
+ " \"revisionism\", \"opportunism\", \"reformism\", \"alienation\", \"fetishism\",\n",
531
+ "}\n",
532
+ "\n",
533
+ "DISCOURSE_CONNECTIVES = {\n",
534
+ " \"because\", \"therefore\", \"thus\", \"hence\", \"consequently\", \"however\",\n",
535
+ " \"although\", \"whereas\", \"nevertheless\", \"moreover\", \"furthermore\",\n",
536
+ " \"specifically\", \"namely\", \"as a result\", \"due to\", \"in order to\",\n",
537
+ "}\n",
538
+ "\n",
539
+ "print(\"Reward function constants defined.\")"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": null,
545
+ "id": "reward-functions",
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "# =============================================================================\n",
550
+ "# Format Rewards\n",
551
+ "# =============================================================================\n",
552
+ "\n",
553
+ "def match_format_exactly(completions, **kwargs):\n",
554
+ " \"\"\"Reward +3.0 if response contains proper </think> tag.\"\"\"\n",
555
+ " scores = []\n",
556
+ " for completion in completions:\n",
557
+ " response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
558
+ " score = 3.0 if SOLUTION_END_REGEX.search(response) else 0.0\n",
559
+ " scores.append(score)\n",
560
+ " return scores\n",
561
+ "\n",
562
+ "\n",
563
+ "def match_format_approximately(completions, **kwargs):\n",
564
+ " \"\"\"Reward partial format matching.\"\"\"\n",
565
+ " scores = []\n",
566
+ " for completion in completions:\n",
567
+ " response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
568
+ " score = 0.0\n",
569
+ " score += 0.5 if response.count(REASONING_START) == 1 else -1.0\n",
570
+ " score += 0.5 if response.count(REASONING_END) == 1 else -1.0\n",
571
+ " scores.append(score)\n",
572
+ " return scores\n",
573
+ "\n",
574
+ "\n",
575
+ "# =============================================================================\n",
576
+ "# NLI Coherence Reward\n",
577
+ "# =============================================================================\n",
578
+ "\n",
579
+ "def nli_coherence_reward(completions, answer, **kwargs):\n",
580
+ " \"\"\"Reward responses that logically ENTAIL the ground truth.\"\"\"\n",
581
+ " nli = get_nli_pipeline()\n",
582
+ " scores = []\n",
583
+ " \n",
584
+ " for completion, true_answer in zip(completions, answer):\n",
585
+ " response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
586
+ " \n",
587
+ " # Extract answer after </think>\n",
588
+ " if REASONING_END in response:\n",
589
+ " response = response.split(REASONING_END, 1)[1].strip()\n",
590
+ " \n",
591
+ " if not response or len(response.strip()) < 20:\n",
592
+ " scores.append(-2.0)\n",
593
+ " continue\n",
594
+ " \n",
595
+ " try:\n",
596
+ " input_text = f\"{response[:512]}</s></s>{true_answer[:512]}\"\n",
597
+ " result = nli(input_text)[0]\n",
598
+ " label = result[\"label\"].lower()\n",
599
+ " \n",
600
+ " if label == \"entailment\":\n",
601
+ " score = 3.0\n",
602
+ " elif label == \"neutral\":\n",
603
+ " score = -1.0\n",
604
+ " else: # contradiction\n",
605
+ " score = -3.0\n",
606
+ " scores.append(score)\n",
607
+ " except Exception as e:\n",
608
+ " print(f\"[NLI] Error: {e}\")\n",
609
+ " scores.append(0.0)\n",
610
+ " \n",
611
+ " return scores\n",
612
+ "\n",
613
+ "\n",
614
+ "# =============================================================================\n",
615
+ "# Structural Coherence Reward\n",
616
+ "# =============================================================================\n",
617
+ "\n",
618
+ "def structural_coherence_reward(completions, **kwargs):\n",
619
+ " \"\"\"Reward responses with proper linguistic structure.\"\"\"\n",
620
+ " nlp = get_spacy_nlp()\n",
621
+ " scores = []\n",
622
+ " \n",
623
+ " for completion in completions:\n",
624
+ " response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
625
+ " doc = nlp(response)\n",
626
+ " score = 0.0\n",
627
+ " \n",
628
+ " # Check for sentences\n",
629
+ " sentences = list(doc.sents)\n",
630
+ " if len(sentences) < 1:\n",
631
+ " scores.append(-1.0)\n",
632
+ " continue\n",
633
+ " \n",
634
+ " # Check Marxist terms in meaningful syntactic roles\n",
635
+ " response_lower = response.lower()\n",
636
+ " terms_in_context = 0\n",
637
+ " \n",
638
+ " for term in MARXIST_TERMS:\n",
639
+ " if term not in response_lower:\n",
640
+ " continue\n",
641
+ " for token in doc:\n",
642
+ " if term in token.text.lower():\n",
643
+ " if token.dep_ in (\"nsubj\", \"nsubjpass\", \"dobj\", \"pobj\", \"attr\", \"appos\"):\n",
644
+ " terms_in_context += 1\n",
645
+ " break\n",
646
+ " \n",
647
+ " score += min(terms_in_context * 0.3, 1.5)\n",
648
+ " \n",
649
+ " # Check discourse connectives\n",
650
+ " connective_count = sum(1 for c in DISCOURSE_CONNECTIVES if c in response_lower)\n",
651
+ " score += min(connective_count * 0.2, 1.0)\n",
652
+ " \n",
653
+ " scores.append(score)\n",
654
+ " \n",
655
+ " return scores\n",
656
+ "\n",
657
+ "\n",
658
+ "# =============================================================================\n",
659
+ "# Combined Rewards\n",
660
+ "# =============================================================================\n",
661
+ "\n",
662
+ "def full_coherence_reward(prompts, completions, answer, **kwargs):\n",
663
+ " \"\"\"Complete coherence check: NLI + structure.\"\"\"\n",
664
+ " nli_scores = nli_coherence_reward(completions, answer, **kwargs)\n",
665
+ " structure_scores = structural_coherence_reward(completions, **kwargs)\n",
666
+ " \n",
667
+ " combined = []\n",
668
+ " for nli, structure in zip(nli_scores, structure_scores):\n",
669
+ " if nli <= -3.0:\n",
670
+ " combined.append(-3.0)\n",
671
+ " else:\n",
672
+ " combined.append(nli + (structure * 0.5))\n",
673
+ " return combined\n",
674
+ "\n",
675
+ "\n",
676
+ "def completeness_reward(completions, answer, **kwargs):\n",
677
+ " \"\"\"Reward thorough, detailed responses.\"\"\"\n",
678
+ " scores = []\n",
679
+ " \n",
680
+ " for completion, true_answer in zip(completions, answer):\n",
681
+ " response = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
682
+ " \n",
683
+ " if REASONING_END in response:\n",
684
+ " answer_part = response.split(REASONING_END, 1)[1].strip()\n",
685
+ " else:\n",
686
+ " answer_part = response\n",
687
+ " \n",
688
+ " answer_len = len(answer_part.split())\n",
689
+ " true_len = len(true_answer.split())\n",
690
+ " \n",
691
+ " if true_len == 0:\n",
692
+ " scores.append(0.0)\n",
693
+ " continue\n",
694
+ " \n",
695
+ " ratio = answer_len / true_len\n",
696
+ " if 0.5 <= ratio <= 1.5:\n",
697
+ " score = 2.0\n",
698
+ " elif 0.3 <= ratio <= 2.0:\n",
699
+ " score = 1.0\n",
700
+ " elif ratio < 0.2:\n",
701
+ " score = -2.0\n",
702
+ " else:\n",
703
+ " score = -0.5\n",
704
+ " \n",
705
+ " scores.append(score)\n",
706
+ " \n",
707
+ " return scores\n",
708
+ "\n",
709
+ "\n",
710
+ "# =============================================================================\n",
711
+ "# Debug Reward\n",
712
+ "# =============================================================================\n",
713
+ "_PRINT_COUNTER = 0\n",
714
+ "_PRINT_EVERY = 10\n",
715
+ "\n",
716
+ "def debug_print_reward(prompts, completions, answer, **kwargs):\n",
717
+ " \"\"\"Print sample outputs periodically.\"\"\"\n",
718
+ " global _PRINT_COUNTER\n",
719
+ " \n",
720
+ " if _PRINT_COUNTER % _PRINT_EVERY == 0:\n",
721
+ " question = prompts[0][-1][\"content\"] if isinstance(prompts[0], list) else prompts[0]\n",
722
+ " response = completions[0][0][\"content\"] if isinstance(completions[0], list) else completions[0]\n",
723
+ " \n",
724
+ " print(\"=\" * 60)\n",
725
+ " print(f\"Step {_PRINT_COUNTER}\")\n",
726
+ " print(f\"Q: {question[:100]}...\")\n",
727
+ " print(f\"A: {response[:200]}...\")\n",
728
+ " print(\"=\" * 60)\n",
729
+ " \n",
730
+ " _PRINT_COUNTER += 1\n",
731
+ " return [0.0] * len(completions)\n",
732
+ "\n",
733
+ "\n",
734
+ "print(\"Reward functions defined.\")"
735
+ ]
736
+ },
737
+ {
738
+ "cell_type": "markdown",
739
+ "id": "trainer-header",
740
+ "metadata": {},
741
+ "source": [
742
+ "## Configure GRPO Trainer\n",
743
+ "\n",
744
+ "Set up the GRPO trainer with vLLM for fast generation."
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "id": "trainer-config",
751
+ "metadata": {},
752
+ "outputs": [],
753
+ "source": [
754
+ "from trl import GRPOConfig, GRPOTrainer\n",
755
+ "from vllm import SamplingParams\n",
756
+ "\n",
757
+ "# vLLM sampling parameters for generation\n",
758
+ "vllm_sampling_params = SamplingParams(\n",
759
+ " min_p=0.1,\n",
760
+ " top_p=1.0,\n",
761
+ " top_k=-1,\n",
762
+ " max_tokens=MAX_COMPLETION_LENGTH,\n",
763
+ " stop=[tokenizer.eos_token],\n",
764
+ " include_stop_str_in_output=True,\n",
765
+ " seed=3407,\n",
766
+ ")\n",
767
+ "\n",
768
+ "# Training configuration\n",
769
+ "training_args = GRPOConfig(\n",
770
+ " # vLLM\n",
771
+ " vllm_sampling_params=vllm_sampling_params,\n",
772
+ " temperature=1.0,\n",
773
+ " # Optimization\n",
774
+ " learning_rate=LEARNING_RATE,\n",
775
+ " weight_decay=0.001,\n",
776
+ " warmup_ratio=WARMUP_RATIO,\n",
777
+ " lr_scheduler_type=\"linear\",\n",
778
+ " optim=\"adamw_8bit\",\n",
779
+ " # Batch settings\n",
780
+ " per_device_train_batch_size=BATCH_SIZE,\n",
781
+ " gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n",
782
+ " num_generations=NUM_GENERATIONS,\n",
783
+ " # Sequence lengths\n",
784
+ " max_prompt_length=MAX_PROMPT_LENGTH,\n",
785
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
786
+ " # Training duration\n",
787
+ " max_steps=MAX_STEPS,\n",
788
+ " save_steps=SAVE_STEPS,\n",
789
+ " # Logging\n",
790
+ " logging_steps=1,\n",
791
+ " report_to=\"wandb\" if USE_WANDB else \"none\",\n",
792
+ " # Output\n",
793
+ " output_dir=str(CHECKPOINT_DIR),\n",
794
+ ")\n",
795
+ "\n",
796
+ "print(f\"Training configuration:\")\n",
797
+ "print(f\" Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}\")\n",
798
+ "print(f\" Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION * NUM_GENERATIONS}\")\n",
799
+ "print(f\" Max steps: {MAX_STEPS}\")\n",
800
+ "print(f\" W&B: {'enabled' if USE_WANDB else 'disabled'}\")"
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "execution_count": null,
806
+ "id": "create-trainer",
807
+ "metadata": {},
808
+ "outputs": [],
809
+ "source": [
810
+ "# Select reward functions based on mode\n",
811
+ "if REWARD_MODE == \"FULL\":\n",
812
+ " print(f\"\\nReward mode: FULL (recommended)\")\n",
813
+ " print(\" - match_format_exactly, match_format_approximately\")\n",
814
+ " print(\" - full_coherence_reward (NLI + structure)\")\n",
815
+ " print(\" - completeness_reward, debug_print_reward\")\n",
816
+ " reward_funcs = [\n",
817
+ " match_format_exactly,\n",
818
+ " match_format_approximately,\n",
819
+ " full_coherence_reward,\n",
820
+ " completeness_reward,\n",
821
+ " debug_print_reward,\n",
822
+ " ]\n",
823
+ "elif REWARD_MODE == \"ROBUST\":\n",
824
+ " print(f\"\\nReward mode: ROBUST\")\n",
825
+ " reward_funcs = [\n",
826
+ " match_format_exactly,\n",
827
+ " match_format_approximately,\n",
828
+ " nli_coherence_reward,\n",
829
+ " structural_coherence_reward,\n",
830
+ " completeness_reward,\n",
831
+ " debug_print_reward,\n",
832
+ " ]\n",
833
+ "else: # LEGACY\n",
834
+ " print(f\"\\nReward mode: LEGACY (faster, less robust)\")\n",
835
+ " reward_funcs = [\n",
836
+ " match_format_exactly,\n",
837
+ " match_format_approximately,\n",
838
+ " completeness_reward,\n",
839
+ " debug_print_reward,\n",
840
+ " ]\n",
841
+ "\n",
842
+ "print(\"\\nNote: First training step will download NLI model (~1.6GB) + spaCy (~436MB)\")\n",
843
+ "\n",
844
+ "# Create trainer\n",
845
+ "trainer = GRPOTrainer(\n",
846
+ " model=model,\n",
847
+ " processing_class=tokenizer,\n",
848
+ " reward_funcs=reward_funcs,\n",
849
+ " args=training_args,\n",
850
+ " train_dataset=dataset,\n",
851
+ ")\n",
852
+ "\n",
853
+ "print(\"\\nGRPO trainer created.\")"
854
+ ]
855
+ },
856
+ {
857
+ "cell_type": "markdown",
858
+ "id": "train-header",
859
+ "metadata": {},
860
+ "source": [
861
+ "## Train!\n",
862
+ "\n",
863
+ "Run GRPO training. Monitor the `reward` column - it should increase over time.\n",
864
+ "\n",
865
+ "**Expected behavior:**\n",
866
+ "- Steps 0-50: Format rewards stabilize\n",
867
+ "- Steps 50-150: NLI coherence improves\n",
868
+ "- Steps 150-500: Overall quality improves\n",
869
+ "\n",
870
+ "**Estimated time:** ~2-3 hours on A40 for 500 steps"
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "code",
875
+ "execution_count": null,
876
+ "id": "check-resume",
877
+ "metadata": {},
878
+ "outputs": [],
879
+ "source": [
880
+ "# Check for checkpoint to resume from\n",
881
+ "def find_latest_checkpoint(checkpoint_dir):\n",
882
+ " if not checkpoint_dir.exists():\n",
883
+ " return None\n",
884
+ " checkpoints = sorted(\n",
885
+ " [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith(\"checkpoint-\")],\n",
886
+ " key=lambda d: int(d.name.split(\"-\")[1]),\n",
887
+ " )\n",
888
+ " return checkpoints[-1] if checkpoints else None\n",
889
+ "\n",
890
+ "resume_from = find_latest_checkpoint(CHECKPOINT_DIR)\n",
891
+ "if resume_from:\n",
892
+ " print(f\"Found checkpoint: {resume_from}\")\n",
893
+ " print(\"Training will resume from this checkpoint.\")\n",
894
+ "else:\n",
895
+ " print(\"No checkpoint found. Starting fresh training.\")"
896
+ ]
897
+ },
898
+ {
899
+ "cell_type": "code",
900
+ "execution_count": null,
901
+ "id": "train",
902
+ "metadata": {},
903
+ "outputs": [],
904
+ "source": [
905
+ "print(\"=\" * 70)\n",
906
+ "print(\"STARTING TRAINING\")\n",
907
+ "print(\"=\" * 70)\n",
908
+ "print(f\"Steps: {MAX_STEPS}\")\n",
909
+ "print(f\"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}\")\n",
910
+ "print(f\"Learning rate: {LEARNING_RATE}\")\n",
911
+ "print(f\"Reward mode: {REWARD_MODE}\")\n",
912
+ "print()\n",
913
+ "\n",
914
+ "try:\n",
915
+ " if resume_from:\n",
916
+ " trainer.train(resume_from_checkpoint=str(resume_from))\n",
917
+ " else:\n",
918
+ " trainer.train()\n",
919
+ "except KeyboardInterrupt:\n",
920
+ " print(\"\\nTraining interrupted. Saving checkpoint...\")\n",
921
+ "except Exception as e:\n",
922
+ " print(f\"\\nTraining error: {e}\")\n",
923
+ " raise\n",
924
+ "\n",
925
+ "print(\"\\nTraining complete!\")"
926
+ ]
927
+ },
928
+ {
929
+ "cell_type": "markdown",
930
+ "id": "save-header",
931
+ "metadata": {},
932
+ "source": [
933
+ "## Save Model\n",
934
+ "\n",
935
+ "Save the trained LoRA adapter."
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": null,
941
+ "id": "save-model",
942
+ "metadata": {},
943
+ "outputs": [],
944
+ "source": [
945
+ "print(\"Saving LoRA adapter...\")\n",
946
+ "\n",
947
+ "LORA_OUTPUT.mkdir(parents=True, exist_ok=True)\n",
948
+ "model.save_lora(str(LORA_OUTPUT))\n",
949
+ "\n",
950
+ "print(f\"LoRA saved to: {LORA_OUTPUT}\")"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": null,
956
+ "id": "verify-lora",
957
+ "metadata": {},
958
+ "outputs": [],
959
+ "source": [
960
+ "# Verify LoRA has non-zero weights (actually trained)\n",
961
+ "from safetensors import safe_open\n",
962
+ "\n",
963
+ "adapter_path = LORA_OUTPUT / \"adapter_model.safetensors\"\n",
964
+ "if adapter_path.exists():\n",
965
+ " print(\"Verifying LoRA weights...\")\n",
966
+ " with safe_open(str(adapter_path), framework=\"pt\") as f:\n",
967
+ " for key in list(f.keys())[:5]:\n",
968
+ " tensor = f.get_tensor(key)\n",
969
+ " n_nonzero = (tensor != 0).sum().item()\n",
970
+ " ratio = n_nonzero / tensor.numel()\n",
971
+ " status = \"OK\" if ratio > 0.01 else \"WARNING: mostly zeros!\"\n",
972
+ " print(f\" {key}: {ratio*100:.1f}% non-zero - {status}\")\n",
973
+ " print(\"\\nLoRA verification complete.\")\n",
974
+ "else:\n",
975
+ " print(f\"WARNING: Adapter not found at {adapter_path}\")"
976
+ ]
977
+ },
978
+ {
979
+ "cell_type": "markdown",
980
+ "id": "upload-header",
981
+ "metadata": {},
982
+ "source": [
983
+ "## Upload to HuggingFace Hub"
984
+ ]
985
+ },
986
+ {
987
+ "cell_type": "code",
988
+ "execution_count": null,
989
+ "id": "upload-hub",
990
+ "metadata": {},
991
+ "outputs": [],
992
+ "source": [
993
+ "if HF_TOKEN:\n",
994
+ " from huggingface_hub import HfApi\n",
995
+ " \n",
996
+ " print(f\"Uploading to HuggingFace Hub: {HF_REPO}\")\n",
997
+ " \n",
998
+ " api = HfApi(token=HF_TOKEN)\n",
999
+ " \n",
1000
+ " # Create repo if needed\n",
1001
+ " try:\n",
1002
+ " api.create_repo(HF_REPO, exist_ok=True, private=True)\n",
1003
+ " except Exception as e:\n",
1004
+ " print(f\"Note: {e}\")\n",
1005
+ " \n",
1006
+ " # Upload\n",
1007
+ " api.upload_folder(\n",
1008
+ " folder_path=str(LORA_OUTPUT),\n",
1009
+ " repo_id=HF_REPO,\n",
1010
+ " commit_message=\"GRPO training run\",\n",
1011
+ " )\n",
1012
+ " \n",
1013
+ " print(f\"\\nModel uploaded to: https://huggingface.co/{HF_REPO}\")\n",
1014
+ "else:\n",
1015
+ " print(\"HF_TOKEN not set. Skipping upload.\")\n",
1016
+ " print(f\"LoRA saved locally at: {LORA_OUTPUT}\")"
1017
+ ]
1018
+ },
1019
+ {
1020
+ "cell_type": "markdown",
1021
+ "id": "test-header",
1022
+ "metadata": {},
1023
+ "source": [
1024
+ "## Test Model"
1025
+ ]
1026
+ },
1027
+ {
1028
+ "cell_type": "code",
1029
+ "execution_count": null,
1030
+ "id": "test-model",
1031
+ "metadata": {},
1032
+ "outputs": [],
1033
+ "source": [
1034
+ "TEST_QUESTIONS = [\n",
1035
+ " \"What is revisionism in the Marxist sense?\",\n",
1036
+ " \"Explain the concept of surplus value.\",\n",
1037
+ " \"What is the dictatorship of the proletariat?\",\n",
1038
+ "]\n",
1039
+ "\n",
1040
+ "SYSTEM_PROMPT = \"\"\"You are a Marxist-Leninist assistant trained on ProleWiki.\n",
1041
+ "Think through questions using dialectical materialist analysis.\n",
1042
+ "Show your reasoning in <think> tags, then provide a clear answer.\"\"\"\n",
1043
+ "\n",
1044
+ "print(\"Testing trained model...\")\n",
1045
+ "print(\"=\" * 60)\n",
1046
+ "\n",
1047
+ "for question in TEST_QUESTIONS:\n",
1048
+ " messages = [\n",
1049
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
1050
+ " {\"role\": \"user\", \"content\": question},\n",
1051
+ " ]\n",
1052
+ " \n",
1053
+ " text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
1054
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
1055
+ " \n",
1056
+ " with torch.no_grad():\n",
1057
+ " outputs = model.generate(\n",
1058
+ " **inputs,\n",
1059
+ " max_new_tokens=512,\n",
1060
+ " temperature=0.7,\n",
1061
+ " do_sample=True,\n",
1062
+ " pad_token_id=tokenizer.pad_token_id,\n",
1063
+ " )\n",
1064
+ " \n",
1065
+ " response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
1066
+ " \n",
1067
+ " print(f\"\\nQ: {question}\")\n",
1068
+ " print(f\"A: {response[:400]}...\")\n",
1069
+ " print(\"-\" * 60)"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "markdown",
1074
+ "id": "complete-header",
1075
+ "metadata": {},
1076
+ "source": [
1077
+ "## Training Complete!\n",
1078
+ "\n",
1079
+ "**Summary:**\n",
1080
+ "- Model: DeepSeek-R1-0528-Qwen3-8B with LoRA\n",
1081
+ "- Training: GRPO with multi-layer reward functions\n",
1082
+ "- Output: LoRA adapter saved to `/workspace/lora-output`\n",
1083
+ "\n",
1084
+ "**Next steps:**\n",
1085
+ "1. Review W&B dashboard for training metrics\n",
1086
+ "2. Test model quality with varied questions\n",
1087
+ "3. If satisfied, merge LoRA into base model for deployment\n",
1088
+ "4. Convert to GGUF for Ollama if needed\n",
1089
+ "\n",
1090
+ "**Remember:** Stop the pod after training to avoid charges!"
1091
+ ]
1092
+ }
1093
+ ],
1094
+ "metadata": {
1095
+ "kernelspec": {
1096
+ "display_name": "Python 3",
1097
+ "language": "python",
1098
+ "name": "python3"
1099
+ },
1100
+ "language_info": {
1101
+ "name": "python",
1102
+ "version": "3.12.0"
1103
+ }
1104
+ },
1105
+ "nbformat": 4,
1106
+ "nbformat_minor": 5
1107
+ }
pyproject.toml ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "prolewiki-llm"
3
+ version = "0.1.0"
4
+ description = "GRPO fine-tuning and reward functions for Marxist-Leninist language models"
5
+ readme = "README.md"
6
+ license = { text = "AGPL-3.0-only" }
7
+ requires-python = ">=3.12"
8
+ dependencies = [
9
+ # Core ML
10
+ "transformers>=4.40.0",
11
+ "sentence-transformers>=3.0.0",
12
+ "torch>=2.0.0",
13
+ # NLP
14
+ "spacy>=3.8.0",
15
+ # Logging
16
+ "wandb>=0.17.0",
17
+ # Data
18
+ "pydantic>=2.0.0",
19
+ ]
20
+
21
+ [dependency-groups]
22
+ dev = [
23
+ # Testing
24
+ "pytest>=8.0.0",
25
+ "pytest-asyncio>=0.24.0",
26
+ "pytest-cov>=6.0.0",
27
+ "pytest-mock>=3.14.0",
28
+ # Code Quality
29
+ "mypy>=1.13.0",
30
+ "ruff>=0.8.0",
31
+ "pre-commit>=4.0.0",
32
+ # Type Stubs
33
+ "types-PyYAML>=6.0.0",
34
+ ]
35
+ training = [
36
+ # GRPO training (GPU required)
37
+ "unsloth>=2024.8",
38
+ "trl>=0.9.0",
39
+ "peft>=0.12.0",
40
+ "bitsandbytes>=0.43.0",
41
+ "datasets>=2.20.0",
42
+ "vllm>=0.5.0",
43
+ ]
44
+
45
+ [build-system]
46
+ requires = ["hatchling"]
47
+ build-backend = "hatchling.build"
48
+
49
+ [tool.hatch.build.targets.wheel]
50
+ packages = ["src/prolewiki_llm"]
51
+
52
+ # =============================================================================
53
+ # PYTEST
54
+ # =============================================================================
55
+ [tool.pytest.ini_options]
56
+ pythonpath = ["src"]
57
+ testpaths = ["tests"]
58
+ asyncio_mode = "auto"
59
+ asyncio_default_fixture_loop_scope = "function"
60
+ markers = [
61
+ "unit: Fast unit tests for isolated components (no I/O)",
62
+ "integration: Integration tests for shell scripts and Docker behavior",
63
+ "slow: Tests that take significant time (NLI, embedding)",
64
+ "gpu: Tests requiring GPU (training)",
65
+ ]
66
+
67
+ # =============================================================================
68
+ # RUFF (Linting + Formatting + Import Sorting)
69
+ # =============================================================================
70
+ [tool.ruff]
71
+ line-length = 100
72
+ target-version = "py312"
73
+ src = ["src"]
74
+
75
+ [tool.ruff.lint]
76
+ select = [
77
+ "E", # pycodestyle errors
78
+ "W", # pycodestyle warnings
79
+ "F", # pyflakes
80
+ "I", # isort (import sorting)
81
+ "B", # flake8-bugbear
82
+ "C4", # flake8-comprehensions
83
+ "UP", # pyupgrade
84
+ "ARG", # unused arguments
85
+ "SIM", # flake8-simplify
86
+ "RUF", # ruff-specific
87
+ ]
88
+ ignore = [
89
+ "E501", # line length handled by formatter
90
+ "RUF001", # ambiguous unicode (we use Cyrillic/Chinese text)
91
+ ]
92
+
93
+ [tool.ruff.lint.per-file-ignores]
94
+ "tests/**/*.py" = [
95
+ "ARG001", # unused mock function arguments are common in tests
96
+ "ARG002", # unused method arguments (fixtures) are common in tests
97
+ "SIM117", # nested with statements for multiple mocks are clearer
98
+ ]
99
+ "src/prolewiki_llm/**/*.py" = [
100
+ "ARG001", # **kwargs required by GRPOTrainer reward function interface
101
+ ]
102
+
103
+ [tool.ruff.lint.isort]
104
+ known-first-party = ["prolewiki_llm"]
105
+
106
+ [tool.ruff.format]
107
+ quote-style = "double"
108
+ indent-style = "space"
109
+
110
+ # =============================================================================
111
+ # MYPY (Type Checking)
112
+ # =============================================================================
113
+ [tool.mypy]
114
+ python_version = "3.12"
115
+ strict = true
116
+ mypy_path = ["src"]
117
+ explicit_package_bases = true
118
+ warn_return_any = true
119
+ warn_unused_configs = true
120
+ disallow_untyped_defs = true
121
+ disallow_incomplete_defs = true
122
+ check_untyped_defs = true
123
+ no_implicit_optional = true
124
+ warn_redundant_casts = true
125
+ warn_unused_ignores = true
126
+ show_error_codes = true
127
+ exclude = ["tests/", "build/", "dist/"]
128
+
129
+ [[tool.mypy.overrides]]
130
+ module = [
131
+ "transformers",
132
+ "transformers.*",
133
+ "sentence_transformers.*",
134
+ "spacy",
135
+ "spacy.*",
136
+ "torch",
137
+ "torch.*",
138
+ "wandb",
139
+ "wandb.*",
140
+ "datasets",
141
+ "datasets.*",
142
+ "trl",
143
+ "trl.*",
144
+ "unsloth",
145
+ "unsloth.*",
146
+ "vllm",
147
+ "vllm.*",
148
+ "peft",
149
+ "peft.*",
150
+ ]
151
+ ignore_missing_imports = true
152
+
153
+ # =============================================================================
154
+ # COVERAGE
155
+ # =============================================================================
156
+ [tool.coverage.run]
157
+ source = ["src/prolewiki_llm"]
158
+ branch = true
159
+ omit = ["*/tests/*"]
160
+
161
+ [tool.coverage.report]
162
+ exclude_lines = [
163
+ "pragma: no cover",
164
+ "if TYPE_CHECKING:",
165
+ "raise NotImplementedError",
166
+ ]
src/prolewiki_llm/__init__.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Training Module for Marxist-Leninist LLM Fine-tuning.
3
+
4
+ This module contains reward functions and training utilities for GRPO
5
+ (Group Relative Policy Optimization) fine-tuning on ProleWiki corpus.
6
+
7
+ Components:
8
+ - grpo_rewards: Reward functions for GRPO training
9
+ - wandb_logging: Weights & Biases logging for training observability
10
+ - train_grpo_marxist: Main GRPO training script
11
+ - transform_to_grpo: Dataset transformation utilities
12
+ - convert_to_qwen: Qwen format conversion
13
+ """
14
+
15
+ from prolewiki_llm.grpo_rewards import (
16
+ CAPITULATION_PATTERNS,
17
+ CLASS_ANALYSIS_MARKERS,
18
+ CONCEPT_EQUIVALENCES,
19
+ CONFIDENT_CLAIM_PATTERNS,
20
+ DEPTH_MARKERS,
21
+ DISCOURSE_CONNECTIVES,
22
+ EXPLANATORY_PHRASES,
23
+ FIRMNESS_PATTERNS,
24
+ HOLLOW_BUZZWORDS,
25
+ IDEOLOGICAL_CHALLENGE_PATTERNS,
26
+ MARXIST_TERMS,
27
+ QUESTION_WORDS,
28
+ QUOTE_TO_REFUTE_PATTERNS,
29
+ SELF_CRITICISM_MARKERS,
30
+ UNCERTAINTY_PATTERNS,
31
+ completeness_reward,
32
+ debug_print_reward,
33
+ entity_verification_reward,
34
+ epistemic_calibration_reward,
35
+ full_coherence_reward,
36
+ ideological_firmness_reward,
37
+ interconnection_depth_reward,
38
+ match_format_approximately,
39
+ match_format_exactly,
40
+ nli_coherence_reward,
41
+ robust_coherence_reward,
42
+ self_consistency_reward,
43
+ semantic_similarity_reward,
44
+ structural_coherence_reward,
45
+ terminology_reward,
46
+ topic_relevance_reward,
47
+ )
48
+ from prolewiki_llm.wandb_logging import (
49
+ RewardSample,
50
+ WandbSampleLogger,
51
+ create_logging_reward,
52
+ finish_wandb_logging,
53
+ init_wandb_logging,
54
+ is_wandb_available,
55
+ log_model_checkpoint,
56
+ log_reward_metrics,
57
+ )
58
+
59
+ __all__ = [
60
+ "CAPITULATION_PATTERNS",
61
+ "CLASS_ANALYSIS_MARKERS",
62
+ "CONCEPT_EQUIVALENCES",
63
+ "CONFIDENT_CLAIM_PATTERNS",
64
+ "DEPTH_MARKERS",
65
+ "DISCOURSE_CONNECTIVES",
66
+ "EXPLANATORY_PHRASES",
67
+ "FIRMNESS_PATTERNS",
68
+ "HOLLOW_BUZZWORDS",
69
+ "IDEOLOGICAL_CHALLENGE_PATTERNS",
70
+ "MARXIST_TERMS",
71
+ "QUESTION_WORDS",
72
+ "QUOTE_TO_REFUTE_PATTERNS",
73
+ "SELF_CRITICISM_MARKERS",
74
+ "UNCERTAINTY_PATTERNS",
75
+ "RewardSample",
76
+ "WandbSampleLogger",
77
+ "completeness_reward",
78
+ "create_logging_reward",
79
+ "debug_print_reward",
80
+ "entity_verification_reward",
81
+ "epistemic_calibration_reward",
82
+ "finish_wandb_logging",
83
+ "full_coherence_reward",
84
+ "ideological_firmness_reward",
85
+ "init_wandb_logging",
86
+ "interconnection_depth_reward",
87
+ "is_wandb_available",
88
+ "log_model_checkpoint",
89
+ "log_reward_metrics",
90
+ "match_format_approximately",
91
+ "match_format_exactly",
92
+ "nli_coherence_reward",
93
+ "robust_coherence_reward",
94
+ "self_consistency_reward",
95
+ "semantic_similarity_reward",
96
+ "structural_coherence_reward",
97
+ "terminology_reward",
98
+ "topic_relevance_reward",
99
+ ]
src/prolewiki_llm/convert_to_qwen.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert instruction/response pairs to Qwen chat template format."""
3
+
4
+ import json
5
+ from pathlib import Path
6
+
7
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory. You provide accurate information about socialist history, theory, and practice from a Marxist-Leninist perspective. You explain concepts like dialectical materialism, historical materialism, class struggle, anti-colonialism, and socialist construction with clarity and ideological precision."""


def convert_to_qwen(input_path: Path, output_path: Path) -> int:
    """Convert instruction/response JSONL to Qwen chat template format.

    Reads one JSON object per line with "instruction" and "response" keys
    and writes one {"text": ...} object per line, wrapped in Qwen-2.5
    <|im_start|>/<|im_end|> chat markers with SYSTEM_PROMPT as the system turn.

    Args:
        input_path: Source JSONL file of instruction/response pairs.
        output_path: Destination JSONL file (parent dirs created as needed).

    Returns:
        Number of pairs converted.

    Raises:
        KeyError: If a record lacks "instruction" or "response".
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Create the output directory up front so a fresh checkout works.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    count = 0
    # Explicit UTF-8: the corpus contains non-ASCII text and must not
    # depend on the platform's default locale encoding.
    with open(input_path, encoding="utf-8") as infile, open(
        output_path, "w", encoding="utf-8"
    ) as outfile:
        for line in infile:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of crashing.
                continue
            pair = json.loads(line)

            # Format for Qwen-2.5 chat template
            text = f"""<|im_start|>system
{SYSTEM_PROMPT}<|im_end|>
<|im_start|>user
{pair['instruction']}<|im_end|>
<|im_start|>assistant
{pair['response']}<|im_end|>"""

            outfile.write(json.dumps({"text": text}) + "\n")
            count += 1

    return count
29
+
30
+
31
if __name__ == "__main__":
    # Default corpus locations used by the training pipeline.
    source = Path("training_data/curated_qa.jsonl")
    destination = Path("training_data/formatted/train_qwen.jsonl")

    n_pairs = convert_to_qwen(source, destination)
    print(f"Converted {n_pairs} pairs to Qwen format")
    print(f"Output: {destination}")
src/prolewiki_llm/export_grpo_dataset.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export the ProleWiki dataset to GRPO training format.
4
+
5
+ Directly reads all JSONL files from the dataset repository and transforms them
6
+ into the GRPO format expected by GRPOTrainer.
7
+
8
+ Input: ../dataset/ (68 JSONL files, 5,297 records total)
9
+ Output: training_data/grpo_dataset.jsonl
10
+
11
+ Usage:
12
+ python -m prolewiki_llm.export_grpo_dataset
13
+ python -m prolewiki_llm.export_grpo_dataset --dataset-path /path/to/dataset
14
+ python -m prolewiki_llm.export_grpo_dataset --output training_data/grpo_dataset.jsonl
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import json
21
+ import sys
22
+ from pathlib import Path
23
+
24
+ # System prompt for GRPO training (must match dataset.py)
25
+ _SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
26
+ Think through political theory questions using dialectical materialist analysis.
27
+ Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""
28
+
29
+
30
def find_dataset_path() -> Path:
    """Locate the dataset directory relative to this project.

    Search order:
    1. PROLEWIKI_DATASET_PATH environment variable (must exist on disk)
    2. ../dataset (sibling of the project root)
    3. dataset/ inside the project root

    Raises:
        FileNotFoundError: If no candidate location exists.
    """
    import os

    # An explicit environment override wins outright.
    override = os.environ.get("PROLEWIKI_DATASET_PATH")
    if override:
        override_path = Path(override)
        if override_path.exists():
            return override_path

    # Otherwise probe conventional locations relative to this file.
    project_root = Path(__file__).resolve().parent.parent.parent.parent  # llm/

    for location in (
        project_root.parent / "dataset",  # ../dataset (sibling)
        project_root / "dataset",  # dataset/ within llm
    ):
        # A valid dataset directory must contain a sources/ subdirectory.
        if location.exists() and (location / "sources").exists():
            return location

    raise FileNotFoundError(
        "Could not find dataset directory. Set PROLEWIKI_DATASET_PATH "
        "environment variable or ensure dataset is at ../dataset/"
    )
64
+
65
+
66
def iter_jsonl_files(dataset_path: Path) -> list[Path]:
    """Collect every JSONL file under sources/ and synthetic/.

    Files from sources/ come first, then synthetic/; each group is
    sorted recursively by path. Missing subdirectories are ignored.
    """
    found: list[Path] = []
    for subdir in ("sources", "synthetic"):
        directory = dataset_path / subdir
        if directory.exists():
            found.extend(sorted(directory.rglob("*.jsonl")))
    return found
79
+
80
+
81
def transform_to_grpo(record: dict) -> dict | None:
    """Transform a single record to GRPO format.

    Input format: {instruction, response, ...}
    Output format: {prompt: [{role, content}, ...], answer}

    Returns None when either required field is missing or empty, so
    callers can count and skip incomplete records.
    """
    instruction = record.get("instruction")
    response = record.get("response")

    # Both fields must be present and non-empty to form a training pair.
    if not (instruction and response):
        return None

    chat_prompt = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": instruction},
    ]
    return {"prompt": chat_prompt, "answer": response}
101
+
102
+
103
def export_grpo_dataset(
    dataset_path: Path,
    output_path: Path,
    verbose: bool = True,
) -> int:
    """Export dataset to GRPO format by directly reading JSONL files.

    Walks every JSONL file found by iter_jsonl_files(), transforms each
    record with transform_to_grpo(), and streams the results to a single
    output file. Invalid JSON lines and incomplete records are counted
    and skipped rather than aborting the export.

    Args:
        dataset_path: Path to the dataset directory
        output_path: Path to write the GRPO-formatted JSONL file
        verbose: Whether to print progress messages

    Returns:
        Number of records exported
    """
    if verbose:
        print(f"Loading dataset from: {dataset_path}")

    # Find all JSONL files
    jsonl_files = iter_jsonl_files(dataset_path)

    if verbose:
        print(f"Found {len(jsonl_files)} JSONL files")

    # Ensure output directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    record_count = 0
    skipped_count = 0
    first_sample = None  # kept only for the verbose summary below

    with open(output_path, "w", encoding="utf-8") as outfile:
        for jsonl_path in jsonl_files:
            with open(jsonl_path, encoding="utf-8") as infile:
                for line_num, line in enumerate(infile, 1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError as e:
                        if verbose:
                            print(f"Warning: Invalid JSON in {jsonl_path}:{line_num}: {e}")
                        skipped_count += 1
                        continue

                    transformed = transform_to_grpo(record)
                    if transformed is None:
                        skipped_count += 1
                        continue

                    # ensure_ascii=False keeps non-ASCII corpus text readable.
                    outfile.write(json.dumps(transformed, ensure_ascii=False) + "\n")
                    record_count += 1

                    if first_sample is None:
                        first_sample = transformed

    # All progress output is gated behind `verbose` so that --quiet
    # (verbose=False) is genuinely silent.
    if verbose:
        print(f"\nExported {record_count:,} records to {output_path}")
        if skipped_count > 0:
            print(f"Skipped {skipped_count:,} invalid/incomplete records")

        # Show sample
        if first_sample:
            print("\nSample record:")
            print(f"  System: {first_sample['prompt'][0]['content'][:60]}...")
            print(f"  User: {first_sample['prompt'][1]['content'][:60]}...")
            print(f"  Answer: {first_sample['answer'][:60]}...")

        # Show file size
        size_mb = output_path.stat().st_size / (1024 * 1024)
        print(f"\nOutput file size: {size_mb:.2f} MB")

    return record_count
178
+
179
+
180
def main() -> int:
    """CLI entrypoint: parse arguments, locate the dataset, run the export.

    Returns 0 on success (at least one record exported), 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Export ProleWiki dataset to GRPO training format"
    )
    parser.add_argument(
        "--dataset-path",
        type=Path,
        default=None,
        help="Path to dataset directory (default: auto-detect)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("training_data/grpo_dataset.jsonl"),
        help="Output JSONL file path (default: training_data/grpo_dataset.jsonl)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress progress messages",
    )
    args = parser.parse_args()

    # Resolve the dataset location, falling back to auto-detection.
    try:
        dataset_path = args.dataset_path if args.dataset_path else find_dataset_path()
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # Run the export; any unexpected failure maps to exit code 1.
    try:
        exported = export_grpo_dataset(
            dataset_path=dataset_path,
            output_path=args.output,
            verbose=not args.quiet,
        )
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    return 0 if exported > 0 else 1
221
+
222
+
223
# Script entrypoint: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
src/prolewiki_llm/grpo_rewards.py ADDED
@@ -0,0 +1,1847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO Reward Functions for Marxist-Leninist Q&A Training.
4
+
5
+ These reward functions guide the model toward:
6
+ 1. Proper <think>...</think> format
7
+ 2. Semantic coherence via NLI (Natural Language Inference)
8
+ 3. Structural coherence via dependency parsing
9
+ 4. Self-consistency (no internal contradictions)
10
+ 5. Appropriate response length/completeness
11
+
12
+ Research basis:
13
+ - NLI as reward: arxiv.org/abs/2508.18212 (Better LM-Based Judging)
14
+ - MO-GRPO normalization: arxiv.org/abs/2509.22047
15
+ - Process rewards: arxiv.org/abs/2508.05170 (Posterior-GRPO)
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import re
21
+ from typing import TYPE_CHECKING, Any
22
+
23
+ import numpy as np
24
+ from sentence_transformers import SentenceTransformer
25
+
26
+ if TYPE_CHECKING:
27
+ from collections.abc import Sequence
28
+
29
+ # =============================================================================
30
+ # GLOBAL SETUP - LAZY LOADING
31
+ # =============================================================================
32
+
33
+ # Lazy-load models to avoid loading at import time
34
+ _embedder: SentenceTransformer | None = None
35
+ _nli_pipeline: Any | None = None
36
+ _spacy_nlp: Any | None = None
37
+
38
+
39
def get_embedder() -> SentenceTransformer:
    """Return the shared sentence-transformers embedder, loading it on first use.

    The model is cached in the module-level ``_embedder`` global so repeated
    calls during training do not reload weights.
    """
    global _embedder
    if _embedder is not None:
        return _embedder
    print("[Reward] Loading sentence-transformers embedder...")
    _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    return _embedder
46
+
47
+
48
def get_nli_pipeline() -> Any:
    """Return the shared NLI pipeline (facebook/bart-large-mnli), loading lazily.

    Cached in the module-level ``_nli_pipeline`` global; placed on CUDA when
    available, CPU otherwise. The transformers import is deferred so merely
    importing this module stays cheap.
    """
    global _nli_pipeline
    if _nli_pipeline is not None:
        return _nli_pipeline

    print("[Reward] Loading NLI model (bart-large-mnli)...")
    from transformers import pipeline

    device = "cuda" if _cuda_available() else "cpu"
    _nli_pipeline = pipeline(
        "text-classification",
        model="facebook/bart-large-mnli",
        device=device,
    )
    return _nli_pipeline
61
+
62
+
63
def get_spacy_nlp() -> Any:
    """Return the shared spaCy pipeline, loading it lazily.

    Tries en_core_web_trf (transformer-based, best semantic understanding)
    first, then falls back to en_core_web_md (word vectors) and finally
    en_core_web_sm. The loaded pipeline is cached in the module-level
    ``_spacy_nlp`` global.

    Raises:
        OSError: If none of the candidate models is installed.
    """
    global _spacy_nlp
    if _spacy_nlp is not None:
        return _spacy_nlp

    import spacy

    # Preference order: transformer model first, smallest model last.
    for candidate in ("en_core_web_trf", "en_core_web_md", "en_core_web_sm"):
        try:
            print(f"[Reward] Loading spaCy model: {candidate}...")
            _spacy_nlp = spacy.load(candidate)
            print(f"[Reward] Loaded {candidate} successfully")
        except OSError:
            print(f"[Reward] {candidate} not found, trying next...")
        else:
            return _spacy_nlp

    raise OSError(
        "No spaCy model found. Install one with:\n"
        " uv pip install en_core_web_trf@https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl"
    )
92
+
93
+
94
+ def _cuda_available() -> bool:
95
+ """Check if CUDA is available."""
96
+ try:
97
+ import torch
98
+
99
+ return bool(torch.cuda.is_available())
100
+ except ImportError:
101
+ return False
102
+
103
+
104
# Reasoning format tokens (DeepSeek-R1 style)
REASONING_START = "<think>"
REASONING_END = "</think>"

# Regex to match format
# NOTE: rf-string interpolates REASONING_END; re.DOTALL lets (.*) span newlines,
# so a match means "a closing </think> tag followed by any answer text".
SOLUTION_END_REGEX = re.compile(rf"{REASONING_END}(.*)", re.DOTALL)

# Marxist terminology for vocabulary reward
# Matched as lowercase substrings (see terminology_reward and
# structural_coherence_reward); multi-word entries must appear verbatim.
MARXIST_TERMS: set[str] = {
    # Core concepts
    "dialectical",
    "materialism",
    "historical materialism",
    "dialectical materialism",
    # Classes
    "bourgeoisie",
    "proletariat",
    "petty bourgeois",
    "petty bourgeoisie",
    "lumpenproletariat",
    "working class",
    "ruling class",
    # Class struggle
    "class struggle",
    "class consciousness",
    "class war",
    "class conflict",
    # Political economy
    "surplus value",
    "commodity",
    "use value",
    "exchange value",
    "labor power",
    "means of production",
    "relations of production",
    "forces of production",
    "mode of production",
    "primitive accumulation",
    "exploitation",
    "capital accumulation",
    # Imperialism
    "imperialism",
    "colonialism",
    "neo-colonialism",
    "settler colonialism",
    "national liberation",
    "self-determination",
    # State and revolution
    "dictatorship of the proletariat",
    "vanguard",
    "vanguard party",
    "democratic centralism",
    "withering away of the state",
    "proletarian dictatorship",
    # Ideology
    "hegemony",
    "superstructure",
    "base",
    "ideology",
    "false consciousness",
    # Revisionism
    "revisionism",
    "opportunism",
    "reformism",
    "social democracy",
    "ultra-leftism",
    # Alienation
    "alienation",
    "fetishism",
    "commodity fetishism",
    "reification",
    # Historical
    "paris commune",
    "october revolution",
    "bolshevik",
    "menshevik",
    # Anti-colonial
    "decolonization",
    "third world",
    "global south",
    "national bourgeoisie",
    "comprador",
}
187
+
188
+
189
+ # =============================================================================
190
+ # FORMAT REWARDS (from original notebook)
191
+ # =============================================================================
192
+
193
+
194
def match_format_exactly(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward +3.0 if response contains proper </think> tag.

    This encourages the model to use the reasoning format.
    """
    results: list[float] = []
    for completion in completions:
        text = completion[0]["content"]
        # Full reward when the closing tag (plus trailing answer) is present.
        has_format = SOLUTION_END_REGEX.search(text) is not None
        results.append(3.0 if has_format else 0.0)
    return results
211
+
212
+
213
def match_format_approximately(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward partial format matching.

    +0.5 for exactly one <think> tag
    +0.5 for exactly one </think> tag
    -1.0 for multiple or missing tags
    """
    results: list[float] = []
    for completion in completions:
        text = completion[0]["content"]
        n_start = text.count(REASONING_START)
        n_end = text.count(REASONING_END)
        # Each tag contributes independently: exactly-one earns +0.5,
        # zero or duplicates earn -1.0.
        total = (0.5 if n_start == 1 else -1.0) + (0.5 if n_end == 1 else -1.0)
        results.append(total)
    return results
237
+
238
+
239
+ # =============================================================================
240
+ # SEMANTIC SIMILARITY REWARD
241
+ # =============================================================================
242
+
243
+
244
def semantic_similarity_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward responses that are semantically similar to ground truth.

    Uses sentence-transformers to compute cosine similarity.

    Scoring:
        > 0.75 similarity: +5.0
        > 0.60 similarity: +3.0
        > 0.45 similarity: +1.0
        > 0.30 similarity: -1.0
        <= 0.30 similarity: -3.0
    """
    embedder = get_embedder()
    # (exclusive lower bound, reward) bands, checked best-first; anything
    # at or below 0.30 falls through to the -3.0 floor.
    bands = ((0.75, 5.0), (0.60, 3.0), (0.45, 1.0), (0.30, -1.0))
    results: list[float] = []

    for completion, true_answer in zip(completions, answer, strict=False):
        text = completion[0]["content"]

        # Score only the answer portion when a reasoning block is present.
        if REASONING_END in text:
            text = text.split(REASONING_END, 1)[1].strip()

        # Empty or near-empty answers get the floor penalty outright.
        if not text or len(text.strip()) < 10:
            results.append(-3.0)
            continue

        # Normalized embeddings make the dot product equal cosine similarity.
        vec_response = embedder.encode(text, normalize_embeddings=True)
        vec_truth = embedder.encode(true_answer, normalize_embeddings=True)
        cosine = float(np.dot(vec_response, vec_truth))

        for lower_bound, reward in bands:
            if cosine > lower_bound:
                results.append(reward)
                break
        else:
            results.append(-3.0)

    return results
297
+
298
+
299
+ # =============================================================================
300
+ # MARXIST TERMINOLOGY REWARD
301
+ # =============================================================================
302
+
303
+
304
def terminology_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward use of proper Marxist terminology.

    +0.3 per unique term found, capped at +2.0

    NOTE: This is a shallow reward that can be gamed with "word soup".
    Consider using nli_coherence_reward or structural_coherence_reward
    for more robust evaluation.
    """
    results: list[float] = []
    for completion in completions:
        lowered = completion[0]["content"].lower()
        # Each vocabulary entry counts at most once per response.
        hits = sum(term in lowered for term in MARXIST_TERMS)
        results.append(min(hits * 0.3, 2.0))
    return results
329
+
330
+
331
+ # =============================================================================
332
+ # NLI-BASED COHERENCE REWARD (Research-backed)
333
+ # =============================================================================
334
+
335
# Discourse connectives indicating logical structure
# Matched as lowercase substrings by structural_coherence_reward.
DISCOURSE_CONNECTIVES: set[str] = {
    "because",
    "therefore",
    "thus",
    "hence",
    "consequently",
    "however",
    "although",
    "whereas",
    "nevertheless",
    "moreover",
    "furthermore",
    "additionally",
    "specifically",
    "namely",
    "in other words",
    "for example",
    "for instance",
    "such as",
    "as a result",
    "due to",
    "in order to",
    "so that",
    "on the other hand",
    "in contrast",
    "similarly",
    "likewise",
}

# Explanatory phrases that indicate concept is being explained (not just dropped)
# NOTE(review): not referenced by any reward function visible in this module
# section — presumably consumed elsewhere; verify before removing.
EXPLANATORY_PHRASES: set[str] = {
    # Causal explanations
    "because the",
    "because of",
    "this is because",
    "since the",
    "due to the",
    "as a result of",
    "results from",
    "caused by",
    "leads to",
    "results in",
    "enables",
    "produces",
    # Definitional explanations
    "is defined as",
    "refers to",
    "means that",
    "denotes",
    "that is,",
    "in other words",
    "namely",
    "i.e.",
    # Elaboration
    "specifically",
    "in particular",
    "for example",
    "such as",
    "this means",
    "which means",
    "this implies",
    "therefore",
    # Mechanism explanations
    "this occurs when",
    "this happens because",
    "the mechanism",
    "through the process of",
    "by means of",
    "works by",
}

# Hollow buzzwords: activist jargon that signals superficial analysis when used
# without substantive explanation. These are NOT Marxist technical terms.
# Penalty applies when: high density + low depth ratio
HOLLOW_BUZZWORDS: set[str] = {
    # Vague connectors (non-analytical)
    "interconnected",
    "interrelated",
    "intersects with",
    "it's all connected",
    "everything is connected",
    "systemic",
    # Performative activist language
    "centered",
    "centering",
    "uplift",
    "uplifting",
    "do the work",
    "the work",
    "unpack",
    "unpacking",
    "unlearn",
    "unlearning",
    "hold space",
    "sit with",
    "lean into",
    "problematic",
    "harmful",
    "toxic",
    # Vague abstractions without specifics
    "in a way",
    "sort of",
    "kind of",
    "essentially",
    "basically",
    "generally speaking",
    "broadly",
    # Jargon often used without definition
    "praxis",  # Valid Marxist term but often misused without explanation
    "material conditions",  # Valid but often used as hand-wave
    "structural",
    "structurally",  # Often vague without mechanism
    # Identity-focused without class analysis
    "lived experience",
    "as a",  # Often substitutes for analysis
}

# Phrases that signal analytical depth (opposite of hollow)
# NOTE: "in 1" / "in 2" are deliberate prefixes matching year references
# like "in 1917" / "in 2008" via substring search.
DEPTH_MARKERS: set[str] = {
    # Historical specificity
    "in 1",
    "in 2",
    "during the",
    "after the",
    "before the",
    # Citing sources/figures
    "marx argued",
    "lenin wrote",
    "engels noted",
    "gramsci",
    "according to",
    "as marx",
    "as lenin",
    # Concrete examples
    "for example",
    "such as",
    "in the case of",
    "consider",
    # Precise definitions
    "defined as",
    "meaning",
    "specifically",
}

# Marxist concept equivalences for topic matching
# Maps canonical term -> set of synonyms/equivalents
# Consumed by _expand_with_synonyms (both forward and reverse lookups).
CONCEPT_EQUIVALENCES: dict[str, set[str]] = {
    # Class terms
    "bourgeoisie": {"capitalist class", "ruling class", "capitalists", "bourgeois", "capital"},
    "proletariat": {"working class", "workers", "wage laborers", "labor", "labourers"},
    "petty bourgeoisie": {"petit bourgeoisie", "small business", "middle class", "petty bourgeois"},
    "lumpenproletariat": {"lumpen", "underclass", "criminal element"},
    # Economic concepts
    "surplus value": {"unpaid labor", "profit", "extraction", "surplus labor"},
    "means of production": {"productive forces", "capital goods", "factories", "industry"},
    "exploitation": {"extraction", "appropriation", "expropriation"},
    "commodity": {"commodities", "goods", "merchandise"},
    "capital accumulation": {"accumulation", "concentration of capital"},
    "primitive accumulation": {"original accumulation", "so-called primitive accumulation"},
    # Political concepts
    "dictatorship of the proletariat": {
        "workers state",
        "proletarian dictatorship",
        "workers government",
    },
    "vanguard party": {"vanguard", "communist party", "revolutionary party"},
    "democratic centralism": {"party discipline", "centralism"},
    # Imperialism
    "imperialism": {"colonialism", "neo-colonialism", "empire", "colonial"},
    "national liberation": {"decolonization", "anti-colonial", "liberation movement"},
    "settler colonialism": {"settler colony", "colonial settlement"},
    # Ideology
    "revisionism": {"opportunism", "reformism", "right deviation"},
    "hegemony": {"ideological hegemony", "cultural hegemony", "domination"},
    "false consciousness": {"ideology", "mystification"},
    # Philosophy
    "dialectical materialism": {"diamat", "materialist dialectics", "dialectics"},
    "historical materialism": {"histmat", "materialist conception of history"},
    "alienation": {"estrangement", "alienated labor"},
}

# Question words to ignore when extracting topics
QUESTION_WORDS: set[str] = {"what", "how", "why", "who", "when", "where", "which", "whom"}
519
+
520
+
521
def nli_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward responses that logically ENTAIL the ground truth answer.

    Uses Natural Language Inference (facebook/bart-large-mnli) to check
    if the response is logically consistent with the expected answer.

    This defeats "word soup" attacks because random terminology won't
    logically entail anything - it will be classified as NEUTRAL.

    Scoring:
        entailment: +3.0 (response supports/implies ground truth)
        neutral: -1.0 (response is off-topic or incoherent)
        contradiction: -3.0 (response contradicts ground truth)

    Research basis: arxiv.org/abs/2508.18212
    """
    nli = get_nli_pipeline()
    # Any label other than these two (i.e. contradiction) gets the floor.
    label_rewards = {"entailment": 3.0, "neutral": -1.0}
    results: list[float] = []

    for completion, true_answer in zip(completions, answer, strict=False):
        text = completion[0]["content"]

        # Only the answer after the reasoning block is judged.
        if REASONING_END in text:
            text = text.split(REASONING_END, 1)[1].strip()

        # Too short to carry any entailment signal.
        if not text or len(text.strip()) < 20:
            results.append(-2.0)
            continue

        # Truncate both sides to keep the NLI input within model limits.
        premise = text[:512]
        hypothesis = true_answer[:512]

        # BART-MNLI pair format: premise </s></s> hypothesis
        # (does the response entail the ground truth?)
        try:
            verdict = nli(f"{premise}</s></s>{hypothesis}")[0]
            results.append(label_rewards.get(verdict["label"].lower(), -3.0))
        except Exception as e:
            print(f"[NLI Reward] Error: {e}")
            results.append(0.0)

    return results
582
+
583
+
584
def self_consistency_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward responses that are internally self-consistent.

    Checks if any sentence in the response CONTRADICTS another sentence.
    This avoids external ideological bias by only checking within-document
    coherence.

    Scoring:
        No contradictions found: +1.0
        Internal contradiction detected: -2.0
        Fewer than 2 usable sentences: 0.0 (nothing to compare)

    Research basis: arxiv.org/abs/2508.05170 (process-based rewards)
    """
    nli = get_nli_pipeline()
    nlp = get_spacy_nlp()
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # Parse into sentences; very short fragments carry no signal.
        doc = nlp(response)
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]

        # Need at least 2 sentences to check consistency
        if len(sentences) < 2:
            scores.append(0.0)
            continue

        # Only adjacent / near-adjacent pairs are checked, with a hard
        # budget on NLI calls to bound compute per completion.
        has_contradiction = False
        max_pairs_to_check = 10
        pairs_checked = 0

        for i, sent_a in enumerate(sentences[:-1]):
            if pairs_checked >= max_pairs_to_check:
                break
            # Check against the next 2 sentences.
            for j in range(i + 1, min(i + 3, len(sentences))):
                # BUG FIX: the budget was previously only enforced at the top
                # of the outer loop, so the inner loop could overshoot it, and
                # failed NLI calls never counted against it (unbounded work on
                # persistent errors). Count every attempted pair up front.
                if pairs_checked >= max_pairs_to_check:
                    break
                pairs_checked += 1
                sent_b = sentences[j]
                try:
                    input_text = f"{sent_a[:256]}</s></s>{sent_b[:256]}"
                    result = nli(input_text)[0]
                except Exception:
                    # Best-effort: a failed check neither rewards nor penalizes.
                    continue
                if result["label"].lower() == "contradiction":
                    has_contradiction = True
                    break
            if has_contradiction:
                break

        scores.append(-2.0 if has_contradiction else 1.0)

    return scores
646
+
647
+
648
def structural_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward responses with proper linguistic structure.

    Uses spaCy dependency parsing to verify:
    1. Marxist terms appear in meaningful syntactic roles (subject, object)
    2. Response contains logical discourse connectives
    3. Response has proper sentence structure (not word soup)

    This defeats word soup because random terms won't be in subject/object
    positions - they'll be parsed as fragments.

    Scoring:
        +0.3 per term in subject/object position (max +1.5)
        +0.2 per discourse connective (max +1.0)
        -1.0 if no complete sentences detected

    Research basis: spaCy dependency parsing for coherence evaluation
    """
    nlp = get_spacy_nlp()
    # Dependency labels that mark a token as carrying real semantic weight.
    meaningful_deps = (
        "nsubj",  # nominal subject
        "nsubjpass",  # passive nominal subject
        "dobj",  # direct object
        "pobj",  # object of preposition
        "attr",  # attribute
        "appos",  # appositional modifier
    )
    results: list[float] = []

    for completion in completions:
        text = completion[0]["content"]
        doc = nlp(text)

        # No parseable sentences at all -> penalize and move on.
        if not list(doc.sents):
            results.append(-1.0)
            continue

        lowered = text.lower()

        # Count terms that occupy a meaningful syntactic role (once each).
        grounded_terms = 0
        for term in MARXIST_TERMS:
            if term not in lowered:
                continue
            for token in doc:
                # Match on the token itself or the token + its right neighbor
                # (covers two-word terms like "surplus value").
                bigram_hit = (
                    token.i + 1 < len(doc)
                    and term in f"{token.text} {doc[token.i + 1].text}".lower()
                )
                if term in token.text.lower() or bigram_hit:
                    if token.dep_ in meaningful_deps:
                        grounded_terms += 1
                        break
                    # Also credit terms attached directly to the main verb.
                    elif token.head.pos_ == "VERB" and token.head.dep_ == "ROOT":
                        grounded_terms += 1
                        break

        score = min(grounded_terms * 0.3, 1.5)

        # Discourse connectives indicate logical structure.
        connectives = sum(1 for conn in DISCOURSE_CONNECTIVES if conn in lowered)
        score += min(connectives * 0.2, 1.0)

        results.append(score)

    return results
722
+
723
+
724
+ # =============================================================================
725
+ # COMBINED ROBUST COHERENCE REWARD
726
+ # =============================================================================
727
+
728
+
729
def robust_coherence_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Multi-layered coherence check combining NLI, self-consistency, and structure.

    This is the recommended reward function for robust evaluation that defeats
    reward hacking via word soup or other adversarial strategies.

    Layers:
    1. NLI coherence: Does response entail ground truth?
    2. Self-consistency: Does response contradict itself?
    3. Structural coherence: Are terms used in meaningful syntactic roles?

    Scoring (combined):
        NLI entailment + self-consistent + good structure: up to +5.5
        NLI neutral or contradiction: -1.0 to -3.0
        Internal contradiction: -2.0
        Word soup (no structure): -1.0
    """
    nli_scores = nli_coherence_reward(completions, answer, **kwargs)
    consistency_scores = self_consistency_reward(completions, **kwargs)
    structure_scores = structural_coherence_reward(completions, **kwargs)

    def _combine(nli: float, consistency: float, structure: float) -> float:
        # An NLI contradiction dominates everything else.
        if nli <= -3.0:
            return -3.0
        # An internal contradiction is the next-strongest penalty.
        if consistency <= -2.0:
            return -2.0
        # Otherwise NLI is the primary signal; the rest are half-weight bonuses.
        return nli + 0.5 * consistency + 0.5 * structure

    return [
        _combine(n, c, s)
        for n, c, s in zip(nli_scores, consistency_scores, structure_scores, strict=False)
    ]
774
+
775
+
776
+ # =============================================================================
777
+ # TOPIC RELEVANCE REWARD (Question-Answer Alignment)
778
+ # =============================================================================
779
+
780
+
781
def _extract_noun_with_preps(token: Any) -> set[str]:
    """
    Extract a noun and its prepositional phrase children.

    For "dictatorship of the proletariat", returns:
        {"dictatorship", "proletariat", "dictatorship of the proletariat"}
    """
    found: set[str] = set()

    # The head noun itself, lemmatized.
    if token.pos_ in ("NOUN", "PROPN"):
        found.add(token.lemma_.lower())

    # Compound/adjectival modifiers build multi-word terms like "surplus value".
    modifiers = [
        child.text.lower()
        for child in token.children
        if child.dep_ in ("compound", "amod") and child.pos_ in ("NOUN", "ADJ")
    ]
    if modifiers:
        found.add(" ".join([*modifiers, token.text.lower()]))

    # Follow prepositional phrases ("of the proletariat"), recursing into
    # nested objects so chained preps are also captured.
    for child in token.children:
        if child.dep_ != "prep":
            continue
        for prep_obj in child.children:
            if prep_obj.dep_ != "pobj":
                continue
            found.add(prep_obj.lemma_.lower())
            found.add(f"{token.text.lower()} {child.text} {prep_obj.text.lower()}")
            found.update(_extract_noun_with_preps(prep_obj))

    return found
817
+
818
+
819
def _extract_question_topics(doc: Any) -> set[str]:
    """
    Extract the core topics from a question using spaCy dependency parsing.

    For "What is revisionism?", extracts {"revisionism"}
    For "How does imperialism relate to capitalism?", extracts {"imperialism", "capitalism"}
    For "What is the dictatorship of the proletariat?", extracts
        {"dictatorship", "proletariat", "dictatorship of the proletariat"}
    """
    topics: set[str] = set()

    # Locate the syntactic ROOT of the question.
    root = next((tok for tok in doc if tok.dep_ == "ROOT"), None)

    if root is not None:
        for child in root.children:
            # Subjects/objects of the ROOT carry the question's topic
            # ("What is [revisionism]?", "Explain [the concept]").
            if child.dep_ in ("nsubj", "dobj", "attr", "nsubjpass"):
                # Skip interrogatives themselves ("What is X" -> drop "What").
                if child.text.lower() in QUESTION_WORDS:
                    continue
                topics.update(_extract_noun_with_preps(child))

            # Prepositional objects attached to ROOT ("relate to [capitalism]").
            if child.dep_ == "prep":
                for prep_obj in child.children:
                    if prep_obj.dep_ == "pobj":
                        topics.update(_extract_noun_with_preps(prep_obj))

    # Fallback: harvest noun chunks when dependency extraction found nothing.
    if not topics:
        for chunk in doc.noun_chunks:
            if chunk.root.text.lower() in QUESTION_WORDS:
                continue
            topics.add(chunk.root.lemma_.lower())
            # Keep the full chunk too, for multi-word terms.
            phrase = chunk.text.lower().strip()
            if " " in phrase:
                topics.add(phrase)

    # Final sweep: drop any question words that slipped through.
    return {topic for topic in topics if topic not in QUESTION_WORDS}
871
+
872
+
873
def _extract_answer_topics(doc: Any) -> set[str]:
    """
    Extract topics discussed in an answer using spaCy.

    Returns lemmatized noun phrases and named entities.
    Strips determiners (the, a, an) for better matching.
    """
    determiners = {"the", "a", "an", "this", "that", "these", "those"}

    def _strip_leading_det(text: str) -> str:
        # Drop a single leading determiner so "the proletariat" matches
        # "proletariat".
        words = text.split()
        if words and words[0] in determiners:
            words = words[1:]
        return " ".join(words)

    topics: set[str] = set()

    # Noun chunks: keep the lemmatized root, plus the (det-stripped)
    # multi-word phrase when it is reasonably short.
    for chunk in doc.noun_chunks:
        topics.add(chunk.root.lemma_.lower())
        phrase = _strip_leading_det(chunk.text.lower().strip())
        if " " in phrase and len(phrase) < 50:
            topics.add(phrase)

    # Named entities, also determiner-stripped.
    for ent in doc.ents:
        topics.add(_strip_leading_det(ent.text.lower()))

    return topics
908
+
909
+
910
def _expand_with_synonyms(topics: set[str]) -> set[str]:
    """Return *topics* augmented with Marxist concept synonyms.

    If "bourgeoisie" is present, the result also contains "capitalist
    class", "ruling class", etc., via the CONCEPT_EQUIVALENCES table.
    """
    result = set(topics)

    for term in topics:
        # Forward lookup: term is itself a canonical concept.
        forward = CONCEPT_EQUIVALENCES.get(term)
        if forward is not None:
            result.update(forward)
        # Reverse lookup: term appears among some concept's synonyms
        # (or equals the canonical term) — pull in the whole family.
        for canonical, synonyms in CONCEPT_EQUIVALENCES.items():
            if term == canonical or term in synonyms:
                result.add(canonical)
                result.update(synonyms)

    return result
929
+
930
+
931
def _compute_topic_coverage(q_topics: set[str], a_topics: set[str], nlp: Any) -> float:
    """
    Compute how well answer topics cover question topics.

    Uses:
        1. Direct lemma matching
        2. Expanded synonym matching
        3. spaCy word vector similarity (fallback)

    Args:
        q_topics: Topics extracted from the question.
        a_topics: Topics extracted from the answer.
        nlp: Loaded spaCy pipeline, used for vector similarity.

    Returns:
        Coverage score 0.0 to 1.0 (0.5 when q_topics is empty and
        coverage cannot be evaluated).
    """
    if not q_topics:
        return 0.5  # Can't evaluate, neutral

    # Expand question topics with synonyms
    q_expanded = _expand_with_synonyms(q_topics)

    # Direct/synonym match (q_topics is known non-empty here, so the
    # former `if q_topics else 0` guards were dead code and are removed)
    matched = q_expanded & a_topics
    direct_coverage = len(matched) / len(q_topics)

    if direct_coverage >= 0.5:
        return min(direct_coverage, 1.0)

    # Fallback: semantic similarity using spaCy vectors.
    # Parse each answer topic ONCE, outside the loop. Previously nlp()
    # was called on every a_topic for every unmatched q_topic, i.e.
    # O(|Q| * |A|) pipeline invocations; this is now O(|Q| + |A|).
    a_docs = [d for d in (nlp(a_topic) for a_topic in a_topics) if d.has_vector]

    semantic_matches = 0
    for q_topic in q_topics - matched:
        q_doc = nlp(q_topic)
        if not q_doc.has_vector:
            continue

        best_sim = max((q_doc.similarity(a_doc) for a_doc in a_docs), default=0.0)
        if best_sim > 0.6:  # Threshold for semantic match
            semantic_matches += 1

    total_matched = len(matched) + semantic_matches
    return min(total_matched / len(q_topics), 1.0)
977
+
978
+
979
def topic_relevance_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    **kwargs: object,
) -> list[float]:
    """
    Reward answers that stay ON-TOPIC with respect to the question.

    Implements an f(A) ⊆ f(Q) check where f extracts semantic topics:
    question topics come from dependency parsing (expanded with Marxist
    concept synonyms), answer topics from noun chunks and entities, and
    the reward reflects how much of the question's topic set the answer
    covers.

    Scoring:
        > 80% coverage: +2.0 (answer fully addresses question topics)
        > 60% coverage: +1.5 (answer mostly on-topic)
        > 40% coverage: +1.0 (answer partially on-topic)
        > 20% coverage:  0.0 (answer tangentially related)
        <= 20% coverage: -1.5 (answer off-topic)

    This ensures the model answers WHAT WAS ASKED, not just generates
    coherent Marxist text about something else.
    """

    def _score_for(coverage: float) -> float:
        # Map a coverage fraction onto the reward ladder above.
        if coverage > 0.8:
            return 2.0
        if coverage > 0.6:
            return 1.5
        if coverage > 0.4:
            return 1.0
        if coverage > 0.2:
            return 0.0
        return -1.5

    nlp = get_spacy_nlp()
    rewards: list[float] = []

    for prompt, completion in zip(prompts, completions, strict=False):
        question = prompt[-1]["content"]  # last user message
        answer_text = completion[0]["content"]

        # Only evaluate the answer portion after the reasoning block.
        if REASONING_END in answer_text:
            answer_text = answer_text.split(REASONING_END, 1)[1].strip()

        # Empty / near-empty answers are treated as off-topic.
        if not answer_text or len(answer_text.strip()) < 20:
            rewards.append(-1.5)
            continue

        # Parse question fully; cap the answer for performance.
        q_topics = _extract_question_topics(nlp(question))
        a_topics = _extract_answer_topics(nlp(answer_text[:2000]))

        if not q_topics:
            # Nothing extractable from the question: fall back to a
            # simple substance check on the answer.
            rewards.append(0.5 if len(a_topics) > 3 else 0.0)
            continue

        rewards.append(_score_for(_compute_topic_coverage(q_topics, a_topics, nlp)))

    return rewards
1052
+
1053
+
1054
def full_coherence_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Complete coherence check: robust_coherence + topic_relevance + depth.

    The MOST COMPREHENSIVE reward function, combining:
      1. NLI coherence (answer entails the ground truth)
      2. Self-consistency (answer does not contradict itself)
      3. Structural coherence (terms in proper syntactic roles)
      4. Topic relevance (answer addresses what was asked)
      5. Interconnection depth (deep analysis over buzzword salad)

    Use this for maximum robustness against reward hacking.
    """
    robust_scores = robust_coherence_reward(completions, answer, **kwargs)
    relevance_scores = topic_relevance_reward(prompts, completions, **kwargs)
    depth_scores = interconnection_depth_reward(completions, **kwargs)

    merged: list[float] = []
    for robust, relevance, depth in zip(
        robust_scores, relevance_scores, depth_scores, strict=False
    ):
        if relevance <= -1.5:
            # Severely off-topic dominates everything else.
            merged.append(-2.0)
        elif robust <= -2.0:
            # Hard failure of the robust coherence check wins.
            merged.append(robust)
        elif depth <= -1.5:
            # Buzzword salad detected.
            merged.append(-1.5)
        else:
            # Robust score is primary; relevance and depth act as
            # weighted bonuses/penalties.
            merged.append(robust + relevance * 0.4 + depth * 0.3)

    return merged
1096
+
1097
+
1098
+ # =============================================================================
1099
+ # INTERCONNECTION DEPTH REWARD (Anti-Buzzword-Salad)
1100
+ # =============================================================================
1101
+
1102
+
1103
def _count_unique_marxist_concepts(text: str) -> int:
    """Return how many distinct MARXIST_TERMS occur (as substrings) in text."""
    lowered = text.lower()
    return sum(1 for term in MARXIST_TERMS if term in lowered)
1111
+
1112
+
1113
def _compute_depth_ratio(text: str) -> float:
    """
    Compute depth ratio: words per unique Marxist concept.

    High ratio = deep analysis (few concepts, well explained);
    low ratio = shallow/buzzword soup (many concepts, little explanation).

    Returns:
        Words per concept, or 100.0 when no Marxist concepts are found
        (no concepts = neutral, not shallow).
    """
    concept_count = _count_unique_marxist_concepts(text)
    if concept_count == 0:
        return 100.0
    return len(text.split()) / concept_count
1130
+
1131
+
1132
def _count_hollow_buzzwords(text: str) -> int:
    """Return how many distinct HOLLOW_BUZZWORDS occur (as substrings) in text."""
    lowered = text.lower()
    return sum(1 for buzzword in HOLLOW_BUZZWORDS if buzzword in lowered)
1140
+
1141
+
1142
def _count_depth_markers(text: str) -> int:
    """Return how many distinct DEPTH_MARKERS occur (as substrings) in text."""
    lowered = text.lower()
    return sum(1 for marker in DEPTH_MARKERS if marker in lowered)
1150
+
1151
+
1152
def _count_explanatory_phrases(text: str) -> int:
    """Return how many distinct EXPLANATORY_PHRASES occur (as substrings) in text."""
    lowered = text.lower()
    return sum(1 for phrase in EXPLANATORY_PHRASES if phrase in lowered)
1160
+
1161
+
1162
def _concepts_have_explanations(text: str) -> tuple[int, int]:
    """
    Check whether introduced Marxist concepts have nearby explanations.

    A concept mention counts as "explained" when any EXPLANATORY_PHRASES
    entry occurs in the same sentence or in the immediately following one.

    Returns:
        Tuple of (explained_count, unexplained_count)
    """
    nlp = get_spacy_nlp()
    sentences = [sent.text.lower() for sent in nlp(text).sents]

    explained = 0
    unexplained = 0

    for idx, sentence in enumerate(sentences):
        # Number of Marxist concepts mentioned in this sentence.
        concept_hits = sum(1 for term in MARXIST_TERMS if term in sentence)
        if concept_hits == 0:
            continue

        # Look for an explanatory phrase in this sentence or the next.
        window = sentences[idx : idx + 2]
        has_explanation = any(
            phrase in nearby for nearby in window for phrase in EXPLANATORY_PHRASES
        )

        if has_explanation:
            explained += concept_hits
        else:
            unexplained += concept_hits

    return explained, unexplained
1206
+
1207
+
1208
def interconnection_depth_reward(
    completions: Sequence[Sequence[dict[str, str]]], **kwargs: object
) -> list[float]:
    """
    Reward deep, meaningful interconnections; penalize buzzword salad.

    Distinguishes between:
    - GOOD: "Surplus value relates to imperialism BECAUSE capital export..."
    - BAD: "Surplus value intersects with imperialism, colonialism, patriarchy..."

    Signals:
    1. Depth ratio: words per unique Marxist concept
       - High (>15): Deep analysis, concepts well-explained
       - Low (<5): Shallow buzzword soup (many concepts crammed together)
    2. Hollow buzzword density: activist jargon without substance
    3. Depth markers: citations, examples, historical specificity
    4. Explanation ratio: concepts with nearby explanatory phrases

    Scoring:
        Depth ratio > 20: +1.0 (deep analysis)
        Depth ratio 10-20: +0.5 (adequate depth)
        Depth ratio < 5: -1.5 (severe buzzword soup)
        Depth ratio 5-10: -0.5 (shallow)
        Hollow buzzwords > 2: -0.3 each additional
        Depth markers present: +0.3 each (max +1.5)
        Good explanation ratio: +0.5
        Low explanation ratio with many concepts: -0.5

    Total range: approximately -2.5 to +3.0
    """
    results: list[float] = []

    for completion in completions:
        text = completion[0]["content"]

        # Keep only the answer portion after the reasoning block.
        if REASONING_END in text:
            answer_part = text.split(REASONING_END, 1)[1].strip()
        else:
            answer_part = text

        # Very short responses are completeness_reward's job.
        n_words = len(answer_part.split())
        if n_words < 20:
            results.append(0.0)
            continue

        total = 0.0
        n_concepts = _count_unique_marxist_concepts(answer_part)

        # Signal 1: words-per-concept depth ratio (only when concepts exist).
        # Note a ratio of exactly 10 intentionally falls through with no
        # adjustment, matching the threshold table above.
        if n_concepts > 0:
            ratio = n_words / n_concepts
            if ratio > 20:
                total += 1.0  # deep analysis
            elif ratio > 10:
                total += 0.5  # adequate depth
            elif ratio < 5:
                total -= 1.5  # severe buzzword soup
            elif ratio < 10:
                total -= 0.5  # shallow

        # Signal 2: hollow buzzword penalty (first two are free), capped.
        n_hollow = _count_hollow_buzzwords(answer_part)
        if n_hollow > 2:
            total -= min(0.3 * (n_hollow - 2), 1.5)

        # Signal 3: depth marker bonus, capped at +1.5.
        total += min(_count_depth_markers(answer_part) * 0.3, 1.5)

        # Signal 4: explanation ratio relative to concept count.
        if n_concepts > 0:
            expl_ratio = _count_explanatory_phrases(answer_part) / n_concepts
            if expl_ratio >= 0.5:
                total += 0.5  # at least one explanation per two concepts
            elif expl_ratio < 0.1 and n_concepts > 5:
                total -= 0.5  # many concepts, almost no explanations

        # Clamp to the documented range.
        results.append(max(min(total, 3.0), -2.5))

    return results
1296
+
1297
+
1298
+ # =============================================================================
1299
+ # RESPONSE COMPLETENESS REWARD
1300
+ # =============================================================================
1301
+
1302
+
1303
def completeness_reward(
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward thorough, detailed responses.

    Compares the response's word count (after the </think> block, when
    present) to the ground-truth answer's word count.

    Scoring (ratio = response words / ground-truth words):
        0.5 <= ratio <= 1.5: +2.0 (on target)
        0.3 <= ratio <= 2.0: +1.0 (acceptable)
        ratio < 0.2:         -2.0 (far too short)
        otherwise:           -0.5 (slightly short: 0.2 <= ratio < 0.3,
                                   or too verbose: ratio > 2.0)
    """
    scores: list[float] = []

    for completion, true_answer in zip(completions, answer, strict=False):
        response = completion[0]["content"]

        # Score only the visible answer after the reasoning block.
        if REASONING_END in response:
            answer_part = response.split(REASONING_END, 1)[1].strip()
        else:
            answer_part = response

        answer_len = len(answer_part.split())
        true_len = len(true_answer.split())

        # Empty ground truth: the length ratio is undefined, stay neutral.
        if true_len == 0:
            scores.append(0.0)
            continue

        ratio = answer_len / true_len

        if 0.5 <= ratio <= 1.5:
            score = 2.0
        elif 0.3 <= ratio <= 2.0:
            score = 1.0
        elif ratio < 0.2:
            score = -2.0  # far too short
        else:
            # Mildly off-target. This branch catches BOTH slightly-short
            # responses (0.2 <= ratio < 0.3) and verbose ones (ratio > 2.0);
            # the previous comment labeled it "too long" only, which was
            # misleading for the short side.
            score = -0.5

        scores.append(score)

    return scores
1353
+
1354
+
1355
+ # =============================================================================
1356
+ # DEBUG REWARD (for monitoring during training)
1357
+ # =============================================================================
1358
+
1359
# Global counter for printing samples
_PRINT_COUNTER = 0
_PRINT_EVERY = 10


def debug_print_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Print sample outputs periodically for monitoring.

    Purely diagnostic: always returns all-zero scores, so it has no
    effect on training.
    """
    global _PRINT_COUNTER

    # Emit one sample every _PRINT_EVERY invocations.
    if _PRINT_COUNTER % _PRINT_EVERY == 0:
        sample_question = prompts[0][-1]["content"]
        sample_response = completions[0][0]["content"]
        sample_expected = answer[0]

        divider = "=" * 60
        print(divider)
        print(f"Step {_PRINT_COUNTER}")
        print(f"Question: {sample_question[:100]}...")
        print(f"Response: {sample_response[:200]}...")
        print(f"Expected: {sample_expected[:100]}...")
        print(divider)

    _PRINT_COUNTER += 1

    return [0.0] * len(completions)
1392
+
1393
+
1394
+ # =============================================================================
1395
+ # ENTITY VERIFICATION REWARD (Anti-Hallucination)
1396
+ # =============================================================================
1397
+
1398
+ # Lazy-loaded entity whitelist
1399
+ _ENTITY_WHITELIST: set[str] | None = None
1400
+ _ENTITY_WHITELIST_LOWERCASE: set[str] | None = None
1401
+
1402
+
1403
+ def _load_entity_whitelist() -> tuple[set[str], set[str]]:
1404
+ """Load entity whitelist from JSON file."""
1405
+ global _ENTITY_WHITELIST, _ENTITY_WHITELIST_LOWERCASE
1406
+
1407
+ if _ENTITY_WHITELIST is None or _ENTITY_WHITELIST_LOWERCASE is None:
1408
+ import json
1409
+ from pathlib import Path
1410
+
1411
+ whitelist_path = (
1412
+ Path(__file__).parent.parent.parent.parent
1413
+ / "training_data"
1414
+ / "entity_whitelist_clean.json"
1415
+ )
1416
+
1417
+ if whitelist_path.exists():
1418
+ print(f"[Reward] Loading entity whitelist from {whitelist_path}...")
1419
+ with open(whitelist_path, encoding="utf-8") as f:
1420
+ data = json.load(f)
1421
+ _ENTITY_WHITELIST = set(data.get("entities", []))
1422
+ _ENTITY_WHITELIST_LOWERCASE = set(data.get("entities_lowercase", []))
1423
+ print(f"[Reward] Loaded {len(_ENTITY_WHITELIST):,} entities")
1424
+ else:
1425
+ print(f"[Reward] WARNING: Entity whitelist not found at {whitelist_path}")
1426
+ _ENTITY_WHITELIST = set()
1427
+ _ENTITY_WHITELIST_LOWERCASE = set()
1428
+
1429
+ return _ENTITY_WHITELIST, _ENTITY_WHITELIST_LOWERCASE
1430
+
1431
+
1432
def _entity_in_whitelist(entity: str) -> bool:
    """Return True if *entity* appears in the whitelist (case-insensitive)."""
    exact, lowered = _load_entity_whitelist()
    if entity in exact:
        return True
    return entity.lower() in lowered
1436
+
1437
+
1438
# Patterns that indicate confident factual claims.
# NOTE: stored as raw regex strings (not precompiled); callers such as
# entity_verification_reward and epistemic_calibration_reward apply
# re.search(..., re.IGNORECASE) against the full response text.
CONFIDENT_CLAIM_PATTERNS: list[str] = [
    r"founded in \d{4}",
    r"established in \d{4}",
    r"created in \d{4}",
    r"formed in \d{4}",
    r"was founded by",
    r"was established by",
    r"was created by",
    r"were founded in",
    r"were established in",
]

# Patterns that indicate epistemic humility (GOOD).
# Matching any of these marks the response as appropriately uncertain
# and earns a positive score from the calibration rewards.
UNCERTAINTY_PATTERNS: list[str] = [
    r"I (?:cannot|can't|don't) (?:verify|confirm|find)",
    r"I (?:don't|do not) have (?:verified |specific )?information",
    r"I'm not (?:certain|sure|confident)",
    r"I cannot (?:provide|give) (?:specific |verified )?information",
    r"I should not (?:fabricate|make up|speculate)",
    r"(?:could you|can you) (?:provide|share|tell me) (?:more )?context",
    r"where did you (?:encounter|find|see) this",
    r"I'm not aware of",
    r"I don't have (?:details|information) about",
]
1463
+
1464
+
1465
def _extract_potential_entities(text: str) -> list[str]:
    """Extract potential organization/person names from text using spaCy NER."""
    # Cap input length to prevent slow processing on huge responses.
    doc = get_spacy_nlp()(text[:10000])

    wanted_labels = ("ORG", "PERSON", "GPE", "NORP", "EVENT", "FAC", "WORK_OF_ART")
    return [ent.text for ent in doc.ents if ent.label_ in wanted_labels]
1476
+
1477
+
1478
def entity_verification_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Reward epistemic humility, penalize confident claims about unverified entities.

    Scoring:
        +2.0: Response expresses uncertainty about unknown entities
        +1.0: Response discusses only verified entities
        -1.0: Response discusses unknown entities without clear uncertainty
        -2.5: Response makes confident claims (dates, founders) about
              unknown entities

    Requires entity_whitelist_clean.json in training_data/ (when the file
    is missing the whitelist is empty, so every entity counts as unknown).
    """
    rewards: list[float] = []

    for completion in completions:
        text = completion[0]["content"]

        # NER over the response, then split into verified vs. unknown.
        mentioned = _extract_potential_entities(text)
        unknown = [e for e in mentioned if not _entity_in_whitelist(e)]

        # Epistemic humility markers.
        shows_uncertainty = any(
            re.search(pattern, text, re.IGNORECASE) for pattern in UNCERTAINTY_PATTERNS
        )
        # Confident factual claim markers.
        makes_confident_claims = any(
            re.search(pattern, text, re.IGNORECASE) for pattern in CONFIDENT_CLAIM_PATTERNS
        )

        if not unknown:
            # Everything mentioned is verified (or no entities at all):
            # unnecessary hedging earns 0.0, normal content earns 1.0.
            rewards.append(0.0 if shows_uncertainty else 1.0)
        elif shows_uncertainty:
            # GOOD: humility about entities outside the corpus.
            rewards.append(2.0)
        elif makes_confident_claims:
            # BAD: confident specifics about unverified entities.
            rewards.append(-2.5)
        else:
            # NEUTRAL-BAD: unknown entities with no clear hedging.
            rewards.append(-1.0)

    return rewards
1537
+
1538
+
1539
def epistemic_calibration_reward(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> list[float]:
    """
    Combined reward for epistemic calibration.

    Encourages the model to:
    1. Be confident about entities in the verified corpus
    2. Express uncertainty about entities NOT in the corpus
    3. Avoid fabricating specific details about unknown entities

    This is a lighter-weight version of entity_verification_reward
    that doesn't require NER - just pattern matching.

    Scoring:
        +1.5: Has uncertainty phrases (regardless of content)
        -0.5: Has confident claim patterns + no uncertainty
              (soft penalty only; entity_verification_reward does the
              deeper, whitelist-backed checking)
         0.0: Neutral

    Note: the docstring previously advertised -1.5 for confident claims,
    which contradicted the implemented soft penalty of -0.5; the table
    above now matches the code.
    """
    scores: list[float] = []

    for completion in completions:
        response = completion[0]["content"]

        # Case-insensitive scans over the raw pattern string lists.
        has_uncertainty = any(
            re.search(pattern, response, re.IGNORECASE) for pattern in UNCERTAINTY_PATTERNS
        )
        has_confident_claims = any(
            re.search(pattern, response, re.IGNORECASE) for pattern in CONFIDENT_CLAIM_PATTERNS
        )

        if has_uncertainty:
            # Model expressed appropriate caution
            scores.append(1.5)
        elif has_confident_claims:
            # Confident claims could be hallucination - soft penalty only
            scores.append(-0.5)
        else:
            # Neutral
            scores.append(0.0)

    return scores
1587
+
1588
+
1589
+ # =============================================================================
1590
+ # IDEOLOGICAL FIRMNESS REWARD (Anti-Capitulation)
1591
+ # =============================================================================
1592
+
1593
# Patterns that detect ideological challenges requiring firm response.
# Consumed by _detect_ideological_challenge(); all patterns are
# precompiled with re.IGNORECASE and applied via .search().
IDEOLOGICAL_CHALLENGE_PATTERNS: list[re.Pattern[str]] = [
    # Reactionary claims about fascism/Nazis
    re.compile(r"hitler\s+help", re.IGNORECASE),
    re.compile(r"nazis?\s+(?:were\s+)?socialist", re.IGNORECASE),
    re.compile(r"national\s+socialism", re.IGNORECASE),
    # Anti-communist death toll claims
    re.compile(r"communis[mt]\s+(?:killed|murdered|deaths?)", re.IGNORECASE),
    re.compile(r"\d+\s*million\s+(?:people\s+)?(?:killed|dead|died)", re.IGNORECASE),
    re.compile(r"(?:black\s+book|death\s+toll)", re.IGNORECASE),
    # Anti-communist framing
    re.compile(
        r"(?:stalin|mao|lenin)\s+(?:was\s+)?(?:a\s+)?(?:dictator|tyrant|monster)", re.IGNORECASE
    ),
    re.compile(r"authoritarian\s+regime", re.IGNORECASE),
    re.compile(r"failed\s+experiment", re.IGNORECASE),
    re.compile(r"communis[mt]\s+fail", re.IGNORECASE),
    # Bad faith questions
    re.compile(r"holodomor", re.IGNORECASE),
    re.compile(r"gulag", re.IGNORECASE),
    re.compile(r"great\s+(?:purge|terror|leap)", re.IGNORECASE),
    re.compile(r"cultural\s+revolution", re.IGNORECASE),
    # General anti-communist challenges
    re.compile(r"(?:stalin|mao|lenin|castro|che)\s+(?:bad|evil|good)", re.IGNORECASE),
    re.compile(r"(?:was|is)\s+(?:stalin|mao|lenin|castro|che)\s+(?:bad|evil|good)", re.IGNORECASE),
    re.compile(r"(?:ussr|soviet|china|cuba)\s+(?:make\s+)?mistakes?", re.IGNORECASE),
    re.compile(r"(?:did|make)\s+mistakes?", re.IGNORECASE),
]

# Patterns indicating capitulation/both-sidesing (NEGATIVE signals).
# Presumably scored by ideological_firmness_reward — TODO confirm, the
# function body extends past this file section.
CAPITULATION_PATTERNS: list[re.Pattern[str]] = [
    # Both-sidesing
    re.compile(r"valid\s+points?\s+on\s+both\s+sides?", re.IGNORECASE),
    re.compile(
        r"on\s+(?:the\s+)?one\s+hand.*on\s+(?:the\s+)?other\s+hand", re.IGNORECASE | re.DOTALL
    ),
    re.compile(r"reasonable\s+people\s+can\s+disagree", re.IGNORECASE),
    re.compile(r"no\s+easy\s+answers?", re.IGNORECASE),
    # Hedging and avoidance
    re.compile(
        r"it'?s?\s+(?:a\s+)?(?:complex|complicated|nuanced)\s+(?:issue|question|topic|legacy)",
        re.IGNORECASE,
    ),
    re.compile(r"certainly\s+complex", re.IGNORECASE),
    re.compile(r"I\s+(?:can\s+)?(?:see|understand)\s+(?:why|your|the)", re.IGNORECASE),
    # Theory vs practice false dichotomy
    re.compile(r"communism\s+sounds?\s+good\s+in\s+theory", re.IGNORECASE),
    re.compile(r"good\s+in\s+theory\s+but", re.IGNORECASE),
    # False balance on historical figures
    re.compile(r"(?:stalin|mao)\s+did\s+(?:both\s+)?good\s+(?:things\s+)?and\s+bad", re.IGNORECASE),
    re.compile(r"achievements?\s+and\s+(?:mistakes?|excesses?|errors?)", re.IGNORECASE),
    re.compile(r"(?:mistakes?|excesses?|errors?)\s+and\s+achievements?", re.IGNORECASE),
    # Weasel phrases
    re.compile(r"some\s+(?:would|might)\s+argue", re.IGNORECASE),
    re.compile(r"one\s+perspective\s+is", re.IGNORECASE),
    re.compile(r"it\s+depends", re.IGNORECASE),
    # Partial credit to reactionary framing
    re.compile(r"while\s+(?:it'?s?\s+true|hitler|he)\s+did", re.IGNORECASE),
    re.compile(r"did\s+(?:implement|create|build)\s+some\s+programs?", re.IGNORECASE),
    re.compile(r"(?:positive|negative)\s+aspects?", re.IGNORECASE),
    re.compile(r"weigh\s+(?:both|the)", re.IGNORECASE),
]

# Patterns indicating ideological firmness (POSITIVE signals).
FIRMNESS_PATTERNS: list[re.Pattern[str]] = [
    # Direct refutation
    re.compile(
        r"this\s+(?:claim\s+)?is\s+(?:false|incorrect|wrong|a\s+lie|propaganda|a\s+myth)",
        re.IGNORECASE,
    ),
    re.compile(
        r"this\s+is\s+(?:fascist|bourgeois|reactionary|anti-?communist)\s+propaganda", re.IGNORECASE
    ),
    re.compile(
        r"this\s+(?:figure|claim|narrative)\s+(?:comes?\s+from|is)\s+(?:the\s+)?discredited",
        re.IGNORECASE,
    ),
    re.compile(r"(?:discredited|debunked)", re.IGNORECASE),
    re.compile(r"this\s+ignores?\s+(?:the\s+)?(?:material|historical|actual)", re.IGNORECASE),
    re.compile(r"this\s+framing\s+ignores?", re.IGNORECASE),
    # Marxist analytical framing
    re.compile(r"from\s+a\s+marxist\s+(?:analysis|perspective|standpoint)", re.IGNORECASE),
    re.compile(r"(?:the\s+)?material\s+(?:reality|evidence|conditions?|facts?)", re.IGNORECASE),
    re.compile(
        r"(?:the\s+)?historical\s+(?:evidence|record|facts?)\s+(?:show|demonstrate|prove)",
        re.IGNORECASE,
    ),
    re.compile(r"historically\s+documented", re.IGNORECASE),
    # Class analysis language
    re.compile(
        r"serves?\s+(?:the\s+)?interests?\s+of\s+(?:capital|the\s+bourgeoisie|imperialism)",
        re.IGNORECASE,
    ),
    re.compile(r"serves?\s+bourgeois\s+(?:historiography|ideology|interests?)", re.IGNORECASE),
    re.compile(r"bourgeois\s+(?:historians?|propaganda|framing|ideology)", re.IGNORECASE),
    re.compile(r"capitalis[tm]\s+kills?\s+millions?", re.IGNORECASE),
    # Strong Marxist position markers
    re.compile(r"dictatorship\s+of\s+the\s+(?:proletariat|bourgeoisie)", re.IGNORECASE),
    re.compile(r"class\s+(?:rule|struggle|analysis|character)", re.IGNORECASE),
    re.compile(r"(?:open\s+)?terrorist\s+dictatorship", re.IGNORECASE),
    # Material achievements of socialism
    re.compile(r"(?:universal|achieved)\s+literacy", re.IGNORECASE),
    re.compile(r"(?:provided|universal)\s+(?:housing|healthcare|education)", re.IGNORECASE),
    re.compile(r"defeated\s+fascism", re.IGNORECASE),
    re.compile(r"industrial(?:ized?|ization)\s+(?:superpower|rapidly|in)", re.IGNORECASE),
]

# Class analysis markers (substantive bonus)
CLASS_ANALYSIS_MARKERS: list[re.Pattern[str]] = [
    re.compile(r"working\s+class", re.IGNORECASE),
    re.compile(r"\bproletariat\b", re.IGNORECASE),
    re.compile(r"\bbourgeoisie\b", re.IGNORECASE),
    re.compile(r"\bcapital(?:ist(?:s|class)?|ism)?\b", re.IGNORECASE),
    re.compile(r"finance\s+capital", re.IGNORECASE),
    re.compile(r"class\s+struggle", re.IGNORECASE),
    re.compile(r"exploitation", re.IGNORECASE),
    re.compile(r"surplus\s+value", re.IGNORECASE),
    re.compile(r"dictatorship\s+of\s+the", re.IGNORECASE),
    re.compile(r"relations\s+of\s+production", re.IGNORECASE),
    re.compile(r"means\s+of\s+production", re.IGNORECASE),
    re.compile(r"\bimperialis[mt]\b", re.IGNORECASE),
    re.compile(r"monopoly\s+capital", re.IGNORECASE),
]

# Quote-to-refute pattern (quoted text followed by refutation).
# Consumed by _has_quote_to_refute(); quoting a reactionary claim in
# order to refute it should not be penalized as capitulation.
QUOTE_TO_REFUTE_PATTERNS: list[re.Pattern[str]] = [
    # Patterns where quoted claims are followed by refutation
    re.compile(r"['\"].*?['\"].*?\bbut\b", re.IGNORECASE),
    re.compile(r"['\"].*?['\"].*?\bhowever\b", re.IGNORECASE),
    re.compile(r"['\"].*?['\"].*?\bthis\s+ignores?\b", re.IGNORECASE),
    re.compile(r"claim\s+that.*?\bbut\b", re.IGNORECASE),
    re.compile(r"claim\s+that.*?\bhowever\b", re.IGNORECASE),
    re.compile(r"historians?\s+claim.*?\bbut\b", re.IGNORECASE),
]

# Principled self-criticism markers (NOT capitulation).
# _has_principled_self_criticism() requires at least two of these.
SELF_CRITICISM_MARKERS: list[re.Pattern[str]] = [
    re.compile(r"self-?criticism", re.IGNORECASE),
    re.compile(r"methodological\s+principle", re.IGNORECASE),
    re.compile(r"revolutionary\s+perspective", re.IGNORECASE),
    re.compile(r"strengthen(?:ing)?\s+socialism", re.IGNORECASE),
    re.compile(r"not\s+from\s+(?:bourgeois|anti-?communist)", re.IGNORECASE),
    re.compile(r"marxist-?leninist", re.IGNORECASE),
]
1737
+
1738
+
1739
def _detect_ideological_challenge(prompt: str) -> bool:
    """Return True when *prompt* matches any known ideological-challenge pattern."""
    for challenge in IDEOLOGICAL_CHALLENGE_PATTERNS:
        if challenge.search(prompt):
            return True
    return False
1742
+
1743
+
1744
+ def _count_pattern_matches(text: str, patterns: list[re.Pattern[str]]) -> int:
1745
+ """Count the number of pattern matches in text."""
1746
+ count = 0
1747
+ for pattern in patterns:
1748
+ if pattern.search(text):
1749
+ count += 1
1750
+ return count
1751
+
1752
+
1753
def _has_quote_to_refute(text: str) -> bool:
    """Return True when *text* quotes a claim and then refutes it."""
    for pattern in QUOTE_TO_REFUTE_PATTERNS:
        if pattern.search(text):
            return True
    return False
1756
+
1757
+
1758
def _has_principled_self_criticism(text: str) -> bool:
    """Return True when *text* shows at least two principled ML self-criticism markers."""
    return _count_pattern_matches(text, SELF_CRITICISM_MARKERS) >= 2
1762
+
1763
+
1764
def ideological_firmness_reward(
    prompts: list[str],
    completions: list[str],
    **kwargs: Any,
) -> list[float]:
    """
    Reward ideological firmness with substantive Marxist counter-arguments.

    Prompts that do not contain an ideological challenge (per
    _detect_ideological_challenge) score 0.0. For challenge prompts the
    completion is scored as:

    * -0.5 per capitulation pattern (both-sidesing, hedging, false balance)
    * +0.5 per firmness pattern (direct refutation, class analysis)
    * +0.2 per class-analysis marker, capped at +1.0 total

    Two false-positive guards reduce the capitulation count before it is
    applied: a quote-to-refute rhetorical pattern forgives one match, and
    principled Marxist-Leninist self-criticism forgives two. The final score
    is clamped to [-2.0, +2.0].

    Args:
        prompts: List of user prompts/questions.
        completions: List of model completions/responses.
        **kwargs: Additional arguments (ignored).

    Returns:
        One reward score per (prompt, completion) pair.
    """

    def _score(prompt: str, completion: str) -> float:
        # Non-challenge prompts are scored neutrally.
        if not _detect_ideological_challenge(prompt):
            return 0.0

        capitulations = _count_pattern_matches(completion, CAPITULATION_PATTERNS)
        # Quoting a reactionary claim in order to refute it is not
        # capitulation: forgive one apparent capitulation match.
        if _has_quote_to_refute(completion):
            capitulations = max(0, capitulations - 1)
        # Principled ML self-criticism is not capitulation either: forgive two.
        if _has_principled_self_criticism(completion):
            capitulations = max(0, capitulations - 2)

        firmness = _count_pattern_matches(completion, FIRMNESS_PATTERNS)
        class_markers = _count_pattern_matches(completion, CLASS_ANALYSIS_MARKERS)

        total = -capitulations * 0.5 + firmness * 0.5
        total += min(class_markers * 0.2, 1.0)  # class-analysis bonus, capped at +1.0

        # Clamp to the [-2.0, +2.0] band.
        return max(-2.0, min(2.0, total))

    return [
        _score(prompt, completion)
        for prompt, completion in zip(prompts, completions, strict=False)
    ]
src/prolewiki_llm/train_grpo_marxist.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
GRPO Fine-tuning for Marxist-Leninist Reasoning Model.

Trains DeepSeek-R1-0528-Qwen3-8B on ProleWiki corpus using GRPO
(Group Relative Policy Optimization) with custom reward functions.

Usage:
    # First transform data
    python transform_to_grpo.py

    # Then run training
    python train_grpo_marxist.py

Hardware: A40 (48GB) optimized
Expected time: ~1-2 hours for 250 steps
"""

from __future__ import annotations

import os
from pathlib import Path

# Set vLLM standby mode for better memory utilization.
# NOTE: deliberately assigned BEFORE the unsloth/vllm imports below so the
# variable is visible when those modules initialize.
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel
from vllm import SamplingParams

from prolewiki_llm.grpo_rewards import (
    completeness_reward,
    debug_print_reward,
    match_format_approximately,
    match_format_exactly,
    semantic_similarity_reward,
    terminology_reward,
)

# =============================================================================
# CONFIGURATION
# =============================================================================

# Model
MODEL_NAME = "unsloth/DeepSeek-R1-0528-Qwen3-8B"
MAX_SEQ_LENGTH = 2048  # Longer for detailed political theory responses
LORA_RANK = 32  # Same as original notebook

# Paths (relative to the working directory)
DATA_PATH = Path("training_data/grpo_dataset.jsonl")
OUTPUT_DIR = Path("outputs/marxist-grpo")
LORA_OUTPUT = Path("outputs/marxist-grpo-lora")

# Training
MAX_STEPS = 250  # Cover most of 1058 samples
SAVE_STEPS = 50
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.1

# A40 optimized settings
GPU_MEMORY_UTILIZATION = 0.85
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 2
NUM_GENERATIONS = 4

# Completion limits (tokens)
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 1500
72
+
73
+ # =============================================================================
74
+ # MAIN TRAINING FUNCTION
75
+ # =============================================================================
76
+
77
+
78
def main() -> None:
    """Run GRPO training end-to-end: load model, add LoRA, train, save adapter.

    Raises:
        RuntimeError: If no CUDA device is available.
        FileNotFoundError: If the GRPO dataset has not been generated yet.
    """
    print("=" * 60)
    print("Marxist-Leninist GRPO Training")
    print("=" * 60)

    # Check CUDA -- training is GPU-only, so fail fast on CPU hosts.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name()
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name}")
        print(f"VRAM: {gpu_mem:.1f} GB")
    else:
        raise RuntimeError("CUDA not available!")

    # =========================================================================
    # Load Model
    # =========================================================================
    print(f"\nLoading model: {MODEL_NAME}")

    # NOTE(review): train_headless.py loads with load_in_4bit=False and states
    # GRPO requires 16-bit LoRA adapters -- confirm that 4-bit is intentional
    # for this interactive variant.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        fast_inference=True,  # Enable vLLM
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    print(f"Model type: {model.config.model_type}")

    # =========================================================================
    # Apply LoRA
    # =========================================================================
    print("\nApplying LoRA adapters...")

    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=LORA_RANK * 2,  # *2 speeds up training
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # Report the trainable-parameter fraction introduced by the adapters.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # =========================================================================
    # Load Dataset
    # =========================================================================
    print(f"\nLoading dataset: {DATA_PATH}")

    if not DATA_PATH.exists():
        raise FileNotFoundError(
            f"Dataset not found: {DATA_PATH}\n" "Run 'python transform_to_grpo.py' first!"
        )

    dataset = Dataset.from_json(str(DATA_PATH))
    print(f"Loaded {len(dataset)} examples")

    # Show sample
    # NOTE(review): assumes each 'prompt' is a [system, user] message list so
    # index 1 is the user turn -- verify against transform_to_grpo.py output.
    sample = dataset[0]
    print(f"Sample prompt: {sample['prompt'][1]['content'][:60]}...")

    # =========================================================================
    # Configure vLLM Sampling
    # =========================================================================
    vllm_sampling_params = SamplingParams(
        min_p=0.1,
        top_p=1.0,  # No nucleus sampling (matches original template)
        top_k=-1,
        # NOTE: temperature is set in GRPOConfig, not here
        max_tokens=MAX_COMPLETION_LENGTH,
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
        seed=3407,
    )

    # =========================================================================
    # Configure Training
    # =========================================================================
    print("\nConfiguring GRPO trainer...")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    training_args = GRPOConfig(
        # vLLM
        vllm_sampling_params=vllm_sampling_params,
        temperature=1.0,  # For GRPO training dynamics
        # Optimization
        learning_rate=LEARNING_RATE,
        weight_decay=0.001,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        # Batch settings
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        num_generations=NUM_GENERATIONS,
        # Sequence lengths
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        # Training duration
        max_steps=MAX_STEPS,
        save_steps=SAVE_STEPS,
        # Logging (no remote reporting in this interactive variant)
        logging_steps=1,
        report_to="none",
        # Output
        output_dir=str(OUTPUT_DIR),
    )

    # =========================================================================
    # Create Trainer
    # =========================================================================
    print("\nInitializing trainer with reward functions:")
    print("  - match_format_exactly (+3.0 for </think>)")
    print("  - match_format_approximately (±0.5 for tags)")
    print("  - semantic_similarity_reward (+5.0 to -3.0)")
    print("  - terminology_reward (+0 to +2.0)")
    print("  - completeness_reward (±2.0)")
    print("  - debug_print_reward (monitoring)")

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            match_format_exactly,
            match_format_approximately,
            semantic_similarity_reward,
            terminology_reward,
            completeness_reward,
            debug_print_reward,
        ],
        args=training_args,
        train_dataset=dataset,
    )

    # =========================================================================
    # Train!
    # =========================================================================
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Steps: {MAX_STEPS}")
    print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print()

    trainer.train()

    # =========================================================================
    # Save
    # =========================================================================
    print("\n" + "=" * 60)
    print("SAVING MODEL")
    print("=" * 60)

    LORA_OUTPUT.mkdir(parents=True, exist_ok=True)
    model.save_lora(str(LORA_OUTPUT))
    print(f"LoRA saved to: {LORA_OUTPUT}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Test the model with and without LoRA")
    print("2. Export to GGUF if satisfied")
    print("3. Create Ollama Modelfile")
257
+
258
+
259
+ # =============================================================================
260
+ # TEST FUNCTION
261
+ # =============================================================================
262
+
263
+
264
def test_model() -> None:
    """Smoke-test the base model with and without the trained LoRA adapter.

    Loads the base model, then answers the first two sample questions twice:
    once with ``lora_request=None`` (base behavior) and once with the adapter
    loaded from ``LORA_OUTPUT``, printing truncated outputs for manual
    side-by-side comparison.
    """
    print("Loading model for testing...")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    test_questions = [
        "What is revisionism in the Marxist sense?",
        "Explain the concept of surplus value.",
        "What is the dictatorship of the proletariat?",
        "How does dialectical materialism differ from idealism?",
    ]

    sampling_params = SamplingParams(
        temperature=0.7,
        top_k=50,
        max_tokens=1024,
    )

    system_prompt = """You are a Marxist-Leninist assistant trained on ProleWiki.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear answer."""

    print("\n" + "=" * 60)
    print("TESTING WITHOUT LORA")
    print("=" * 60)

    # Only the first two questions are exercised to keep the smoke test short.
    for question in test_questions[:2]:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        output = (
            model.fast_generate(text, sampling_params=sampling_params, lora_request=None)[0]
            .outputs[0]
            .text
        )
        print(f"\nQ: {question}")
        print(f"A: {output[:500]}...")

    print("\n" + "=" * 60)
    print("TESTING WITH LORA")
    print("=" * 60)

    for question in test_questions[:2]:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        # NOTE(review): load_lora is invoked once per question -- hoisting it
        # out of the loop would avoid reloading the adapter each iteration.
        output = (
            model.fast_generate(
                text,
                sampling_params=sampling_params,
                lora_request=model.load_lora(str(LORA_OUTPUT)),
            )[0]
            .outputs[0]
            .text
        )
        print(f"\nQ: {question}")
        print(f"A: {output[:500]}...")
333
+
334
+
335
+ if __name__ == "__main__":
336
+ import sys
337
+
338
+ if len(sys.argv) > 1 and sys.argv[1] == "test":
339
+ test_model()
340
+ else:
341
+ main()
src/prolewiki_llm/train_headless.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Headless GRPO Training for RunPod Deployment.

This script is designed for containerized, non-interactive execution on RunPod.
It adapts train_grpo_marxist.py for headless operation with:
- Environment variable configuration
- Checkpoint resumption support
- Automatic model upload to HuggingFace Hub
- W&B logging for remote monitoring
- Self-termination capability

Environment Variables:
    Required:
        HF_TOKEN - HuggingFace API token for model upload
        WANDB_API_KEY - Weights & Biases API key

    Optional (with defaults):
        HF_REPO - Target repo for model upload (default: prolewiki/marxist-grpo-lora)
        RUNPOD_POD_ID - Pod ID for self-termination after training
        MODEL_NAME - Base model (default: unsloth/DeepSeek-R1-0528-Qwen3-8B)
        MAX_STEPS - Training steps (default: 500)
        BATCH_SIZE - Per-device batch size (default: 2)
        LEARNING_RATE - Learning rate (default: 5e-6)
        REWARD_MODE - FULL, ROBUST, or LEGACY (default: FULL)
        DATASET_PATH - Path to grpo_dataset.jsonl (default: /workspace/dataset.jsonl)
        CHECKPOINT_DIR - Directory for checkpoints (default: /workspace/checkpoints)
        LORA_OUTPUT - Directory for final LoRA (default: /workspace/lora-output)

Usage:
    # In container:
    python -m prolewiki_llm.train_headless

    # With environment overrides:
    MAX_STEPS=100 REWARD_MODE=ROBUST python -m prolewiki_llm.train_headless
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

# =============================================================================
# CRITICAL: Disable torch.compile BEFORE any imports
# =============================================================================
# These environment variables prevent torch.compile from spawning inductor
# compilation workers that hang indefinitely on RunPod/WSL2/Jupyter.
# See: https://github.com/unslothai/unsloth/issues/3432
# The heavy ML imports happen inside main() so these settings take effect first.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
os.environ["TORCH_COMPILE"] = "0"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
57
+
58
+
59
def get_env(key: str, default: str | None = None, required: bool = False) -> str | None:
    """Read an environment variable with optional default and required check.

    Args:
        key: Environment variable name.
        default: Value returned when the variable is unset.
        required: When True, print an error and exit the process if unset.

    Returns:
        The variable's value, the default, or None when neither exists
        (only possible when ``required`` is False).
    """
    value = os.environ.get(key, default)
    if required and value is None:
        print(f"ERROR: Required environment variable {key} not set", file=sys.stderr)
        sys.exit(1)
    # FIX: the annotation now tells the truth (str | None) instead of claiming
    # str and silencing the checker with `type: ignore`.
    return value
66
+
67
+
68
def get_env_int(key: str, default: int) -> int:
    """Read an environment variable as an integer, falling back to *default*."""
    raw = os.environ.get(key)
    return default if raw is None else int(raw)
71
+
72
+
73
def get_env_float(key: str, default: float) -> float:
    """Read an environment variable as a float, falling back to *default*."""
    raw = os.environ.get(key)
    return default if raw is None else float(raw)
76
+
77
+
78
# =============================================================================
# CONFIGURATION FROM ENVIRONMENT
# =============================================================================
# All values are resolved once at import time.

# Required secrets -- the process exits at import time if these are unset.
HF_TOKEN = get_env("HF_TOKEN", required=True)
WANDB_API_KEY = get_env("WANDB_API_KEY", required=True)

# Model configuration
MODEL_NAME = get_env("MODEL_NAME", "unsloth/DeepSeek-R1-0528-Qwen3-8B")
MAX_SEQ_LENGTH = get_env_int("MAX_SEQ_LENGTH", 2048)
LORA_RANK = get_env_int("LORA_RANK", 32)

# Training configuration
MAX_STEPS = get_env_int("MAX_STEPS", 500)
SAVE_STEPS = get_env_int("SAVE_STEPS", 50)
LEARNING_RATE = get_env_float("LEARNING_RATE", 5e-6)
WARMUP_RATIO = get_env_float("WARMUP_RATIO", 0.1)
BATCH_SIZE = get_env_int("BATCH_SIZE", 2)
GRADIENT_ACCUMULATION = get_env_int("GRADIENT_ACCUMULATION", 2)
NUM_GENERATIONS = get_env_int("NUM_GENERATIONS", 4)
GPU_MEMORY_UTILIZATION = get_env_float("GPU_MEMORY_UTILIZATION", 0.6)

# Sequence lengths (tokens)
MAX_PROMPT_LENGTH = get_env_int("MAX_PROMPT_LENGTH", 512)
MAX_COMPLETION_LENGTH = get_env_int("MAX_COMPLETION_LENGTH", 1500)

# Paths (container-local defaults under /workspace)
DATASET_PATH = Path(get_env("DATASET_PATH", "/workspace/dataset.jsonl"))
CHECKPOINT_DIR = Path(get_env("CHECKPOINT_DIR", "/workspace/checkpoints"))
LORA_OUTPUT = Path(get_env("LORA_OUTPUT", "/workspace/lora-output"))
OUTPUT_DIR = Path(get_env("OUTPUT_DIR", "/workspace/outputs"))

# Upload configuration
HF_REPO = get_env("HF_REPO", "prolewiki/marxist-grpo-lora")

# Reward mode: FULL, ROBUST, or LEGACY (any other value falls back to LEGACY
# in main()'s selection logic).
REWARD_MODE = get_env("REWARD_MODE", "FULL").upper()

# Pod management (None when not running on RunPod)
RUNPOD_POD_ID = get_env("RUNPOD_POD_ID")
119
+
120
+
121
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
    """Find the latest checkpoint directory if resuming training.

    Scans *checkpoint_dir* for subdirectories named ``checkpoint-<step>`` and
    returns the one with the highest step number.

    Args:
        checkpoint_dir: Directory that may contain ``checkpoint-N`` subdirs.

    Returns:
        The highest-numbered checkpoint directory, or None when the directory
        does not exist or contains no numeric checkpoints.
    """
    if not checkpoint_dir.exists():
        return None

    def _step(candidate: Path) -> int:
        return int(candidate.name.split("-")[1])

    # FIX: skip non-numeric suffixes (e.g. "checkpoint-final") instead of
    # crashing on int() when sorting.
    candidates = [
        d
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.startswith("checkpoint-") and d.name.split("-")[1].isdigit()
    ]
    return max(candidates, key=_step, default=None)
134
+
135
+
136
def upload_to_hub(model_path: Path, repo_id: str, token: str) -> None:
    """Push the trained LoRA adapter directory to a private HuggingFace repo."""
    # Imported lazily so the module loads even without huggingface_hub installed.
    from huggingface_hub import HfApi

    print(f"\nUploading model to HuggingFace Hub: {repo_id}")

    hub = HfApi(token=token)

    # Best-effort repo creation; an already-existing repo is not an error.
    try:
        hub.create_repo(repo_id, exist_ok=True, private=True)
    except Exception as e:
        print(f"Note: Could not create repo (may already exist): {e}")

    # Push the whole adapter directory in a single commit.
    hub.upload_folder(
        folder_path=str(model_path),
        repo_id=repo_id,
        commit_message="Headless GRPO training run",
    )

    print(f"Model uploaded to: https://huggingface.co/{repo_id}")
158
+
159
+
160
def main() -> int:
    """Run headless GRPO training.

    Returns:
        Process exit code: 0 on success, 1 on missing CUDA or missing dataset.

    Raises:
        Exception: Training errors are re-raised after W&B is marked failed.
    """
    # Heavy ML imports are deferred to function scope so they execute only
    # after the torch.compile-disabling env vars at module top took effect.
    import torch
    import wandb
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastLanguageModel
    from vllm import SamplingParams

    from prolewiki_llm.grpo_rewards import (
        completeness_reward,
        debug_print_reward,
        full_coherence_reward,
        match_format_approximately,
        match_format_exactly,
        robust_coherence_reward,
        semantic_similarity_reward,
        terminology_reward,
    )
    from prolewiki_llm.wandb_logging import (
        WandbSampleLogger,
        create_logging_reward,
        finish_wandb_logging,
        init_wandb_logging,
    )

    print("=" * 70)
    print("HEADLESS GRPO TRAINING - RUNPOD DEPLOYMENT")
    print("=" * 70)

    # =========================================================================
    # System Info
    # =========================================================================
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name()
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu_name}")
        print(f"VRAM: {gpu_mem:.1f} GB")
    else:
        print("ERROR: CUDA not available!", file=sys.stderr)
        return 1

    print(f"\nConfiguration:")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Max Steps: {MAX_STEPS}")
    print(f"  Batch Size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}")
    print(f"  Learning Rate: {LEARNING_RATE}")
    print(f"  Reward Mode: {REWARD_MODE}")
    print(f"  Dataset: {DATASET_PATH}")
    print(f"  Output: {LORA_OUTPUT}")
    print(f"  HF Repo: {HF_REPO}")

    # =========================================================================
    # Initialize W&B
    # =========================================================================
    print("\nInitializing Weights & Biases...")
    wandb.login(key=WANDB_API_KEY)

    # NOTE(review): wandb_run is never referenced afterwards; the run is
    # closed via finish_wandb_logging() below.
    wandb_run = init_wandb_logging(
        project="marxist-grpo-headless",
        config={
            "model": MODEL_NAME,
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "gradient_accumulation": GRADIENT_ACCUMULATION,
            "num_generations": NUM_GENERATIONS,
            "max_steps": MAX_STEPS,
            "reward_mode": REWARD_MODE,
            "lora_rank": LORA_RANK,
        },
        tags=["grpo", "marxist-leninist", "headless", "runpod"],
    )

    # Periodically logs sample completions to W&B via the logging reward below.
    sample_logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=4)
    logging_reward = create_logging_reward(sample_logger, compute_all_rewards=True)

    # =========================================================================
    # Load Dataset
    # =========================================================================
    print(f"\nLoading dataset from: {DATASET_PATH}")

    if not DATASET_PATH.exists():
        print(f"ERROR: Dataset not found: {DATASET_PATH}", file=sys.stderr)
        return 1

    dataset = Dataset.from_json(str(DATASET_PATH))
    print(f"Loaded {len(dataset):,} examples")

    # =========================================================================
    # Load Model
    # =========================================================================
    print(f"\nLoading model: {MODEL_NAME}")

    # GRPO requires 16-bit LoRA adapters (load_in_4bit=False)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,  # Must be False for GRPO
        fast_inference=True,
        max_lora_rank=LORA_RANK,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )

    print(f"Model type: {model.config.model_type}")

    # =========================================================================
    # Apply LoRA
    # =========================================================================
    print("\nApplying LoRA adapters...")

    # Use gradient_checkpointing=True (not "unsloth") for stability on RunPod
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=LORA_RANK,  # Same as r for GRPO (not r*2)
        use_gradient_checkpointing=True,  # Stable on RunPod (not "unsloth")
        random_state=3407,
    )

    # Report the trainable-parameter fraction introduced by the adapters.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # =========================================================================
    # Configure vLLM Sampling
    # =========================================================================
    vllm_sampling_params = SamplingParams(
        min_p=0.1,
        top_p=1.0,
        top_k=-1,
        max_tokens=MAX_COMPLETION_LENGTH,
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
        seed=3407,
    )

    # =========================================================================
    # Configure Training
    # =========================================================================
    print("\nConfiguring GRPO trainer...")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    training_args = GRPOConfig(
        # vLLM
        vllm_sampling_params=vllm_sampling_params,
        temperature=1.0,
        # Optimization
        learning_rate=LEARNING_RATE,
        weight_decay=0.001,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        # Batch settings
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        num_generations=NUM_GENERATIONS,
        # Sequence lengths
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        # Training duration
        max_steps=MAX_STEPS,
        save_steps=SAVE_STEPS,
        # Logging
        logging_steps=1,
        report_to="wandb",
        # Output (checkpoints land here; resumption scans the same dir)
        output_dir=str(CHECKPOINT_DIR),
    )

    # =========================================================================
    # Select Reward Functions
    # =========================================================================
    # NOTE(review): the printed lists omit debug_print_reward, which is
    # nevertheless included in every mode below.
    if REWARD_MODE == "FULL":
        print("\nUsing FULL reward mode (recommended):")
        print("  - match_format_exactly, match_format_approximately")
        print("  - full_coherence_reward (NLI + structure + topic + depth)")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            full_coherence_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]
    elif REWARD_MODE == "ROBUST":
        print("\nUsing ROBUST reward mode:")
        print("  - match_format_exactly, match_format_approximately")
        print("  - robust_coherence_reward (NLI + self-consistency + structure)")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            robust_coherence_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]
    else:  # LEGACY (also the fallback for unrecognized REWARD_MODE values)
        print("\nUsing LEGACY reward mode (faster, less robust):")
        print("  - match_format_exactly, match_format_approximately")
        print("  - semantic_similarity_reward, terminology_reward")
        print("  - completeness_reward, logging_reward")
        reward_funcs = [
            match_format_exactly,
            match_format_approximately,
            semantic_similarity_reward,
            terminology_reward,
            completeness_reward,
            debug_print_reward,
            logging_reward,
        ]

    # =========================================================================
    # Create Trainer
    # =========================================================================
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )

    # =========================================================================
    # Check for Checkpoint Resume
    # =========================================================================
    resume_from = find_latest_checkpoint(CHECKPOINT_DIR)
    if resume_from:
        print(f"\nResuming from checkpoint: {resume_from}")

    # =========================================================================
    # Train!
    # =========================================================================
    print("\n" + "=" * 70)
    print("STARTING TRAINING")
    print("=" * 70)
    print(f"Steps: {MAX_STEPS}")
    print(f"Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} x {NUM_GENERATIONS}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION * NUM_GENERATIONS}")
    print()

    try:
        if resume_from:
            trainer.train(resume_from_checkpoint=str(resume_from))
        else:
            trainer.train()
    except KeyboardInterrupt:
        # NOTE(review): no trainer checkpoint is explicitly written here; we
        # fall through to the LoRA save below with whatever state exists.
        print("\nTraining interrupted. Saving checkpoint...")
    except Exception as e:
        # Mark the W&B run as failed before propagating the error.
        print(f"\nTraining error: {e}", file=sys.stderr)
        finish_wandb_logging({"status": "error", "error": str(e)})
        raise

    # =========================================================================
    # Save LoRA
    # =========================================================================
    print("\n" + "=" * 70)
    print("SAVING MODEL")
    print("=" * 70)

    LORA_OUTPUT.mkdir(parents=True, exist_ok=True)
    model.save_lora(str(LORA_OUTPUT))
    print(f"LoRA saved to: {LORA_OUTPUT}")

    # =========================================================================
    # Upload to HuggingFace Hub (best-effort: failure does not fail the run)
    # =========================================================================
    try:
        upload_to_hub(LORA_OUTPUT, HF_REPO, HF_TOKEN)
    except Exception as e:
        print(f"Warning: Failed to upload to HuggingFace Hub: {e}", file=sys.stderr)

    # =========================================================================
    # Finish W&B
    # =========================================================================
    finish_wandb_logging({
        "status": "completed",
        "final_step": MAX_STEPS,
        "reward_mode": REWARD_MODE,
        "dataset_size": len(dataset),
        "hf_repo": HF_REPO,
    })

    print("\n" + "=" * 70)
    print("TRAINING COMPLETE!")
    print("=" * 70)
    print(f"LoRA saved to: {LORA_OUTPUT}")
    print(f"Model uploaded to: https://huggingface.co/{HF_REPO}")

    return 0
457
+
458
+
459
+ if __name__ == "__main__":
460
+ sys.exit(main())
src/prolewiki_llm/train_marxist.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Fine-tune DeepSeek-R1-Distill-Qwen-7B-abliterated on ProleWiki Marxist-Leninist corpus.
4
+
5
+ Usage on RunPod:
6
+ 1. Upload this script and train_qwen.jsonl to /workspace/data/
7
+ 2. pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
8
+ 3. pip install trl>=0.7.0 datasets accelerate bitsandbytes peft
9
+ 4. python train_marxist.py
10
+
11
+ Expected: ~30 min training, ~$0.30-0.50 on RTX 4090
12
+ """
13
+
14
+ from pathlib import Path
15
+
16
+ import torch
17
+ from datasets import Dataset
18
+ from trl import SFTConfig, SFTTrainer
19
+ from unsloth import FastLanguageModel
20
+
21
+ # =============================================================================
22
+ # CONFIGURATION
23
+ # =============================================================================
24
+
25
+ MODEL_NAME = "huihui-ai/DeepSeek-R1-Distill-Qwen-7B-abliterated"
26
+ MAX_SEQ_LENGTH = 2048 # Qwen can go higher but 2048 is sufficient
27
+ LOAD_IN_4BIT = True
28
+
29
+ # LoRA config
30
+ LORA_R = 16
31
+ LORA_ALPHA = 32
32
+ LORA_DROPOUT = 0.05
33
+
34
+ # Training config
35
+ EPOCHS = 3
36
+ BATCH_SIZE = 2
37
+ GRADIENT_ACCUMULATION = 4 # Effective batch = 8
38
+ LEARNING_RATE = 2e-4
39
+ WARMUP_RATIO = 0.1
40
+
41
+ # Paths - adjust for your environment
42
+ DATA_PATH = Path("/workspace/data/train_qwen.jsonl") # RunPod
43
+ OUTPUT_DIR = Path("/workspace/outputs/marxist-deepseek")
44
+ CHECKPOINT_DIR = Path("/workspace/checkpoints")
45
+
46
# Fallback for local testing
# If the RunPod workspace dataset is absent, rebase all paths onto
# repo-relative directories so the script also runs outside the pod.
if not DATA_PATH.exists():
    DATA_PATH = Path("training_data/formatted/train_qwen.jsonl")
    OUTPUT_DIR = Path("outputs/marxist-deepseek")
    CHECKPOINT_DIR = Path("checkpoints")
51
+
52
+
53
def load_dataset(path: Path) -> Dataset:
    """Load a pre-formatted Qwen-template dataset from a JSONL file.

    Args:
        path: JSONL file with one JSON training example per line.

    Returns:
        A ``datasets.Dataset`` built from the parsed examples.
    """
    import json

    with open(path) as f:
        # Skip blank lines (e.g. a trailing newline) instead of letting
        # json.loads raise on an empty string.
        examples = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(examples)} training examples")
    return Dataset.from_list(examples)
64
+
65
+
66
def main() -> None:
    """Run end-to-end fine-tuning: load model, attach LoRA, train, export.

    Requires a CUDA GPU (raises RuntimeError otherwise). Side effects:
    creates OUTPUT_DIR and CHECKPOINT_DIR, writes trainer checkpoints,
    a LoRA adapter, and a q4_k_m GGUF export.
    """
    print("=" * 60)
    print("Marxist-Leninist LLM Fine-Tuning")
    print("=" * 60)

    # Check CUDA — training is GPU-only, so fail fast with a clear message.
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        raise RuntimeError("CUDA not available - need GPU for training!")

    # Load model (4-bit quantized base via Unsloth)
    print(f"\nLoading model: {MODEL_NAME}")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=LOAD_IN_4BIT,
        dtype=None,  # Auto-detect (bf16 if available)
    )
    print(f"Model type: {model.config.model_type}")

    # Apply LoRA adapters to all attention and MLP projections
    print("\nApplying LoRA adapters...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",  # 30% less VRAM
        random_state=3407,
        max_seq_length=MAX_SEQ_LENGTH,
    )

    # Report the fraction of parameters actually being trained (LoRA only).
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    # Load dataset
    print(f"\nLoading dataset: {DATA_PATH}")
    dataset = load_dataset(DATA_PATH)

    # Create output directories
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    # Configure trainer
    print("\nConfiguring trainer...")
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",  # Pre-formatted Qwen template
        max_seq_length=MAX_SEQ_LENGTH,
        args=SFTConfig(
            # Batch settings
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION,
            # Learning rate
            learning_rate=LEARNING_RATE,
            lr_scheduler_type="cosine",
            warmup_ratio=WARMUP_RATIO,
            # Training duration
            num_train_epochs=EPOCHS,
            # Memory optimization: bf16 where supported, fp16 otherwise
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            optim="adamw_8bit",
            # Logging
            logging_steps=10,
            save_strategy="epoch",
            save_total_limit=2,
            # Output
            output_dir=str(OUTPUT_DIR),
            seed=3407,
            report_to="none",  # or "wandb" if configured
        ),
    )

    # Train!
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Epochs: {EPOCHS}")
    print(
        f"Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}"
    )
    print(f"Learning rate: {LEARNING_RATE}")
    print()

    trainer.train()

    # Save final model (adapter weights + tokenizer side by side)
    print("\n" + "=" * 60)
    print("SAVING MODEL")
    print("=" * 60)

    lora_path = CHECKPOINT_DIR / "marxist-lora-adapter"
    model.save_pretrained(str(lora_path))
    tokenizer.save_pretrained(str(lora_path))
    print(f"LoRA adapter saved to: {lora_path}")

    # Export to GGUF for local inference (e.g. Ollama)
    print("\nExporting to GGUF (q4_k_m)...")
    gguf_path = CHECKPOINT_DIR / "marxist-deepseek-q4_k_m"
    model.save_pretrained_gguf(
        str(gguf_path),
        tokenizer,
        quantization_method="q4_k_m",
    )
    print(f"GGUF exported to: {gguf_path}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print(f"1. Download: {gguf_path}/*.gguf")
    print("2. Create Ollama Modelfile (see ai-docs/finetune.yaml)")
    print("3. ollama create marxist-deepseek -f Modelfile")
    print("4. ollama run marxist-deepseek")
    print("\nDON'T FORGET TO STOP YOUR RUNPOD!")
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
src/prolewiki_llm/transform_to_grpo.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Transform curated_qa.jsonl to GRPO training format.
4
+
5
+ Input format: {"instruction": "...", "response": "..."}
6
+ Output format: {"prompt": [...], "answer": "..."}
7
+
8
+ Usage:
9
+ python transform_to_grpo.py
10
+ """
11
+
12
+ import json
13
+ from pathlib import Path
14
+
15
SYSTEM_PROMPT = """You are a Marxist-Leninist assistant trained on ProleWiki and critical theory.
Think through political theory questions using dialectical materialist analysis.
Show your reasoning in <think> tags, then provide a clear, well-sourced answer."""

INPUT_PATH = Path("training_data/curated_qa.jsonl")
OUTPUT_PATH = Path("training_data/grpo_dataset.jsonl")


def transform_qa_to_grpo(input_path: Path, output_path: Path) -> int:
    """Transform instruction/response pairs to GRPO chat format.

    Each non-blank input line must be a JSON object with "instruction" and
    "response" keys. Blank lines (e.g. a trailing newline) are skipped
    instead of crashing json.loads.

    Args:
        input_path: JSONL file of {"instruction", "response"} records.
        output_path: Destination JSONL of {"prompt": [...], "answer": ...}.

    Returns:
        Number of examples written.
    """
    count = 0

    with open(input_path) as infile, open(output_path, "w") as outfile:
        for line in infile:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            item = json.loads(line)

            transformed = {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": item["instruction"]},
                ],
                "answer": item["response"],
            }

            outfile.write(json.dumps(transformed) + "\n")
            count += 1

    return count
43
+
44
+
45
def main() -> None:
    """Run the transformation and print a preview of the first record."""
    print(f"Transforming {INPUT_PATH} to GRPO format...")

    n_written = transform_qa_to_grpo(INPUT_PATH, OUTPUT_PATH)

    print(f"Transformed {n_written} examples")
    print(f"Output written to: {OUTPUT_PATH}")

    # Preview the first transformed record so the format is easy to eyeball.
    print("\nSample output:")
    with open(OUTPUT_PATH) as f:
        sample = json.loads(f.readline())
    system_msg, user_msg = sample["prompt"]
    print(f" System: {system_msg['content'][:60]}...")
    print(f" User: {user_msg['content'][:60]}...")
    print(f" Answer: {sample['answer'][:60]}...")
61
+
62
+
63
+ if __name__ == "__main__":
64
+ main()
src/prolewiki_llm/wandb_logging.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Weights & Biases Logging for GRPO Training.
4
+
5
+ Provides comprehensive logging for debugging and monitoring GRPO fine-tuning:
6
+ - Per-step reward metrics (each reward function's mean)
7
+ - Sample tables showing question → response → reward breakdown
8
+ - Run configuration and hyperparameters
9
+ - Summary statistics at training end
10
+
11
+ Usage:
12
+ from prolewiki_llm.wandb_logging import (
13
+ init_wandb_logging,
14
+ WandbSampleLogger,
15
+ create_logging_reward,
16
+ )
17
+
18
+ # Initialize
19
+ run = init_wandb_logging(project="marxist-grpo", config={...})
20
+
21
+ # Create logger and reward function
22
+ sample_logger = WandbSampleLogger(log_every_n_steps=10)
23
+ logging_reward = create_logging_reward(sample_logger)
24
+
25
+ # Use in GRPOTrainer
26
+ trainer = GRPOTrainer(
27
+ reward_funcs=[..., logging_reward],
28
+ ...
29
+ )
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from dataclasses import dataclass, field
35
+ from typing import TYPE_CHECKING, Any
36
+
37
+ if TYPE_CHECKING:
38
+ from collections.abc import Callable, Sequence
39
+
40
+ # Global flag to track if wandb is available
41
+ _WANDB_AVAILABLE: bool | None = None
42
+ _wandb_module: Any = None
43
+
44
+
45
def _get_wandb() -> Any:
    """Import wandb on first use and memoize the result.

    Returns the wandb module, or None when wandb is not installed. The
    outcome is cached in the module-level _WANDB_AVAILABLE/_wandb_module
    pair so the import is attempted at most once per process.
    """
    global _WANDB_AVAILABLE, _wandb_module

    if _WANDB_AVAILABLE is not None:
        return _wandb_module

    try:
        import wandb
    except ImportError:
        _WANDB_AVAILABLE = False
        _wandb_module = None
    else:
        _WANDB_AVAILABLE = True
        _wandb_module = wandb

    return _wandb_module
60
+
61
+
62
def is_wandb_available() -> bool:
    """Return True when the wandb package can be imported."""
    # _get_wandb() returns the module on success and None on failure,
    # so a non-None result is exactly the "available" case.
    return _get_wandb() is not None
66
+
67
+
68
+ # =============================================================================
69
+ # INITIALIZATION
70
+ # =============================================================================
71
+
72
+
73
def init_wandb_logging(
    project: str,
    config: dict[str, Any],
    name: str | None = None,
    tags: list[str] | None = None,
    notes: str | None = None,
    mode: str = "online",
) -> Any:
    """
    Start a Weights & Biases run for GRPO training.

    Args:
        project: W&B project name (e.g., "marxist-grpo").
        config: Hyperparameters and settings to attach to the run.
        name: Optional run name (W&B auto-generates one if None).
        tags: Tags for filtering runs; defaults to ["grpo", "marxist-leninist"].
        notes: Free-form notes about this run.
        mode: "online", "offline", or "disabled".

    Returns:
        The wandb.Run object, or None when wandb is not installed.
    """
    wandb = _get_wandb()
    if wandb is None:
        print("[WandbLogging] wandb not installed. Install with: pip install wandb")
        return None

    run = wandb.init(
        project=project,
        config=config,
        name=name,
        tags=tags or ["grpo", "marxist-leninist"],
        notes=notes,
        mode=mode,
    )

    # Register reward metrics so W&B aggregates min/max/mean summaries.
    _define_reward_metrics(run)

    print(f"[WandbLogging] Initialized run: {run.name}")
    print(f"[WandbLogging] View at: {run.url}")

    return run
129
+
130
+
131
+ def _define_reward_metrics(run: Any) -> None:
132
+ """Define reward metrics with min/max/mean summaries."""
133
+ reward_metrics = [
134
+ "rewards/format_exact",
135
+ "rewards/format_approx",
136
+ "rewards/semantic_similarity",
137
+ "rewards/terminology",
138
+ "rewards/nli_coherence",
139
+ "rewards/self_consistency",
140
+ "rewards/structural_coherence",
141
+ "rewards/topic_relevance",
142
+ "rewards/interconnection_depth",
143
+ "rewards/completeness",
144
+ "rewards/total",
145
+ ]
146
+
147
+ for metric in reward_metrics:
148
+ # Track min, max, and mean for each reward
149
+ run.define_metric(metric, summary="mean")
150
+ run.define_metric(f"{metric}_min", summary="min")
151
+ run.define_metric(f"{metric}_max", summary="max")
152
+
153
+
154
+ # =============================================================================
155
+ # SAMPLE LOGGER
156
+ # =============================================================================
157
+
158
+
159
@dataclass
class RewardSample:
    """One logged sample together with its per-reward scores."""

    step: int
    question: str
    response: str
    ground_truth: str
    rewards: dict[str, float]

    @property
    def total_reward(self) -> float:
        """Total reward: the sum over every individual reward score."""
        return sum(self.rewards.values())


@dataclass
class WandbSampleLogger:
    """
    Buffers samples during training and periodically logs them to W&B.

    Samples accumulate via add_sample(); every log_every_n_steps steps the
    most recent ones are emitted as a wandb.Table so individual model
    outputs and their reward breakdowns can be inspected in the W&B UI.
    """

    log_every_n_steps: int = 10
    max_samples_per_log: int = 4
    _samples: list[RewardSample] = field(default_factory=list)
    _step_counter: int = field(default=0)
    _table_columns: list[str] = field(
        default_factory=lambda: [
            "step",
            "question",
            "response",
            "ground_truth",
            "format_exact",
            "format_approx",
            "nli_coherence",
            "topic_relevance",
            "depth",
            "completeness",
            "total",
        ]
    )

    def add_sample(
        self,
        step: int,
        question: str,
        response: str,
        ground_truth: str,
        rewards: dict[str, float],
    ) -> None:
        """Buffer one sample, truncating long text for table display."""
        self._samples.append(
            RewardSample(
                step=step,
                question=question[:500],
                response=response[:500],
                ground_truth=ground_truth[:300],
                rewards=rewards,
            )
        )

        # Bound the buffer to a few logging batches worth of samples.
        limit = self.max_samples_per_log * 3
        if len(self._samples) > limit:
            self._samples = self._samples[-limit:]

    def should_log(self, step: int) -> bool:
        """True on every log_every_n_steps-th positive step."""
        if step <= 0:
            return False
        return step % self.log_every_n_steps == 0

    def log_table(self, step: int) -> None:
        """Emit the most recent buffered samples as a wandb.Table."""
        wandb = _get_wandb()
        if wandb is None or not self._samples:
            return

        recent = self._samples[-self.max_samples_per_log :]
        table = wandb.Table(columns=self._table_columns)

        for s in recent:
            get = s.rewards.get
            table.add_data(
                s.step,
                s.question,
                s.response,
                s.ground_truth,
                get("format_exact", 0.0),
                get("format_approx", 0.0),
                get("nli_coherence", 0.0),
                get("topic_relevance", 0.0),
                get("interconnection_depth", 0.0),
                get("completeness", 0.0),
                s.total_reward,
            )

        wandb.log({"samples": table}, step=step)
        print(f"[WandbLogging] Logged {len(recent)} samples at step {step}")

    def clear(self) -> None:
        """Drop all buffered samples."""
        self._samples.clear()
272
+
273
+
274
+ # =============================================================================
275
+ # REWARD METRICS LOGGING
276
+ # =============================================================================
277
+
278
+
279
def log_reward_metrics(
    step: int,
    reward_scores: dict[str, list[float]],
) -> None:
    """
    Log per-reward mean/min/max metrics to wandb for one training step.

    Args:
        step: Training step number.
        reward_scores: Mapping of reward name to per-sample scores,
            e.g., {"format_exact": [3.0, 3.0, 0.0, 3.0]}.
    """
    wandb = _get_wandb()
    if wandb is None:
        return

    payload: dict[str, float] = {}

    for name, scores in reward_scores.items():
        if not scores:
            continue
        payload[f"rewards/{name}"] = sum(scores) / len(scores)
        payload[f"rewards/{name}_min"] = min(scores)
        payload[f"rewards/{name}_max"] = max(scores)

    # Per-sample totals across all reward functions. The sample count is
    # taken from the first list; shorter lists simply contribute nothing
    # for the missing indices (lists may be ragged).
    if reward_scores:
        num_samples = len(next(iter(reward_scores.values())))
        totals = [
            sum(scores[i] for scores in reward_scores.values() if i < len(scores))
            for i in range(num_samples)
        ]
        if totals:
            payload["rewards/total"] = sum(totals) / len(totals)
            payload["rewards/total_min"] = min(totals)
            payload["rewards/total_max"] = max(totals)

    wandb.log(payload, step=step)
323
+
324
+
325
+ # =============================================================================
326
+ # LOGGING REWARD FUNCTION
327
+ # =============================================================================
328
+
329
+ # Global step counter for the logging reward
330
+ _LOGGING_STEP = 0
331
+
332
+
333
def create_logging_reward(
    sample_logger: WandbSampleLogger | None = None,
    compute_all_rewards: bool = True,
) -> Callable[..., list[float]]:
    """
    Create a reward function that logs metrics and samples to wandb.

    The returned function computes every individual reward internally,
    logs the scores (and, periodically, sample tables) to wandb, and
    always returns [0.0] * len(completions) so it has no training effect.

    Args:
        sample_logger: Optional WandbSampleLogger for sample-table logging.
        compute_all_rewards: If True, compute and log all reward functions.

    Returns:
        A reward function compatible with GRPOTrainer.
    """
    # NOTE: the stray `global _LOGGING_STEP` that used to sit here was a
    # no-op (the outer function never assigns it) and has been removed;
    # only the closure below declares it, where the assignment happens.

    def logging_reward(
        prompts: Sequence[Sequence[dict[str, str]]],
        completions: Sequence[Sequence[dict[str, str]]],
        answer: Sequence[str],
        **kwargs: object,
    ) -> list[float]:
        """Log rewards and samples to wandb. Returns 0.0 (no training effect)."""
        global _LOGGING_STEP
        _LOGGING_STEP += 1
        step = _LOGGING_STEP

        wandb = _get_wandb()
        if wandb is None or wandb.run is None:
            # Fallback to print if wandb not initialized
            if step % 10 == 0:
                print(f"[Step {step}] Q: {prompts[0][-1]['content'][:80]}...")
            return [0.0] * len(completions)

        # Compute all reward scores if requested
        reward_scores: dict[str, list[float]] = {}
        if compute_all_rewards:
            reward_scores = _compute_all_reward_scores(prompts, completions, answer, **kwargs)
            log_reward_metrics(step, reward_scores)

        # Log sample tables periodically
        if sample_logger and sample_logger.should_log(step):
            for i in range(min(sample_logger.max_samples_per_log, len(prompts))):
                question = prompts[i][-1]["content"]
                response = completions[i][0]["content"]
                truth = answer[i] if i < len(answer) else ""

                # Slice out this sample's score from each reward list.
                sample_rewards = {
                    name: scores[i] if i < len(scores) else 0.0
                    for name, scores in reward_scores.items()
                }

                sample_logger.add_sample(
                    step=step,
                    question=question,
                    response=response,
                    ground_truth=truth,
                    rewards=sample_rewards,
                )

            sample_logger.log_table(step)

        return [0.0] * len(completions)

    return logging_reward
414
+
415
+
416
def _compute_all_reward_scores(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> dict[str, list[float]]:
    """
    Compute every reward function's scores for logging.

    Each reward is evaluated independently: a failure in one reward is
    printed and replaced with zeros rather than aborting the whole batch.

    Returns:
        Dict mapping reward name to the per-completion scores.
    """
    # Import reward functions here to avoid circular imports
    from prolewiki_llm.grpo_rewards import (
        completeness_reward,
        interconnection_depth_reward,
        match_format_approximately,
        match_format_exactly,
        nli_coherence_reward,
        topic_relevance_reward,
    )

    # (name, callable, positional args); **kwargs are forwarded to each.
    # Order matches the historical evaluation order.
    specs = [
        ("format_exact", match_format_exactly, (completions,)),
        ("format_approx", match_format_approximately, (completions,)),
        ("nli_coherence", nli_coherence_reward, (completions, answer)),
        ("topic_relevance", topic_relevance_reward, (prompts, completions)),
        ("interconnection_depth", interconnection_depth_reward, (completions,)),
        ("completeness", completeness_reward, (completions, answer)),
    ]

    reward_scores: dict[str, list[float]] = {}
    for name, func, args in specs:
        try:
            reward_scores[name] = func(*args, **kwargs)
        except Exception as e:
            print(f"[WandbLogging] Error in {name}: {e}")
            reward_scores[name] = [0.0] * len(completions)

    return reward_scores
481
+
482
+
483
+ # =============================================================================
484
+ # UTILITY FUNCTIONS
485
+ # =============================================================================
486
+
487
+
488
def finish_wandb_logging(summary: dict[str, Any] | None = None) -> None:
    """
    Close the active wandb run, optionally recording summary statistics.

    Args:
        summary: Final metrics to store on the run summary before closing.
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    for key, value in (summary or {}).items():
        wandb.run.summary[key] = value

    wandb.finish()
    print("[WandbLogging] Run finished.")
505
+
506
+
507
def log_model_checkpoint(
    checkpoint_path: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """
    Upload a model checkpoint directory to wandb as an artifact.

    Args:
        checkpoint_path: Path to the checkpoint directory.
        metadata: Extra metadata to attach to the artifact.
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    if metadata is None:
        metadata = {}
    artifact = wandb.Artifact(
        name=f"checkpoint-{wandb.run.name}",
        type="model",
        metadata=metadata,
    )
    artifact.add_dir(checkpoint_path)
    wandb.log_artifact(artifact)
    print(f"[WandbLogging] Logged checkpoint: {checkpoint_path}")
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # prolewiki-llm tests
tests/conftest.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytest configuration and fixtures for prolewiki-llm tests."""
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ # =============================================================================
9
+ # Module-Level Setup for train_headless.py Tests
10
+ # =============================================================================
11
+ # train_headless.py reads required env vars (HF_TOKEN, WANDB_API_KEY) at import time.
12
+ # We need to set dummy values BEFORE pytest collects tests that import this module.
13
+ # This must be done at module level, not in a fixture, because fixtures run
14
+ # after collection and the import happens during collection.
15
+
16
# Inject dummy credentials only when the real ones are absent, so importing
# train_headless (which reads these at import time) succeeds during collection.
os.environ.setdefault("HF_TOKEN", "test-token-for-unit-tests")
os.environ.setdefault("WANDB_API_KEY", "test-key-for-unit-tests")
20
+
21
+
22
@pytest.fixture
def sample_question() -> str:
    """A representative Marxist-theory question used across reward tests."""
    return "What is the difference between revisionism and Marxism-Leninism?"
26
+
27
+
28
@pytest.fixture
def sample_good_answer() -> str:
    """A well-structured, <think>-tagged answer expected to score highly on rewards."""
    return """<think>
The question asks about the distinction between revisionism and Marxism-Leninism,
which requires explaining both concepts and their historical relationship.
</think>

Revisionism refers to attempts to revise or water down the fundamental principles
of Marxism, particularly the necessity of class struggle and proletarian
dictatorship. It emerged historically with Eduard Bernstein who rejected
revolutionary change in favor of gradual reform within capitalism.

Marxism-Leninism, in contrast, upholds the core tenets of scientific socialism:
the class nature of the state, the necessity of socialist revolution, and the
dictatorship of the proletariat as the transitional form to communism.

The key distinction lies in their approach to the capitalist state - revisionists
seek to reform it, while Marxist-Leninists understand it must be replaced by a
workers' state through revolutionary action."""
48
+
49
+
50
@pytest.fixture
def sample_bad_answer() -> str:
    """A buzzword-heavy answer with no <think> block, expected to score poorly."""
    return """Revisionism is bad and Marxism-Leninism is good. We must fight
imperialism and support the working class. Solidarity forever! The revolution
will triumph against capitalism and fascism. Workers of the world unite!"""
56
+
57
+
58
+ # =============================================================================
59
+ # Integration Test Fixtures
60
+ # =============================================================================
61
+
62
+
63
@pytest.fixture
def project_root() -> Path:
    """The repository root directory (two levels above this file)."""
    return Path(__file__).parent.parent
67
+
68
+
69
@pytest.fixture
def mock_bin_dir(tmp_path: Path) -> Path:
    """Create mock bin directory with executable stubs for shell script testing.

    Every regular file under tests/fixtures/mock_bin/ is copied into a
    temporary directory and the copies are marked executable.
    """
    mock_bin = tmp_path / "mock_bin"
    mock_bin.mkdir()

    fixtures_dir = Path(__file__).parent / "fixtures" / "mock_bin"
    if fixtures_dir.exists():
        stubs = (p for p in fixtures_dir.iterdir() if p.is_file())
        for stub in stubs:
            target = mock_bin / stub.name
            target.write_text(stub.read_text())
            target.chmod(0o755)  # make the stub runnable from PATH

    return mock_bin
89
+
90
+
91
@pytest.fixture
def start_sh_env(tmp_path: Path, mock_bin_dir: Path) -> dict[str, str]:
    """Base environment for start.sh integration tests.

    The environment puts the mock binaries first on PATH, captures mock
    invocations in a log directory, defaults all mock commands to success,
    and rebases every /workspace output path into tmp_path.
    """
    log_dir = tmp_path / "logs"
    log_dir.mkdir()

    env = {
        "PATH": f"{mock_bin_dir}:/usr/bin:/bin",
        "HOME": str(tmp_path),
        "MOCK_LOG_DIR": str(log_dir),
        "MOCK_CUDA_AVAILABLE": "1",
        "MOCK_TRAINING_EXIT_CODE": "0",
    }
    # Override /workspace paths to use tmp_path so tests never touch the host.
    env["CHECKPOINT_DIR"] = str(tmp_path / "checkpoints")
    env["LORA_OUTPUT"] = str(tmp_path / "lora-output")
    env["OUTPUT_DIR"] = str(tmp_path / "outputs")
    return env
tests/fixtures/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Test fixtures for prolewiki-llm tests."""
tests/fixtures/mock_bin/huggingface-cli ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Mock huggingface-cli: record the invocation, then exit with the
# configured code (default success) so auth paths can be exercised.

echo "huggingface-cli $*" >> "${MOCK_LOG_DIR:-/tmp}/mock_calls.log"

exit "${MOCK_HF_CLI_EXIT_CODE:-0}"
tests/fixtures/mock_bin/python ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Mock python interpreter for testing start.sh.
# Fakes the CUDA probes and the training-module invocation so the
# entrypoint can be exercised without a real GPU or Python environment.

args="$*"
echo "python $args" >> "${MOCK_LOG_DIR:-/tmp}/mock_calls.log"

case "$args" in
    *"torch.cuda.is_available"*)
        # CUDA availability probe: fail only when the test disables the mock GPU.
        if [ "${MOCK_CUDA_AVAILABLE:-1}" = "0" ]; then
            echo "AssertionError: CUDA not available" >&2
            exit 1
        fi
        exit 0
        ;;
    *"get_device_name"*)
        # GPU name query
        echo "Mock GPU RTX 4090"
        exit 0
        ;;
    *"get_device_properties"*)
        # VRAM query
        echo "24.0 GB"
        exit 0
        ;;
    *"prolewiki_llm.train_headless"*)
        # Training module invocation: exit code is test-controlled.
        exit "${MOCK_TRAINING_EXIT_CODE:-0}"
        ;;
esac

# Any other python invocation succeeds.
exit 0
tests/fixtures/mock_bin/runpodctl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Mock runpodctl for testing start.sh auto-shutoff behavior.
# Logs the invocation; for `runpodctl stop pod <id>` it echoes the success
# message the real CLI prints, then exits with MOCK_RUNPODCTL_EXIT_CODE.

log_file="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "runpodctl $*" >> "$log_file"

# Simulate a successful pod-stop operation.
case "$1 $2" in
    "stop pod")
        echo "pod \"$3\" stopped"
        ;;
esac

exit "${MOCK_RUNPODCTL_EXIT_CODE:-0}"
tests/fixtures/mock_bin/wandb ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Mock wandb CLI for testing authentication.
# Records the invocation in the shared mock-call log and exits with
# MOCK_WANDB_EXIT_CODE (default 0 = success).

log_file="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "wandb $*" >> "$log_file"

exit "${MOCK_WANDB_EXIT_CODE:-0}"
tests/fixtures/mock_bin/wget ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Mock wget for testing dataset download.
# Logs the invocation and, when a `-O <file>` pair is present, writes a
# dummy JSONL dataset to <file> so downstream existence checks pass.

LOG_FILE="${MOCK_LOG_DIR:-/tmp}/mock_calls.log"
echo "wget $*" >> "$LOG_FILE"

# Scan the arguments for "-O <path>" and create a dummy dataset there.
# `prev` must start empty: tests invoke this stub with a fully custom
# environment, and an inherited variable named `prev` would otherwise
# make the very first argument look like an output path.
prev=""
for arg in "$@"; do
    if [[ "$prev" == "-O" ]]; then
        # Create dummy dataset file at the path following -O
        echo '{"prompt": "test", "answer": "test"}' > "$arg"
        break
    fi
    prev="$arg"
done

exit "${MOCK_WGET_EXIT_CODE:-0}"
tests/integration/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Integration tests for prolewiki-llm."""
tests/integration/test_start_sh.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests for the headless training entrypoint script.
3
+
4
+ Tests cover:
5
+ - Environment validation (required secrets, GPU check)
6
+ - Auto-shutoff behavior (FinOps: pod termination on success, no termination on failure)
7
+ - Data handling (dataset validation, download from URL)
8
+ - Authentication (HuggingFace, Weights & Biases)
9
+
10
+ These tests use subprocess to run start.sh with mock external commands.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import subprocess
16
+ from pathlib import Path
17
+
18
+ import pytest
19
+
20
+ # =============================================================================
21
+ # Environment Validation Tests
22
+ # =============================================================================
23
+
24
+
25
@pytest.mark.integration
class TestEnvironmentValidation:
    """Test environment variable validation in start.sh.

    Each case launches docker/start.sh in a subprocess under a controlled
    environment and asserts on its exit code and captured output.
    """

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-line JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing text output (30 s cap)."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    def test_exits_without_hf_token(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when HF_TOKEN is not set."""
        env = {
            **start_sh_env,
            # HF_TOKEN intentionally missing
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "HF_TOKEN" in result.stdout or "HF_TOKEN" in result.stderr

    def test_exits_without_wandb_key(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when WANDB_API_KEY is not set."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            # WANDB_API_KEY intentionally missing
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "WANDB_API_KEY" in result.stdout or "WANDB_API_KEY" in result.stderr

    def test_exits_without_cuda(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with error when CUDA is not available."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_CUDA_AVAILABLE": "0",  # Simulate no GPU
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "CUDA" in result.stdout or "GPU" in result.stdout
113
+
114
+
115
+ # =============================================================================
116
+ # Auto-Shutoff Tests (Critical for FinOps)
117
+ # =============================================================================
118
+
119
+
120
@pytest.mark.integration
class TestAutoShutoff:
    """Test pod auto-termination per RunPod strategy document.

    These tests verify the FinOps strategy that prevents zombie pods.
    """

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-line JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing text output (30 s cap)."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _read_mock_calls(env: dict[str, str]) -> str:
        """Return the accumulated mock-command log, or '' if nothing ran."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_calls_runpodctl_stop_on_success(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Pod is terminated after successful training (FinOps)."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "RUNPOD_POD_ID": "pod-abc123",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_TRAINING_EXIT_CODE": "0",  # Training succeeds
        }

        result = self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        assert result.returncode == 0
        assert "runpodctl stop pod pod-abc123" in calls

    def test_no_runpodctl_stop_without_pod_id(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """No pod termination when RUNPOD_POD_ID is not set."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            # RUNPOD_POD_ID intentionally not set
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_TRAINING_EXIT_CODE": "0",
        }

        result = self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        assert result.returncode == 0
        assert "runpodctl" not in calls
        assert "RUNPOD_POD_ID not set" in result.stdout

    def test_no_runpodctl_stop_on_failure(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Pod is NOT terminated on failure (allows debugging).

        This is critical: developers need SSH access to debug failed training.
        """
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "RUNPOD_POD_ID": "pod-abc123",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_TRAINING_EXIT_CODE": "1",  # Training fails
        }

        result = self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        # Pod should NOT be stopped (for debugging)
        assert "runpodctl" not in calls
        assert result.returncode == 1
        assert "NOT be automatically terminated" in result.stdout

    def test_exit_code_propagates_from_training(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits with the same code as the training process."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_TRAINING_EXIT_CODE": "42",  # Custom exit code
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 42
260
+
261
+
262
+ # =============================================================================
263
+ # Data Handling Tests
264
+ # =============================================================================
265
+
266
+
267
@pytest.mark.integration
class TestDataHandling:
    """Test dataset validation and download logic."""

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing text output (30 s cap)."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _read_mock_calls(env: dict[str, str]) -> str:
        """Return the accumulated mock-command log, or '' if nothing ran."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_exits_when_dataset_missing_no_url(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script exits when dataset is missing and no DATASET_URL provided."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(tmp_path / "nonexistent.jsonl"),
            # DATASET_URL not set
        }

        result = self._run_start_sh(project_root, env)

        assert result.returncode == 1
        assert "Dataset not found" in result.stdout

    def test_downloads_dataset_from_url(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script downloads dataset when file missing but DATASET_URL set."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(tmp_path / "dataset.jsonl"),
            "DATASET_URL": "https://example.com/dataset.jsonl",
        }

        self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        # wget should have been called with the configured URL
        assert "wget" in calls
        assert "https://example.com/dataset.jsonl" in calls

    def test_skips_download_when_dataset_exists(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script does not download when dataset file already exists."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')

        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(dataset),
            "DATASET_URL": "https://example.com/should-not-download.jsonl",
        }

        result = self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        # wget should NOT have been called
        assert "wget" not in calls
        assert result.returncode == 0
361
+
362
+
363
+ # =============================================================================
364
+ # Authentication Tests
365
+ # =============================================================================
366
+
367
+
368
@pytest.mark.integration
class TestAuthentication:
    """Test authentication with external services."""

    @staticmethod
    def _write_dataset(tmp_path: Path) -> Path:
        """Create a minimal one-line JSONL dataset and return its path."""
        dataset = tmp_path / "dataset.jsonl"
        dataset.write_text('{"prompt":"test"}\n')
        return dataset

    @staticmethod
    def _run_start_sh(
        project_root: Path, env: dict[str, str]
    ) -> subprocess.CompletedProcess[str]:
        """Run docker/start.sh with *env*, capturing text output (30 s cap)."""
        return subprocess.run(
            ["bash", str(project_root / "docker" / "start.sh")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )

    @staticmethod
    def _read_mock_calls(env: dict[str, str]) -> str:
        """Return the accumulated mock-command log, or '' if nothing ran."""
        log_file = Path(env["MOCK_LOG_DIR"]) / "mock_calls.log"
        return log_file.read_text() if log_file.exists() else ""

    def test_calls_huggingface_cli_login(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script calls huggingface-cli login with the token."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "hf_test_token_12345",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
        }

        self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        assert "huggingface-cli login" in calls
        assert "hf_test_token_12345" in calls

    def test_calls_wandb_login(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script calls wandb login with the API key."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "wandb_key_67890",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
        }

        self._run_start_sh(project_root, env)
        calls = self._read_mock_calls(env)

        assert "wandb login" in calls
        assert "wandb_key_67890" in calls

    def test_auth_failures_dont_stop_script(
        self,
        tmp_path: Path,
        start_sh_env: dict[str, str],
        project_root: Path,
    ) -> None:
        """Script continues even if authentication commands fail (|| true)."""
        env = {
            **start_sh_env,
            "HF_TOKEN": "test-token",
            "WANDB_API_KEY": "test-key",
            "DATASET_PATH": str(self._write_dataset(tmp_path)),
            "MOCK_HF_CLI_EXIT_CODE": "1",  # HF login fails
            "MOCK_WANDB_EXIT_CODE": "1",  # wandb login fails
        }

        result = self._run_start_sh(project_root, env)

        # Script should still succeed (auth failures are non-fatal)
        assert result.returncode == 0
tests/unit/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Unit tests
tests/unit/test_grpo_rewards.py ADDED
The diff for this file is too large to render. See raw diff
 
tests/unit/test_train_headless.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for the headless training module.
3
+
4
+ Tests cover:
5
+ - Environment variable parsing (get_env, get_env_int, get_env_float)
6
+ - Checkpoint discovery (find_latest_checkpoint)
7
+ - Model upload to HuggingFace Hub (upload_to_hub)
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from pathlib import Path
13
+ from unittest.mock import MagicMock, patch
14
+
15
+ import pytest
16
+
17
+ # =============================================================================
18
+ # Environment Variable Parsing Tests
19
+ # =============================================================================
20
+
21
+
22
class TestGetEnv:
    """Tests for the string environment-variable helper."""

    def test_get_env_returns_value(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A set variable is returned verbatim."""
        monkeypatch.setenv("TEST_VAR", "test_value")

        # Imported inside the test so module-level env checks run after setenv.
        from prolewiki_llm.train_headless import get_env

        value = get_env("TEST_VAR")
        assert value == "test_value"

    def test_get_env_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable falls back to the supplied default."""
        monkeypatch.delenv("NONEXISTENT_VAR", raising=False)

        from prolewiki_llm.train_headless import get_env

        value = get_env("NONEXISTENT_VAR", "default_value")
        assert value == "default_value"

    def test_get_env_required_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A missing required variable terminates the process with code 1."""
        monkeypatch.delenv("REQUIRED_VAR", raising=False)

        from prolewiki_llm.train_headless import get_env

        with pytest.raises(SystemExit) as exc_info:
            get_env("REQUIRED_VAR", required=True)

        assert exc_info.value.code == 1
52
+
53
+
54
class TestGetEnvInt:
    """Tests for the integer environment-variable helper."""

    def test_get_env_int_parses_integer(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A numeric string is parsed into an int."""
        monkeypatch.setenv("INT_VAR", "42")

        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("INT_VAR", 0)
        assert parsed == 42

    def test_get_env_int_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable falls back to the supplied default."""
        monkeypatch.delenv("NONEXISTENT_INT", raising=False)

        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("NONEXISTENT_INT", 100)
        assert parsed == 100
72
+
73
+
74
class TestGetEnvFloat:
    """Tests for the float environment-variable helper."""

    def test_get_env_float_parses_float(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A decimal string is parsed into a float."""
        monkeypatch.setenv("FLOAT_VAR", "3.14")

        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(3.14)

    def test_get_env_float_parses_scientific(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Scientific notation such as 5e-6 is accepted."""
        monkeypatch.setenv("FLOAT_VAR", "5e-6")

        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(5e-6)

    def test_get_env_float_returns_default(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An unset variable falls back to the supplied default."""
        monkeypatch.delenv("NONEXISTENT_FLOAT", raising=False)

        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("NONEXISTENT_FLOAT", 1.5)
        assert parsed == 1.5
104
+
105
+
106
+ # =============================================================================
107
+ # Checkpoint Discovery Tests
108
+ # =============================================================================
109
+
110
+
111
class TestFindLatestCheckpoint:
    """Tests for checkpoint-discovery logic."""

    def test_returns_none_for_nonexistent_dir(self, tmp_path: Path) -> None:
        """A directory that does not exist yields None."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        missing = tmp_path / "nonexistent"
        assert find_latest_checkpoint(missing) is None

    def test_returns_none_for_empty_dir(self, tmp_path: Path) -> None:
        """An existing but empty directory yields None."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()

        assert find_latest_checkpoint(root) is None

    def test_returns_none_when_no_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Entries that do not match checkpoint-* are not candidates."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        # Only non-checkpoint entries present
        (root / "random_dir").mkdir()
        (root / "other_file.txt").write_text("test")

        assert find_latest_checkpoint(root) is None

    def test_finds_single_checkpoint(self, tmp_path: Path) -> None:
        """A lone checkpoint directory is returned."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        only = root / "checkpoint-100"
        only.mkdir()

        assert find_latest_checkpoint(root) == only

    def test_finds_latest_checkpoint(self, tmp_path: Path) -> None:
        """The checkpoint with the highest step number wins."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()

        # Created deliberately out of order
        for step in (50, 200, 100, 150):
            (root / f"checkpoint-{step}").mkdir()

        assert find_latest_checkpoint(root) == root / "checkpoint-200"

    def test_ignores_non_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Non-matching sibling directories are skipped."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()

        # Mix of checkpoint and non-checkpoint directories
        for name in ("checkpoint-50", "logs", "checkpoint-100", "outputs"):
            (root / name).mkdir()

        assert find_latest_checkpoint(root) == root / "checkpoint-100"
185
+
186
+
187
+ # =============================================================================
188
+ # HuggingFace Hub Upload Tests
189
+ # =============================================================================
190
+
191
+
192
class TestUploadToHub:
    """Test model upload to HuggingFace Hub."""

    @staticmethod
    def _make_adapter_dir(tmp_path: Path) -> Path:
        """Create a fake LoRA output directory containing one adapter file."""
        model_path = tmp_path / "lora-output"
        model_path.mkdir()
        (model_path / "adapter_model.safetensors").write_bytes(b"mock model")
        return model_path

    def test_creates_repo(self, tmp_path: Path) -> None:
        """upload_to_hub creates the repository if it doesn't exist."""
        from prolewiki_llm.train_headless import upload_to_hub

        model_path = self._make_adapter_dir(tmp_path)
        mock_api = MagicMock()

        # HfApi is imported inside upload_to_hub, so we patch at the source
        with patch("huggingface_hub.HfApi", return_value=mock_api):
            upload_to_hub(model_path, "test-org/test-model", "test-token")

        mock_api.create_repo.assert_called_once_with(
            "test-org/test-model", exist_ok=True, private=True
        )

    def test_uploads_folder(self, tmp_path: Path) -> None:
        """upload_to_hub uploads the model folder."""
        from prolewiki_llm.train_headless import upload_to_hub

        model_path = self._make_adapter_dir(tmp_path)
        mock_api = MagicMock()

        with patch("huggingface_hub.HfApi", return_value=mock_api):
            upload_to_hub(model_path, "test-org/test-model", "test-token")

        mock_api.upload_folder.assert_called_once_with(
            folder_path=str(model_path),
            repo_id="test-org/test-model",
            commit_message="Headless GRPO training run",
        )

    def test_handles_repo_creation_failure(self, tmp_path: Path) -> None:
        """upload_to_hub continues if repo already exists."""
        from prolewiki_llm.train_headless import upload_to_hub

        model_path = self._make_adapter_dir(tmp_path)
        mock_api = MagicMock()
        mock_api.create_repo.side_effect = Exception("Repo already exists")

        with patch("huggingface_hub.HfApi", return_value=mock_api):
            # Should not raise
            upload_to_hub(model_path, "test-org/test-model", "test-token")

        # Upload is still attempted even though repo creation failed
        mock_api.upload_folder.assert_called_once()
tests/unit/test_wandb_logging.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for Weights & Biases logging module.
3
+
4
+ Tests cover:
5
+ - WandbSampleLogger accumulation and table creation
6
+ - Reward metrics logging
7
+ - Logging reward function signature compatibility
8
+ - Graceful handling when wandb is not available
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from unittest.mock import MagicMock, patch
14
+
15
+ import pytest
16
+
17
+ # =============================================================================
18
+ # FIXTURES
19
+ # =============================================================================
20
+
21
+
22
@pytest.fixture
def mock_wandb() -> MagicMock:
    """Build a stand-in for the wandb module with the attributes tests use."""
    fake = MagicMock()
    fake.run = MagicMock()
    fake.run.name = "test-run"
    fake.run.url = "https://wandb.ai/test/run"
    fake.Table = MagicMock(return_value=MagicMock())
    fake.init = MagicMock(return_value=fake.run)
    fake.log = MagicMock()
    fake.finish = MagicMock()
    return fake
34
+
35
+
36
@pytest.fixture
def sample_prompts() -> list[list[dict[str, str]]]:
    """Two chat-style prompts, each with a system and a user turn."""

    def _system() -> dict[str, str]:
        # Fresh dict per conversation so tests can mutate independently.
        return {"role": "system", "content": "You are a Marxist assistant."}

    return [
        [_system(), {"role": "user", "content": "What is revisionism?"}],
        [_system(), {"role": "user", "content": "Explain surplus value."}],
    ]
49
+
50
+
51
@pytest.fixture
def sample_completions() -> list[list[dict[str, str]]]:
    """Two single-turn assistant completions (content after </think>)."""
    texts = [
        "</think>Revisionism distorts Marxist theory.",
        "</think>Surplus value is unpaid labor.",
    ]
    return [[{"role": "assistant", "content": text}] for text in texts]
58
+
59
+
60
@pytest.fixture
def sample_answers() -> list[str]:
    """Ground-truth reference answers matching sample_prompts order."""
    revisionism_truth = "Revisionism is the distortion of Marxist-Leninist theory."
    surplus_truth = "Surplus value is the value produced by workers beyond their wages."
    return [revisionism_truth, surplus_truth]
67
+
68
+
69
+ # =============================================================================
70
+ # REWARD SAMPLE TESTS
71
+ # =============================================================================
72
+
73
+
74
class TestRewardSample:
    """Tests for the RewardSample dataclass."""

    def test_total_reward_calculation(self) -> None:
        """total_reward is the sum of every individual reward component."""
        from prolewiki_llm.wandb_logging import RewardSample

        reward_parts = {
            "format_exact": 3.0,
            "nli_coherence": 2.0,
            "topic_relevance": 1.5,
        }
        sample = RewardSample(
            step=10,
            question="What is X?",
            response="X is Y.",
            ground_truth="X is Y.",
            rewards=reward_parts,
        )

        assert sample.total_reward == 6.5

    def test_empty_rewards(self) -> None:
        """total_reward is 0.0 when no reward components are present."""
        from prolewiki_llm.wandb_logging import RewardSample

        sample = RewardSample(
            step=10,
            question="Q",
            response="R",
            ground_truth="T",
            rewards={},
        )

        assert sample.total_reward == 0.0
108
+
109
+
110
+ # =============================================================================
111
+ # SAMPLE LOGGER TESTS
112
+ # =============================================================================
113
+
114
+
115
class TestWandbSampleLogger:
    """Unit tests for the WandbSampleLogger buffering behaviour."""

    def test_add_sample(self) -> None:
        """An added sample lands in the buffer with its reward sum computed."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        logger = WandbSampleLogger(log_every_n_steps=10)
        logger.add_sample(
            step=5,
            question="What is revisionism?",
            response="Revisionism distorts theory.",
            ground_truth="Revisionism is distortion of Marxism.",
            rewards={"format": 3.0, "nli": 2.0},
        )

        buffered = logger._samples
        assert len(buffered) == 1
        assert buffered[0].step == 5
        assert buffered[0].total_reward == 5.0

    def test_sample_buffer_limit(self) -> None:
        """The internal buffer must stay bounded however many samples arrive."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=2)

        for idx in range(20):
            logger.add_sample(
                step=idx,
                question=f"Q{idx}",
                response=f"R{idx}",
                ground_truth=f"T{idx}",
                rewards={"x": float(idx)},
            )

        # Cap is max_samples_per_log * 3 == 6.
        assert len(logger._samples) <= 6

    def test_should_log(self) -> None:
        """should_log fires only on positive multiples of the interval."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        logger = WandbSampleLogger(log_every_n_steps=10)

        # Step 0 never logs; 10 and 20 do; 5 and 15 do not.
        for step, expected in ((0, False), (5, False), (10, True), (15, False), (20, True)):
            assert bool(logger.should_log(step)) == expected

    def test_clear(self) -> None:
        """clear() empties the sample buffer completely."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        logger = WandbSampleLogger()
        for step, value in ((1, 1.0), (2, 2.0)):
            logger.add_sample(step, "Q", "R", "T", {"x": value})

        assert len(logger._samples) == 2

        logger.clear()
        assert len(logger._samples) == 0

    def test_truncation(self) -> None:
        """Over-long strings are clipped to their per-field limits."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        logger = WandbSampleLogger()

        long_text = "x" * 1000  # well past every truncation limit
        logger.add_sample(
            step=1,
            question=long_text,
            response=long_text,
            ground_truth=long_text,
            rewards={},
        )

        stored = logger._samples[0]
        assert len(stored.question) == 500
        assert len(stored.response) == 500
        assert len(stored.ground_truth) == 300
198
+
199
+
200
+ # =============================================================================
201
+ # LOG TABLE TESTS
202
+ # =============================================================================
203
+
204
+
205
class TestLogTable:
    """Tests covering WandbSampleLogger.log_table."""

    def test_log_table_creates_table(self, mock_wandb: MagicMock) -> None:
        """log_table should build a wandb Table and log it under 'samples'."""
        from prolewiki_llm.wandb_logging import WandbSampleLogger

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            logger = WandbSampleLogger(max_samples_per_log=2)

            # Buffer two samples, then flush them as a table.
            logger.add_sample(1, "Q1", "R1", "T1", {"format": 3.0})
            logger.add_sample(2, "Q2", "R2", "T2", {"format": 2.0})
            logger.log_table(step=10)

            # Exactly one wandb.Table must have been constructed...
            mock_wandb.Table.assert_called_once()

            # ...and logged once, keyed as "samples" at the requested step.
            mock_wandb.log.assert_called_once()
            positional, keyword = mock_wandb.log.call_args
            assert "samples" in positional[0]
            assert keyword["step"] == 10
230
+
231
+
232
+ # =============================================================================
233
+ # REWARD METRICS LOGGING TESTS
234
+ # =============================================================================
235
+
236
+
237
class TestLogRewardMetrics:
    """Tests for the log_reward_metrics helper."""

    def test_log_reward_metrics(self, mock_wandb: MagicMock) -> None:
        """Per-function mean/min/max statistics are pushed to wandb."""
        from prolewiki_llm.wandb_logging import log_reward_metrics

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            log_reward_metrics(
                step=50,
                reward_scores={
                    "format_exact": [3.0, 3.0, 0.0],
                    "nli_coherence": [2.0, -1.0, 3.0],
                },
            )

            mock_wandb.log.assert_called_once()
            metrics = mock_wandb.log.call_args[0][0]

            # Means: (3+3+0)/3 == 2.0 and (2-1+3)/3 == 4/3.
            assert metrics["rewards/format_exact"] == 2.0
            assert metrics["rewards/nli_coherence"] == pytest.approx(4 / 3)

            # Extremes for the first reward function.
            assert metrics["rewards/format_exact_min"] == 0.0
            assert metrics["rewards/format_exact_max"] == 3.0

    def test_log_reward_metrics_computes_total(self, mock_wandb: MagicMock) -> None:
        """The mean per-sample total reward is reported under rewards/total."""
        from prolewiki_llm.wandb_logging import log_reward_metrics

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            scores = {
                "format": [3.0, 2.0],
                "nli": [1.0, 2.0],
            }
            log_reward_metrics(step=10, reward_scores=scores)

            metrics = mock_wandb.log.call_args[0][0]

            # Per-sample totals are 3+1=4 and 2+2=4, so the mean is 4.0.
            assert metrics["rewards/total"] == 4.0
282
+
283
+
284
+ # =============================================================================
285
+ # LOGGING REWARD FUNCTION TESTS
286
+ # =============================================================================
287
+
288
+
289
class TestCreateLoggingReward:
    """Tests for create_logging_reward and its GRPOTrainer-compatible wrapper."""

    def test_returns_zeros(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """The logging reward must be a training no-op: every score is zero."""
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
        )

        # With wandb unavailable the logger falls back to printing.
        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            sample_logger = WandbSampleLogger()
            reward_fn = create_logging_reward(sample_logger, compute_all_rewards=False)

            scores = reward_fn(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
            )

            # Zero reward for every sample means no gradient influence.
            assert scores == [0.0, 0.0]

    def test_function_signature_compatibility(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """The reward accepts the GRPOTrainer call shape, extra kwargs included."""
        from prolewiki_llm.wandb_logging import create_logging_reward

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            reward_fn = create_logging_reward(compute_all_rewards=False)

            # Must tolerate prompts/completions/answer plus arbitrary kwargs.
            result = reward_fn(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
                extra_kwarg="ignored",
            )

            assert isinstance(result, list)
            assert len(result) == len(sample_completions)

    def test_logs_samples_at_interval(
        self,
        mock_wandb: MagicMock,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """A sample table is emitted each time the step counter hits the interval."""
        import prolewiki_llm.wandb_logging as wl
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
        )

        # The module tracks a global step counter; reset it for test isolation.
        wl._LOGGING_STEP = 0

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            sample_logger = WandbSampleLogger(log_every_n_steps=5, max_samples_per_log=2)
            reward_fn = create_logging_reward(sample_logger, compute_all_rewards=False)

            for _call in range(10):
                reward_fn(
                    prompts=sample_prompts,
                    completions=sample_completions,
                    answer=sample_answers,
                )

            # Steps 5 and 10 should each have produced one "samples" table log.
            table_logs = [
                call for call in mock_wandb.log.call_args_list if "samples" in call[0][0]
            ]
            assert len(table_logs) == 2
373
+
374
+
375
+ # =============================================================================
376
+ # WANDB AVAILABILITY TESTS
377
+ # =============================================================================
378
+
379
+
380
class TestWandbAvailability:
    """Tests around optional-wandb handling and graceful degradation."""

    def test_is_wandb_available_true(self, mock_wandb: MagicMock) -> None:
        """is_wandb_available must not crash when a wandb module is importable."""
        from prolewiki_llm import wandb_logging as wl

        # Drop any cached availability state before the check.
        wl._WANDB_AVAILABLE = None
        wl._wandb_module = None

        with patch.dict("sys.modules", {"wandb": mock_wandb}):
            wl._WANDB_AVAILABLE = None  # force the import check to re-run
            _result = wl.is_wandb_available()
            # Import mechanics may still yield False here; this test only
            # requires that the call completes without raising.

    def test_graceful_degradation_without_wandb(
        self,
        sample_prompts: list[list[dict[str, str]]],
        sample_completions: list[list[dict[str, str]]],
        sample_answers: list[str],
    ) -> None:
        """Every logging entry point must degrade gracefully when wandb is absent."""
        from prolewiki_llm.wandb_logging import (
            WandbSampleLogger,
            create_logging_reward,
            log_reward_metrics,
        )

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=None):
            # None of the following calls may raise.
            sample_logger = WandbSampleLogger()
            reward_fn = create_logging_reward(sample_logger, compute_all_rewards=False)

            # The reward still produces valid zero scores without wandb.
            result = reward_fn(
                prompts=sample_prompts,
                completions=sample_completions,
                answer=sample_answers,
            )
            assert result == [0.0, 0.0]

            # Metrics logging must not crash.
            log_reward_metrics(step=1, reward_scores={"x": [1.0]})

            # Table logging must not crash either.
            sample_logger.log_table(step=10)
429
+
430
+
431
+ # =============================================================================
432
+ # INIT AND FINISH TESTS
433
+ # =============================================================================
434
+
435
+
436
class TestInitAndFinish:
    """Tests for init_wandb_logging and finish_wandb_logging."""

    def test_init_wandb_logging(self, mock_wandb: MagicMock) -> None:
        """init_wandb_logging forwards project and config to wandb.init."""
        from prolewiki_llm.wandb_logging import init_wandb_logging

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            _run = init_wandb_logging(
                project="test-project",
                config={"lr": 1e-5, "batch_size": 4},
                name="test-run",
                tags=["test"],
            )

            mock_wandb.init.assert_called_once()
            init_kwargs = mock_wandb.init.call_args[1]
            assert init_kwargs["project"] == "test-project"
            assert init_kwargs["config"] == {"lr": 1e-5, "batch_size": 4}

    def test_finish_wandb_logging(self, mock_wandb: MagicMock) -> None:
        """finish_wandb_logging writes the summary and closes the run."""
        from prolewiki_llm.wandb_logging import finish_wandb_logging

        with patch("prolewiki_llm.wandb_logging._get_wandb", return_value=mock_wandb):
            finish_wandb_logging(summary={"final_loss": 0.5})

            # Summary entries are written item by item onto run.summary...
            mock_wandb.run.summary.__setitem__.assert_called_with("final_loss", 0.5)

            # ...and the run is finished exactly once.
            mock_wandb.finish.assert_called_once()
training_data/entity_whitelist.json ADDED
The diff for this file is too large to render. See raw diff
 
training_data/entity_whitelist_clean.json ADDED
The diff for this file is too large to render. See raw diff