Upload folder using huggingface_hub
- .gitattributes +2 -0
- trainer-kit/GRPO-14B/README.md +155 -0
- trainer-kit/GRPO-14B/__pycache__/f1_score_utils.cpython-310.pyc +0 -0
- trainer-kit/GRPO-14B/config_grpo.yaml +155 -0
- trainer-kit/GRPO-14B/config_grpo_14b.yaml +160 -0
- trainer-kit/GRPO-14B/config_grpo_test.yaml +138 -0
- trainer-kit/GRPO-14B/config_grpo_test_0.5b.yaml +158 -0
- trainer-kit/GRPO-14B/f1_score_utils.py +283 -0
- trainer-kit/GRPO-14B/generate_grpo_data.py +229 -0
- trainer-kit/GRPO-14B/grpo_dataset.jsonl +3 -0
- trainer-kit/GRPO-14B/grpo_dataset2.jsonl +3 -0
- trainer-kit/GRPO-14B/grpo_dataset_api_test.jsonl +10 -0
- trainer-kit/GRPO-14B/grpo_dataset_test.jsonl +3 -0
- trainer-kit/GRPO-14B/prepare_grpo_data.py +355 -0
- trainer-kit/GRPO-14B/prepare_grpo_data_api.py +356 -0
- trainer-kit/GRPO-14B/prepare_grpo_data_vllm.py +233 -0
- trainer-kit/GRPO-14B/requirements.txt +26 -0
- trainer-kit/GRPO-14B/run_grpo_fixed.py +1216 -0
- trainer-kit/GRPO-14B/sft_dataset.jsonl +5 -0
- trainer-kit/GRPO-14B/sft_dataset_test.jsonl +3 -0
- trainer-kit/GRPO-14B/start_vllm_server.sh +39 -0
.gitattributes CHANGED

@@ -68,3 +68,5 @@ grpo_qwen_14B_v2/best_adapter/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 grpo_qwen_14B_v2/checkpoints/checkpoint-400/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 grpo_qwen_14B_v2/checkpoints/checkpoint-500/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 grpo_qwen_14B_v2/wandb/run-20251227_194423-jz7bptqa/run-jz7bptqa.wandb filter=lfs diff=lfs merge=lfs -text
+trainer-kit/GRPO-14B/grpo_dataset.jsonl filter=lfs diff=lfs merge=lfs -text
+trainer-kit/GRPO-14B/grpo_dataset2.jsonl filter=lfs diff=lfs merge=lfs -text
trainer-kit/GRPO-14B/README.md ADDED

@@ -0,0 +1,155 @@

# GRPO Training - Fixed & Production Ready

**Status**: ✅ Production-ready (Rating: 9.5/10)

## Quick Start

```bash
# 1. Prepare the GRPO dataset from SFT data
python prepare_grpo_data.py \
    --sft_dataset sft_dataset.jsonl \
    --output grpo_dataset.jsonl \
    --model_path ../Models/Qwen2.5-Coder-14B-CPT-SFT \
    --num_completions 6 \
    --temperatures 0.6,0.7,0.8,0.9,1.0,1.1

# 2. Run GRPO training
python run_grpo_fixed.py --config config_grpo.yaml

# 3. Monitor training
tail -f runs/grpo_run_14b_v1/logs/train.jsonl
tail -f runs/grpo_run_14b_v1/logs/grpo_metrics.jsonl
```
## What is GRPO?

**Group Relative Policy Optimization** learns to prefer higher-quality completions within a group:
- Generate 4-8 completions per prompt
- Score each completion with the F1 metric
- Train the model to increase the probability of high-F1 outputs
- Apply a KL-divergence penalty to prevent drift from the reference model
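The group-relative step above can be sketched in plain Python. This is a minimal illustration, not the trainer's actual code; the `normalize` and `eps` parameters mirror the `normalize_advantages` and `advantage_eps` config keys.

```python
import math

def group_advantages(scores, normalize=True, eps=1e-8):
    """Turn per-completion F1 scores into group-relative advantages.

    Each completion is compared against its own group's mean, so the
    training signal is "better or worse than the siblings", not an
    absolute score.
    """
    mean = sum(scores) / len(scores)
    adv = [s - mean for s in scores]
    if normalize:
        std = math.sqrt(sum(a * a for a in adv) / len(adv))
        adv = [a / (std + eps) for a in adv]
    return adv

# Four completions for one prompt, scored by F1:
print(group_advantages([0.2, 0.5, 0.9, 0.6]))
```

After normalization the advantages sum to zero and have unit standard deviation, so the best completion in a group is always pushed up and the worst pushed down, regardless of the group's absolute F1 level.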
## Critical Fixes Applied

### 1. **GRPODataCollator** - Fixed data batching
- Original: the default collator crashed on nested lists
- Fixed: a custom collator handles the `completions` and `scores` lists properly

### 2. **Pre-tokenization** - Fixed efficiency
- Original: tokenized on the fly during training (10-50x slower)
- Fixed: pre-tokenize during data prep and store the results in the dataset

### 3. **Prompt Masking** - Fixed loss computation
- Original: loss was computed over the entire sequence (prompt + completion)
- Fixed: completion masks restrict the loss to completion tokens only

### 4. **KL Divergence** - Fixed formula
- Original: a simple log-prob difference, not a proper KL term
- Fixed: proper KL divergence: `E[log(p) - log(q)]`

### 5. **Data Validation** - Fixed crashes
- Original: no validation; crashed on malformed data
- Fixed: check that `completions` and `scores` have equal length, check for NaN/Inf, filter invalid rows

### 6. **Metrics Logging** - Fixed visibility
- Original: no GRPO-specific metrics
- Fixed: log advantages, KL divergence, and group scores for monitoring

### 7. **Reproducibility** - Fixed random sampling
- Original: no seed for group sampling
- Fixed: use a numpy RandomState with a fixed seed
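Fixes 3 and 4 can be illustrated together with a small scalar sketch. The real trainer works on tensors; `prompt_len` here stands in for the prompt lengths stored during pre-tokenization.

```python
def completion_mask(seq_len, prompt_len):
    """1 for completion tokens, 0 for prompt tokens (fix 3)."""
    return [0] * prompt_len + [1] * (seq_len - prompt_len)

def masked_mean_kl(logp_policy, logp_ref, mask):
    """Per-token KL estimate E[log p - log q], averaged over
    completion tokens only (fix 4)."""
    terms = [(p - q) * m for p, q, m in zip(logp_policy, logp_ref, mask)]
    return sum(terms) / max(sum(mask), 1)

mask = completion_mask(seq_len=6, prompt_len=4)   # [0, 0, 0, 0, 1, 1]
kl = masked_mean_kl([-1.0, -2.0, -1.5, -0.5, -0.7, -0.9],
                    [-1.0, -2.0, -1.5, -0.5, -1.0, -1.2], mask)
print(kl)  # only the last two tokens contribute
```

Without the mask, the (identical) prompt log-probs would dilute both the loss and the KL estimate; with it, only the generated tokens carry gradient.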
## Key Hyperparameters

```yaml
grpo:
  group_size: 4               # Sample 4 completions per group
  kl_coef: 0.05               # KL penalty weight
  normalize_advantages: true  # Normalize per-group advantages
  reward_scaling: 1.0         # Scale F1 scores
  reward_clip: 1.0            # Clip rewards for stability
  use_reference_model: true   # Use a frozen reference model for KL
```
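Assuming the reward-shaping keys act in the usual scale-shift-clip order (an assumption; check `run_grpo_fixed.py` for the exact order applied), the shaping amounts to:

```python
def shape_reward(f1, scaling=1.0, bias=0.0, clip=1.0):
    """Scale, shift, then clip an F1 score into [-clip, +clip]."""
    r = f1 * scaling + bias
    return max(-clip, min(clip, r))

print(shape_reward(0.9))               # unchanged: 0.9
print(shape_reward(0.9, scaling=2.0))  # 1.8, clipped to 1.0
```

Since F1 already lies in [0, 1], the defaults (`scaling=1.0`, `bias=0.0`, `clip=1.0`) are a no-op; the knobs matter only if you rescale or recenter the scores.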
## Why This Approach Works

**GRPO learns relative quality within groups:**
```
Prompt: "Fix the bug..."

Completions:          F1 Scores:   Advantages:   Training Signal:
1. [buggy code]       0.2          -1.5          ↓ Decrease probability
2. [partial fix]      0.5          -0.3          ↓ Decrease slightly
3. [correct fix]      0.9          +0.8          ↑ Increase probability
4. [overcomplicated]  0.6          +0.1          ↑ Increase slightly

→ Model learns: prefer completion #3 over the others
```
## Pipeline Architecture

```
SFT Dataset (input, output)
        ↓
[Generate multiple completions]
        ↓
[Compute F1 scores vs ground truth]
        ↓
GRPO Dataset (prompt, completions[], scores[])
        ↓
[Pre-tokenize with prompt lengths]
        ↓
[GRPODataCollator batches data]
        ↓
[GRPOTrainer.compute_loss]
  - Sample group_size completions
  - Compute log probs (masked to completions only)
  - Get reference log probs (KL divergence)
  - Normalize advantages from F1 scores
  - Loss = -E[advantages * log_probs] + kl_coef * KL
        ↓
[Optimized model prefers high-F1 outputs]
```
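The "Compute F1 scores vs ground truth" stage is presumably a token-overlap F1 along these lines (a sketch only; see `f1_score_utils.py` for the project's actual scoring):

```python
from collections import Counter

def token_f1(prediction: str, reference: str) -> float:
    """Bag-of-tokens F1 between a completion and the ground-truth output."""
    pred, ref = prediction.split(), reference.split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

print(token_f1("modify foo.rs", "modify foo.rs"))  # 1.0
print(token_f1("add bar.rs", "modify foo.rs"))     # 0.0
```

A continuous score like this is what lets GRPO rank a whole group, rather than only declaring a single winner as in pairwise preference data.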
## What Makes This Different from DPO?

| Aspect | DPO | GRPO (This Implementation) |
|--------|-----|----------------------------|
| Data Format | Pairwise (chosen/rejected) | Groups (4-8 ranked completions) |
| Learning Signal | Binary preference | Continuous relative ranking |
| Score Usage | Implicit (binary) | Explicit (F1 scores → advantages) |
| Efficiency | 2 forward passes | group_size forward passes |
| Best For | Human preferences | Objective metrics (F1, accuracy) |
## Monitoring Training

**Key metrics to watch:**
- `loss`: should decrease steadily
- `grpo_mean_advantage`: should stay near 0 (if normalized)
- `grpo_std_advantage`: should stay near 1.0 (if normalized)
- `grpo_mean_kl_div`: should stay small (<0.1); this is what prevents drift
- `grpo_mean_group_score`: average F1 within groups; should improve over training
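A quick sanity check over the metrics log might look like this. It assumes `grpo_metrics.jsonl` holds one JSON object per logging step with the metric names above; adjust the path and keys to your run.

```python
import json, io

# In practice: open("runs/grpo_run_14b_v1/logs/grpo_metrics.jsonl")
log = io.StringIO(
    '{"grpo_mean_advantage": 0.01, "grpo_mean_kl_div": 0.03}\n'
    '{"grpo_mean_advantage": -0.02, "grpo_mean_kl_div": 0.05}\n'
)
steps = [json.loads(line) for line in log]
mean_adv = sum(s["grpo_mean_advantage"] for s in steps) / len(steps)
max_kl = max(s["grpo_mean_kl_div"] for s in steps)
print(f"mean advantage {mean_adv:+.3f}, max KL {max_kl:.3f}")
assert abs(mean_adv) < 0.5, "advantages should hover near 0 when normalized"
assert max_kl < 0.1, "KL above 0.1 suggests drift; consider raising kl_coef"
```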
## Files

- `run_grpo_fixed.py` - Fixed GRPO trainer (use this!)
- `prepare_grpo_data.py` - Generate the multi-completion dataset with F1 scores
- `config_grpo.yaml` - Training configuration
- `requirements.txt` - Dependencies
## Troubleshooting

**OOM errors?**
- Reduce `group_size` from 4 to 2-3
- Reduce `per_device_train_batch_size` to 1
- Increase `gradient_accumulation_steps`

**Training unstable?**
- Increase `kl_coef` from 0.05 to 0.1
- Add `reward_clip: 1.0` to the config
- Reduce the learning rate to 5e-6

**No improvement?**
- Check the F1 score distribution in the data (it needs variety)
- Ensure `min_completions >= group_size`
- Verify the completions have quality variance
trainer-kit/GRPO-14B/__pycache__/f1_score_utils.cpython-310.pyc ADDED

Binary file (7.25 kB)
trainer-kit/GRPO-14B/config_grpo.yaml ADDED

@@ -0,0 +1,155 @@

run:
  run_dir: "./runs/grpo_run_14b_v1"
  seed: 42

# WandB integration for experiment tracking
wandb:
  enabled: true
  project: "grpo-training"
  entity: null
  name: null
  tags: ["grpo-lora", "group-preference-optimization"]
  notes: null

model:
  # Use the SFT or DPO model as base
  repo_id: "./runs/dpo_run_14b_v1/merged_14b_dpo_lora"
  revision: null

  # Used only when repo_id is a HF repo (not a local path)
  base_local_dir: "base_model"

  trust_remote_code: true
  tokenizer_use_fast: true
  device_map: "auto"

  torch_dtype: "bfloat16"  # "float16" | "bfloat16" | "float32"

  # QLoRA
  use_4bit: false
  bnb_4bit_quant_type: "nf4"
  bnb_4bit_use_double_quant: false
  bnb_4bit_compute_dtype: "bfloat16"

  # optional: "flash_attention_2" | "sdpa" | null
  attn_implementation: null

data:
  train_jsonl: "grpo_dataset.jsonl"
  eval_jsonl: null
  eval_split_ratio: 0.1

  # Field names in your JSONL data for GRPO.
  # GRPO requires: prompt, completions (list), scores (list)
  prompt_field: "prompt"
  completions_field: "completions"  # List of candidate responses
  scores_field: "scores"            # List of F1 scores, one per completion

  # Formatting options
  format_type: "chatml"  # "chatml" | "alpaca" | "custom"

  # System prompt to prepend to all prompts
  system_prompt: |
    You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.

    ## Output Format

    ##OUTPUT
    Explain the data flow and why each component must change:
    - Flow: [Input → Processing → Output with arrows]
    - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
    - Explain coupling between components

    ##SELECT
    modify::crates/path/to/file.rs::impl::ComponentName
    add::crates/another/file.rs::function::AnotherComponent
    <EOS>

    ## Rules

    1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
    2. Use `::` for nested items: `status::StructName::Type::Name`
    3. Always explain "must change because" and "without this"
    4. Types of components: function, struct, enum, impl, trait
    5. If there is extra information (e.g., enum variants), include that too.
    6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>

  max_length: 2048
  shuffle: true
  num_proc: 4

  # GRPO specific: bounds on the number of completions per prompt
  min_completions: 2
  max_completions: 8

peft:
  enabled: true
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  bias: "none"
  target_modules: "auto"

# GRPO specific parameters
grpo:
  # Group size for relative ranking
  group_size: 4  # Number of completions to sample and rank per prompt

  # Advantage normalization
  normalize_advantages: true
  advantage_eps: 1e-8

  # Policy optimization
  kl_coef: 0.05  # KL divergence coefficient (lower than PPO for stability)

  # Reward shaping
  reward_scaling: 1.0  # Scale rewards (F1 scores) by this factor
  reward_bias: 0.0     # Bias to add to rewards
  reward_clip: 1.0     # Clip rewards to [-reward_clip, +reward_clip] for stability

  # Reference model for KL penalty
  use_reference_model: true

  # Temperature for advantage softmax
  advantage_temperature: 1.0

train:
  num_train_epochs: 2

  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 8

  learning_rate: 1e-5  # Even lower than DPO, for stability
  weight_decay: 0.01
  warmup_ratio: 0.1
  lr_scheduler_type: "cosine"

  optim: "adamw_torch"
  max_grad_norm: 0.5  # Lower for GRPO stability
  gradient_checkpointing: true

  logging_steps: 2
  save_strategy: "steps"
  save_steps: 100
  save_total_limit: 10

  evaluation_strategy: "steps"
  eval_steps: 50
  load_best_model_at_end: true

  # Early stopping
  early_stopping:
    enabled: true
    patience: 5
    min_delta: 0.001
    metric: "eval_loss"
    mode: "min"

  resume_from_checkpoint: "auto"

merge:
  enabled: true
  merged_dtype: "float16"
  max_shard_size: "2GB"
  output_dir: "./merged_14b_grpo_lora"
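A few invariants tie these keys together: the sampled group must fit inside each stored completion list, and a group of one carries no relative signal. A minimal pre-flight check, sketched over a plain dict standing in for the parsed YAML (illustrative; the trainer's own validation may differ):

```python
def check_grpo_config(cfg):
    """Cross-field sanity checks before launching a run (illustrative)."""
    grpo, data = cfg["grpo"], cfg["data"]
    assert data["min_completions"] >= 2, "need at least 2 completions to rank"
    assert grpo["group_size"] >= 2, "a group of 1 has no relative signal"
    assert grpo["group_size"] <= data["max_completions"], (
        "group_size cannot exceed the completions stored per prompt")
    assert grpo["reward_clip"] > 0 and grpo["kl_coef"] >= 0

check_grpo_config({
    "grpo": {"group_size": 4, "kl_coef": 0.05, "reward_clip": 1.0},
    "data": {"min_completions": 2, "max_completions": 8},
})
print("config OK")
```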
trainer-kit/GRPO-14B/config_grpo_14b.yaml ADDED

@@ -0,0 +1,160 @@

# GRPO Training Configuration for Qwen2.5-Coder-14B

# Run configuration
run:
  run_dir: "runs/grpo_qwen_14b"

model:
  repo_id: "/workspace/Models/Qwen2.5-Coder-14B-CPT-SFT_v2"
  tokenizer_name: "Qwen/Qwen2.5-Coder-14B"  # Use the base tokenizer (SFT_v2 has a compatibility issue)
  load_in_8bit: false
  load_in_4bit: false
  torch_dtype: "bfloat16"
  device_map: "auto"
  trust_remote_code: true

# WandB integration for experiment tracking
wandb:
  enabled: true                     # Set to true to enable wandb logging
  project: "rl-training"            # WandB project name
  entity: null                      # WandB entity/team (optional)
  name: null                        # Run name (auto-generated if null)
  tags: ["grpo-lora", "14B-QWEN"]   # List of tags for the run
  notes: null                       # Run description/notes (optional)

# Data configuration
data:
  train_jsonl: "grpo_dataset.jsonl"
  eval_jsonl: null
  eval_split_ratio: 0.0
  shuffle: true
  num_proc: 1

  # Field names in the JSONL (defaults)
  prompt_field: "prompt"
  completions_field: "completions"
  scores_field: "scores"

  # Format: prompts are already formatted in the JSONL
  format_type: "raw"
  max_length: 2048
  min_completions: 2

  # For chatml format
  system_prompt: |
    You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.

    ## Output Format

    ##OUTPUT
    Explain the data flow and why each component must change:
    - Flow: [Input → Processing → Output with arrows]
    - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
    - Explain coupling between components

    ##SELECT
    modify::crates/path/to/file.rs::impl::ComponentName
    add::crates/another/file.rs::function::AnotherComponent
    <EOS>

    ## Rules

    1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
    2. Use `::` for nested items: `status::StructName::Type::Name`
    3. Always explain "must change because" and "without this"
    4. Types of components: function, struct, enum, impl, trait
    5. If there is extra information (e.g., enum variants), include that too.
    6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>

    ## Example

    ##TASK
    Add webhook subscription support

    ##OUTPUT
    The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.

    ##SELECT
    crates/common_enums/src/enums.rs::EventClass
    crates/common_enums/src/transformers.rs::SubscriptionStatus
    <EOS>

  # For custom format (only used when format_type="custom")
  custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n{input}<|im_end|>\n{output}<|im_end|>"

# GRPO-specific parameters
grpo:
  group_size: 4     # Sample 4 completions per prompt for comparison
  kl_coef: 0.05     # KL divergence coefficient (beta)
  normalize_advantages: true
  reward_scaling: 1.0
  reward_bias: 0.0
  reward_clip: 5.0  # Clip rewards to avoid extreme values
  advantage_temperature: 1.0
  use_reference_model: false  # Skip the separate reference model (faster; disables the PEFT adapter instead)
  seed: 42

# PEFT/LoRA configuration
peft:
  enabled: true
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"
  bias: "none"
  task_type: "CAUSAL_LM"

# Training arguments
train:
  output_dir: "runs/grpo_14b_run1"
  num_train_epochs: 2
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 8  # Effective batch = 8
  per_device_eval_batch_size: 1

  # Optimization
  learning_rate: 5.0e-6  # Lower LR for GRPO
  weight_decay: 0.01
  warmup_ratio: 0.05
  lr_scheduler_type: "cosine"

  # Mixed precision
  fp16: false
  bf16: true

  # Gradient settings
  max_grad_norm: 1.0
  gradient_checkpointing: true

  # Logging and saving
  logging_steps: 5
  save_steps: 100
  save_total_limit: 2
  evaluation_strategy: "no"  # No eval for now

  # Performance
  dataloader_num_workers: 4
  dataloader_pin_memory: true
  remove_unused_columns: false

  # Misc
  report_to: []  # Disable wandb reporting from the Trainer
  seed: 42
  ddp_find_unused_parameters: false

# Merge and upload settings
merge:
  enabled: true
  merged_dtype: "float16"
  max_shard_size: "2GB"
  output_dir: "./Models/Qwen-Coder-14B-HS-CPT-SFT-v2-GRPO"

upload:
  enabled: false
trainer-kit/GRPO-14B/config_grpo_test.yaml ADDED

@@ -0,0 +1,138 @@

run:
  run_dir: "./runs/grpo_test_0.5b"
  seed: 42

# WandB integration for experiment tracking
wandb:
  enabled: false  # Disabled for quick tests
  project: "grpo-training-test"
  entity: null
  name: "grpo-0.5b-test"
  tags: ["grpo-lora", "0.5b-test"]
  notes: "Quick test run on Qwen2.5-Coder-0.5B"

model:
  # Use a small model for testing
  repo_id: "Qwen/Qwen2.5-Coder-0.5B-Instruct"
  revision: null
  base_local_dir: "base_model"

  trust_remote_code: true
  tokenizer_use_fast: true
  device_map: "auto"
  torch_dtype: "bfloat16"

  # QLoRA disabled for the small model
  use_4bit: false
  bnb_4bit_quant_type: "nf4"
  bnb_4bit_use_double_quant: false
  bnb_4bit_compute_dtype: "bfloat16"

  attn_implementation: null

data:
  train_jsonl: "grpo_dataset_test.jsonl"
  eval_jsonl: null
  eval_split_ratio: 0.2  # Use 20% for eval

  # Field names in the GRPO dataset
  prompt_field: "prompt"
  completions_field: "completions"
  scores_field: "scores"

  # Formatting - MUST MATCH the SFT training format!
  format_type: "custom"  # Use the custom format, like SFT

  # System prompt for code analysis
  system_prompt: |
    You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.

    ## Output Format

    ##OUTPUT
    Explain the data flow and why each component must change:
    - Flow: [Input → Processing → Output with arrows]
    - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
    - Explain coupling between components

    ##SELECT
    modify::crates/path/to/file.rs::impl::ComponentName
    add::crates/another/file.rs::function::AnotherComponent
    <EOS>

    ## Rules

    1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
    2. Use `::` for nested items: `status::StructName::Type::Name`
    3. Always explain "must change because" and "without this"
    4. Types of components: function, struct, enum, impl, trait
    5. If there is extra information (e.g., enum variants), include that too.
    6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>

  # Custom template matching SFT training
  custom_template: "##INSTRUCTION\n{system_prompt}<|im_end|>\n{prompt}<|im_end|>\n{completion}<|im_end|>"

  max_length: 2048
  shuffle: true
  num_proc: 1
  min_completions: 2
  max_completions: 8

peft:
  enabled: true
  r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  bias: "none"
  target_modules: "auto"

# GRPO specific parameters
grpo:
  group_size: 3  # Sample 3 completions per group (small for testing)
  normalize_advantages: true
  advantage_eps: 1e-8
  kl_coef: 0.05
  reward_scaling: 1.0
  reward_bias: 0.0
  reward_clip: 1.0
  use_reference_model: true
  advantage_temperature: 1.0

train:
  num_train_epochs: 2  # Just 2 epochs for the test
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 2

  learning_rate: 5e-5  # LR for the small test model
  weight_decay: 0.01
  warmup_ratio: 0.1
  lr_scheduler_type: "cosine"

  optim: "adamw_torch"
  max_grad_norm: 0.5
  gradient_checkpointing: true

  logging_steps: 1
  save_strategy: "steps"
  save_steps: 20
  save_total_limit: 3

  evaluation_strategy: "steps"
  eval_steps: 10
  load_best_model_at_end: true

  early_stopping:
    enabled: true
    patience: 3
    min_delta: 0.001
    metric: "eval_loss"
    mode: "min"

  resume_from_checkpoint: null

merge:
  enabled: false  # Don't merge for the test run
  merged_dtype: "float16"
  max_shard_size: "2GB"
  output_dir: "./merged_0.5b_grpo_test"
trainer-kit/GRPO-14B/config_grpo_test_0.5b.yaml
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
run:
  run_dir: "./runs/grpo_test_0.5b"
  seed: 42

# WandB integration for experiment tracking
wandb:
  enabled: false
  project: "grpo-training"
  entity: null
  name: "grpo-0.5b-test"
  tags: ["grpo-lora", "test", "0.5b"]
  notes: "Test run on 0.5B model with 3 examples"

model:
  # Use the SFT or DPO model as base
  repo_id: "/workspace/Models/Qwen2.5-Coder-0.5B"
  revision: null

  # Used only when repo_id is a HF repo (not a local path)
  base_local_dir: "base_model"

  trust_remote_code: true
  tokenizer_use_fast: true
  device_map: "auto"

  torch_dtype: "bfloat16"  # "float16" | "bfloat16" | "float32"

  # QLoRA
  use_4bit: false
  bnb_4bit_quant_type: "nf4"
  bnb_4bit_use_double_quant: false
  bnb_4bit_compute_dtype: "bfloat16"

  # optional: "flash_attention_2" | "sdpa" | null
  attn_implementation: null

data:
  train_jsonl: "grpo_dataset_test.jsonl"
  eval_jsonl: null
  eval_split_ratio: 0.0

  # Field names in your JSONL data for GRPO
  # GRPO requires: prompt, completions (list), scores (list)
  prompt_field: "prompt"
  completions_field: "completions"  # List of candidate responses
  scores_field: "scores"  # List of F1 scores for each completion

  # Formatting options
  format_type: "custom"  # "chatml" | "alpaca" | "custom"

  # System prompt to prepend to all prompts
  system_prompt: |
    You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.

    ## Output Format

    ##OUTPUT
    Explain the data flow and why each component must change:
    - Flow: [Input → Processing → Output with arrows]
    - For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
    - Explain coupling between components

    ##SELECT
    modify::crates/path/to/file.rs::impl::ComponentName
    add::crates/another/file.rs::function::AnotherComponent
    <EOS>

    ## Rules

    1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
    2. Use `::` for nested items: `status::StructName::Type::Name`
    3. Always explain "must change because" and "without this"
    4. Types of components: function, struct, enum, impl, trait
    5. If there is extra information (e.g., enum variants), include that too.
    6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>

  # Custom template (matches SFT training format)
  custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n{input}<|im_end|>\n{output}<|im_end|>"

  max_length: 2048
  shuffle: true
  num_proc: 4

  # GRPO specific: minimum number of completions per prompt
  min_completions: 2
  max_completions: 4

peft:
  enabled: true
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  bias: "none"
  target_modules: "auto"

# GRPO specific parameters
grpo:
  # Group size for relative ranking
  group_size: 2  # Number of completions to sample and rank per prompt

  # Advantage normalization
  normalize_advantages: true
  advantage_eps: 1e-8

  # Policy optimization
  kl_coef: 0.05  # KL divergence coefficient (lower than PPO for stability)

  # Reward shaping
  reward_scaling: 1.0  # Scale rewards (F1 scores) by this factor
  reward_bias: 0.0  # Bias to add to rewards
  reward_clip: 1.0  # Clip rewards to [-reward_clip, +reward_clip] for stability

  # Reference model for KL penalty
  use_reference_model: false

  # Temperature for advantage softmax
  advantage_temperature: 1.0

train:
  num_train_epochs: 1

  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 2

  learning_rate: 1e-5  # Even lower than DPO for stability
  weight_decay: 0.01
  warmup_ratio: 0.1
  lr_scheduler_type: "cosine"

  optim: "adamw_torch"
  max_grad_norm: 0.5  # Lower for GRPO stability
  gradient_checkpointing: true

  logging_steps: 1
  save_strategy: "steps"
  save_steps: 10
  save_total_limit: 2

  evaluation_strategy: "no"
  eval_steps: 10
  load_best_model_at_end: false

  # Early stopping
  early_stopping:
    enabled: false
    patience: 5
    min_delta: 0.001
    metric: "eval_loss"
    mode: "min"

  resume_from_checkpoint: "auto"

merge:
  enabled: false
  merged_dtype: "float16"
  max_shard_size: "2GB"
  output_dir: "./merged_0.5b_grpo_lora"
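The `grpo` reward-shaping and advantage knobs compose in a standard way: rewards are scaled, biased, and clipped, then converted to group-relative advantages by subtracting the group mean and (optionally) dividing by the group standard deviation plus `advantage_eps`. A minimal sketch of that computation — a hypothetical helper mirroring the config above, not the trainer's actual implementation:

```python
from statistics import mean, pstdev

def group_advantages(rewards, scale=1.0, bias=0.0, clip=1.0,
                     normalize=True, eps=1e-8):
    """Turn raw per-completion rewards (e.g. F1 scores) into
    group-relative advantages, mirroring the grpo config knobs."""
    # Reward shaping: scale, bias, then clip for stability
    shaped = [max(-clip, min(clip, r * scale + bias)) for r in rewards]
    mu = mean(shaped)
    adv = [s - mu for s in shaped]
    if normalize:
        sd = pstdev(shaped)  # population std over the group
        adv = [a / (sd + eps) for a in adv]
    return adv

# With group_size=2, the better completion gets a positive advantage,
# the worse a negative one; they sum to zero within the group.
advs = group_advantages([0.8, 0.2])
```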
trainer-kit/GRPO-14B/f1_score_utils.py
ADDED
@@ -0,0 +1,283 @@
"""
Utility for computing F1 scores at file level for ranking generated outputs.
This helps create preference pairs for DPO training.
"""

import json
import re
from typing import List, Set, Tuple, Dict


def extract_files_from_selection(output_text: str) -> Set[str]:
    """
    Extract file paths from the ##SELECT section.
    Expected format: modify::crates/path/to/file.rs::impl::ComponentName
    Returns a set of unique file paths.
    """
    files = set()

    # Find the ##SELECT section
    select_match = re.search(r'##SELECT\s*(.*?)<EOS>', output_text, re.DOTALL | re.IGNORECASE)
    if not select_match:
        return files

    select_section = select_match.group(1)

    # Extract the file path from each line
    # Format: action::path::type::name
    for line in select_section.strip().split('\n'):
        line = line.strip()
        if not line:
            continue

        # Split by :: and extract the file path (second component)
        parts = line.split('::')
        if len(parts) >= 2:
            file_path = parts[1]
            files.add(file_path)

    return files


def compute_file_level_f1(predicted: str, ground_truth: str) -> Dict[str, float]:
    """
    Compute the F1 score based on file-level predictions.

    Args:
        predicted: Model output with a ##SELECT section
        ground_truth: Ground truth output with a ##SELECT section

    Returns:
        Dictionary with precision, recall, and f1 scores
    """
    pred_files = extract_files_from_selection(predicted)
    gt_files = extract_files_from_selection(ground_truth)

    if len(gt_files) == 0:
        # No ground truth files
        if len(pred_files) == 0:
            return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
        else:
            return {"precision": 0.0, "recall": 1.0, "f1": 0.0}

    if len(pred_files) == 0:
        # No predicted files, but ground truth has some
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    # Calculate metrics
    true_positives = len(pred_files & gt_files)
    false_positives = len(pred_files - gt_files)
    false_negatives = len(gt_files - pred_files)

    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0

    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "pred_files": list(pred_files),
        "gt_files": list(gt_files),
    }


def rank_outputs_by_f1(outputs: List[str], ground_truth: str) -> List[Tuple[str, float, Dict]]:
    """
    Rank multiple outputs by their F1 scores compared to the ground truth.

    Args:
        outputs: List of model outputs to rank
        ground_truth: Ground truth output

    Returns:
        List of tuples: (output, f1_score, metrics_dict) sorted by F1 descending
    """
    ranked = []
    for output in outputs:
        metrics = compute_file_level_f1(output, ground_truth)
        ranked.append((output, metrics["f1"], metrics))

    # Sort by F1 score descending
    ranked.sort(key=lambda x: x[1], reverse=True)
    return ranked


def create_dpo_pairs_from_generations(
    prompt: str,
    generations: List[str],
    ground_truth: str,
    min_f1_difference: float = 0.1
) -> List[Dict[str, str]]:
    """
    Create DPO training pairs from multiple generations.
    Uses the F1 score to determine which generation is better.

    Args:
        prompt: Input prompt/task
        generations: List of generated outputs
        ground_truth: Ground truth output
        min_f1_difference: Minimum F1 difference to create a pair

    Returns:
        List of DPO pairs: {"prompt": str, "chosen": str, "rejected": str}
    """
    if len(generations) < 2:
        return []

    ranked = rank_outputs_by_f1(generations, ground_truth)
    pairs = []

    # Create pairs from ranked outputs
    for i in range(len(ranked)):
        for j in range(i + 1, len(ranked)):
            better_output, better_f1, _ = ranked[i]
            worse_output, worse_f1, _ = ranked[j]

            # Only create a pair if the F1 difference is significant
            if better_f1 - worse_f1 >= min_f1_difference:
                pairs.append({
                    "prompt": prompt,
                    "chosen": better_output,
                    "rejected": worse_output,
                    "chosen_f1": better_f1,
                    "rejected_f1": worse_f1,
                })

    return pairs


def convert_sft_to_dpo_with_sampling(
    sft_jsonl_path: str,
    output_jsonl_path: str,
    model_inference_fn,
    num_samples: int = 4,
    min_f1_difference: float = 0.1,
    temperature: float = 0.8
):
    """
    Convert an SFT dataset to a DPO dataset by sampling multiple outputs and ranking by F1.

    Args:
        sft_jsonl_path: Path to the SFT JSONL file
        output_jsonl_path: Path to the output DPO JSONL file
        model_inference_fn: Function that takes (prompt, num_samples, temperature) and returns List[str]
        num_samples: Number of outputs to sample per prompt
        min_f1_difference: Minimum F1 difference to create a pair
        temperature: Sampling temperature
    """
    pairs_created = 0

    with open(sft_jsonl_path, 'r') as f_in, open(output_jsonl_path, 'w') as f_out:
        for line in f_in:
            data = json.loads(line)

            # Extract prompt and ground truth
            prompt = data.get("input", "")
            ground_truth = data.get("output", "")

            if not prompt or not ground_truth:
                continue

            # Generate multiple outputs
            try:
                generations = model_inference_fn(prompt, num_samples, temperature)
            except Exception as e:
                print(f"Error generating outputs: {e}")
                continue

            # Create DPO pairs
            pairs = create_dpo_pairs_from_generations(
                prompt, generations, ground_truth, min_f1_difference
            )

            # Write pairs to output
            for pair in pairs:
                f_out.write(json.dumps(pair) + '\n')
                pairs_created += 1

    print(f"Created {pairs_created} DPO pairs from {sft_jsonl_path}")


def prepare_dpo_data_from_instruct(
    instruct_jsonl: str,
    output_dpo_jsonl: str,
):
    """
    Simple conversion from instruction data to DPO format.
    This assumes you already have multiple outputs per input or will generate them.

    For demonstration, this shows the expected structure. In practice, you need to:
    1. Generate multiple outputs for each input
    2. Rank them by F1 score
    3. Create chosen/rejected pairs
    """
    print(f"Converting {instruct_jsonl} to DPO format...")
    print("Note: This requires generating multiple outputs per prompt.")
    print("Use convert_sft_to_dpo_with_sampling() with your model for actual conversion.")

    # Example structure - you'll need to fill this with actual generations
    with open(instruct_jsonl, 'r') as f:
        for line in f:
            data = json.loads(line)
            print(f"Input: {data.get('input', '')[:100]}...")
            print(f"Ground truth output available: {len(data.get('output', ''))} chars")
            print("  -> Need to generate multiple outputs and rank by F1 score")
            print()
            break  # Just show one example


if __name__ == "__main__":
    # Example usage
    print("F1 Score Utility for File-Level Ranking")
    print("=" * 50)

    # Example 1: Compute F1 for two outputs
    # Note: ##SELECT lines use the action::path::type::name format so that
    # the parser extracts the file path (second component) correctly.
    ground_truth = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::enum::EventClass
modify::crates/router/src/webhooks.rs::function::process_webhook
<EOS>
"""

    prediction1 = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::enum::EventClass
modify::crates/router/src/webhooks.rs::function::process_webhook
<EOS>
"""

    prediction2 = """
##OUTPUT
The webhook system requires subscription support.
##SELECT
modify::crates/common_enums/src/enums.rs::enum::EventClass
modify::crates/router/src/handlers.rs::function::handle_request
<EOS>
"""

    print("\nExample 1: Perfect match")
    metrics1 = compute_file_level_f1(prediction1, ground_truth)
    print(f"F1 Score: {metrics1['f1']:.3f}")
    print(f"Precision: {metrics1['precision']:.3f}, Recall: {metrics1['recall']:.3f}")

    print("\nExample 2: Partial match")
    metrics2 = compute_file_level_f1(prediction2, ground_truth)
    print(f"F1 Score: {metrics2['f1']:.3f}")
    print(f"Precision: {metrics2['precision']:.3f}, Recall: {metrics2['recall']:.3f}")

    print("\nExample 3: Ranking outputs")
    outputs = [prediction1, prediction2]
    ranked = rank_outputs_by_f1(outputs, ground_truth)
    print("Ranked outputs:")
    for i, (output, f1, metrics) in enumerate(ranked, 1):
        print(f"  {i}. F1={f1:.3f} - {metrics['true_positives']} correct files")
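Once both outputs are reduced to file-path sets, the metric above is plain set-overlap F1. A self-contained illustration with hypothetical paths (independent of the module, for intuition only — it reimplements the core formula on bare sets):

```python
def set_f1(pred: set, gold: set) -> float:
    """F1 over two file-path sets, mirroring the file-level metric."""
    if not gold:
        return 1.0 if not pred else 0.0  # empty-ground-truth edge cases
    tp = len(pred & gold)  # true positives: files in both sets
    if tp == 0:
        return 0.0
    precision = tp / len(pred)
    recall = tp / len(gold)
    return 2 * precision * recall / (precision + recall)

gold = {"crates/a/enums.rs", "crates/b/webhooks.rs"}
print(set_f1({"crates/a/enums.rs", "crates/b/webhooks.rs"}, gold))  # 1.0
print(set_f1({"crates/a/enums.rs", "crates/c/handlers.rs"}, gold))  # 0.5
```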
trainer-kit/GRPO-14B/generate_grpo_data.py
ADDED
@@ -0,0 +1,229 @@
"""
GRPO Data Generation Script for Custom Format

Generates multiple completions from an SFT model and scores them with F1.
Matches the custom format used in SFT training.
"""

import json
import argparse
from typing import List, Dict, Any, Optional

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import F1 score computation
from f1_score_utils import compute_file_level_f1


def load_dataset(path: str) -> List[Dict[str, Any]]:
    """Load a dataset from JSONL."""
    data = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data


def format_prompt_custom(system_prompt: str, input_text: str) -> str:
    """
    Format a prompt using the custom template matching SFT training:
    "##INSTRUCTION\n{instruction}<|im_end|>\n{input}<|im_end|>\n{output}<|im_end|>"
    The system prompt fills the {instruction} slot; the {output} part is generated.
    """
    return f"##INSTRUCTION\n{system_prompt}<|im_end|>\n{input_text}<|im_end|>\n"


def generate_completions(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    num_completions: int = 6,
    temperature: float = 0.9,
    max_new_tokens: int = 1024,
    top_p: float = 0.95,
    device: str = "cuda"
) -> List[str]:
    """Generate multiple diverse completions for a prompt."""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).to(device)
    prompt_len = inputs.input_ids.shape[1]

    completions = []

    with torch.no_grad():
        for _ in range(num_completions):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.05,
            )

            # Decode only the completion
            completion_ids = outputs[0][prompt_len:]
            completion = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
            completions.append(completion)

    return completions


def prepare_grpo_dataset(
    input_dataset_path: str,
    output_path: str,
    model_path: str,
    num_completions: int = 6,
    temperature: float = 0.9,
    max_samples: Optional[int] = None,
    min_score_variance: float = 0.05,
    system_prompt: str = "You are a helpful AI assistant that analyzes code repositories and provides precise file-level solutions.",
):
    """
    Generate a GRPO dataset with multiple scored completions per prompt.
    """
    print("=" * 80)
    print("GRPO Data Generation")
    print("=" * 80)

    # Load dataset
    print(f"\n[1/5] Loading dataset: {input_dataset_path}")
    dataset = load_dataset(input_dataset_path)
    if max_samples:
        dataset = dataset[:max_samples]
    print(f"  → Loaded {len(dataset)} samples")

    # Load model
    print(f"\n[2/5] Loading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()
    print("  → Model loaded")

    # Generate completions and compute F1 scores
    print(f"\n[3/5] Generating {num_completions} completions per sample...")
    grpo_samples = []
    skipped = 0

    for idx, sample in enumerate(tqdm(dataset, desc="Processing")):
        try:
            # Extract input and reference output
            input_text = sample.get("input", "")
            reference_output = sample.get("output", "")

            if not input_text or not reference_output:
                print(f"  ⚠️ Skipping sample {idx}: missing input or output")
                skipped += 1
                continue

            # Format prompt with the custom template
            prompt = format_prompt_custom(system_prompt, input_text)

            # Generate completions
            completions = generate_completions(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                num_completions=num_completions,
                temperature=temperature,
            )

            # Compute the F1 score for each completion
            scores = []
            for completion in completions:
                metrics = compute_file_level_f1(completion, reference_output)
                scores.append(metrics["f1"])

            # Check score variance (skip if all completions have the same quality)
            if max(scores) - min(scores) < min_score_variance:
                skipped += 1
                continue

            # Store GRPO sample
            grpo_sample = {
                "prompt": input_text,  # Just the task input
                "completions": completions,
                "scores": scores,
                "reference": reference_output,  # For debugging
            }
            grpo_samples.append(grpo_sample)

        except Exception as e:
            print(f"  ⚠️ Error processing sample {idx}: {e}")
            skipped += 1
            continue

    print("\n[4/5] Generation complete!")
    print(f"  → Generated samples: {len(grpo_samples)}")
    print(f"  → Skipped samples: {skipped}")

    # Compute statistics
    if grpo_samples:
        all_scores = [s for sample in grpo_samples for s in sample["scores"]]
        print("\n[5/5] Score statistics:")
        print(f"  → Mean F1: {sum(all_scores)/len(all_scores):.3f}")
        print(f"  → Min F1: {min(all_scores):.3f}")
        print(f"  → Max F1: {max(all_scores):.3f}")

        # Show sample
        print("\n📊 Sample GRPO entry:")
        sample = grpo_samples[0]
        print(f"  Prompt: {sample['prompt'][:100]}...")
        print(f"  Num completions: {len(sample['completions'])}")
        print(f"  Scores: {[f'{s:.2f}' for s in sample['scores']]}")

    # Save dataset
    print(f"\n💾 Saving to: {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in grpo_samples:
            # Drop the reference field from the output (kept only for debugging)
            output_sample = {
                "prompt": sample["prompt"],
                "completions": sample["completions"],
                "scores": sample["scores"],
            }
            f.write(json.dumps(output_sample, ensure_ascii=False) + '\n')

    print(f"\n✅ Done! Generated {len(grpo_samples)} GRPO samples")
    print("=" * 80)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate GRPO dataset with scored completions")
    parser.add_argument("--input_dataset", type=str, required=True, help="Path to input dataset (JSONL)")
    parser.add_argument("--output_dataset", type=str, required=True, help="Path to output GRPO dataset")
    parser.add_argument("--model_path", type=str, required=True, help="Path to SFT model for generation")
    parser.add_argument("--num_completions", type=int, default=6, help="Number of completions per prompt")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process (for testing)")
    parser.add_argument("--min_score_variance", type=float, default=0.05, help="Minimum F1 variance to keep sample")
    parser.add_argument("--system_prompt", type=str,
                        default="You are a helpful AI assistant that analyzes code repositories and provides precise file-level solutions.",
                        help="System prompt for generation")

    args = parser.parse_args()

    prepare_grpo_dataset(
        input_dataset_path=args.input_dataset,
        output_path=args.output_dataset,
        model_path=args.model_path,
        num_completions=args.num_completions,
        temperature=args.temperature,
        max_samples=args.max_samples,
        min_score_variance=args.min_score_variance,
        system_prompt=args.system_prompt,
    )
trainer-kit/GRPO-14B/grpo_dataset.jsonl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1fbc3893e9e3646f44e06f0cf8a012906350f7b3b660b9901a12944aa2e789b
size 31823333
trainer-kit/GRPO-14B/grpo_dataset2.jsonl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:763d4a41018837a4cf80b0d011151fda74d4a8a20190f0c0e5fb00e7adb184ae
size 18597919
trainer-kit/GRPO-14B/grpo_dataset_api_test.jsonl
ADDED
@@ -0,0 +1,10 @@
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nUpdate connector configuration and payment response structures to support new payment processors and extended authorization features. This enables integration with Finix and Tesouro connectors while providing visibility into extended authorization requests in payment responses.<|im_end|>\n", "completions": ["##CONTEXT\nThe connector configuration system requires updates to support new payment processors and maintain backward compatibility with existing connectors. The payment response structure needs enhancement to provide visibility into extended authorization requests, which is crucial for merchants who need to track when authorization requests are extended beyond their initial time window. This enhancement ensures that payment response data includes the request_extended_authorization field, which was previously missing from the payment response structure. This field indicates whether the authorization request was extended, providing critical information for payment reconciliation, fraud detection, and customer service operations. 
The connector configuration changes are independent of the payment response structure enhancement, and both represent essential updates to the payment processing infrastructure.", "##CHANGELOG\n### CONNECTORS\n- Added Finix connector configuration with payment processing endpoints, authentication support, and webhook event handling capabilities.\n- Added Tesouro connector configuration with payment processing endpoints, authentication support, and webhook event handling capabilities.\n### PAYMENTS\n- Updated payment response structure to include extended_authorization_request_id field, providing visibility into extended authorization requests for merchants to track authorization flows and make informed decisions about payment processing strategies.\n### CONNECTOR VALIDATIONS\n- Fixed validation logic to properly handle connector-specific authentication requirements and configuration validation across different payment processors.\n### PAYMENT LINKS\n- Updated payment link response structure to include updated_at timestamp field, enabling merchants to track when payment links were last modified for audit purposes and to ensure proper handling of payment link lifecycle events.\n### PAYMENT LINKS CONFIGURATION\n- Updated payment links configuration structure to include new payment_methods_enabled field, allowing merchants to specify which payment methods should be available on their payment links. This enhancement provides greater control over payment method availability, enabling merchants to optimize their payment collection strategies based on their business requirements and market preferences.", "##INSTRUCTION\nYou are a helpful assistant.\u0e40\u0e19\u0e47\u0e15\n##TASK\nUpdate connector configuration and payment response structures to support new payment processors and extended authorization features. 
This enables integration with Finix and Tesouro connectors while providing visibility into extended authorization requests in payment responses.\uf909\n##INSTRUCTION\nYou are a helpful assistant.\ud83d\udc17\n##TASK\nUpdate connector configuration and payment response structures to support new payment processors and extended authorization features. This enables integration with Finix and Tesouro connectors while providing visibility into extended authorization requests in payment responses.\ud83d\udc17\n##INSTRUCTION\nYou are a helpful assistant. The task involves updating connector configuration structures and payment response models to support new payment processors and extended authorization features. This requires modifying existing enum variants in connector configuration structures and adding new fields to payment response models to track authorization requests. The changes enable proper routing and processing of payments through the newly integrated Finix and Tesouro payment processors, while maintaining backward compatibility with existing payment flows.", "##CHANGES\n### Connector Configuration Updates\n- Add Finix connector configuration to support Finix payment processing platform.\n- Add Tesouro connector configuration to support Tesouro payment processing platform.\n\n### Payment Response Structure Updates\n- Add extended_authorization_applied field to PaymentsResponse structure to track whether extended authorization was applied during payment processing.\n\n### Connector Metadata Updates\n- Update Finix connector metadata to include additional processing information, including processor_id, processor_transaction_id, and network_transaction_id, which are essential for proper payment reconciliation and dispute resolution.\n\n### Payment Method Support Updates\n- Add Tesouro to the list of connectors that support payment method types, specifically enabling card payments through the Tesouro connector platform.", "##CODE_CHANGES\n### Hyperspay connector 
configuration\n- Added Finix and Tesouro connectors to the payment processor mapping table. This enables merchants to configure and use these new payment processors through the existing connector configuration system. The configuration change specifically adds these connectors with appropriate parameters (display_name, is_default) to make them available for merchant selection during connector setup. This configuration update is critical because without these entries in the connector mapping table, merchants would not be able to discover or configure these payment processors, effectively making them non-functional for merchant integration. The other connector configuration changes in the codebase (for other payment processors) are excluded from this focused change, as they represent separate connector integrations that do not share the same configuration structure or update requirements as the Finix and Tesouro connectors.", "##REASON\nThis change is necessary to expand the payment processing capabilities by adding new connector integrations. The Finix and Tesouro connectors provide access to additional payment processing services that customers may need for their specific payment requirements. Without these connector additions, merchants using Finix or Tesouro payment processors would be unable to process payments through these services, leading to payment processing failures and potential revenue loss. The extended authorization field addition to PaymentResponse provides visibility into whether extended authorization was requested for a payment, which is critical for audit trails and compliance reporting. Without this field, merchants would lack visibility into extended authorization requests, making it impossible to track compliance with card network regulations for extended authorization scenarios."], "scores": [0.39554317548746515, 0.2968036529680365, 0.3180722891566265, 0.33613445378151263, 0.34645669291338577, 0.36464088397790057]}
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nFix TokenEx vault integration to properly handle card tokenization and improve vault error messaging for better debugging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail.<|im_end|>\n", "completions": ["##INSTRUCTION\nThe TokenEx vault integration requires fixes to properly handle card tokenization and improve error messaging. This ensures card data is correctly formatted for TokenEx and provides developers with clear error information for debugging vault operations.\ud872\udf64\n##TASK\nFix TokenEx vault integration to properly handle card tokenization and improve vault error messaging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail. The changes improve the reliability of card tokenization by correctly formatting card data and providing better error visibility.adaptiveStyles:false,\n##INSTRUCTION\nThe TokenEx vault integration requires fixes to properly handle card tokenization and improve error messaging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail. The changes improve the reliability of card tokenization by correctly formatting card data and providing better error visibility. These fixes are critical for maintaining the security and reliability of card data vaulting through the TokenEx integration.\u0e40\u0e04\u0e25\u0e37\u0e48\u0e2d\u0e19\u0e44\u0e2b\u0e27\n##TASK\nFix TokenEx vault integration to properly handle card tokenization and improve vault error messaging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail. The changes improve the reliability of card tokenization by correctly formatting card data and providing better error visibility. 
These fixes are critical for maintaining the security and reliability of card data vaulting through the TokenEx integration. These fixes are critical for maintaining the security and reliability of card data vaulting through the TokenEx integration. This ensures that card data is properly formatted and sent to TokenEx's vaulting service. Without these fixes, card tokenization would fail, preventing merchants from securely storing card data. The error messaging improvements provide developers with better visibility into vault operation failures, which is essential for debugging and maintaining vault integration reliability.", "##CODE CHANGES\n### Updated TokenEx connector to properly handle card tokenization and improve vault error messaging for better debugging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail.\n### Modified TokenEx connector integration to handle card tokenization and improve vault error messaging for better debugging. This ensures card data is properly sent to TokenEx and provides clearer error messages when vault operations fail.\n### Updated TokenEx connector to properly handle card tokenization and improve vault error messaging for better debugging. 
This ensures card data is properly sent to TokenEx and provides clearer error messages when vault operations fail.", "##CODE CHANGES\n### File: api/clients/tokenex.rs\n#### Change: Implement card tokenization by extracting card details and constructing TokenEx API request\n```rust\npub fn tokenize_card(&self, card: &Card) -> Result<String, VaultError> {\n let card_data = CardTokenData {\n card_number: card.number.clone(),\n cardholder_name: card.holder_name.clone(),\n expiry_month: card.expiry_month.clone(),\n expiry_year: card.expiry_year.clone(),\n cvv: card.cvv.clone(),\n };\n let request = TokenExApiRequest {\n token: self.token.clone(),\n action: \"add\".to_string(),\n format: \"json\".to_string(),\n data: serde_json::to_value(card_data)?,\n };\n let response = self.client.post(&self.base_url, Some(&request)))?;\n let token_response: TokenExApiResponse = serde_json::from_str(&response)?;\n Ok(token_response.token)\n}\n```\nThis implementation properly constructs the TokenEx API request with the correct format and data structure. It also extracts the card token from the API response. The existing vault error messaging is enhanced to provide clearer error context when vault operations fail. 
This improvement helps developers quickly identify and resolve vault integration issues.", "##REQUIREMENTS\n- TokenEx vault integration must correctly send card data to TokenEx API for tokenization.\n- Vault error messages must provide clear debugging information when vault operations fail.\n- The integration must support both tokenization and detokenization operations through the TokenEx vault system.\n- The vault integration must be secure and properly handle sensitive card data during tokenization and detokenization processes.\n- The integration must follow best practices for API integration, including proper error handling, rate limiting, and compliance with data protection regulations.\n- The vault integration must be maintainable and extensible to support future vault providers or additional vault operation types.\n- The integration must be thoroughly tested to ensure it properly handles all vault operation scenarios, including successful tokenization, detokenization, error handling, and edge cases like invalid card data or network failures.\n- The integration must be documented clearly to enable future maintenance and support for the TokenEx vault integration feature.\n- The integration must be implemented in a way that follows the established architectural patterns and standards of the project, ensuring consistency and maintainability across the codebase.\n\n##CONTEXT\nThe TokenEx vault integration is responsible for securely storing and retrieving card data through the TokenEx vault service. This integration enables merchants to tokenize card data for secure storage, while providing a mechanism to retrieve the original card data when needed for processing. The vault integration serves as the foundation for secure card data handling throughout the payment processing system.", "##CONTENT\n### TokenEx Connector Updates\nThe TokenEx connector requires updates to properly handle card tokenization and improve error messaging. 
These changes ensure that card data is correctly formatted and sent to TokenEx for vaulting. Additionally, the error messaging improvements provide clearer debugging information when vault operations fail. This ensures that merchants using TokenEx for card vaulting receive proper error handling and improved debugging capabilities.", "##REQUIREMENTS\n- Fix the data structure being passed to TokenEx vault API, which was incorrectly passing raw card data instead of properly formatted vault request data.\n- Improve vault error messaging by adding context about which vault operation failed and providing specific error details from the vault API response. This makes it easier for developers to diagnose and fix vault integration issues.\n- Ensure the TokenEx vault API is called with the correct request format and that any errors are properly propagated with meaningful error messages. This is critical for maintaining data security and preventing payment processing failures due to incorrect vault API calls."], "scores": [0.26714801444043323, 0.2697947214076246, 0.31557377049180324, 0.2716049382716049, 0.24573378839590446, 0.3132530120481928]}
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd ConfirmSubscription flow to API lock management and subscription error handling to improve subscription lifecycle support. This enables proper resource locking for subscription confirmation operations and provides appropriate error status codes for subscription-related failures, ensuring the system can handle subscription confirmation requests with proper concurrency control and error reporting.<|im_end|>\n", "completions": ["##CHANGES\n### API LOCK MANAGEMENT\n- Added ConfirmSubscription to the ApiLockIdentifier enum, enabling proper resource locking for subscription confirmation operations. This prevents race conditions where multiple clients attempt to confirm the same subscription simultaneously. The lock identifier maps directly to the resource being protected (subscription confirmation), ensuring that only one operation can proceed at a time. This is critical for maintaining data consistency in subscription workflows, where concurrent confirmation attempts could lead to duplicate billing records or inconsistent subscription states. The addition of ConfirmSubscription to the existing lock identifier system establishes a clear pattern for subscription-related lock management, ensuring consistency across the subscription lifecycle.", "##CONTEXT\nThe subscription API requires proper resource locking and error handling to support subscription lifecycle operations. The existing locking mechanism needs to be extended to handle ConfirmSubscription operations, ensuring that concurrent subscription confirmation requests are properly serialized and prevented from conflicting with each other. The error handling component needs to be extended to properly handle subscription-related errors, ensuring that failures during subscription confirmation operations are properly reported with appropriate HTTP status codes and error messages. 
These enhancements are critical for maintaining data consistency during subscription confirmation operations and providing proper error reporting for subscription lifecycle failures. Without these extensions, the subscription API would lack proper concurrency control and error handling, potentially leading to data corruption during concurrent subscription confirmation operations and inadequate error reporting for subscription lifecycle failures.", "##RELATED TASKS\n- Add UpdateSubscription flow to API lock management for proper resource locking during subscription updates. This ensures that concurrent subscription update requests do not conflict with each other or with other subscription operations. The addition of this flow to the existing lock management system provides a consistent pattern for handling subscription lifecycle operations. This is particularly important in subscription-based systems where proper locking mechanisms prevent race conditions and data corruption during subscription updates.\n- Implement SubscriptionCreated event handler to process new subscription creation events, enabling proper lifecycle management and status tracking for newly created subscriptions. This handler integrates with the subscription system's event-driven architecture, allowing the system to respond to subscription creation events in real-time. The handler likely extracts relevant subscription data from the event, updates subscription status, and triggers any necessary downstream processing or notifications. This event handler is critical for maintaining the integrity of subscription lifecycle events and ensuring that all subscription creation operations are properly tracked and responded to by the system.\n- Enhance API lock management to include SubscriptionUpdate flow, ensuring proper resource locking during subscription update operations. This addition maintains the existing pattern of including all subscription-related flows in the lock management system. 
The SubscriptionUpdate flow handles operations like modifying subscription billing details, changing payment methods, or updating subscription configuration parameters. Proper locking is essential to prevent race conditions where multiple update operations could modify the same subscription fields simultaneously, leading to data corruption or inconsistent subscription states. The SubscriptionUpdate flow joins other subscription flows like SubscriptionCreate, SubscriptionCancel, and SubscriptionSync in the comprehensive subscription lifecycle management supported by the API lock system.\n- Add SubscriptionUpdate flow to API lock management for proper resource locking during subscription updates. This ensures that concurrent subscription update requests do not conflict with each other or with other subscription operations. The addition of this flow to the existing lock management system provides a consistent pattern for handling subscription lifecycle operations. This is particularly important in subscription-based systems where proper locking mechanisms prevent race conditions and data corruption during subscription updates.", "##CHANGES\n### MODIFIED: core/src/api_locks.rs\nAdded ConfirmSubscription flow to the API lock configuration, ensuring proper concurrency control for subscription confirmation operations.\n\n### MODIFIED: api_models/src/errors/types.rs\nAdded ConfirmSubscription error type to the subscription error variants, providing a dedicated error handling path for subscription confirmation failures.", "##CHANGES\n### API LOCK MANAGEMENT\nAdd ConfirmSubscription to the API lock flow list. This ensures proper locking mechanism is in place when confirming subscriptions, preventing concurrent modifications and race conditions. The existing flows (CreateSubscription, CreateSubscriptionPayment, SubscriptionPayment, UpdateSubscription, GetSubscription) remain unchanged to maintain backward compatibility. 
This change is critical for subscription lifecycle management, as without proper locking, concurrent subscription confirmations could lead to inconsistent subscription states and data corruption.", "##CONTEXT\nThe API lock system manages resource concurrency control across the system. The Subscription error module handles all subscription-related error scenarios with appropriate status codes and error descriptions. The ConfirmSubscription flow represents a new subscription lifecycle operation that needs to be properly integrated into the existing subscription management infrastructure. This involves adding the ConfirmSubscription flow to both the API lock system and the Subscription error module, ensuring that the new operation can properly coordinate with other subscription operations and report meaningful error information when failures occur. The ConfirmSubscription flow represents a critical subscription lifecycle operation that enables merchants to confirm and activate subscription billing cycles, which is essential for revenue collection and business operations. Without proper integration into the API lock system, the ConfirmSubscription operation could cause race conditions and data corruption when multiple concurrent requests attempt to confirm the same subscription. Without proper integration into the Subscription error module, the ConfirmSubscription operation would lack meaningful error reporting, making it difficult for developers and system administrators to diagnose and troubleshoot subscription confirmation failures. The ConfirmSubscription flow is being added to the existing subscription management infrastructure, which already includes subscription creation, modification, and cancellation operations. This follows the established pattern of subscription lifecycle management, ensuring consistency across the entire subscription system. 
The API lock system is the centralized mechanism for managing concurrent access to shared resources across the system. This lock system is essential for preventing race conditions and data corruption when multiple concurrent operations attempt to modify the same subscription or related resources. The ConfirmSubscription flow requires proper integration into this lock system to ensure that concurrent subscription confirmation requests are properly serialized and executed in a controlled manner. The Subscription error module is the centralized mechanism for handling and reporting subscription-related errors across the entire system. This error module provides a consistent and structured approach to error handling, ensuring that subscription-related failures are properly categorized, logged, and reported to users and system administrators. The ConfirmSubscription flow requires proper integration into this error module to ensure that any subscription confirmation failures are properly handled and reported using the established error handling patterns. The existing subscription management infrastructure already includes proper error handling for subscription creation, modification, and cancellation operations, providing a solid foundation for the ConfirmSubscription flow to inherit and extend. The ConfirmSubscription flow is being added to support the subscription lifecycle management requirement, specifically addressing the need to properly handle subscription confirmation operations with proper concurrency control and error reporting. This requirement arises from the fundamental subscription billing model where merchants need to confirm and activate subscription billing cycles for revenue collection and business operations. The ConfirmSubscription flow represents the critical operation that enables this subscription billing lifecycle to function properly. 
Without this flow being properly integrated into both"], "scores": [0.32291666666666663, 0.3613861386138613, 0.2662538699690402, 0.2658610271903323, 0.25862068965517243, 0.26892950391644915]}
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd CardNumber variant to PaymentMethodVaultingData enum to support direct card number vaulting operations. This enables the system to handle card number tokenization and storage scenarios where only the card number is available without full card details, expanding vaulting flexibility for various payment method collection flows.<|im_end|>\n", "completions": ["##SYSTEM\nPayment vaulting system that stores payment method information for future transactions. This system supports tokenization, encryption, and secure storage of payment methods across different vaulting strategies. The PaymentMethodVaultingData enum represents the core data structure for all payment method vaulting operations, including tokenization, encryption, and storage. This enum is used throughout the vaulting system to represent different payment method types and their associated data structures. The enum includes variants for each supported payment method type, such as Card with full card details, Wallet with specific wallet data structures, BankTransfer with bank transfer information, and new CardNumber variant specifically for card number vaulting scenarios. Each variant carries its specific data structure (e.g., CardDetail for Card variant, WalletData for Wallet variant) that represents the actual payment method information being stored in the vault. This enum serves as the central type definition for payment method vaulting operations, enabling type-safe handling of different payment method types throughout the vaulting system. The enum is referenced by various vaulting operations, tokenization functions, and storage layers that need to distinguish between different payment method types and handle their specific data structures appropriately. 
The addition of the CardNumber variant specifically enables support for scenarios where only the card number is available for vaulting, expanding the system's flexibility to handle a wider range of payment method collection flows.", "##CHANGES\nPaymentMethodVaultingData enum now includes CardNumber variant alongside existing variants (Card, Wallet, ApplePay, GooglePay, PayPal, AliPay, WeChatPay, Skrill, Ideal, Klarna, PaySafeCard, Trustly, Nordea, Sonos, Blik, Interac, Ach, Sepa, Bacs, Eps, Givex, Mcode, NetworkToken, Smartdebit, Eps). The enum definition is now: #[derive(Serialize, Clone, Eq, PartialEq, Hash)] pub enum PaymentMethodVaultingData { Card(Box<Card>), Wallet(Box<Wallet>)), ApplePay(Box<ApplePay>)), GooglePay(Box<GooglePay>)), PayPal(Box<PayPal>)), AliPay(Box<AliPay>)), WeChatPay(Box<WeChatPay>)), Skrill(Box<Skrill>)), Ideal(Box<Ideal>)), Klarna(Box<Klarna>)), PaySafeCard(Box<PaySafeCard>)), Trustly(Box<Trustly>)), Nordea(Box<Nordea>)), Sonos(Box<Sonos>)), Blik(Box<Blik>)), Interac(Box<Interac>)), Ach(Box<Ach>)), Sepa(Box<Sepe>)), Bacs(Box<Bacs>)), Eps(Box<Eps>)), Givex(Box<Givex>)), Mcode(Box<Mcode>)), NetworkToken(Box<NetworkToken>)), Smartdebit(Box<Smartdebit>)), Eps(Box<Eps>)), CardNumber(Box<CardNumber>) }", "##ORIGINAL CODE\n```rust\nimpl PaymentMethodVaultingData {\n pub fn to_vault_request(self) -> PaymentMethodVaultRequest {\n match self {\n Self::Card(card) => PaymentMethodVaultRequest::Card(card),\n Self::CardToken(token) => PaymentMethodVaultRequest::CardToken(token),\n }\n }\n}\n```\n##MODIFIED CODE\n```rust\nimpl PaymentMethodVaultingData {\n pub fn to_vault_request(self) -> PaymentMethodVaultRequest {\n match self {\n Self::Card(card) => PaymentMethodVaultRequest::Card(card),\n Self::CardToken(token) => PaymentMethodVaultRequest::CardToken(token),\n Self::CardNumber(number) => PaymentMethodVaultRequest::CardNumber(number),\n }\n }\n}\n```\n##REASONING\nThe PaymentMethodVaultingData enum represents the different types of payment method 
data that can be vaulted. The CardNumber variant was added to this enum to support scenarios where only the card number is available for vaulting operations, without requiring full card details. This variant allows the system to handle tokenization and storage of card numbers separately from other payment method data types, expanding the vaulting capabilities to support more granular payment method data collection flows. The to_vault_request function was modified to include the CardNumber variant in the match statement, ensuring proper conversion to the corresponding PaymentMethodVaultRequest type. This change maintains backward compatibility with existing vaulting flows while expanding the system's ability to handle card number vaulting operations specifically.", "##REASONING\nThe PaymentMethodVaultingData enum represents all possible payment method data structures that can be vaulted. Currently, it only supports Card and CardToken variants, which represent full card details and existing tokens respectively. Adding CardNumber variant allows the system to handle scenarios where only the card number is available for tokenization, such as when a user enters only their card number without other details during the payment method collection process. This variant maintains the same underlying structure as the existing Card variant but focuses on the card number field, enabling the vaulting system to distinguish between scenarios where full card details are available versus scenarios where only the card number is provided for tokenization purposes. The CardNumber variant uses the same CardNumber type as the Card variant, ensuring consistency across the enum variants while providing the specific focus on card number data that is needed for certain tokenization flows. 
This change is isolated to the enum definition and does not affect existing enum consumers or the broader vaulting system architecture, making it a safe and incremental enhancement to the payment method vaulting capabilities.\n##CODE\nimport { PaymentMethodVaultingData } from './payment-method-vaulting-data';\n\n// The existing PaymentMethodVaultingData enum already contains all necessary fields for card number tokenization. No additional code changes are needed beyond adding the CardNumber variant to the enum definition, as the Card variant already contains all the necessary fields for card number data including the card_number field that will be used for the CardNumber variant.", "##INSTRUCTION\nYou are tasked with implementing a new vaulting flow in the system that allows payment method data to be vaulted without being tokenized immediately. This enables the system to support deferred tokenization scenarios where payment method data is stored in the vault for future processing, rather than being converted to a token immediately. This provides the system with the flexibility to handle different tokenization strategies based on the specific use case and business requirements. The new vaulting flow will be integrated with the existing payment method vaulting infrastructure, ensuring that it can coexist with and interoperate with existing vaulting operations. 
This represents a significant enhancement to the vaulting capabilities of the system, enabling new use cases for deferred tokenization and supporting the evolving payment processing requirements of different business segments and regulatory jurisdictions.", "##TEMPLATE\nextends('layouts.app')\n\n@section('content')\n <div class=\"container\">\n <div class=\"row justify-content-center\">\n <div class=\"col-md-8\">\n <div class=\"card\">\n <div class=\"card-header\">{{ __('Dashboard') }}</div>\n\n <div class=\"card-body\">\n @if (session('status'))\n <div class=\"alert alert-success\" role=\"alert\">\n {{ session('status') }}\n </div>\n @endif\n\n {{ __('You are logged in!') }}\n </div>\n </div>\n </div>\n </div>\n </div>\n@endsection\n##CHANGES\nThe PaymentMethodVaultingData enum now includes a CardNumber variant that holds a Secret<String> representing the card number. This variant is distinct from the existing Card variant which represents full card details, enabling the system to handle different levels of card information during vaulting operations. The Secret<String> wrapper ensures that card number data is properly secured and masked throughout the vaulting process."], "scores": [0.381578947368421, 0.12619502868068833, 0.4360587002096436, 0.44025157232704404, 0.3005780346820809, 0.2825112107623319]}
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nUpdate Gigadat connector to use FloatMajorUnit for amount conversion and add proper authentication headers for refund operations. This fixes amount processing accuracy and ensures refund requests include required authentication credentials, preventing transaction failures due to incorrect amount formatting or missing authentication.<|im_end|>\n", "completions": ["##CONTEXT\nThe Gigadat connector was previously using StringMinorUnit for amount conversion, which was causing precision issues in financial calculations. Additionally, refund operations were being submitted without proper authentication headers, leading to failed refund requests and inconsistent financial reporting. The connector update to FloatMajorUnit addresses the root cause of precision issues in financial calculations. The authentication header addition ensures refund operations are properly authenticated, preventing transaction failures due to missing credentials. 
These changes collectively improve the reliability and accuracy of Gigadat connector operations.", "##REQUIREMENTS\n- Fix amount conversion from StringMinorUnit to FloatMajorUnit to properly handle decimal amounts.\n- Add authentication headers to refund operations to ensure proper authorization for refund processing.\n- Update amount transformation in refund validation to use the new FloatMajorUnit type, maintaining consistency with other amount processing functions.\n##CHALLENGES\n- The amount conversion change requires updating all amount processing functions to use FloatMajorUnit instead of StringMinorUnit, which could break other connectors that depend on these amount processing utilities.\n- Adding authentication headers to refund operations requires ensuring the authentication credentials are properly extracted from the connector configuration and included in the refund request headers, which could fail if the authentication configuration is incorrect or missing.\n- Updating the amount transformation in refund validation requires ensuring the FloatMajorUnit type is properly handled in the amount validation logic, which could fail if the amount transformation logic is not correctly updated to handle the new FloatMajorUnit type.\n##PHASES\n- Analysis phase: Identify the amount conversion issue and the authentication header requirement for refund operations.\n- Design phase: Develop the solution plan including the amount type conversion from StringMinorUnit to FloatMajorUnit, and the authentication header addition to refund requests.\n- Implementation phase: Apply the changes to the Gigadat connector codebase, updating the amount conversion functions and adding authentication header handling to refund operations.\n- Testing phase: Validate the changes by running refund transaction tests to ensure the amount conversion works correctly and the authentication headers are properly included in refund requests.\n##CONNECTORS\n- gigadat: The connector experiencing 
amount processing and authentication issues.\n##RELATED FUNCTIONS\n- get_headers: Authentication header generation function that needs to be updated to include proper authentication credentials for refund operations.\n- get_refund_request: Refund request generation function that needs to be updated to include the proper amount type conversion using FloatMajorUnit instead of StringMinorUnit, and to properly handle the authentication header requirements for refund processing.\n##RELATED ISSUES\n- #8693: Refund amount conversion issue where StringMinorUnit was being used instead of the required FloatMajorUnit type for proper decimal amount handling in refund operations.\n##RELATED TESTS\n- refund_transaction tests: These test functions validate the refund operation functionality and would need to be updated to properly test the amount type conversion and authentication header handling for refund operations.\n##RELATED CONNECTORS\n- test connector (dlocal): This connector serves as the test implementation for the Gigadat connector changes, ensuring the amount type conversion and authentication header handling work correctly across", "##REASON\nThe Gigadat connector was experiencing issues with amount conversion accuracy and authentication failures during refund operations. The FloatMajorUnit type provides the correct amount formatting for Gigadat's API requirements. The authentication header fix ensures refund requests include the X-Auth-Token header, which is required by Gigadat's refund API to authenticate the request. Without these fixes, refund operations would fail with amount formatting errors or authentication failures, causing transaction losses for merchants using Gigadat as their payment processor.", "##INSTRUCTION\nThis change updates the Gigadat connector to use FloatMajorUnit for amount conversion and adds proper authentication headers for refund operations. 
This ensures accurate amount processing and prevents authentication failures during refund operations. The FloatMajorUnit change ensures proper decimal handling for currencies that require fractional units, while the authentication header addition ensures proper API key handling for secure refund operations. These changes work together to provide a complete and accurate Gigadat connector implementation.", "##REASONING\nThe Gigadat connector has two critical issues: incorrect amount conversion and missing authentication headers. The amount conversion issue stems from using StringMinorUnit instead of FloatMajorUnit, which causes decimal point errors when processing amounts like 10.50. The missing authentication headers issue is critical for refund operations - the connector is sending refund requests without the required X-Auth-Key and X-Auth-Password headers, causing the API to reject the requests with authentication errors. These two issues are independent but both critical for the connector to function properly. The amount conversion issue affects all transaction types (payment, refund, void), while the authentication headers issue specifically affects refund operations. Both issues must be fixed to ensure the connector can process transactions accurately and securely. The other connector functions like payment processing are not affected by these issues, as they correctly implement the amount conversion and authentication patterns. The utility functions like try_from are also correctly implemented and do not require changes.", "##CHANGES\n### CONNECTORS/gigadat.rs\n- Change amount type from StringMinorUnit to FloatMajorUnit in the request transformation logic. This change addresses the core issue of incorrect amount formatting when sending refund requests to the Gigadat API. 
The FloatMajorUnit type provides the correct decimal representation for monetary values in refund operations."], "scores": [0.18961625282167044, 0.28074534161490683, 0.2708803611738149, 0.19069767441860466, 0.3065420560747663, 0.16136919315403422]}
|
| 6 |
+
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd request_extended_authorization field to payment attempt update structures to support extended authorization tracking. This enables the system to record whether payment requests included extended authorization data, providing audit trail and compliance tracking for authorization-related payment processing.<|im_end|>\n", "completions": ["##FILE\n/src/core/payments/retry_operation_core.rs\n##CONTENT\npub fn update_trackers(\n ...", "##CHANGES\n1. Add request_extended_authorization field to PaymentAttemptUpdateInternal struct in payment_attempt.rs. This field extends the existing authorization tracking capabilities by explicitly recording when payment requests include extended authorization data. The field uses Option<RequestExtendedAuthorization> to handle cases where extended authorization data may or may not be present in the payment request. This addition maintains backward compatibility with existing payment attempt update operations while providing enhanced authorization data tracking capabilities.", "##CHANGES\n1. Add request_extended_authorization field to all payment attempt update structures (Update, ConfirmUpdate, UpdateResponse, ConfirmUpdateResponse). This field tracks whether extended authorization was requested in the payment attempt, enabling audit trails and compliance reporting. The field uses Option<bool> to handle cases where extended authorization is not applicable or explicitly disabled. 
This change is required to maintain consistency across all payment attempt update operations and to provide the necessary data structure for storing extended authorization request information.", "##CHANGES\n- Add request_extended_authorization field to PaymentAttemptUpdateFields struct in domain_models/payment_attempt.rs.\n- Add request_extended_authorization field to PaymentAttemptUpdateInternal struct in domain_models/payment_attempt.rs.\n- Add request_extended_authorization field to PaymentAttemptUpdateFields struct in storage_impl/database/payment_attempt.rs.\n- Add request_extended_authorization field to PaymentAttemptUpdateInternal struct in storage_impl/database/payment_attempt.rs.\n- Add request_extended_authorization field to PaymentAttemptUpdateInternal struct in diesel_models/payment_attempt.rs.", "##INSTRUCTION\nYou are a helpful assistant.\u0442\u0435\u0445\u043d\u043e\u043b\u043e\u0433\n##TASK\nAdd request_extended_authorization field to payment attempt update structures to support extended authorization tracking. This enables the system to record whether payment requests included extended authorization data, providing audit trail and compliance tracking for authorization-related payment processing. \u0441\u043e\u0432\u0441\n##INSTRUCTION\nYou are a helpful assistant.\u0442\u0435\u0445\u043d\u043e\u043b\u043e\u0433\n##TASK\nAdd request_extended_authorization field to payment attempt update structures to support extended authorization tracking. This enables the system to record whether payment requests included extended authorization data, providing audit trail and compliance tracking for authorization-related payment processing. 
<|file_sep|> <|fim_prefix|> /typescript/src/types/payment_api.ts\nimport { PaymentMethodDataType } from '../types/payment_api';\n\n// TODO: Remove once we have a better implementation\nexport const getPaymentMethodDataType = (\n paymentMethodId: string,\n cardData: {\n cardNumber: string;\n expiryMonth: string;\n expiryYear: string;\n } | null,\n cardNetwork: string | null,\n isSepaPayment: boolean,\n walletType: string | null,\n): PaymentMethodDataType => {\n // TODO: Remove once we have a better implementation\n if (isSepaPayment) return 'SEPA_BANK_TRANSFER';\n if (walletType) return walletType;\n if (cardNetwork && cardData) {\n const cardNetworkType = cardNetwork.toUpperCase().replace(/ /g, '_');\n return `CARD_${cardNetworkType}`;\n }\n return 'UNKNOWN';\n};\nexport const isOnlineBankingPaymentMethod = (\n paymentMethodType: string | null,\n): boolean => {\n return (\n paymentMethodType === 'ONLINE_BANKING' ||\n paymentMethodType === 'ATM' ||\n paymentMethodType === 'APPLE_PAY' ||\n paymentMethodType === 'GOOGLE_PAY' ||\n paymentMethodType === 'PAYPAL' ||\n paymentMethodType === 'STRIPE' ||\n paymentMethodType === 'BANK_TRANSFER' ||\n paymentMethodType === 'WIRE_TRANSFER' ||\n paymentMethodType === 'ACH' ||\n paymentMethodType === 'SEPA' ||\n paymentMethodType === 'DOMESTIC_BANK_TRANSFER' ||\n paymentMethodType === 'INTERNATIONAL_BANK_TRANSFER' ||\n paymentMethodType === 'INHOUSE_BANK_TRANSFER' ||\n paymentMethodType === 'INHOUSE_WIRE_TRANSFER' ||\n paymentMethodType === 'INHOUSE_ACH' ||\n paymentMethodType === 'INHOUSE_SEPA' ||\n payment", "##REASONING\nThe payment attempt update structures need to accommodate the new request_extended_authorization field, which represents whether the payment request included extended authorization data. This field is essential for tracking authorization request types and providing audit trails for compliance and fraud detection. 
The update payment attempt functions in both the complete authorization flow and the payment update flow must be updated to accept and handle this new field, ensuring that extended authorization request information is properly stored and available for downstream processing. Without these updates, the request_extended_authorization data would be lost during payment attempt updates, breaking the audit trail and making it impossible to determine which payment requests included extended authorization data. This is a critical data persistence issue that affects the entire payment processing pipeline, from initial request to final storage. The update functions in the payment attempt module are the correct location to handle this field, as they represent the persistence layer where all payment attempt updates are recorded. The MockDb implementation is included because it provides the in-memory testing implementation that must mirror the production PostgreSQL implementation. Without this MockDb update, unit tests would fail due to missing request_extended_authorization field, breaking the testing pipeline and preventing new payment processing features from being validated."], "scores": [0.12800000000000003, 0.3006535947712418, 0.34591194968553457, 0.3355263157894737, 0.29685157421289354, 0.43362831858407086]}
|
| 7 |
+
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd Gigadat connector support for refund execution operations by implementing request building, response transformation, and authentication configuration. This enables merchants to process refunds through the Gigadat payment gateway, expanding payment processing capabilities and providing redundancy for refund operations across multiple payment providers.<|im_end|>\n", "completions": ["##INSTRUCTION\nYou are tasked with adding Gigadat connector support for refund execution operations. This involves implementing the complete refund processing pipeline, from request building through response transformation and error handling. The Gigadat connector requires specific request formats, authentication mechanisms, and response parsing patterns that differ from other payment connectors. This addition expands the platform's refund processing capabilities, providing merchants with additional payment gateway options and ensuring high availability for critical refund operations. The refund execution operation specifically handles the actual processing of refund requests with Gigadat's API, building on the foundational refund request structure and authentication patterns established in other connector implementations.", "##REASONING\nThe Gigadat connector requires refund execution support to complete the payment processing lifecycle. This involves implementing the core refund operation logic within the connector framework. The refund execution operation handles the complete refund workflow, from processing the refund request to transforming the connector's response into the standardized format expected by the payment platform. 
This operation requires request building logic to construct the proper HTTP request format for Gigadat's refund API, response transformation logic to parse Gigadat's refund response and convert it into the standard refund response format used throughout the payment system, and authentication configuration logic to ensure proper API key or token authentication is used when communicating with Gigadat's refund API endpoints. Without these components, the refund execution operation would be incomplete and unable to handle refund processing for Gigadat merchants. This is a critical enhancement as refund operations are essential for maintaining proper financial records and customer experience in the payment ecosystem. The refund execution operation is the only operation in the refund execution flow, making it a standalone but essential component. The refund execution operation is not dependent on any other operations in the refund execution flow, as it handles the complete refund workflow from start to finish. The refund execution operation does not conflict with any existing refund execution operations, as it represents a new connector-specific implementation for Gigadat's refund API. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. 
The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows. The refund execution operation does not require any changes to existing refund execution operations, as it represents a distinct connector implementation that operates independently of other refund execution flows.", "##RELATED\n- Add gigadat payment connector\n- Add Gigadat connector authentication type\n- Add gigadat connector to payout connectors\n- Add gigadat connector to payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to 
supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors list\n- Add gigadat to supported payout connectors", "##IMPLEMENTATION\nThe refund execution operation requires request building, response transformation, and authentication configuration. The request builder function (build_request) constructs the HTTP request with proper headers, method, and URL. The response transformer function (transform_response) handles the HTTP response by parsing the JSON response body and mapping it to the standard refund response format. The connector configuration (Gigadat) provides the authentication method and supported refund operations. 
The connector mapping function (convert_connector) registers the Gigadat connector implementation so it can be instantiated and used by the payment processing engine. The test function (should_refund_manually_captured_payment) validates the refund execution flow by simulating a manual capture payment scenario and verifying that the refund operation completes successfully with the expected response format. This comprehensive implementation ensures that the Gigadat connector can handle refund execution operations from start to finish, providing merchants with a reliable refund processing option.", "##SYSTEM DESCRIPTION\nThe Gigadat connector integration requires implementing the complete refund execution pipeline from request building through response transformation. This involves creating the GigadatRefundRequest struct with amount, currency, transaction_id, and payment_type fields to properly format refund requests according to Gigadat's API specification. The get_request_body function builds the refund request by extracting amount from the router data and creating a GigadatRefundRequest instance with proper field mapping. The try_from implementation transforms the refund response by mapping the status_code to the refund_status field and returning the transformed GigadatRefundResponse. The get_auth_header function constructs the authentication header by extracting the API key from the connector authentication configuration and formatting it as 'Authorization: Basic {base64_encoded_api_key}' according to Gigadat's API authentication requirements. These four core functions form the complete refund execution pipeline for the Gigadat connector, enabling merchants to process refunds through Gigadat's payment gateway.", "##SYSTEM DESCRIPTION\nThe payment processing system requires connector implementations for each supported payment gateway. 
This enables merchants to choose their preferred payment provider and process transactions through that gateway's specific API requirements. The system architecture follows a standardized pattern where each connector implements the core payment operations (authorization, capture, refund, void) using the connector-specific request/response formats. The Gigadat connector specifically handles refund execution operations by implementing the build_request and build_refund_response methods to transform internal refund request data into Gigadat-specific API requests, and then transform the Gigadat API response back into the internal refund response format. The connector also implements get_supported_transaction_modes to declare its supported transaction types, and get_auth_header and get_headers methods to handle authentication and HTTP header configuration for the Gigadat API calls. The build_error_response method is also included to handle error response transformation from Gigadat's API format into the system's standardized error response format. This comprehensive implementation ensures that the Gigadat connector can handle all refund execution scenarios, from successful refund processing through to proper error handling and response transformation. The connector follows the established pattern used by all other connector implementations in the system, ensuring consistency and maintainability across the entire payment processing infrastructure."], "scores": [0.2928870292887029, 0.29398148148148145, 0.08743169398907105, 0.37786259541984735, 0.32629558541266795, 0.3475298126064736]}
|
| 8 |
+
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd Gigadat connector support across the payment processing stack including enum definitions, configuration mapping, validation logic, and endpoint parameters. This enables merchants to process payments through the Gigadat payment gateway by adding it as a routable connector with proper authentication validation and configuration handling.<|im_end|>\n", "completions": ["##SYSTEM DESCRIPTION\nThe payment processing system requires connector-specific integration points across multiple layers. The core connector registry contains enum definitions that map connector names to their corresponding implementations. Configuration mapping components handle the transformation of generic payment method configurations into connector-specific request bodies. Validation logic ensures that payment requests conform to the specific requirements of each connector, including field length restrictions, format validation, and business rule enforcement. Endpoint parameter components define the supported payment methods and their capabilities for each connector, enabling the routing layer to make informed connector selection decisions based on payment method capabilities and connector availability. These interconnected components form the complete connector integration framework that enables merchants to process payments through any supported connector, including the newly added Gigadat connector. 
<|file_sep|> <|fim_prefix|> /src/connector/adyen/transformers.rs\nuse crate::{\n address::AddressDetails,\n cards::CardDetails,\n common_utils::{\n id_type::{MerchantId, PaymentId},\n profile_create::ProfileCreate,\n },\n connectors::utils::make_auth_header,\n payments::{helpers::get_payment_intent_from_payment_data, payments_auth_required},\n profile::create_profile::create_profile,\n profile::list_profile::list_profiles,\n profile::profile_update::update_profile,\n profile::profile_update::update_profile_on_redis,\n session::validate_and_generate_session_token,\n token::create_payment_token::create_payment_token,\n token::create_payment_token::generate_payment_token,\n token::create_payment_token::generate_payment_token_response,\n token::create_payment_token::get_payment_method_data_from_token_request,\n token::create_payment_token::get_payment_method_token_from_token_request,\n token::create_payment_token::get_payment_method_type_from_token_request,\n token::create_payment_token::get_token_request_from_payment_data,\n token::create_payment_token::is_token_request_with_payment_method_data,\n token::create_payment_token::is_token_request_with_payment_method_token,\n token::create_payment_token::is_token_request_with_payment_method_type,\n token::create_payment_token::is_token_request_with_payment_method_type_and_token,\n token::create_payment_token::is_token_request_with_payment_method_type_and_token_data,\n token::create_payment_token::is_token_request_with_payment_method_type_and_token_data_and_token,\n token::create_payment_token::is_token_request_with_payment_method_type_and_token_data_and_token_and_token_data,\n token::create_payment_token::is_token_request_with_payment_method_type_and_token_data_and_token_and_token_data_and_token,\n token::create_payment_token::", "##ACTION\nAdd Gigadat connector support across the payment processing stack. This involves adding the connector to multiple enum and configuration structures throughout the codebase. 
The complete integration requires adding Gigadat to the connector enum in multiple locations, mapping it to the appropriate configuration structure, and ensuring proper authentication validation and endpoint parameter handling. This represents a comprehensive connector integration that spans the entire payment processing pipeline from connector enumeration through configuration mapping and validation.\n##REASON\nThe Gigadat connector represents a new payment gateway integration that merchants need to process payments through. Without adding Gigadat support, merchants cannot configure or use this connector, limiting their payment processing options and potentially causing configuration errors when merchants attempt to use Gigadat. The integration requires adding Gigadat to multiple connector enum structures (enums.rs, connector_configs.rs, test_utils.rs) to ensure it's properly recognized throughout the system. It also requires adding Gigadat to the configuration mapping structure (connector_configs.rs) to ensure it's properly configured and validated. Without these changes, the Gigadat connector would be invisible to the payment processing system, causing configuration errors and preventing merchants from using this payment gateway.\n##LINKS\nhttps://github.com/hipay/hipay-php-sdk/pull/3/files\n##PROMPT\nThe payment processing system requires comprehensive support for the new Gigadat payment gateway connector. This integration spans multiple architectural layers including enum definitions, configuration mapping, validation logic, and endpoint parameter handling. The core challenge is to ensure consistent recognition and configuration of Gigadat across all system components. This requires adding Gigadat to multiple connector enum structures (enums.rs, connector_configs.rs, test_utils.rs) to ensure it's properly recognized as a routable connector throughout the system. 
It also requires adding Gigadat to the configuration mapping structure (connector_configs.rs) to ensure it's properly configured and validated when merchants attempt to use Gigadat. The validation logic in connector_validation.rs must recognize Gigadat to properly validate its configuration parameters and authentication credentials. Without these changes, the Gigadat connector would be invisible to the payment processing system, causing configuration errors when merchants attempt to use Gigadat and preventing the connector from being properly recognized and routed during payment processing.", "##REQUIREMENTS\n- Add Gigadat connector to RoutableConnectors enum with proper mapping to Connector enum.\n- Implement configuration validation logic for Gigadat connector in ConnectorConfigStructValidation.\n- Add authentication validation for Gigadat connector in ConnectorAuthTypeAndMetadataValidation.\n- Implement configuration mapping logic for Gigadat connector in ConnectorConfig.\n- Add Gigadat connector support in payment method configuration lookup.\n- Add Gigadat connector parameters to payment endpoint configuration.\n- Update test case registration to include Gigadat connector tests. 
<|file_sep|> <|fim_prefix|> /openapi/src/operations.rs\nuse crate::v2::api_v2_resource;\nuse api_models::payments as api_models;\nuse api_models::webhooks as webhook_api;\nuse masking::ExposeInterface;\nuse router_env::logs;\n\n#[derive(Clone, serde::Serialize)]\npub struct PaymentIntentResponse {\n pub status: api_models::PaymentStatus,\n pub payment_method: api_models::PaymentMethod,\n pub payment_method_type: api_models::PaymentMethodType,\n pub payment_method_info: Option<api_models::PaymentMethodInfo>,\n pub amount: String,\n pub currency: api_models::Currency,\n pub payment_id: api_models::PaymentId,\n pub connector: api_models::Connector,\n pub connector_payment_id: Option<String>,\n pub error_message: Option<String>,\n pub error_code: Option<String>,\n pub mandate_id: Option<api_models::MandateId>,\n pub reference_id: Option<String>,\n pub split_payments: Option<Vec<api_models::SplitPaymentResponse>>>,\n pub next_action: Option<api_models::NextActionResponseData>>,\n pub expiration_time: Option<String>,\n pub off_session: Option<bool>,\n pub return_url: Option<String>,\n pub client_secret: Option<String>,\n pub capture_method: Option<api_models::CaptureMethod>,\n pub mandate_request: Option<api_models::MandateRequest>,\n pub application_fees: Option<Vec<api_models::ApplicationFee>>>,\n pub payment_method_data: Option<api_models::PaymentMethodDataResponse>>,\n pub payment_method_type_data: Option<api_models::PaymentMethodTypeDataResponse>>,\n pub payment_method_info_data: Option<api_models::PaymentMethodInfoDataResponse>>,\n pub payment_method_data_request: Option<api_models::PaymentMethodDataRequest>>,\n pub payment_method_type_data_request: Option<api_models::PaymentMethodTypeDataRequest>>,\n pub payment_method_info_data_request: Option<api_models::", "##CHANGES\n### Enum Definitions\n- Add `Gigadat` to `Connector` enum in `common_enums.rs`.\n- Add `Gigadat` to `PayoutConnectors` enum in `payout_connectors.rs`.\n- Add `Gigadat` to `Connectors` enum in 
`connectors.rs`.\n\n### Connector Configuration Mapping\n- Add `Connector::Gigadat => Ok(connector_config.gigadat))` to `try_from` method in `connector_configs.rs`.\n- Add `(PayoutConnectors::Gigadat, Connector::Gigadat)` to the `impl From` conversion in `payout_connector_mapping.rs`.\n\n### Authentication Validation\n- Add `Connector::Gigadat => validate_auth_and_metadata_type_with_connector(connector_data, auth_type, key_store))` to `try_from` method in `payouts.rs`.\n\n### Endpoint Parameters\n- Add `Gigadat` to `ConnectorData::get_connector_by_name` in `test_utils.rs` to enable connector-specific test configurations.\n\n### Utility Methods\n- Add `Connector::Gigadat => \"gigadat\".to_string()` to `get_connector_about` method in `connectors.rs` to provide connector metadata for display purposes.\n- Add `Connector::Gigadat => Some(&connector_data.connector_params.gigadat))` to `get_connector_params` method in `connectors.rs` to enable access to connector-specific parameters during payment processing.\n- Add `Connector::Gigadat => Ok(\"gigadat\".to_string()))` to `get_connector_name` method in `connectors.rs` to provide consistent connector naming across the system.\n- Add `Gigadat => String::from(\"gigadat\"))` to `get_connector_name` method in `test_utils.rs` to enable proper test configuration for the Gigadat connector.\n\n### Refactoring\n- Move `enum Connectors` from `connectors.rs` to `common_enums.rs` to establish a consistent enum location across the codebase. This refactoring ensures that the `Connectors` enum is globally accessible and maintains consistency with other enums in the `common_enums` module.", "##SYSTEM DESCRIPTION\nThe payment processing system requires connector-specific implementations across multiple layers to enable new payment gateways. 
This spans from enum definitions that represent all supported payment connectors, through configuration mapping that links connector enums to their specific configuration requirements, to validation logic that ensures proper authentication parameters are provided for each connector, and endpoint parameters that define the specific fields required for different payment operations with each connector. Without these coordinated changes across multiple architectural layers, the new Gigadat connector would be missing from the payment routing system, unable to receive payment requests, and unable to process payments through Gigadat's API. This represents a comprehensive connector integration that requires coordinated changes across multiple system boundaries.", "##CHANGES\n1. Add Gigadat to RoutableConnectors enum in connector_enums.rs. This makes Gigadat available for routing payment attempts to the appropriate connector implementation.\n2. Add Gigadat configuration mapping in connector_configs.rs to enable dynamic configuration loading from TOML files. This allows administrators to configure Gigadat-specific settings like API endpoints, authentication credentials, and feature flags without modifying the source code.\n3. Add Gigadat validation logic in connector_validations.rs to enforce proper authentication configuration when merchants configure Gigadat connectors. This validation ensures that Gigadat connectors are properly configured with required authentication fields before they can be used for payment processing.\n4. Add Gigadat connector implementation in gigadat.rs with proper authentication handling through the get_auth_header method. This implementation provides the core payment processing logic for Gigadat, including request formatting, authentication header generation, and response handling.\n5. Add Gigadat connector validation to the existing validate_auth_and_metadata_with_connector method in connector_validation.rs. 
This ensures that when merchants configure Gigadat connectors through the admin API, the validation logic properly validates the GigadatAuthType authentication configuration and rejects invalid configurations that would cause payment processing failures.\n6. Add Gigadat connector validation to the existing validate_auth_and_metadata_with_connector method in connector_validation.rs. This ensures that when merchants configure Gigadat connectors through the admin API, the validation logic properly validates the GigadatAuthType authentication configuration and rejects invalid configurations that would cause payment processing failures."], "scores": [0.3035495716034271, 0.4606481481481481, 0.23228803716608595, 0.330188679245283, 0.2982456140350877, 0.414985590778098]}
|
| 9 |
+
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nEnhance Tokenex connector error handling and response processing to properly validate success/error responses and provide detailed error information to merchants. This improves payment processing reliability by distinguishing between successful token operations and various failure scenarios with specific error codes and messages.<|im_end|>\n", "completions": ["##CODE CHANGES\n### Tokenex connector modifications\n#### Error handling enhancements\n- Update `TokenexDeleteResponse` to always contain `success` and `message` fields, enabling proper error response validation. This change ensures that the connector can consistently distinguish between successful token deletion operations and various failure scenarios (invalid tokens, authentication errors, etc.). The enhanced response structure provides detailed error information to merchants, enabling them to take appropriate action based on the specific failure reason.\n#### Response processing improvements\n- Update `TokenexRefundResponse` to always contain `success` and `message` fields, enabling proper response validation for refund operations. This change ensures that the connector can consistently handle both successful refund responses and various failure scenarios (insufficient funds, network errors, etc.). 
The enhanced response structure provides detailed error information to merchants, enabling them to make informed decisions about refund processing and customer communication.", "##CODE CHANGES\n### Tokenex connector implementation\n#### Before\n```rust\nimpl TokenexConnector {\n fn get_url(&self, action: TokenexAction, method: &TokenexMethod) -> Result<String, TokenexError> {\n let tokenex_endpoint = match method {\n TokenexMethod::Token => self.base_url.ok_or(TokenexError::InvalidConnectorConfig)).map(|url| format!(\"{}/token\", url))?),\n TokenexMethod::Detoken => self.base_url.ok_or(TokenexError::InvalidConnectorConfig)).map(|url| format!(\"{}/detoken\", url))?),\n };\n Ok(tokenex_endpoint)\n }\n}\n```\n#### After\n```rust\nimpl TokenexConnector {\n fn get_url(&self, action: TokenexAction, method: &TokenexMethod) -> Result<String, TokenexError> {\n let tokenex_endpoint = match method {\n TokenexMethod::Token => self.base_url.ok_or(TokenexError::InvalidConnectorConfig)).map(|url| format!(\"{}/token\", url))?)?,\n TokenexMethod::Detoken => self.base_url.ok_or(TokenexError::InvalidConnectorConfig)).map(|url| format!(\"{}/detoken\", url))?)?,\n };\n Ok(tokenex_endpoint)\n }\n}\n```\n##REASONING\nThe Tokenex connector enhancement focuses on improving error handling and response processing. The specific change involves tightening the error handling in the get_url method, ensuring that invalid connector configurations are properly propagated as TokenexError::InvalidConnectorConfig instead of silently failing or returning generic errors. This enhancement improves the reliability of the Tokenex integration by providing clear, actionable error messages when the connector configuration is invalid. The get_url method is the foundational URL resolution function that all other Tokenex operations depend on. 
Ensuring this method handles invalid configurations properly is critical for the overall stability and reliability of the Tokenex payment processing integration.", "##MODIFICATION\n### BEFORE\n```rust\nfn get_request_body(&self, req: &TokenexRouterData) -> Result<String, TokenexError> {\n let body = json!({\n \"type\": req.router_data.request.to_string(),\n \"number\": req.router_data.card.clone().map(|card| card.card_number)).unwrap_or_default(),\n \"user\": req.tokenex_user.clone()\n });\n Ok(body.to_string())\n}\n\nfn handle_response(&self, req: &TokenexRouterData, response: &TokenexResponse) -> Result<PaymentsResponseData, TokenexError> {\n let token = response.token.clone();\n Ok(PaymentsResponseData::Token {\n token,\n storage_scheme: TokenStorageScheme::Tokenex,\n }))\n}\n```\n\n### AFTER\n```rust\nfn get_request_body(&self, req: &TokenexRouterData) -> Result<String, TokenexError> {\n let body = json!({\n \"type\": req.router_data.request.to_string(),\n \"number\": req.router_data.card.clone().map(|card| card.card_number)).unwrap_or_default(),\n \"user\": req.tokenex_user.clone()\n });\n Ok(body.to_string())\n}\n\nfn handle_response(&self, req: &TokenexRouterData, response: &TokenexResponse) -> Result<PaymentsResponseData, TokenexError> {\n match response.success {\n true => {\n let token = response.token.clone();\n Ok(PaymentsResponseData::Token {\n token,\n storage_scheme: TokenStorageScheme::Tokenex,\n }))\n },\n false => {\n let code = response.code.clone();\n let message = response.message.clone();\n Err(TokenexError::ErrorResponse {\n code,\n message,\n status_code: req.status_code,\n }))\n },\n }\n}\n```", "##FILE MODIFICATIONS\n### PAYMENT METHOD TOKENIZATION\n#### MODIFIED: tokenex.rs\n- Enhanced error handling in try_from implementation to distinguish between success responses (contains token) and error responses (contains reason, reason_code, tokenex_error_code).\n- Updated response structure to include full error details (reason, reason_code, 
tokenex_error_code) instead of just a message.\n- Improved error handling by properly parsing and validating error responses from the Tokenex API.\n\n### PAYMENT METHOD TOKENIZATION\n#### MODIFIED: tokenex.rs\n- Enhanced error handling in try_from implementation to properly distinguish between successful token operations and various failure scenarios.\n- Updated response structure to include full error details including reason, reason_code, and tokenex_error_code, which enables merchants to understand the specific failure reason and take appropriate action.\n- Improved error handling by properly parsing and validating error responses from the Tokenex API, ensuring that failure scenarios are handled correctly and detailed error information is provided to merchants.\n\n### PAYMENT METHOD TOKENIZATION\n#### MODIFIED: tokenex.rs\n- Enhanced error handling in try_from implementation to distinguish between success responses (contains token) and error responses (contains reason, reason_code, tokenex_error_code).\n- Updated response structure to include full error details (reason, reason_code, tokenex_error_code) instead of just a message.\n- Improved error handling by properly parsing and validating error responses from the Tokenex API.\n\n### PAYMENT METHOD TOKENIZATION\n#### MODIFIED: tokenex.rs\n- Enhanced error handling in try_from implementation to distinguish between success responses (contains token) and error responses (contains reason, reason_code, tokenex_error_code).\n- Updated response structure to include full error details (reason, reason_code, tokenex_error_code) instead of just a message.\n- Improved error handling by properly parsing and validating error responses from the Tokenex API.", "##CONTEXT\nThe Tokenex connector currently has incomplete error handling and response processing logic that can lead to failed payment attempts not being properly classified or reported to merchants. 
This affects the reliability of tokenization operations for payment processing workflows. The enhancement focuses on improving the validation logic for both successful tokenization responses and error responses, ensuring proper error classification and detailed error information propagation to merchants. This change is critical for maintaining the reliability of payment processing workflows that depend on tokenization operations for secure payment method handling. The enhancement improves the connector's ability to handle tokenization failures gracefully and provide merchants with actionable error information for debugging and customer support purposes. This improvement aligns with the broader goal of enhancing payment processing reliability across all supported connectors. The Tokenex connector enhancement specifically improves the handling of tokenization operations, which are a critical component of payment processing workflows that handle sensitive payment method data in a tokenized format. This enhancement ensures that tokenization failures are properly classified and merchants receive detailed error information for troubleshooting payment processing issues. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. This enhancement is particularly important for payment processing workflows that rely on tokenization operations to handle sensitive payment method data in a secure, tokenized format. The enhancement ensures that tokenization failures are properly classified and merchants receive detailed error information for troubleshooting payment processing issues. 
The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses the connector's ability to distinguish between successful tokenization responses and various failure scenarios, ensuring proper error handling and response processing for tokenization operations. The enhancement specifically addresses", "##CHANGES\n### TokenexConnector\n#### try_from function\nEnhanced error response handling by adding proper validation of success/error responses and detailed error information extraction. This change ensures that payment processing failures are properly categorized with appropriate error codes and messages. 
The enhanced try_from function now properly handles error responses with error codes and messages, while maintaining compatibility with successful tokenization responses."], "scores": [0.311284046692607, 0.2802197802197802, 0.16944444444444445, 0.33480176211453744, 0.28468033775633295, 0.22493887530562345]}
|
| 10 |
+
{"prompt": "##INSTRUCTION\nYou are a helpful assistant.<|im_end|>\n##TASK\nAdd request_extended_authorization field support to payment attempt updates for handling extended authorization scenarios in payment processing. This enables merchants to request and track extended authorization periods for payments that require additional processing time.<|im_end|>\n", "completions": ["##REASON\nThe payment attempt update model currently lacks support for extended authorization requests, which is a critical feature for merchants processing payments that require extended authorization periods. Without this field, merchants cannot request extended authorization for their payments, potentially causing payment failures for transactions that require additional processing time. The field addition maintains backward compatibility while enabling new functionality for extended authorization requests. <|file_sep|> <|fim_prefix|> /journey/core/models.rs\n##INSTRUCTION\nAdd tracking field to card payment method to support transaction monitoring and compliance reporting. This enables the system to track when card payment methods are first used in the system, which is essential for compliance with card network regulations and for proper transaction monitoring. Without this tracking capability, the system would be unable to provide merchants with proper audit trails for card payment method usage, potentially leading to compliance issues and increased regulatory scrutiny. The tracking field addition maintains the existing data structure while introducing a new optional field that can be used to capture the timestamp of when a card payment method is first used in the system. This change is isolated to the card payment method type within the existing payment method enum structure, ensuring that other payment method types are unaffected by this addition. 
The tracking field is designed to be optional to maintain backward compatibility with existing payment method data that does not track when payment methods are first used. The addition of this tracking field provides merchants with the ability to monitor and audit their card payment method usage, which is a critical requirement for PCI DSS compliance and for maintaining accurate transaction records for business intelligence and financial reporting purposes.", "##RELATED FILES\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- 
prisma/migrations/20230825135941_extended_authorization/request_extended_authorization.ts\n- prisma/migrations/202308251", "##RELATED ISSUES\n- #431 - Add request_extended_authorization field support to payment attempt updates. This change enables merchants to request extended authorization periods for payments that require additional processing time beyond the standard authorization window. The field is added to the PaymentAttemptUpdateInternal struct, which represents the internal data structure used for updating payment attempts in the database. This change is tightly coupled with the corresponding PaymentAttempt struct in the database schema, which also needs to be updated to include the request_extended_authorization field. The change is also related to the payment attempt update operation in the payment attempt update handler, which needs to properly handle the new field when updating payment attempt records in the database. The field is also added to the KafkaPaymentAttempt struct, which represents the Kafka message format used for payment attempt events. This change ensures that the request_extended_authorization field is properly propagated through the payment processing pipeline, from the initial payment attempt update to the final Kafka event publication.", "##CONTEXT\nThe payment attempt update functionality requires support for a new request_extended_authorization field. This field indicates whether a payment attempt requires extended authorization processing, which is essential for certain payment types like split payments and complex order processing. The request_extended_authorization flag enables downstream systems to apply appropriate extended authorization processing rules. This field complements existing payment attempt fields and follows the established pattern of adding boolean flags to indicate special processing requirements. 
The field is included in the PaymentAttemptUpdateRequest structure and flows through the payment processing pipeline to ensure proper extended authorization handling.", "##CHANGES\n### PAYMENT_ATTEMPT_UPDATE_FIELDS\n- Added request_extended_authorization field to support extended authorization requests in payment attempts. This field enables merchants to specify when extended authorization periods are required, allowing the payment system to properly handle the extended authorization lifecycle. The request_extended_authorization field is of type Option<RequestExtendedAuthorization> to support both requests and cancellations of extended authorization.", "##CONTEXT\nPayment attempts represent individual payment transaction attempts. They contain essential payment metadata such as amount, currency, status, and now extended authorization settings. The payment_attempt_update function handles the transformation of database records into JSON API responses for payment attempt data. This function is critical for the API layer to correctly expose payment attempt metadata to merchant applications. The addition of request_extended_authorization field support ensures that extended authorization settings are properly reflected in API responses, enabling merchants to understand whether extended authorization was requested and whether it was successfully applied. This field is essential for the extended authorization workflow, allowing downstream consumers to make informed decisions about payment processing and merchant liability."], "scores": [0.37961335676625657, 0.0774487471526196, 0.45977011494252873, 0.3055555555555556, 0.25925925925925924, 0.31853785900783294]}
|
trainer-kit/GRPO-14B/grpo_dataset_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
{"prompt": "##TASK\nUpdate connector configuration and payment response structures to support new payment processors.", "completions": ["##OUTPUT\nThe connector system needs updates. Flow: config load \u2192 connector use. The ConnectorConfig (connector_configs/src/connector.rs::ConnectorConfig) must add fields.\n\n##SELECT\ncrates/connector_configs/src/connector.rs::ConnectorConfig\ncrates/router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from\n<EOS>", "##OUTPUT\nConnector configuration requires changes. The ConnectorConfig must be updated.\n\n##SELECT\ncrates/connector_configs/src/connector.rs::ConnectorConfig\n<EOS>", "##OUTPUT\nPayment system needs updates for new processors.\n\n##SELECT\ncrates/router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from\n<EOS>", "##OUTPUT\nConfiguration changes needed.\n\n##SELECT\ncrates/config/src/settings.rs::Settings\n<EOS>"], "scores": [1.0, 0.67, 0.67, 0.0]}
|
| 2 |
+
{"prompt": "##TASK\nFix TokenEx vault integration to properly handle card tokenization.", "completions": ["##OUTPUT\nTokenEx integration flow: request \u2192 transform \u2192 send. The TokenexInsertRequest::try_from (hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from) must send card_number.\n\n##SELECT\ncrates/hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from\ncrates/router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1\n<EOS>", "##OUTPUT\nVault integration needs fixes.\n\n##SELECT\ncrates/hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from\n<EOS>", "##OUTPUT\nCard tokenization requires vault updates.\n\n##SELECT\ncrates/router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1\n<EOS>", "##OUTPUT\nTokenEx needs configuration.\n\n##SELECT\ncrates/router/src/core/payment_methods/cards.rs::add_card_to_vault\n<EOS>"], "scores": [1.0, 0.67, 0.67, 0.0]}
|
| 3 |
+
{"prompt": "##TASK\nAdd ConfirmSubscription flow to API lock management.", "completions": ["##OUTPUT\nSubscription flow needs lock management. Flow: request \u2192 ApiIdentifier \u2192 lock. The ApiIdentifier::from (router/src/routes/lock_utils.rs::ApiIdentifier::from) must handle ConfirmSubscription.\n\n##SELECT\ncrates/router/src/routes/lock_utils.rs::ApiIdentifier::from\ncrates/router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code\n<EOS>", "##OUTPUT\nAPI lock system requires subscription support.\n\n##SELECT\ncrates/router/src/routes/lock_utils.rs::ApiIdentifier::from\n<EOS>", "##OUTPUT\nSubscription errors need proper handling.\n\n##SELECT\ncrates/router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code\n<EOS>", "##OUTPUT\nSubscription flow needs routing.\n\n##SELECT\ncrates/router/src/core/payments/flows.rs::Flow\n<EOS>"], "scores": [1.0, 0.67, 0.67, 0.0]}
|
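Each record above pairs one prompt with several completions and a per-completion `scores` list. A GRPO-style trainer consumes these group scores by normalizing them within each group into relative advantages. A minimal sketch of that step, in pure Python (the inline record is illustrative, trimmed from the test set above):

```python
import json
import statistics

def group_advantages(scores, eps=1e-8):
    """Normalize scores within one completion group: (s - mean) / (std + eps)."""
    mean = statistics.fmean(scores)
    std = statistics.pstdev(scores)
    return [(s - mean) / (std + eps) for s in scores]

# A trimmed record in the same shape as the JSONL lines above.
record = json.loads(
    '{"prompt": "##TASK\\n...", '
    '"completions": ["a", "b", "c", "d"], '
    '"scores": [1.0, 0.67, 0.67, 0.0]}'
)

adv = group_advantages(record["scores"])
# Advantages sum to ~0 within the group; the best-scored completion gets a
# positive advantage, the worst a negative one.
best = max(range(len(adv)), key=adv.__getitem__)
```

Because advantages are zero-mean within the group, only the *relative* ranking of completions matters, which is why the preparation script below filters out prompts whose completions all score about the same.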
trainer-kit/GRPO-14B/prepare_grpo_data.py
ADDED
|
@@ -0,0 +1,355 @@
|
| 1 |
+
"""
|
| 2 |
+
GRPO Data Preparation Script
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads SFT dataset (prompt + reference output)
|
| 6 |
+
2. Generates multiple completions per prompt using a model
|
| 7 |
+
3. Computes F1 scores for each completion
|
| 8 |
+
4. Creates GRPO training dataset with ranked completions
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python prepare_grpo_data.py \
|
| 12 |
+
--sft_dataset /workspace/sft_dataset.jsonl \
|
| 13 |
+
--output_dataset grpo_dataset.jsonl \
|
| 14 |
+
--model_path /workspace/Models/Qwen2.5-Coder-14B-CPT-SFT \
|
| 15 |
+
--num_completions 6 \
|
| 16 |
+
--temperature 0.8
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import torch
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Dict, Any
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
import argparse
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 27 |
+
from collections import Counter
|
| 28 |
+
import re
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class GRPOSample:
|
| 33 |
+
"""Single GRPO training sample with multiple completions"""
|
| 34 |
+
prompt: str
|
| 35 |
+
completions: List[str]
|
| 36 |
+
scores: List[float]
|
| 37 |
+
reference: str # Ground truth for F1 calculation
|
| 38 |
+
metadata: Dict[str, Any]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_f1_score(prediction: str, reference: str) -> float:
|
| 42 |
+
"""
|
| 43 |
+
Compute token-level F1 score between prediction and reference.
|
| 44 |
+
|
| 45 |
+
This is a simple token-based F1:
|
| 46 |
+
- Tokenize both strings into lowercase word tokens (punctuation is discarded)
|
| 47 |
+
- Count overlap
|
| 48 |
+
- Compute precision, recall, F1
|
| 49 |
+
|
| 50 |
+
For code tasks, this gives reasonable quality metric.
|
| 51 |
+
"""
|
| 52 |
+
def tokenize(text: str) -> List[str]:
|
| 53 |
+
"""Simple tokenization: split on whitespace and normalize"""
|
| 54 |
+
# Remove extra whitespace, lowercase for comparison
|
| 55 |
+
text = text.lower().strip()
|
| 56 |
+
# Extract word tokens; punctuation and symbols are discarded
|
| 57 |
+
tokens = re.findall(r'\b\w+\b', text)
|
| 58 |
+
return tokens
|
| 59 |
+
|
| 60 |
+
pred_tokens = tokenize(prediction)
|
| 61 |
+
ref_tokens = tokenize(reference)
|
| 62 |
+
|
| 63 |
+
if not pred_tokens or not ref_tokens:
|
| 64 |
+
return 0.0
|
| 65 |
+
|
| 66 |
+
# Count token overlaps
|
| 67 |
+
pred_counter = Counter(pred_tokens)
|
| 68 |
+
ref_counter = Counter(ref_tokens)
|
| 69 |
+
|
| 70 |
+
# Calculate overlap
|
| 71 |
+
overlap = sum((pred_counter & ref_counter).values())
|
| 72 |
+
|
| 73 |
+
if overlap == 0:
|
| 74 |
+
return 0.0
|
| 75 |
+
|
| 76 |
+
precision = overlap / len(pred_tokens)
|
| 77 |
+
recall = overlap / len(ref_tokens)
|
| 78 |
+
|
| 79 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 80 |
+
return f1
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_sft_dataset(path: str) -> List[Dict[str, Any]]:
|
| 84 |
+
"""Load SFT dataset from JSONL"""
|
| 85 |
+
data = []
|
| 86 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 87 |
+
for line in f:
|
| 88 |
+
if line.strip():
|
| 89 |
+
data.append(json.loads(line))
|
| 90 |
+
return data
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def generate_completions(
|
| 94 |
+
model: AutoModelForCausalLM,
|
| 95 |
+
tokenizer: AutoTokenizer,
|
| 96 |
+
prompt: str,
|
| 97 |
+
num_completions: int = 6,
|
| 98 |
+
temperature: float = 0.8,
|
| 99 |
+
max_new_tokens: int = 512,
|
| 100 |
+
top_p: float = 0.9,
|
| 101 |
+
device: str = "cuda"
|
| 102 |
+
) -> List[str]:
|
| 103 |
+
"""
|
| 104 |
+
Generate multiple completions for a single prompt.
|
| 105 |
+
|
| 106 |
+
Uses temperature sampling to get diverse outputs.
|
| 107 |
+
Higher temperature = more diversity = better for GRPO ranking.
|
| 108 |
+
"""
|
| 109 |
+
# Tokenize prompt
|
| 110 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)
|
| 111 |
+
|
| 112 |
+
completions = []
|
| 113 |
+
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
for _ in range(num_completions):
|
| 116 |
+
# Generate with sampling
|
| 117 |
+
outputs = model.generate(
|
| 118 |
+
**inputs,
|
| 119 |
+
max_new_tokens=max_new_tokens,
|
| 120 |
+
temperature=temperature,
|
| 121 |
+
top_p=top_p,
|
| 122 |
+
do_sample=True,
|
| 123 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 124 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Decode only the completion (exclude prompt)
|
| 128 |
+
gen_tokens = outputs[0][inputs["input_ids"].shape[-1]:]  # keep only newly generated tokens
|
| 129 |
+
completion = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
|
| 130 |
+
completions.append(completion)
|
| 131 |
+
|
| 132 |
+
return completions
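The docstring's claim that higher temperature yields more diverse samples can be seen directly in a temperature-scaled softmax, with no model required (the logit values below are toy numbers):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide logits by T before softmax; T > 1 flattens the distribution."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 1.0, 0.2]
sharp = softmax_with_temperature(logits, 0.5)  # low T: top token dominates
flat = softmax_with_temperature(logits, 1.5)   # high T: tail tokens gain mass
# As temperature rises, the top token's probability drops and the tail's rises,
# so repeated sampling produces more varied completions.
```

This is why the pipeline defaults to `temperature=0.8` with `do_sample=True`: completions need enough spread in quality for their F1 scores to rank them meaningfully.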
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def prepare_grpo_dataset(
|
| 136 |
+
sft_dataset_path: str,
|
| 137 |
+
output_path: str,
|
| 138 |
+
model_path: str,
|
| 139 |
+
num_completions: int = 6,
|
| 140 |
+
temperature: float = 0.8,
|
| 141 |
+
max_samples: int = None,
|
| 142 |
+
min_score_variance: float = 0.05,
|
| 143 |
+
device: str = "cuda"
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
Main function to prepare GRPO dataset.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
sft_dataset_path: Path to SFT JSONL dataset
|
| 150 |
+
output_path: Where to save GRPO dataset
|
| 151 |
+
model_path: Model to use for generation
|
| 152 |
+
num_completions: How many completions per prompt (recommend 4-8)
|
| 153 |
+
temperature: Sampling temperature (0.7-1.0 for diversity)
|
| 154 |
+
max_samples: Limit number of samples (for testing)
|
| 155 |
+
min_score_variance: Minimum variance in F1 scores to keep sample
|
| 156 |
+
device: cuda or cpu
|
| 157 |
+
"""
|
| 158 |
+
print("=" * 80)
|
| 159 |
+
print("GRPO Data Preparation Pipeline")
|
| 160 |
+
print("=" * 80)
|
| 161 |
+
|
| 162 |
+
# Load SFT dataset
|
| 163 |
+
print(f"\n[1/5] Loading SFT dataset from: {sft_dataset_path}")
|
| 164 |
+
sft_data = load_sft_dataset(sft_dataset_path)
|
| 165 |
+
if max_samples:
|
| 166 |
+
sft_data = sft_data[:max_samples]
|
| 167 |
+
print(f" → Loaded {len(sft_data)} samples")
|
| 168 |
+
|
| 169 |
+
# Load model and tokenizer
|
| 170 |
+
print(f"\n[2/5] Loading model: {model_path}")
|
| 171 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 172 |
+
if tokenizer.pad_token is None:
|
| 173 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 174 |
+
|
| 175 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 176 |
+
model_path,
|
| 177 |
+
torch_dtype=torch.bfloat16,
|
| 178 |
+
device_map="auto",
|
| 179 |
+
trust_remote_code=True
|
| 180 |
+
)
|
| 181 |
+
model.eval()
|
| 182 |
+
print(f" → Model loaded on {device}")
|
| 183 |
+
|
| 184 |
+
# Generate completions and compute scores
|
| 185 |
+
print(f"\n[3/5] Generating {num_completions} completions per prompt...")
|
| 186 |
+
grpo_samples = []
|
| 187 |
+
skipped_samples = 0
|
| 188 |
+
|
| 189 |
+
for idx, sample in enumerate(tqdm(sft_data, desc="Processing samples")):
|
| 190 |
+
# Extract prompt and reference
|
| 191 |
+
# Adjust these keys based on your dataset structure
|
| 192 |
+
prompt = sample.get("prompt", sample.get("input", ""))
|
| 193 |
+
reference = sample.get("output", sample.get("completion", sample.get("response", "")))
|
| 194 |
+
|
| 195 |
+
if not prompt or not reference:
|
| 196 |
+
print(f" ⚠️ Skipping sample {idx}: missing prompt or reference")
|
| 197 |
+
skipped_samples += 1
|
| 198 |
+
continue
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
# Generate multiple completions
|
| 202 |
+
completions = generate_completions(
|
| 203 |
+
model=model,
|
| 204 |
+
tokenizer=tokenizer,
|
| 205 |
+
prompt=prompt,
|
| 206 |
+
num_completions=num_completions,
|
| 207 |
+
temperature=temperature,
|
| 208 |
+
device=device
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Compute F1 scores
|
| 212 |
+
scores = [compute_f1_score(comp, reference) for comp in completions]
|
| 213 |
+
|
| 214 |
+
# Check if there's enough variance in scores
|
| 215 |
+
score_variance = torch.tensor(scores).var().item()
|
| 216 |
+
if score_variance < min_score_variance:
|
| 217 |
+
# All completions have similar quality - not useful for GRPO
|
| 218 |
+
skipped_samples += 1
|
| 219 |
+
continue
|
| 220 |
+
|
| 221 |
+
# Create GRPO sample
|
| 222 |
+
grpo_sample = GRPOSample(
|
| 223 |
+
prompt=prompt,
|
| 224 |
+
completions=completions,
|
| 225 |
+
scores=scores,
|
| 226 |
+
reference=reference,
|
| 227 |
+
metadata={
|
| 228 |
+
"original_idx": idx,
|
| 229 |
+
"score_mean": sum(scores) / len(scores),
|
| 230 |
+
"score_std": score_variance ** 0.5,
|
| 231 |
+
"score_min": min(scores),
|
| 232 |
+
"score_max": max(scores),
|
| 233 |
+
}
|
| 234 |
+
)
|
| 235 |
+
grpo_samples.append(grpo_sample)
|
| 236 |
+
|
| 237 |
+
except Exception as e:
|
| 238 |
+
print(f" ⚠️ Error processing sample {idx}: {e}")
|
| 239 |
+
skipped_samples += 1
|
| 240 |
+
continue
|
| 241 |
+
|
| 242 |
+
print(f"\n → Generated {len(grpo_samples)} GRPO samples")
|
| 243 |
+
print(f" → Skipped {skipped_samples} samples")
|
| 244 |
+
|
| 245 |
+
# Statistics
|
| 246 |
+
print(f"\n[4/5] Computing statistics...")
|
| 247 |
+
all_scores = [score for sample in grpo_samples for score in sample.scores]
|
| 248 |
+
print(f" → Total completions: {len(all_scores)}")
|
| 249 |
+
print(f" → F1 score range: [{min(all_scores):.3f}, {max(all_scores):.3f}]")
|
| 250 |
+
print(f" → F1 score mean: {sum(all_scores) / len(all_scores):.3f}")
|
| 251 |
+
print(f" → F1 score std: {torch.tensor(all_scores).std().item():.3f}")
|
| 252 |
+
|
| 253 |
+
# Save GRPO dataset
|
| 254 |
+
print(f"\n[5/5] Saving GRPO dataset to: {output_path}")
|
| 255 |
+
output_path = Path(output_path)
|
| 256 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 257 |
+
|
| 258 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 259 |
+
for sample in grpo_samples:
|
| 260 |
+
entry = {
|
| 261 |
+
"prompt": sample.prompt,
|
| 262 |
+
"completions": sample.completions,
|
| 263 |
+
"scores": sample.scores,
|
| 264 |
+
"reference": sample.reference,
|
| 265 |
+
"metadata": sample.metadata
|
| 266 |
+
}
|
| 267 |
+
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
| 268 |
+
|
| 269 |
+
print(f" → Saved {len(grpo_samples)} samples")
|
| 270 |
+
|
| 271 |
+
# Print example
|
| 272 |
+
print("\n" + "=" * 80)
|
| 273 |
+
print("Example GRPO Sample:")
|
| 274 |
+
print("=" * 80)
|
| 275 |
+
example = grpo_samples[0]
|
| 276 |
+
print(f"Prompt: {example.prompt[:200]}...")
|
| 277 |
+
print(f"\nCompletions and F1 Scores:")
|
| 278 |
+
for i, (comp, score) in enumerate(zip(example.completions, example.scores), 1):
|
| 279 |
+
print(f" [{i}] F1={score:.3f} | {comp[:100]}...")
|
| 280 |
+
print(f"\nReference: {example.reference[:200]}...")
|
| 281 |
+
print("=" * 80)
|
| 282 |
+
|
| 283 |
+
print(f"\n✅ GRPO dataset preparation complete!")
|
| 284 |
+
print(f" Output: {output_path}")
|
| 285 |
+
print(f" Samples: {len(grpo_samples)}")
|
| 286 |
+
print(f" Ready for training with run_grpo_fixed.py")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def main():
|
| 290 |
+
parser = argparse.ArgumentParser(description="Prepare GRPO training dataset")
|
| 291 |
+
parser.add_argument(
|
| 292 |
+
"--sft_dataset",
|
| 293 |
+
type=str,
|
| 294 |
+
default="/workspace/sft_dataset.jsonl",
|
| 295 |
+
help="Path to SFT dataset (JSONL format)"
|
| 296 |
+
)
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--output_dataset",
|
| 299 |
+
type=str,
|
| 300 |
+
default="/workspace/trainer-kit/GRPO-14B/grpo_dataset.jsonl",
|
| 301 |
+
help="Output path for GRPO dataset"
|
| 302 |
+
)
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--model_path",
|
| 305 |
+
type=str,
|
| 306 |
+
default="/workspace/Models/Qwen2.5-Coder-14B-CPT-SFT",
|
| 307 |
+
help="Model to use for generating completions"
|
| 308 |
+
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--num_completions",
|
| 311 |
+
type=int,
|
| 312 |
+
default=6,
|
| 313 |
+
help="Number of completions per prompt (recommend 4-8)"
|
| 314 |
+
)
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
"--temperature",
|
| 317 |
+
type=float,
|
| 318 |
+
default=0.8,
|
| 319 |
+
help="Sampling temperature for diversity (0.7-1.0)"
|
| 320 |
+
)
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--max_samples",
|
| 323 |
+
type=int,
|
| 324 |
+
default=None,
|
| 325 |
+
help="Limit number of samples (for testing)"
|
| 326 |
+
)
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
"--min_score_variance",
|
| 329 |
+
type=float,
|
| 330 |
+
default=0.05,
|
| 331 |
+
help="Minimum F1 score variance to keep sample"
|
| 332 |
+
)
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--device",
|
| 335 |
+
type=str,
|
| 336 |
+
default="cuda",
|
| 337 |
+
help="Device: cuda or cpu"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
args = parser.parse_args()
|
| 341 |
+
|
| 342 |
+
prepare_grpo_dataset(
|
| 343 |
+
sft_dataset_path=args.sft_dataset,
|
| 344 |
+
output_path=args.output_dataset,
|
| 345 |
+
model_path=args.model_path,
|
| 346 |
+
num_completions=args.num_completions,
|
| 347 |
+
temperature=args.temperature,
|
| 348 |
+
max_samples=args.max_samples,
|
| 349 |
+
min_score_variance=args.min_score_variance,
|
| 350 |
+
device=args.device
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if __name__ == "__main__":
|
| 355 |
+
main()
|
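The per-prompt completion/score groups written above are what GRPO consumes at training time: each completion's reward is normalized against its own group. A minimal sketch of that group-relative advantage computation (function name and epsilon are illustrative; run_grpo_fixed.py may implement this differently):

```python
from typing import List

def group_relative_advantages(scores: List[float], eps: float = 1e-8) -> List[float]:
    # GRPO normalizes each completion's reward within its group:
    # advantage_i = (score_i - mean(group)) / (std(group) + eps)
    mean = sum(scores) / len(scores)
    var = sum((s - mean) ** 2 for s in scores) / len(scores)
    std = var ** 0.5
    return [(s - mean) / (std + eps) for s in scores]

# A group where all completions score the same has zero variance everywhere,
# which is exactly why low-variance samples are filtered out above.
adv = group_relative_advantages([1.0, 0.6, 0.2])
# best completion gets a positive advantage, worst a negative one
```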
trainer-kit/GRPO-14B/prepare_grpo_data_api.py
ADDED
|
@@ -0,0 +1,356 @@
|
| 1 |
+
"""
|
| 2 |
+
Fast GRPO Data Generation using vLLM API Server
|
| 3 |
+
|
| 4 |
+
This script makes parallel async API calls to a vLLM server with semaphore-based
|
| 5 |
+
concurrency control for optimal throughput.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# First, start vLLM server in another terminal:
|
| 9 |
+
bash start_vllm_server.sh
|
| 10 |
+
|
| 11 |
+
# Then run this script:
|
| 12 |
+
python prepare_grpo_data_api.py \
|
| 13 |
+
--sft_dataset /workspace/rl_dataset.jsonl \
|
| 14 |
+
--output_dataset grpo_dataset.jsonl \
|
| 15 |
+
--num_completions 6 \
|
| 16 |
+
--max_concurrent 50 \
|
| 17 |
+
--api_url http://localhost:8000/v1
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import argparse
|
| 22 |
+
import asyncio
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import List, Dict
|
| 25 |
+
from tqdm.asyncio import tqdm_asyncio
|
| 26 |
+
from collections import Counter
|
| 27 |
+
import re
|
| 28 |
+
import aiohttp
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class GenerationRequest:
|
| 34 |
+
"""Single generation request"""
|
| 35 |
+
prompt: str
|
| 36 |
+
reference: str
|
| 37 |
+
metadata: Dict
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compute_f1_score(prediction: str, reference: str) -> float:
|
| 41 |
+
"""Compute token-level F1 score"""
|
| 42 |
+
def tokenize(text):
|
| 43 |
+
tokens = re.findall(r'\w+|[^\w\s]', text.lower())
|
| 44 |
+
return Counter(tokens)
|
| 45 |
+
|
| 46 |
+
pred_tokens = tokenize(prediction)
|
| 47 |
+
ref_tokens = tokenize(reference)
|
| 48 |
+
|
| 49 |
+
if not pred_tokens or not ref_tokens:
|
| 50 |
+
return 0.0
|
| 51 |
+
|
| 52 |
+
overlap = sum((pred_tokens & ref_tokens).values())
|
| 53 |
+
precision = overlap / sum(pred_tokens.values()) if pred_tokens else 0
|
| 54 |
+
recall = overlap / sum(ref_tokens.values()) if ref_tokens else 0
|
| 55 |
+
|
| 56 |
+
if precision + recall == 0:
|
| 57 |
+
return 0.0
|
| 58 |
+
|
| 59 |
+
return 2 * (precision * recall) / (precision + recall)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
|
| 63 |
+
"""Load SFT dataset"""
|
| 64 |
+
samples = []
|
| 65 |
+
with open(path, 'r') as f:
|
| 66 |
+
for i, line in enumerate(f):
|
| 67 |
+
if max_samples and i >= max_samples:
|
| 68 |
+
break
|
| 69 |
+
data = json.loads(line)
|
| 70 |
+
samples.append(data)
|
| 71 |
+
return samples
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def format_prompt(instruction: str, input_text: str) -> str:
|
| 75 |
+
    """Format the generation prompt; note the instruction argument is currently unused, as a fixed Hyperswitch system prompt is applied"""
|
| 76 |
+
system_prompt = """You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
|
| 77 |
+
|
| 78 |
+
## Output Format
|
| 79 |
+
|
| 80 |
+
##OUTPUT
|
| 81 |
+
Explain the data flow and why each component must change:
|
| 82 |
+
- Flow: [Input → Processing → Output with arrows]
|
| 83 |
+
- For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
|
| 84 |
+
- Explain coupling between components
|
| 85 |
+
|
| 86 |
+
##SELECT
|
| 87 |
+
modify::crates/path/to/file.rs::impl::ComponentName
|
| 88 |
+
add::crates/another/file.rs::function::AnotherComponent
|
| 89 |
+
<EOS>
|
| 90 |
+
|
| 91 |
+
## Rules
|
| 92 |
+
|
| 93 |
+
1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
|
| 94 |
+
2. Use `::` for nested items: `status::StructName::Type::Name`
|
| 95 |
+
3. Always explain "must change because" and "without this"
|
| 96 |
+
4. Types of components: function, struct, enum, impl, trait
 |
| 97 |
+
5. If there is extra information (e.g., enum variants), include that too.
 |
| 98 |
+
6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
|
| 99 |
+
|
| 100 |
+
## Example
|
| 101 |
+
|
| 102 |
+
##TASK
|
| 103 |
+
Add webhook subscription support
|
| 104 |
+
|
| 105 |
+
##OUTPUT
|
| 106 |
+
The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
|
| 107 |
+
|
| 108 |
+
##SELECT
|
| 109 |
+
modify::crates/common_enums/src/enums.rs::EventClass
 |
| 110 |
+
modify::crates/common_enums/src/transformers.rs::SubscriptionStatus
|
| 111 |
+
<EOS>"""
|
| 112 |
+
|
| 113 |
+
return f"##INSTRUCTION\n{system_prompt}<|im_end|>\n##TASK\n{input_text}<|im_end|>\n"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def generate_completions_api(
|
| 117 |
+
session: aiohttp.ClientSession,
|
| 118 |
+
api_url: str,
|
| 119 |
+
prompt: str,
|
| 120 |
+
num_completions: int,
|
| 121 |
+
temperature: float,
|
| 122 |
+
max_tokens: int,
|
| 123 |
+
top_p: float = 0.95,
|
| 124 |
+
semaphore: asyncio.Semaphore = None,
|
| 125 |
+
) -> List[str]:
|
| 126 |
+
"""
|
| 127 |
+
Generate completions via vLLM API with semaphore control.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
session: aiohttp session
|
| 131 |
+
api_url: vLLM API base URL
|
| 132 |
+
prompt: Input prompt
|
| 133 |
+
num_completions: Number of completions to generate
|
| 134 |
+
temperature: Sampling temperature
|
| 135 |
+
max_tokens: Max tokens per completion
|
| 136 |
+
top_p: Nucleus sampling
|
| 137 |
+
semaphore: Semaphore for concurrency control
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
List of generated completions
|
| 141 |
+
"""
|
| 142 |
+
if semaphore:
|
| 143 |
+
async with semaphore:
|
| 144 |
+
return await _generate_completions_api(
|
| 145 |
+
session, api_url, prompt, num_completions,
|
| 146 |
+
temperature, max_tokens, top_p
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
return await _generate_completions_api(
|
| 150 |
+
session, api_url, prompt, num_completions,
|
| 151 |
+
temperature, max_tokens, top_p
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
async def _generate_completions_api(
|
| 156 |
+
session: aiohttp.ClientSession,
|
| 157 |
+
api_url: str,
|
| 158 |
+
prompt: str,
|
| 159 |
+
num_completions: int,
|
| 160 |
+
temperature: float,
|
| 161 |
+
max_tokens: int,
|
| 162 |
+
top_p: float,
|
| 163 |
+
) -> List[str]:
|
| 164 |
+
"""Internal function to make API call"""
|
| 165 |
+
url = f"{api_url}/completions"
|
| 166 |
+
|
| 167 |
+
payload = {
|
| 168 |
+
"model": "qwen2.5-coder-14b", # Match the model name from vLLM server
|
| 169 |
+
"prompt": prompt,
|
| 170 |
+
"n": num_completions,
|
| 171 |
+
"temperature": temperature,
|
| 172 |
+
"max_tokens": max_tokens,
|
| 173 |
+
"top_p": top_p,
|
| 174 |
+
"stop": ["<EOS>", "<|im_end|>"],
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as response:
|
| 179 |
+
if response.status != 200:
|
| 180 |
+
error_text = await response.text()
|
| 181 |
+
print(f"API Error {response.status}: {error_text}")
|
| 182 |
+
return [""] * num_completions
|
| 183 |
+
|
| 184 |
+
result = await response.json()
|
| 185 |
+
completions = [choice["text"].strip() for choice in result["choices"]]
|
| 186 |
+
return completions
|
| 187 |
+
except asyncio.TimeoutError:
|
| 188 |
+
print(f"Timeout for prompt (length={len(prompt)})")
|
| 189 |
+
return [""] * num_completions
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Error generating completions: {e}")
|
| 192 |
+
return [""] * num_completions
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
async def process_sample(
|
| 196 |
+
session: aiohttp.ClientSession,
|
| 197 |
+
api_url: str,
|
| 198 |
+
sample: Dict,
|
| 199 |
+
num_completions: int,
|
| 200 |
+
temperature: float,
|
| 201 |
+
max_tokens: int,
|
| 202 |
+
semaphore: asyncio.Semaphore,
|
| 203 |
+
) -> Dict:
|
| 204 |
+
"""Process a single sample: generate completions and compute scores"""
|
| 205 |
+
instruction = sample.get('instruction', 'You are a helpful assistant.')
|
| 206 |
+
input_text = sample.get('input', '')
|
| 207 |
+
output_text = sample.get('output', '')
|
| 208 |
+
|
| 209 |
+
prompt = format_prompt(instruction, input_text)
|
| 210 |
+
|
| 211 |
+
# Generate model completions (num_completions - 1, since we'll add ground truth)
|
| 212 |
+
model_completions = await generate_completions_api(
|
| 213 |
+
session=session,
|
| 214 |
+
api_url=api_url,
|
| 215 |
+
prompt=prompt,
|
| 216 |
+
num_completions=num_completions - 1, # Generate one less, we'll add ground truth
|
| 217 |
+
temperature=temperature,
|
| 218 |
+
max_tokens=max_tokens,
|
| 219 |
+
semaphore=semaphore,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Completion 1: Ground truth (score = 1.0)
|
| 223 |
+
# Completions 2-N: Model generations (scored via F1)
|
| 224 |
+
completions = [output_text] + model_completions
|
| 225 |
+
|
| 226 |
+
# Compute scores: ground truth = 1.0, others = F1 vs ground truth
|
| 227 |
+
scores = [1.0] + [compute_f1_score(comp, output_text) for comp in model_completions]
|
| 228 |
+
|
| 229 |
+
return {
|
| 230 |
+
"prompt": prompt,
|
| 231 |
+
"completions": completions,
|
| 232 |
+
"scores": scores,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
async def process_batch(
|
| 237 |
+
session: aiohttp.ClientSession,
|
| 238 |
+
api_url: str,
|
| 239 |
+
samples: List[Dict],
|
| 240 |
+
num_completions: int,
|
| 241 |
+
temperature: float,
|
| 242 |
+
max_tokens: int,
|
| 243 |
+
semaphore: asyncio.Semaphore,
|
| 244 |
+
output_file,
|
| 245 |
+
) -> List[Dict]:
|
| 246 |
+
"""Process a batch of samples concurrently"""
|
| 247 |
+
tasks = [
|
| 248 |
+
process_sample(
|
| 249 |
+
session, api_url, sample, num_completions,
|
| 250 |
+
temperature, max_tokens, semaphore
|
| 251 |
+
)
|
| 252 |
+
for sample in samples
|
| 253 |
+
]
|
| 254 |
+
|
| 255 |
+
results = await tqdm_asyncio.gather(*tasks, desc="Processing batch")
|
| 256 |
+
|
| 257 |
+
# Write results immediately
|
| 258 |
+
for result in results:
|
| 259 |
+
output_file.write(json.dumps(result) + '\n')
|
| 260 |
+
output_file.flush()
|
| 261 |
+
|
| 262 |
+
return results
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
async def main_async(args):
|
| 266 |
+
"""Main async function"""
|
| 267 |
+
print(f"\n{'='*60}")
|
| 268 |
+
print("GRPO Data Generation via vLLM API")
|
| 269 |
+
print(f"{'='*60}")
|
| 270 |
+
print(f"API URL: {args.api_url}")
|
| 271 |
+
print(f"Input dataset: {args.sft_dataset}")
|
| 272 |
+
print(f"Output dataset: {args.output_dataset}")
|
| 273 |
+
print(f"Completions per prompt: {args.num_completions}")
|
| 274 |
+
print(f"Max concurrent requests: {args.max_concurrent}")
|
| 275 |
+
print(f"Temperature: {args.temperature}")
|
| 276 |
+
print(f"Max samples: {args.max_samples or 'All'}")
|
| 277 |
+
print(f"{'='*60}\n")
|
| 278 |
+
|
| 279 |
+
# Load dataset
|
| 280 |
+
print("Loading SFT dataset...")
|
| 281 |
+
samples = load_sft_dataset(args.sft_dataset, args.max_samples)
|
| 282 |
+
print(f"Loaded {len(samples)} samples\n")
|
| 283 |
+
|
| 284 |
+
# Create semaphore for concurrency control
|
| 285 |
+
semaphore = asyncio.Semaphore(args.max_concurrent)
|
| 286 |
+
|
| 287 |
+
# Process all samples
|
| 288 |
+
all_results = []
|
| 289 |
+
|
| 290 |
+
async with aiohttp.ClientSession() as session:
|
| 291 |
+
with open(args.output_dataset, 'w') as f_out:
|
| 292 |
+
# Process in batches for better progress tracking
|
| 293 |
+
batch_size = args.batch_size
|
| 294 |
+
num_batches = (len(samples) + batch_size - 1) // batch_size
|
| 295 |
+
|
| 296 |
+
for batch_idx in range(num_batches):
|
| 297 |
+
batch_start = batch_idx * batch_size
|
| 298 |
+
batch_end = min(batch_start + batch_size, len(samples))
|
| 299 |
+
batch_samples = samples[batch_start:batch_end]
|
| 300 |
+
|
| 301 |
+
print(f"\nBatch {batch_idx + 1}/{num_batches} ({len(batch_samples)} samples)")
|
| 302 |
+
|
| 303 |
+
results = await process_batch(
|
| 304 |
+
session=session,
|
| 305 |
+
api_url=args.api_url,
|
| 306 |
+
samples=batch_samples,
|
| 307 |
+
num_completions=args.num_completions,
|
| 308 |
+
temperature=args.temperature,
|
| 309 |
+
max_tokens=args.max_tokens,
|
| 310 |
+
semaphore=semaphore,
|
| 311 |
+
output_file=f_out,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
all_results.extend(results)
|
| 315 |
+
|
| 316 |
+
# Statistics
|
| 317 |
+
print(f"\n{'='*60}")
|
| 318 |
+
print("Generation Complete!")
|
| 319 |
+
print(f"{'='*60}")
|
| 320 |
+
print(f"Generated {len(all_results)} GRPO samples")
|
| 321 |
+
print(f"Output saved to: {args.output_dataset}")
|
| 322 |
+
|
| 323 |
+
# Compute statistics
|
| 324 |
+
all_scores = [score for sample in all_results for score in sample['scores']]
|
| 325 |
+
if all_scores:
|
| 326 |
+
avg_score = sum(all_scores) / len(all_scores)
|
| 327 |
+
max_score = max(all_scores)
|
| 328 |
+
min_score = min(all_scores)
|
| 329 |
+
|
| 330 |
+
print(f"\nScore Statistics:")
|
| 331 |
+
print(f" Average F1: {avg_score:.3f}")
|
| 332 |
+
print(f" Max F1: {max_score:.3f}")
|
| 333 |
+
print(f" Min F1: {min_score:.3f}")
|
| 334 |
+
print(f"\n{'='*60}\n")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def main():
|
| 338 |
+
parser = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM API")
|
| 339 |
+
parser.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
|
| 340 |
+
parser.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")
|
| 341 |
+
parser.add_argument("--api_url", type=str, default="http://localhost:8000/v1", help="vLLM API URL")
|
| 342 |
+
parser.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
|
| 343 |
+
parser.add_argument("--max_concurrent", type=int, default=50, help="Max concurrent API requests")
|
| 344 |
+
parser.add_argument("--batch_size", type=int, default=100, help="Batch size for progress tracking")
|
| 345 |
+
parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
|
| 346 |
+
parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
|
| 347 |
+
parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process")
|
| 348 |
+
|
| 349 |
+
args = parser.parse_args()
|
| 350 |
+
|
| 351 |
+
# Run async main
|
| 352 |
+
asyncio.run(main_async(args))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
if __name__ == "__main__":
|
| 356 |
+
main()
|
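The token-level F1 used in both scripts can be sanity-checked in isolation: identical strings score 1.0, disjoint strings score 0.0, and half-overlapping token bags land in between. A self-contained copy of the same logic:

```python
import re
from collections import Counter

def compute_f1_score(prediction: str, reference: str) -> float:
    # Same token-level F1 as the scripts above: lowercase, split into
    # word/punctuation tokens, compare bag-of-tokens overlap.
    def tokenize(text):
        return Counter(re.findall(r'\w+|[^\w\s]', text.lower()))

    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred or not ref:
        return 0.0
    overlap = sum((pred & ref).values())
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

print(compute_f1_score("modify foo", "modify foo"))  # → 1.0
print(compute_f1_score("modify foo", "remove bar"))  # → 0.0
```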
trainer-kit/GRPO-14B/prepare_grpo_data_vllm.py
ADDED
|
@@ -0,0 +1,233 @@
|
| 1 |
+
"""
|
| 2 |
+
Fast GRPO Data Generation using vLLM
|
| 3 |
+
|
| 4 |
+
Generates multiple completions per prompt in parallel using vLLM for 10-20x speedup.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
pip install vllm
|
| 8 |
+
python prepare_grpo_data_vllm.py \
|
| 9 |
+
--sft_dataset /workspace/rl_dataset.jsonl \
|
| 10 |
+
--output_dataset grpo_dataset.jsonl \
|
| 11 |
+
--model_path /workspace/Models/Qwen2.5-Coder-14B-CPT-SFT_v2 \
|
| 12 |
+
--num_completions 6 \
|
| 13 |
+
--batch_size 50
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import argparse
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Dict
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from collections import Counter
|
| 22 |
+
import re
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from vllm import LLM, SamplingParams
|
| 26 |
+
VLLM_AVAILABLE = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
VLLM_AVAILABLE = False
|
| 29 |
+
print("vLLM not available. Install with: pip install vllm")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_f1_score(prediction: str, reference: str) -> float:
|
| 33 |
+
"""Compute token-level F1 score"""
|
| 34 |
+
def tokenize(text):
|
| 35 |
+
# Split on whitespace and punctuation
|
| 36 |
+
tokens = re.findall(r'\w+|[^\w\s]', text.lower())
|
| 37 |
+
return Counter(tokens)
|
| 38 |
+
|
| 39 |
+
pred_tokens = tokenize(prediction)
|
| 40 |
+
ref_tokens = tokenize(reference)
|
| 41 |
+
|
| 42 |
+
if not pred_tokens or not ref_tokens:
|
| 43 |
+
return 0.0
|
| 44 |
+
|
| 45 |
+
# Count overlapping tokens
|
| 46 |
+
overlap = sum((pred_tokens & ref_tokens).values())
|
| 47 |
+
|
| 48 |
+
precision = overlap / sum(pred_tokens.values()) if pred_tokens else 0
|
| 49 |
+
recall = overlap / sum(ref_tokens.values()) if ref_tokens else 0
|
| 50 |
+
|
| 51 |
+
if precision + recall == 0:
|
| 52 |
+
return 0.0
|
| 53 |
+
|
| 54 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
| 55 |
+
return f1
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
|
| 59 |
+
"""Load SFT dataset"""
|
| 60 |
+
samples = []
|
| 61 |
+
with open(path, 'r') as f:
|
| 62 |
+
for i, line in enumerate(f):
|
| 63 |
+
if max_samples and i >= max_samples:
|
| 64 |
+
break
|
| 65 |
+
data = json.loads(line)
|
| 66 |
+
samples.append(data)
|
| 67 |
+
return samples
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def format_prompt(instruction: str, input_text: str) -> str:
|
| 71 |
+
"""Format prompt using custom template"""
|
| 72 |
+
return f"##INSTRUCTION\n{instruction}<|im_end|>\n{input_text}<|im_end|>\n"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def generate_completions_vllm(
|
| 76 |
+
llm: LLM,
|
| 77 |
+
prompts: List[str],
|
| 78 |
+
num_completions: int,
|
| 79 |
+
temperature: float,
|
| 80 |
+
max_tokens: int,
|
| 81 |
+
top_p: float = 0.95,
|
| 82 |
+
) -> List[List[str]]:
|
| 83 |
+
"""
|
| 84 |
+
Generate multiple completions for multiple prompts using vLLM.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
llm: vLLM LLM instance
|
| 88 |
+
prompts: List of prompts
|
| 89 |
+
num_completions: Number of completions per prompt
|
| 90 |
+
temperature: Sampling temperature
|
| 91 |
+
max_tokens: Maximum tokens to generate
|
| 92 |
+
top_p: Nucleus sampling parameter
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
List of lists, where each inner list contains completions for one prompt
|
| 96 |
+
"""
|
| 97 |
+
sampling_params = SamplingParams(
|
| 98 |
+
n=num_completions, # Generate n completions per prompt
|
| 99 |
+
temperature=temperature,
|
| 100 |
+
top_p=top_p,
|
| 101 |
+
max_tokens=max_tokens,
|
| 102 |
+
stop=["<EOS>", "<|im_end|>"],
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# vLLM automatically batches and processes efficiently
|
| 106 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 107 |
+
|
| 108 |
+
# Extract completions
|
| 109 |
+
all_completions = []
|
| 110 |
+
for output in outputs:
|
| 111 |
+
completions = [o.text.strip() for o in output.outputs]
|
| 112 |
+
all_completions.append(completions)
|
| 113 |
+
|
| 114 |
+
return all_completions
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main():
|
| 118 |
+
parser = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM")
|
| 119 |
+
parser.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
|
| 120 |
+
parser.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")
|
| 121 |
+
parser.add_argument("--model_path", type=str, required=True, help="Model path")
|
| 122 |
+
parser.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
|
| 123 |
+
parser.add_argument("--batch_size", type=int, default=50, help="Number of prompts to process at once")
|
| 124 |
+
parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
|
| 125 |
+
parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
|
| 126 |
+
parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process (for testing)")
|
| 127 |
+
parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism")
|
| 128 |
+
parser.add_argument("--gpu_memory_utilization", type=float, default=0.85, help="GPU memory utilization (0-1)")
|
| 129 |
+
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
|
| 132 |
+
if not VLLM_AVAILABLE:
|
| 133 |
+
raise ImportError("vLLM is required. Install with: pip install vllm")
|
| 134 |
+
|
| 135 |
+
print(f"\n{'='*60}")
|
| 136 |
+
print("GRPO Data Generation with vLLM")
|
| 137 |
+
print(f"{'='*60}")
|
| 138 |
+
print(f"Model: {args.model_path}")
|
| 139 |
+
print(f"Input dataset: {args.sft_dataset}")
|
| 140 |
+
print(f"Output dataset: {args.output_dataset}")
|
| 141 |
+
print(f"Completions per prompt: {args.num_completions}")
|
| 142 |
+
print(f"Batch size: {args.batch_size}")
|
| 143 |
+
print(f"Temperature: {args.temperature}")
|
| 144 |
+
print(f"Max samples: {args.max_samples or 'All'}")
|
| 145 |
+
print(f"{'='*60}\n")
|
| 146 |
+
|
| 147 |
+
# Load SFT dataset
|
| 148 |
+
print("Loading SFT dataset...")
|
| 149 |
+
sft_samples = load_sft_dataset(args.sft_dataset, args.max_samples)
|
| 150 |
+
print(f"Loaded {len(sft_samples)} samples\n")
|
| 151 |
+
|
| 152 |
+
# Initialize vLLM
|
| 153 |
+
print("Initializing vLLM...")
|
| 154 |
+
llm = LLM(
|
| 155 |
+
model=args.model_path,
|
| 156 |
+
tensor_parallel_size=args.tensor_parallel_size,
|
| 157 |
+
gpu_memory_utilization=args.gpu_memory_utilization,
|
| 158 |
+
trust_remote_code=True,
|
| 159 |
+
dtype="bfloat16",
|
| 160 |
+
max_model_len=4096,
|
| 161 |
+
)
|
| 162 |
+
    print("vLLM initialized!\n")

    # Process in batches
    grpo_samples = []
    num_batches = (len(sft_samples) + args.batch_size - 1) // args.batch_size

    with open(args.output_dataset, 'w') as f_out:
        for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
            batch_start = batch_idx * args.batch_size
            batch_end = min(batch_start + args.batch_size, len(sft_samples))
            batch_samples = sft_samples[batch_start:batch_end]

            # Prepare prompts for this batch
            prompts = []
            references = []
            for sample in batch_samples:
                instruction = sample.get('instruction', 'You are a helpful assistant.')
                input_text = sample.get('input', '')
                output_text = sample.get('output', '')

                prompt = format_prompt(instruction, input_text)
                prompts.append(prompt)
                references.append(output_text)

            # Generate completions for the entire batch
            batch_completions = generate_completions_vllm(
                llm=llm,
                prompts=prompts,
                num_completions=args.num_completions,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )

            # Compute F1 scores and create GRPO samples
            for sample, prompt, completions, reference in zip(batch_samples, prompts, batch_completions, references):
                # Compute F1 score for each completion
                scores = [compute_f1_score(comp, reference) for comp in completions]

                # Create GRPO sample
                grpo_sample = {
                    "prompt": prompt,
                    "completions": completions,
                    "scores": scores,
                }

                # Write immediately (streaming)
                f_out.write(json.dumps(grpo_sample) + '\n')
                f_out.flush()
                grpo_samples.append(grpo_sample)

    # Statistics
    print(f"\n{'='*60}")
    print("Generation Complete!")
    print(f"{'='*60}")
    print(f"Generated {len(grpo_samples)} GRPO samples")
    print(f"Output saved to: {args.output_dataset}")

    # Compute statistics
    all_scores = [score for sample in grpo_samples for score in sample['scores']]
    avg_score = sum(all_scores) / len(all_scores) if all_scores else 0
    max_score = max(all_scores) if all_scores else 0
    min_score = min(all_scores) if all_scores else 0

    print(f"\nScore Statistics:")
    print(f"  Average F1: {avg_score:.3f}")
    print(f"  Max F1: {max_score:.3f}")
    print(f"  Min F1: {min_score:.3f}")
    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()
trainer-kit/GRPO-14B/requirements.txt
ADDED
@@ -0,0 +1,26 @@
# Core
torch>=2.1.0
transformers>=4.41.0
datasets>=2.18.0
accelerate>=0.30.0

# PEFT / QLoRA
peft>=0.11.1
bitsandbytes>=0.43.1

# Hugging Face Hub
huggingface_hub>=0.23.0

# Config + utilities
pyyaml>=6.0
tqdm>=4.66.0

# Tokenizers and safetensors
tokenizers>=0.15.0
safetensors>=0.4.2

# Experiment tracking
wandb>=0.16.0

# For F1 score computation
scikit-learn>=1.3.0
trainer-kit/GRPO-14B/run_grpo_fixed.py
ADDED
@@ -0,0 +1,1216 @@
import argparse
import json
import inspect
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List
from dataclasses import dataclass
import numpy as np

import torch
import torch.nn.functional as F
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    EarlyStoppingCallback,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None

# --------------------------
# Helpers
# --------------------------


def _dtype_from_str(s: str) -> torch.dtype:
    s = (s or "").lower()
    if s in ("float16", "fp16"):
        return torch.float16
    if s in ("bfloat16", "bf16"):
        return torch.bfloat16
    if s in ("float32", "fp32"):
        return torch.float32
    raise ValueError(f"Unknown torch_dtype: {s}")


def _now_iso() -> str:
    return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())


def _safe_exp(x: float) -> float:
    x = min(float(x), 50.0)
    return float(math.exp(x))


def _ensure_dir(p: Path) -> Path:
    p.mkdir(parents=True, exist_ok=True)
    return p


def _looks_like_model_dir(p: Path) -> bool:
    if not p.exists() or not p.is_dir():
        return False
    if (p / "config.json").exists():
        return True
    if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
        return True
    return False


def _infer_target_modules(model) -> List[str]:
    names = set()
    for n, _ in model.named_modules():
        names.add(n.split(".")[-1])

    for group in [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]:
        if all(x in names for x in group):
            return group

    fallback = [
        x
        for x in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "c_attn",
            "c_proj",
            "out_proj",
            "dense",
        ]
        if x in names
    ]
    if fallback:
        return fallback

    raise ValueError(
        "Could not auto-infer target_modules. Set peft.target_modules explicitly."
    )


def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
    return cfg.get("model", {}).get("attn_implementation", None)

# --------------------------
# Wandb Integration
# --------------------------

def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Initialize Wandb if enabled in configuration."""
    wandb_cfg = cfg.get("wandb", {})

    if not wandb_cfg.get("enabled", False):
        print("Wandb logging disabled")
        return None

    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None

    project = wandb_cfg.get("project", "grpo-training")
    entity = wandb_cfg.get("entity", None)
    name = wandb_cfg.get("name", None)
    tags = wandb_cfg.get("tags", [])
    notes = wandb_cfg.get("notes", None)

    try:
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "grpo": cfg.get("grpo", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            }
        )
        print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None


def finish_wandb():
    """Finish Wandb run if active."""
    if WANDB_AVAILABLE and wandb.run is not None:
        wandb.finish()
        print("Wandb run finished")


# --------------------------
# JSONL Logger Callback
# --------------------------

class JsonlLoggerCallback(TrainerCallback):
    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
        self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
        self.start_time = None

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        elapsed = time.time() - self.start_time
        sec_per_step = elapsed / global_step
        remaining = max(0, max_steps - global_step) * sec_per_step
        h = int(remaining // 3600)
        m = int((remaining % 3600) // 60)
        s = int(remaining % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return

        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))

        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            **logs,
        }

        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return

        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")


# --------------------------
# GRPO Metrics Logging Callback
# --------------------------

class GRPOMetricsCallback(TrainerCallback):
    """Logs GRPO-specific metrics during training."""

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        self.metrics_path = _ensure_dir(run_dir / "logs") / "grpo_metrics.jsonl"

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return

        # Extract GRPO-specific metrics if present
        grpo_metrics = {k: v for k, v in logs.items() if k.startswith("grpo_")}

        if grpo_metrics:
            payload = {
                "ts": _now_iso(),
                "step": int(state.global_step),
                "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
                **grpo_metrics,
            }
            with self.metrics_path.open("a", encoding="utf-8") as f:
                f.write(json.dumps(payload, ensure_ascii=False) + "\n")


# --------------------------
# Data Validation
# --------------------------

def validate_grpo_example(example: Dict[str, Any], data_cfg: Dict[str, Any]) -> bool:
    """
    Validate that a GRPO example has required fields and proper structure.
    Returns True if valid, False otherwise.
    """
    completions_field = data_cfg.get("completions_field", "completions")
    scores_field = data_cfg.get("scores_field", "scores")

    # Check for required fields
    if completions_field not in example or scores_field not in example:
        return False

    completions = example[completions_field]
    scores = example[scores_field]

    # Must be lists
    if not isinstance(completions, list) or not isinstance(scores, list):
        return False

    # Must have same length
    if len(completions) != len(scores):
        return False

    # Must have at least min_completions
    min_completions = int(data_cfg.get("min_completions", 2))
    if len(completions) < min_completions:
        return False

    # Completions must be non-empty strings
    for completion in completions:
        if not isinstance(completion, str) or len(completion.strip()) == 0:
            return False

    # Scores must be valid numbers (not NaN or Inf)
    for score in scores:
        if not isinstance(score, (int, float)) or np.isnan(score) or np.isinf(score):
            return False

    return True


# --------------------------
# Data Pipeline (GRPO Format)
# --------------------------

def format_grpo_example(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Optional[Dict[str, Any]]:
    """
    Format GRPO data which requires prompt and multiple completions with scores.
    Returns formatted data with tokenized inputs or None if validation fails.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    # Validate example first
    if not validate_grpo_example(example, data_cfg):
        return None

    # Get field names from config
    prompt_field = data_cfg.get("prompt_field", "prompt")
    completions_field = data_cfg.get("completions_field", "completions")
    scores_field = data_cfg.get("scores_field", "scores")

    # Extract data from example
    prompt = example.get(prompt_field, "")
    completions = example.get(completions_field, [])
    scores = example.get(scores_field, [])

    if format_type == "chatml":
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")

        # Format prompt with system message
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Apply chat template for prompt only
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Format completions
        formatted_completions = []
        for completion in completions:
            formatted_completion = completion
            if tokenizer.eos_token and not formatted_completion.endswith(tokenizer.eos_token):
                formatted_completion += tokenizer.eos_token
            formatted_completions.append(formatted_completion)

    elif format_type == "alpaca":
        prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
        formatted_prompt = prefix

        formatted_completions = []
        for completion in completions:
            formatted_completion = completion
            if tokenizer.eos_token and not formatted_completion.endswith(tokenizer.eos_token):
                formatted_completion += tokenizer.eos_token
            formatted_completions.append(formatted_completion)

    elif format_type == "custom":
        # Use custom template (matches SFT training format)
        custom_template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")

        # Format prompt part using template (instruction + input)
        formatted_prompt = custom_template.format(
            instruction=system_prompt,
            input=prompt,
            output=""  # Empty output - model will generate
        )

        # Format completions (just the output part)
        formatted_completions = []
        for completion in completions:
            formatted_completion = completion
            if tokenizer.eos_token and not formatted_completion.endswith(tokenizer.eos_token):
                formatted_completion += tokenizer.eos_token
            formatted_completions.append(formatted_completion)

    else:
        # Simple format
        formatted_prompt = prompt
        formatted_completions = completions

    # PRE-TOKENIZE all completions with prompts
    # This fixes the critical issue of on-the-fly tokenization during training
    max_length = int(data_cfg.get("max_length", 2048))

    full_texts = [formatted_prompt + completion for completion in formatted_completions]

    # Tokenize all prompt + completion pairs
    tokenized_full = tokenizer(
        full_texts,
        padding=False,  # Don't pad yet, will pad in collator
        truncation=True,
        max_length=max_length,
        return_tensors=None,  # Return lists, not tensors
    )

    # Tokenize prompt only to get prompt length
    tokenized_prompt = tokenizer(
        formatted_prompt,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_tensors=None,
    )
    prompt_length = len(tokenized_prompt["input_ids"])

    return {
        "prompt": formatted_prompt,
        "completions": formatted_completions,
        "scores": scores,
        "input_ids_all": tokenized_full["input_ids"],  # List of token lists
        "attention_mask_all": tokenized_full["attention_mask"],  # List of mask lists
        "prompt_length": prompt_length,
    }

def build_grpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build datasets for GRPO training with pre-tokenization.
    Expected JSONL format: {"prompt": "...", "completions": [...], "scores": [...]}
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))

    # Ensure tokenizer has pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load datasets
    ds = load_dataset("json", data_files={"train": train_path})

    if eval_path:
        ds_eval = load_dataset("json", data_files={"eval": eval_path})
        dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
    else:
        if 0.0 < split_ratio < 1.0:
            split = ds["train"].train_test_split(
                test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
            )
            dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
        else:
            dsd = DatasetDict({"train": ds["train"], "eval": None})

    # Format and tokenize GRPO examples
    def format_fn(example):
        formatted = format_grpo_example(example, cfg, tokenizer)
        return formatted if formatted is not None else None

    formatted_train = dsd["train"].map(
        format_fn,
        num_proc=1,  # Use 1 to avoid tokenizer issues with multiprocessing
        remove_columns=dsd["train"].column_names,
        desc="Formatting and tokenizing train GRPO data",
    ).filter(lambda x: x is not None)

    formatted_eval = None
    if dsd["eval"] is not None:
        formatted_eval = dsd["eval"].map(
            format_fn,
            num_proc=1,
            remove_columns=dsd["eval"].column_names,
            desc="Formatting and tokenizing eval GRPO data",
        ).filter(lambda x: x is not None)

    if shuffle:
        formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42)))

    print(f"Training samples: {len(formatted_train)}")
    if formatted_eval is not None:
        print(f"Evaluation samples: {len(formatted_eval)}")

    return formatted_train, formatted_eval


# --------------------------
# GRPO Data Collator (CRITICAL FIX)
# --------------------------

@dataclass
|
| 515 |
+
class GRPODataCollator:
|
| 516 |
+
"""
|
| 517 |
+
Data collator for GRPO training.
|
| 518 |
+
Handles batching of multiple completions per prompt with proper padding.
|
| 519 |
+
"""
|
| 520 |
+
tokenizer: Any
|
| 521 |
+
pad_to_multiple_of: Optional[int] = None
|
| 522 |
+
return_tensors: str = "pt"
|
| 523 |
+
|
| 524 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 525 |
+
"""
|
| 526 |
+
Collate a batch of GRPO examples.
|
| 527 |
+
Each example has: prompt, completions, scores, input_ids_all, attention_mask_all, prompt_length
|
| 528 |
+
"""
|
| 529 |
+
batch = {
|
| 530 |
+
"prompt": [],
|
| 531 |
+
"completions": [],
|
| 532 |
+
"scores": [],
|
| 533 |
+
"input_ids_all": [],
|
| 534 |
+
"attention_mask_all": [],
|
| 535 |
+
"prompt_lengths": [],
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
for feature in features:
|
| 539 |
+
batch["prompt"].append(feature["prompt"])
|
| 540 |
+
batch["completions"].append(feature["completions"])
|
| 541 |
+
batch["scores"].append(feature["scores"])
|
| 542 |
+
batch["input_ids_all"].append(feature["input_ids_all"])
|
| 543 |
+
batch["attention_mask_all"].append(feature["attention_mask_all"])
|
| 544 |
+
            batch["prompt_lengths"].append(feature["prompt_length"])

        return batch


# --------------------------
# GRPO Trainer
# --------------------------


class GRPOTrainer(Trainer):
    """
    Group Relative Policy Optimization Trainer.
    Implements the GRPO algorithm for fine-tuning with group-based rewards.

    FIXED ISSUES:
    - Pre-tokenized data (no on-the-fly tokenization)
    - Proper prompt masking for completion-only loss
    - Correct KL divergence calculation
    - Improved advantage computation
    - GRPO-specific metrics logging
    - Better error handling
    """

    def __init__(
        self,
        model,
        ref_model,
        args,
        train_dataset,
        eval_dataset,
        tokenizer,
        grpo_config,
        data_collator=None,
        **kwargs,
    ):
        # Use the custom data collator by default
        if data_collator is None:
            data_collator = GRPODataCollator(tokenizer=tokenizer)

        if tokenizer is None:
            raise ValueError("Tokenizer cannot be None!")

        # Store processing_class and token IDs separately (following the official
        # TRL pattern). This avoids self.tokenizer being overwritten by the
        # parent class in newer transformers versions.
        self.processing_class = tokenizer

        # Ensure pad_token is set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Store token IDs as separate attributes (official TRL pattern)
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            **kwargs,
        )
        self.ref_model = ref_model
        self.grpo_config = grpo_config

        # Ensure ref_model is on the same device as the policy model, if it exists
        if self.ref_model is not None:
            self.ref_model.to(self.model.device)

        # For metrics tracking
        self.grpo_metrics = {
            "advantages": [],
            "kl_divs": [],
            "group_scores": [],
            "best_completion_indices": [],
        }

        # Random state for reproducibility
        self.rng = np.random.RandomState(grpo_config.get("seed", 42))

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute GRPO loss using group-based ranking.

        Improved implementation with:
        - Proper prompt masking
        - Correct KL divergence
        - Better error handling
        - Metrics logging
        """
        try:
            # Extract data from the collated batch
            prompts = inputs["prompt"]
            completions_lists = inputs["completions"]
            scores_lists = inputs["scores"]
            input_ids_all_lists = inputs["input_ids_all"]
            attention_mask_all_lists = inputs["attention_mask_all"]
            prompt_lengths = inputs["prompt_lengths"]

            total_loss = torch.tensor(0.0, device=model.device)
            batch_size = len(prompts)

            # Metrics for this batch
            batch_advantages = []
            batch_kl_divs = []
            batch_group_scores = []
            batch_best_indices = []

            for prompt, completions, scores, input_ids_all, attention_mask_all, prompt_length in zip(
                prompts, completions_lists, scores_lists, input_ids_all_lists, attention_mask_all_lists, prompt_lengths
            ):
                # Sample group_size completions
                group_size = min(self.grpo_config["group_size"], len(completions))
                indices = self.rng.permutation(len(completions))[:group_size]

                # Get group data
                group_input_ids = [input_ids_all[i] for i in indices]
                group_attention_masks = [attention_mask_all[i] for i in indices]
                group_scores = torch.tensor([scores[i] for i in indices], device=model.device, dtype=torch.float32)

                # Apply reward shaping
                reward_scaling = float(self.grpo_config.get("reward_scaling", 1.0))
                reward_bias = float(self.grpo_config.get("reward_bias", 0.0))
                reward_clip = self.grpo_config.get("reward_clip", None)

                group_scores = group_scores * reward_scaling + reward_bias

                if reward_clip is not None:
                    group_scores = torch.clamp(group_scores, -reward_clip, reward_clip)

                # Pad sequences to a common length within the group
                max_len = max(len(ids) for ids in group_input_ids)
                padded_input_ids = []
                padded_attention_masks = []
                completion_masks = []

                for ids, mask in zip(group_input_ids, group_attention_masks):
                    pad_len = max_len - len(ids)
                    padded_ids = ids + [self.pad_token_id] * pad_len
                    padded_mask = mask + [0] * pad_len

                    # Completion mask: 1 for completion tokens, 0 for prompt/padding
                    comp_mask = [0] * prompt_length + [1] * (len(ids) - prompt_length) + [0] * pad_len

                    padded_input_ids.append(padded_ids)
                    padded_attention_masks.append(padded_mask)
                    completion_masks.append(comp_mask)

                # Convert to tensors
                input_ids_tensor = torch.tensor(padded_input_ids, device=model.device)
                attention_mask_tensor = torch.tensor(padded_attention_masks, device=model.device)
                completion_mask_tensor = torch.tensor(completion_masks, device=model.device, dtype=torch.float32)

                # Get log probabilities from the policy model
                # (torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast)
                with torch.amp.autocast("cuda", enabled=self.args.fp16 or self.args.bf16):
                    outputs = model(
                        input_ids=input_ids_tensor,
                        attention_mask=attention_mask_tensor,
                        output_hidden_states=False,
                        use_cache=False,
                    )
                    logits = outputs.logits

                # Compute token-level log probs
                log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
                target_log_probs = torch.gather(
                    log_probs,
                    2,
                    input_ids_tensor[:, 1:].unsqueeze(-1)
                ).squeeze(-1)

                # Apply the completion mask (loss is computed on completion tokens only)
                completion_mask_shifted = completion_mask_tensor[:, 1:]
                masked_log_probs = target_log_probs * completion_mask_shifted

                # Length-normalized log prob of each completion (completion tokens only)
                sequence_log_probs = masked_log_probs.sum(dim=1) / (completion_mask_shifted.sum(dim=1) + 1e-8)

                # Get reference log probabilities if using a reference model
                kl_div = torch.tensor(0.0, device=model.device)
                if self.ref_model is not None:
                    with torch.no_grad():
                        ref_outputs = self.ref_model(
                            input_ids=input_ids_tensor,
                            attention_mask=attention_mask_tensor,
                            output_hidden_states=False,
                            use_cache=False,
                        )
                        ref_logits = ref_outputs.logits

                    # Compute reference log probs
                    ref_log_probs = F.log_softmax(ref_logits[:, :-1, :], dim=-1)
                    ref_target_log_probs = torch.gather(
                        ref_log_probs,
                        2,
                        input_ids_tensor[:, 1:].unsqueeze(-1)
                    ).squeeze(-1)

                    # Apply the completion mask
                    ref_masked_log_probs = ref_target_log_probs * completion_mask_shifted
                    ref_sequence_log_probs = ref_masked_log_probs.sum(dim=1) / (completion_mask_shifted.sum(dim=1) + 1e-8)

                    # Sequence-level KL estimate: E[log(p/q)] = E[log p - log q]
                    kl_div = (sequence_log_probs - ref_sequence_log_probs).mean()

                # Compute advantages from scores
                if self.grpo_config.get("normalize_advantages", True):
                    advantage_eps = float(self.grpo_config.get("advantage_eps", 1e-8))
                    advantages = (group_scores - group_scores.mean()) / (group_scores.std() + advantage_eps)
                else:
                    advantages = group_scores

                # Apply temperature to the advantages
                advantage_temp = float(self.grpo_config.get("advantage_temperature", 1.0))
                advantages = advantages / advantage_temp

                # GRPO loss: log probabilities weighted by advantages
                policy_loss = -(advantages * sequence_log_probs).mean()

                # Add the KL penalty if using a reference model
                kl_coef = float(self.grpo_config.get("kl_coef", 0.05))
                loss = policy_loss + kl_coef * kl_div

                total_loss += loss

                # Track metrics
                batch_advantages.append(advantages.detach().cpu().numpy())
                batch_kl_divs.append(kl_div.item())
                batch_group_scores.append(group_scores.detach().cpu().numpy())
                batch_best_indices.append(int(torch.argmax(group_scores).item()))

            # Average over the batch
            total_loss = total_loss / batch_size

            # Store metrics for logging
            if batch_advantages:
                self.grpo_metrics["advantages"].extend([adv.mean() for adv in batch_advantages])
                self.grpo_metrics["kl_divs"].extend(batch_kl_divs)
                self.grpo_metrics["group_scores"].extend([scores.mean() for scores in batch_group_scores])
                self.grpo_metrics["best_completion_indices"].extend(batch_best_indices)

            return (total_loss, outputs) if return_outputs else total_loss

        except Exception as e:
            print(f"Error in compute_loss: {e}")
            import traceback
            traceback.print_exc()
            raise

    def log(self, logs: Dict[str, float], start_time: float = None, **kwargs) -> None:
        """Override log to add GRPO-specific metrics."""
        # Add GRPO metrics if available
        if self.grpo_metrics["advantages"]:
            logs["grpo_mean_advantage"] = float(np.mean(self.grpo_metrics["advantages"]))
            logs["grpo_std_advantage"] = float(np.std(self.grpo_metrics["advantages"]))
            logs["grpo_mean_kl_div"] = float(np.mean(self.grpo_metrics["kl_divs"]))
            logs["grpo_mean_group_score"] = float(np.mean(self.grpo_metrics["group_scores"]))

            # Clear metrics after logging
            self.grpo_metrics = {
                "advantages": [],
                "kl_divs": [],
                "group_scores": [],
                "best_completion_indices": [],
            }

        super().log(logs, start_time=start_time, **kwargs)
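The group-relative advantage step in `compute_loss` above can be illustrated in isolation. This is a minimal, dependency-free sketch with made-up scores and log probs (not values from the trainer); `group_advantages` and `grpo_policy_loss` are hypothetical helpers mirroring the tensor arithmetic, not functions this script defines:

```python
import math

def group_advantages(scores, eps=1e-8):
    # Normalize raw reward scores within one group to zero mean / unit std,
    # as done with group_scores in compute_loss.
    mean = sum(scores) / len(scores)
    var = sum((s - mean) ** 2 for s in scores) / len(scores)
    return [(s - mean) / (math.sqrt(var) + eps) for s in scores]

def grpo_policy_loss(scores, seq_log_probs):
    # -mean(advantage * sequence log prob): higher-scoring completions
    # pull their log probs up, lower-scoring ones push them down.
    advs = group_advantages(scores)
    return -sum(a * lp for a, lp in zip(advs, seq_log_probs)) / len(advs)

advs = group_advantages([0.2, 0.5, 0.8, 0.5])
# → roughly [-1.41, 0.0, 1.41, 0.0]: above-mean completions get positive advantages
```

Sequence log probs are length-normalized in the trainer before this weighting, so long completions are not rewarded merely for having more tokens.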


# --------------------------
# Model Loading + PEFT
# --------------------------


def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    # Use a custom tokenizer if specified (for compatibility issues)
    tokenizer_name = model_cfg.get("tokenizer_name", str(base_dir))
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )

    attn_impl = _choose_attn_impl(cfg)

    try:
        model = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
            attn_implementation=attn_impl,
        )
    except Exception as e:
        if attn_impl is not None:
            print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
            print("[warn] Falling back to default attention implementation.")
        model = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
        )

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config


# --------------------------
# Merge Logic
# --------------------------


def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    base = AutoModelForCausalLM.from_pretrained(
        str(base_dir),
        torch_dtype=merged_dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote_code,
    )

    merged = PeftModel.from_pretrained(base, str(adapter_dir))
    merged = merged.merge_and_unload()

    _ensure_dir(final_dir)
    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Use tokenizer_name from config if available, falling back to base_dir
    tokenizer_path = model_cfg.get("tokenizer_name", str(base_dir))
    tok = AutoTokenizer.from_pretrained(
        tokenizer_path, trust_remote_code=trust_remote_code
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))

    print("--- Merge complete ---")


# --------------------------
# Main
# --------------------------


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    # Local model path
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        # HF repo_id -> download
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    # Merge-only
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    # Initialize Wandb
    wandb_run = setup_wandb(cfg, run_dir)

    # Training
    set_seed(int(cfg["run"].get("seed", 42)))

    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)

    # Load reference model for GRPO
    grpo_cfg = cfg.get("grpo", {})
    use_reference_model = bool(grpo_cfg.get("use_reference_model", True))

    ref_model = None
    if use_reference_model:
        print("Loading reference model (frozen copy)...")
        ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
        ref_model, _ = apply_peft(cfg, ref_model)
        # Freeze the reference model
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.eval()
        print("Reference model loaded and frozen")

    train_ds, eval_ds = build_grpo_datasets(cfg, tokenizer)

    tr_cfg = cfg["train"]

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16

    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))

    # Handle the evaluation-strategy parameter rename across transformers versions
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = (
        "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    )

    # Setup reporting
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")

    ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(
            tr_cfg.get(
                "per_device_eval_batch_size",
                tr_cfg.get("per_device_train_batch_size", 1),
            )
        ),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(tr_cfg.get("learning_rate", 1e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
        optim=str(
            tr_cfg.get(
                "optim",
                (
                    "paged_adamw_8bit"
                    if bool(model_cfg.get("use_4bit", False))
                    else "adamw_torch"
                ),
            )
        ),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 0.5)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 100)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 50)),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", True))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
    )

    ta_kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )

    training_args = TrainingArguments(**ta_kwargs)

    # Setup callbacks
    callbacks = [
        JsonlLoggerCallback(run_dir),
        GRPOMetricsCallback(run_dir),
    ]

    # Add early stopping callback
    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}")

    # Validate GRPO config
    group_size = int(grpo_cfg.get("group_size", 4))
    min_completions = int(cfg["data"].get("min_completions", 2))

    if group_size > min_completions:
        print(f"Warning: group_size ({group_size}) > min_completions ({min_completions})")
        print(f"Setting group_size to {min_completions}")
        group_size = min_completions

    # Assemble the GRPO config parameters
    grpo_config = {
        "group_size": group_size,
        "normalize_advantages": bool(grpo_cfg.get("normalize_advantages", True)),
        "advantage_eps": float(grpo_cfg.get("advantage_eps", 1e-8)),
        "kl_coef": float(grpo_cfg.get("kl_coef", 0.05)),
        "reward_scaling": float(grpo_cfg.get("reward_scaling", 1.0)),
        "reward_bias": float(grpo_cfg.get("reward_bias", 0.0)),
        "reward_clip": grpo_cfg.get("reward_clip", None),
        "advantage_temperature": float(grpo_cfg.get("advantage_temperature", 1.0)),
        "max_length": int(cfg["data"].get("max_length", 2048)),
        "seed": int(cfg["run"].get("seed", 42)),
    }

    print("\n=== GRPO Training Configuration ===")
    print(f"Group size: {grpo_config['group_size']}")
    print(f"KL coefficient: {grpo_config['kl_coef']}")
    print(f"Normalize advantages: {grpo_config['normalize_advantages']}")
    print(f"Reward scaling: {grpo_config['reward_scaling']}")
    print(f"Advantage temperature: {grpo_config['advantage_temperature']}")
    print(f"Use reference model: {use_reference_model}")
    print("===================================\n")

    trainer = GRPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
        grpo_config=grpo_config,
        callbacks=callbacks,
    )

    # Resume
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")

    print("Starting GRPO training...")
    trainer.train(resume_from_checkpoint=resume_from)

    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final metrics: {metrics}")

    if bool(cfg.get("merge", {}).get("enabled", False)):
        del trainer, model
        if ref_model is not None:
            del ref_model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")

    # Finish the Wandb run
    finish_wandb()


if __name__ == "__main__":
    main()
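The prompt/completion masking that `compute_loss` relies on can be illustrated standalone. `build_completion_mask` is a hypothetical helper mirroring the inline list arithmetic in the trainer, not a function this script defines:

```python
# Illustration of the completion mask built in compute_loss: 0 over the prompt
# and padding, 1 over completion tokens, so the loss covers completions only.
def build_completion_mask(seq_len, prompt_length, max_len):
    pad_len = max_len - seq_len
    return [0] * prompt_length + [1] * (seq_len - prompt_length) + [0] * pad_len

mask = build_completion_mask(seq_len=6, prompt_length=4, max_len=8)
# → [0, 0, 0, 0, 1, 1, 0, 0]
```

In the trainer this mask is shifted by one position (`completion_mask_tensor[:, 1:]`) to align with next-token log probs before the masked sum.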

trainer-kit/GRPO-14B/sft_dataset.jsonl
ADDED
@@ -0,0 +1,5 @@
{"input": "##TASK\nUpdate connector configuration and payment response structures to support new payment processors and extended authorization features. This enables integration with Finix and Tesouro connectors while providing visibility into extended authorization requests in payment responses.", "output": "##OUTPUT\nThe connector configuration system requires new connector entries and payment responses need extended authorization tracking to support enhanced payment processing capabilities. The flow is: system loads ConnectorConfig at startup → populates with all supported connectors → payment processing uses configured connectors → payment response includes authorization details. The ConnectorConfig struct (connector_configs/src/connector.rs::ConnectorConfig) must add finix and tesouro fields because it defines the complete set of supported connectors—without these fields, the system cannot load configurations for Finix and Tesouro, preventing merchants from using these payment processors. The PaymentsResponse::foreign_from method (router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from) needs to add request_extended_authorization field because it maps database payment attempt data to API responses—without this field, merchants cannot see whether extended authorization was requested, losing visibility into payment processing decisions. These changes are independent but both enhance payment processing capabilities: one expands connector support, the other improves response transparency.\n\n##SELECT\ncrates/connector_configs/src/connector.rs::ConnectorConfig\ncrates/router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from\n<EOS>"}
{"input": "##TASK\nFix TokenEx vault integration to properly handle card tokenization and improve vault error messaging for better debugging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail.", "output": "##OUTPUT\nThe TokenEx vault integration has incorrect data handling and poor error messaging that need to be fixed for reliable payment method vaulting. The flow is: merchant initiates vault operation → TokenexInsertRequest::try_from transforms request → sends to TokenEx → response processed by get_vault_response_for_retrieve_payment_method_data_v1. The TokenexInsertRequest::try_from method (hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from) must change from encoding entire card as JSON to sending just card_number because TokenEx API expects raw card number tokenization, not JSON payload—without this change, TokenEx rejects requests with malformed data. The get_vault_response_for_retrieve_payment_method_data_v1 function (router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1) needs updated error messages to specify 'from external vault' because current generic messages don't indicate vault source—without this, debugging vault failures becomes difficult when multiple vault providers are configured. These fixes are independent but both improve TokenEx integration reliability.\n\n##SELECT\ncrates/hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from\ncrates/router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1\n<EOS>"}
{"input": "##TASK\nAdd ConfirmSubscription flow to API lock management and subscription error handling to improve subscription lifecycle support. This enables proper resource locking for subscription confirmation operations and provides appropriate error status codes for subscription-related failures, ensuring the system can handle subscription confirmation requests with proper concurrency control and error reporting.", "output": "##OUTPUT\nThe subscription management system requires support for subscription confirmation operations with proper API locking and error handling. The flow is: HTTP POST for subscription confirmation → ApiIdentifier::from (router/src/routes/lock_utils.rs::ApiIdentifier::from) maps Flow::ConfirmSubscription to Self::Subscription for resource locking → subscription operation executes → if subscription fails, StripeErrorCode::status_code (router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code) returns INTERNAL_SERVER_ERROR for SubscriptionError variants. The ApiIdentifier::from function must include Flow::ConfirmSubscription in the subscription match arm because without it, subscription confirmation requests would not acquire proper locks, leading to potential race conditions during concurrent subscription operations. The StripeErrorCode::status_code function must handle SubscriptionError because subscription operations can fail and need proper HTTP status code mapping - without this, subscription errors would default to 200 OK instead of 500 Internal Server Error, breaking REST API conventions and confusing clients about operation success/failure. These two functions are tightly coupled through the subscription lifecycle - one handles the locking before operation, the other handles error reporting after operation.\n\n##SELECT\ncrates/router/src/routes/lock_utils.rs::ApiIdentifier::from\ncrates/router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code\n<EOS>"}
{"input": "##TASK\nAdd CardNumber variant to PaymentMethodVaultingData enum to support direct card number vaulting operations. This enables the system to handle card number tokenization and storage scenarios where only the card number is available without full card details, expanding vaulting flexibility for various payment method collection flows.", "output": "##OUTPUT\nThe payment method vaulting system needs to support card number-only vaulting scenarios for simplified tokenization flows. The data flow is: client sends card number → PaymentMethodVaultingData::CardNumber (hyperswitch_domain_models/src/payment_method_data.rs::PaymentMethodVaultingData) receives the cards::CardNumber → vaulting operation processes the card number → returns tokenized payment method. The PaymentMethodVaultingData enum must add CardNumber(cards::CardNumber) variant because the current enum only supports full Card details and NetworkToken, but some vaulting flows only have access to the card number (like partial card collection or migration scenarios). Without this variant, the system cannot represent card number-only vaulting requests, forcing clients to provide unnecessary card details or preventing vaulting operations entirely. The PaymentMethodData enum (hyperswitch_domain_models/src/payment_method_data.rs::PaymentMethodData) is excluded because it already handles complete payment method data structures - the vaulting enum is specifically for simplified vaulting scenarios where only partial information is available.\n\n##SELECT\ncrates/hyperswitch_domain_models/src/payment_method_data.rs::PaymentMethodVaultingData\n<EOS>"}
|
| 5 |
+
{"input": "##TASK\nUpdate Gigadat connector to use FloatMajorUnit for amount conversion and add proper authentication headers for refund operations. This fixes amount processing accuracy and ensures refund requests include required authentication credentials, preventing transaction failures due to incorrect amount formatting or missing authentication.", "output": "##OUTPUT\nThe Gigadat connector requires fixes for amount conversion and authentication to properly process refund transactions. The complete flow is: refund request initiated → Gigadat::get_headers (hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::get_headers) builds Basic auth header from connector_auth_type → GigadatRefundRequest::try_from (hyperswitch_connectors/src/connectors/gigadat/transformers.rs::GigadatRefundRequest::try_from) creates request with transaction_id and campaign_id → amount processed using FloatMajorUnitForConnector (hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::new) → error responses handled by build_error_response (hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::build_error_response). The get_headers function must change from RefundSyncRouterData to RefundsRouterData<Execute> and implement proper Basic auth because the current signature doesn't match the actual refund flow and lacks authentication, causing 401 Unauthorized errors. The try_from function must add transaction_id and campaign_id fields because Gigadat's API requires these identifiers for refund processing - without them, refunds fail with 'missing required fields' errors. The new function must use FloatMajorUnitForConnector instead of StringMinorUnitForConnector because Gigadat expects decimal amounts, not integer minor units - using the wrong converter causes amount validation failures and transaction rejections. The build_error_response must use response.err for code/message/reason fields because that's the actual error field in Gigadat's response structure - using non-existent fields causes response parsing failures.\n\n##SELECT\ncrates/hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::new\ncrates/hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::get_headers\ncrates/hyperswitch_connectors/src/connectors/gigadat.rs::Gigadat::build_error_response\ncrates/hyperswitch_connectors/src/connectors/gigadat/transformers.rs::GigadatRefundRequest::try_from\n<EOS>"}
|
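The records above all follow the same `##TASK` / `##OUTPUT` / `##SELECT` prompt format, stored one JSON object per line. A minimal sketch of a line-level validator for that format is below; `validate_sft_record` and `REQUIRED_KEYS` are hypothetical helpers written for illustration, not functions from the trainer kit — they only encode the structure visible in the records above.

```python
import json

REQUIRED_KEYS = {"input", "output"}


def validate_sft_record(line: str) -> dict:
    """Parse one JSONL line and check the ##TASK / ##OUTPUT / ##SELECT markers.

    Hypothetical helper for illustration; it mirrors the record shape seen in
    grpo_dataset.jsonl / sft_dataset.jsonl but is not part of the trainer kit.
    """
    record = json.loads(line)
    missing = REQUIRED_KEYS - record.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    if not record["input"].startswith("##TASK"):
        raise ValueError("input must start with ##TASK")
    if "##SELECT" not in record["output"]:
        raise ValueError("output must contain a ##SELECT section")
    return record


# Round-trip a synthetic record shaped like the ones above.
line = json.dumps({
    "input": "##TASK\nExample task description.",
    "output": "##OUTPUT\nExample reasoning.\n\n##SELECT\ncrates/example.rs::Example\n<EOS>",
})
record = validate_sft_record(line)
print(record["input"].startswith("##TASK"))  # True
```

Running such a check over a dataset file before training catches truncated or split lines early, which is cheaper than a mid-run tokenization failure.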
trainer-kit/GRPO-14B/sft_dataset_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
{"input": "##TASK\nUpdate connector configuration and payment response structures to support new payment processors and extended authorization features. This enables integration with Finix and Tesouro connectors while providing visibility into extended authorization requests in payment responses.", "output": "##OUTPUT\nThe connector configuration system requires new connector entries and payment responses need extended authorization tracking to support enhanced payment processing capabilities. The flow is: system loads ConnectorConfig at startup → populates with all supported connectors → payment processing uses configured connectors → payment response includes authorization details. The ConnectorConfig struct (connector_configs/src/connector.rs::ConnectorConfig) must add finix and tesouro fields because it defines the complete set of supported connectors—without these fields, the system cannot load configurations for Finix and Tesouro, preventing merchants from using these payment processors. The PaymentsResponse::foreign_from method (router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from) needs to add request_extended_authorization field because it maps database payment attempt data to API responses—without this field, merchants cannot see whether extended authorization was requested, losing visibility into payment processing decisions. These changes are independent but both enhance payment processing capabilities: one expands connector support, the other improves response transparency.\n\n##SELECT\ncrates/connector_configs/src/connector.rs::ConnectorConfig\ncrates/router/src/core/payments/transformers.rs::api::PaymentsResponse::foreign_from\n<EOS>"}
|
| 2 |
+
{"input": "##TASK\nFix TokenEx vault integration to properly handle card tokenization and improve vault error messaging for better debugging. This ensures card data is correctly sent to TokenEx and provides clearer error messages when vault operations fail.", "output": "##OUTPUT\nThe TokenEx vault integration has incorrect data handling and poor error messaging that need to be fixed for reliable payment method vaulting. The flow is: merchant initiates vault operation → TokenexInsertRequest::try_from transforms request → sends to TokenEx → response processed by get_vault_response_for_retrieve_payment_method_data_v1. The TokenexInsertRequest::try_from method (hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from) must change from encoding entire card as JSON to sending just card_number because TokenEx API expects raw card number tokenization, not JSON payload—without this change, TokenEx rejects requests with malformed data. The get_vault_response_for_retrieve_payment_method_data_v1 function (router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1) needs updated error messages to specify 'from external vault' because current generic messages don't indicate vault source—without this, debugging vault failures becomes difficult when multiple vault providers are configured. These fixes are independent but both improve TokenEx integration reliability.\n\n##SELECT\ncrates/hyperswitch_connectors/src/connectors/tokenex/transformers.rs::TokenexInsertRequest::try_from\ncrates/router/src/core/payment_methods/vault.rs::get_vault_response_for_retrieve_payment_method_data_v1\n<EOS>"}
|
| 3 |
+
{"input": "##TASK\nAdd ConfirmSubscription flow to API lock management and subscription error handling to improve subscription lifecycle support. This enables proper resource locking for subscription confirmation operations and provides appropriate error status codes for subscription-related failures, ensuring the system can handle subscription confirmation requests with proper concurrency control and error reporting.", "output": "##OUTPUT\nThe subscription management system requires support for subscription confirmation operations with proper API locking and error handling. The flow is: HTTP POST for subscription confirmation → ApiIdentifier::from (router/src/routes/lock_utils.rs::ApiIdentifier::from) maps Flow::ConfirmSubscription to Self::Subscription for resource locking → subscription operation executes → if subscription fails, StripeErrorCode::status_code (router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code) returns INTERNAL_SERVER_ERROR for SubscriptionError variants. The ApiIdentifier::from function must include Flow::ConfirmSubscription in the subscription match arm because without it, subscription confirmation requests would not acquire proper locks, leading to potential race conditions during concurrent subscription operations. The StripeErrorCode::status_code function must handle SubscriptionError because subscription operations can fail and need proper HTTP status code mapping - without this, subscription errors would default to 200 OK instead of 500 Internal Server Error, breaking REST API conventions and confusing clients about operation success/failure. These two functions are tightly coupled through the subscription lifecycle - one handles the locking before operation, the other handles error reporting after operation.\n\n##SELECT\ncrates/router/src/routes/lock_utils.rs::ApiIdentifier::from\ncrates/router/src/compatibility/stripe/errors.rs::StripeErrorCode::status_code\n<EOS>"}
|
trainer-kit/GRPO-14B/start_vllm_server.sh
ADDED
|
@@ -0,0 +1,39 @@
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Start vLLM OpenAI-compatible API server
|
| 3 |
+
#
|
| 4 |
+
# Usage:
|
| 5 |
+
# bash start_vllm_server.sh
|
| 6 |
+
|
| 7 |
+
# Use CPT model (no tokenizer issues) or base model
|
| 8 |
+
MODEL_PATH="/workspace/Models/Qwen2.5-Coder-14B-CPT"
|
| 9 |
+
PORT=8000
|
| 10 |
+
TENSOR_PARALLEL_SIZE=2
|
| 11 |
+
GPU_MEMORY_UTIL=0.85
|
| 12 |
+
|
| 13 |
+
echo "Starting vLLM server..."
|
| 14 |
+
echo "Model: $MODEL_PATH"
|
| 15 |
+
echo "Port: $PORT"
|
| 16 |
+
echo "Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
|
| 17 |
+
echo "GPU Memory Utilization: $GPU_MEMORY_UTIL"
|
| 18 |
+
echo ""
|
| 19 |
+
|
| 20 |
+
# NOTE: this invocation blocks, so the `vllm serve` command later in this
# script only runs after this server exits (and both bind port 8000).
# Keep whichever of the two invocations matches your deployment.
python -m vllm.entrypoints.openai.api_server \
|
| 21 |
+
--model "$MODEL_PATH" \
|
| 22 |
+
--port $PORT \
|
| 23 |
+
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
|
| 24 |
+
--gpu-memory-utilization $GPU_MEMORY_UTIL \
|
| 25 |
+
--dtype bfloat16 \
|
| 26 |
+
--max-model-len 4096 \
|
| 27 |
+
--host 0.0.0.0 \
|
| 28 |
+
--disable-log-requests
|
| 29 |
+
vllm serve /workspace/Models/Qwen2.5-Coder-14B-CPT-SFT \
|
| 30 |
+
--served-model-name qwen2.5-coder-14b \
|
| 31 |
+
--tokenizer Qwen/Qwen2.5-Coder-14B \
|
| 32 |
+
--tokenizer-mode slow \
|
| 33 |
+
--tensor-parallel-size 2 \
|
| 34 |
+
--gpu-memory-utilization 0.4 \
|
| 35 |
+
--dtype bfloat16 \
|
| 36 |
+
--max-model-len 4096 \
|
| 37 |
+
--host 0.0.0.0 \
|
| 38 |
+
--port 8000 \
|
| 39 |
+
--disable-log-requests
|
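Once either server invocation from the script above is running, it exposes the OpenAI-compatible REST API. The sketch below builds a request body for the `/v1/completions` endpoint; it assumes the `--served-model-name qwen2.5-coder-14b` and `--port 8000` values from the script, and the `##TASK` prompt prefix and `<EOS>` stop token come from the dataset format in this folder.

```python
import json

# Request body for the OpenAI-compatible /v1/completions endpoint exposed by
# the vLLM server started by start_vllm_server.sh. The model name matches the
# --served-model-name flag; the stop token matches the <EOS> marker used in
# the JSONL training records.
payload = {
    "model": "qwen2.5-coder-14b",
    "prompt": "##TASK\nAdd a health check endpoint to the router.",
    "max_tokens": 256,
    "temperature": 0.0,
    "stop": ["<EOS>"],
}

body = json.dumps(payload)
# Send it with, e.g.:
#   curl -s http://localhost:8000/v1/completions \
#        -H 'Content-Type: application/json' -d @request.json
print(json.loads(body)["model"])  # qwen2.5-coder-14b
```

Setting `temperature` to 0.0 makes completions deterministic, which is useful when smoke-testing a freshly started server against known dataset prompts.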