oxdev commited on
Commit
75be256
·
verified ·
1 Parent(s): 93b6a9a

fix: escape syntax in quality_reward

Browse files
Files changed (1) hide show
  1. train_grpo_v2.py +8 -10
train_grpo_v2.py CHANGED
@@ -214,7 +214,7 @@ def quality_reward(prompts, completions, completion_ids=None, **kwargs):
214
  reward += 0.1
215
 
216
  # Penalize generic/unhelpful responses
217
- generic_phrases = ['i cannot', 'i don\\'t', 'no vulnerabilities found', 'the code looks safe']
218
  if any(p in text.lower() for p in generic_phrases):
219
  reward -= 0.5
220
 
@@ -239,42 +239,40 @@ def main():
239
  dataset = load_dataset(DATASET_ID, split="train")
240
  logger.info(f"Dataset: {len(dataset)} samples, columns={dataset.column_names}")
241
 
 
 
 
242
  # Log severity distribution
243
  sev_dist = Counter(dataset['severity'])
244
  logger.info(f"Severity distribution: {dict(sev_dist)}")
245
 
246
- # Subsample for 0.5B model on T4 — use 15K most relevant samples
247
  # Prioritize: HIGH > MEDIUM > has_code > has_poc
248
  logger.info("Selecting high-quality training subset...")
249
  indices = []
250
- idx_set = set()
251
 
252
  # Priority 1: HIGH severity with code (most valuable)
253
  for i, row in enumerate(dataset):
254
  if row['severity'] in ('high', 'critical') and row['has_code']:
255
  indices.append(i)
256
- idx_set.add(i)
257
  logger.info(f" HIGH+CRITICAL with code: {len(indices)}")
258
 
259
  # Priority 2: MEDIUM severity with code
260
  for i, row in enumerate(dataset):
261
- if row['severity'] == 'medium' and row['has_code'] and i not in idx_set:
262
  indices.append(i)
263
- idx_set.add(i)
264
  logger.info(f" + MEDIUM with code: {len(indices)}")
265
 
266
  # Priority 3: Any with PoC reference
267
  for i, row in enumerate(dataset):
268
- if row['has_poc'] and i not in idx_set:
269
  indices.append(i)
270
- idx_set.add(i)
271
  logger.info(f" + Has PoC: {len(indices)}")
272
 
273
  # Priority 4: Fill remaining with HIGH/MED without code
274
  for i, row in enumerate(dataset):
275
- if row['severity'] in ('high', 'medium', 'critical') and i not in idx_set:
276
  indices.append(i)
277
- idx_set.add(i)
278
  if len(indices) >= 15000:
279
  break
280
  logger.info(f" Final subset: {len(indices)} samples")
 
214
  reward += 0.1
215
 
216
  # Penalize generic/unhelpful responses
217
+ generic_phrases = ['i cannot', 'i don\'t', 'no vulnerabilities found', 'the code looks safe']
218
  if any(p in text.lower() for p in generic_phrases):
219
  reward -= 0.5
220
 
 
239
  dataset = load_dataset(DATASET_ID, split="train")
240
  logger.info(f"Dataset: {len(dataset)} samples, columns={dataset.column_names}")
241
 
242
+ # For GRPO we only need 'prompt' column + metadata columns for reward
243
+ # The reward functions access metadata via kwargs passed from the dataset
244
+
245
  # Log severity distribution
246
  sev_dist = Counter(dataset['severity'])
247
  logger.info(f"Severity distribution: {dict(sev_dist)}")
248
 
249
+ # Subsample for 0.5B model on T4 — use 10K most relevant samples
250
  # Prioritize: HIGH > MEDIUM > has_code > has_poc
251
  logger.info("Selecting high-quality training subset...")
252
  indices = []
 
253
 
254
  # Priority 1: HIGH severity with code (most valuable)
255
  for i, row in enumerate(dataset):
256
  if row['severity'] in ('high', 'critical') and row['has_code']:
257
  indices.append(i)
 
258
  logger.info(f" HIGH+CRITICAL with code: {len(indices)}")
259
 
260
  # Priority 2: MEDIUM severity with code
261
  for i, row in enumerate(dataset):
262
+ if row['severity'] == 'medium' and row['has_code'] and i not in set(indices):
263
  indices.append(i)
 
264
  logger.info(f" + MEDIUM with code: {len(indices)}")
265
 
266
  # Priority 3: Any with PoC reference
267
  for i, row in enumerate(dataset):
268
+ if row['has_poc'] and i not in set(indices):
269
  indices.append(i)
 
270
  logger.info(f" + Has PoC: {len(indices)}")
271
 
272
  # Priority 4: Fill remaining with HIGH/MED without code
273
  for i, row in enumerate(dataset):
274
+ if row['severity'] in ('high', 'medium', 'critical') and i not in set(indices):
275
  indices.append(i)
 
276
  if len(indices) >= 15000:
277
  break
278
  logger.info(f" Final subset: {len(indices)} samples")