oxdev committed on
Commit
3c818d7
·
verified ·
1 Parent(s): 75be256

v2: 5K subset for A10G, fix escaping

Browse files
Files changed (1) hide show
  1. train_grpo_v2.py +13 -15
train_grpo_v2.py CHANGED
@@ -246,34 +246,32 @@ def main():
246
  sev_dist = Counter(dataset['severity'])
247
  logger.info(f"Severity distribution: {dict(sev_dist)}")
248
 
249
- # Subsample for 0.5B model on T4: use 10K most relevant samples
250
- # Prioritize: HIGH > MEDIUM > has_code > has_poc
251
- logger.info("Selecting high-quality training subset...")
252
  indices = []
 
253
 
254
- # Priority 1: HIGH severity with code (most valuable)
255
  for i, row in enumerate(dataset):
256
  if row['severity'] in ('high', 'critical') and row['has_code']:
257
  indices.append(i)
 
258
  logger.info(f" HIGH+CRITICAL with code: {len(indices)}")
259
 
260
- # Priority 2: MEDIUM severity with code
261
  for i, row in enumerate(dataset):
262
- if row['severity'] == 'medium' and row['has_code'] and i not in set(indices):
263
- indices.append(i)
264
- logger.info(f" + MEDIUM with code: {len(indices)}")
265
-
266
- # Priority 3: Any with PoC reference
267
- for i, row in enumerate(dataset):
268
- if row['has_poc'] and i not in set(indices):
269
  indices.append(i)
 
270
  logger.info(f" + Has PoC: {len(indices)}")
271
 
272
- # Priority 4: Fill remaining with HIGH/MED without code
273
  for i, row in enumerate(dataset):
274
- if row['severity'] in ('high', 'medium', 'critical') and i not in set(indices):
275
  indices.append(i)
276
- if len(indices) >= 15000:
 
277
  break
278
  logger.info(f" Final subset: {len(indices)} samples")
279
 
 
246
  sev_dist = Counter(dataset['severity'])
247
  logger.info(f"Severity distribution: {dict(sev_dist)}")
248
 
249
+ # Subsample 5K highest-value samples for A10G (fits in ~6hrs)
250
+ # Focus on HIGH+CRITICAL with code: the most valuable training signal
251
+ logger.info("Selecting high-quality training subset (5K for A10G)...")
252
  indices = []
253
+ idx_set = set()
254
 
255
+ # Priority 1: HIGH+CRITICAL severity with code (most valuable)
256
  for i, row in enumerate(dataset):
257
  if row['severity'] in ('high', 'critical') and row['has_code']:
258
  indices.append(i)
259
+ idx_set.add(i)
260
  logger.info(f" HIGH+CRITICAL with code: {len(indices)}")
261
 
262
+ # Priority 2: Any with PoC reference
263
  for i, row in enumerate(dataset):
264
+ if row['has_poc'] and i not in idx_set:
 
 
 
 
 
 
265
  indices.append(i)
266
+ idx_set.add(i)
267
  logger.info(f" + Has PoC: {len(indices)}")
268
 
269
+ # Priority 3: MEDIUM with code (fill to 5K cap)
270
  for i, row in enumerate(dataset):
271
+ if row['severity'] == 'medium' and row['has_code'] and i not in idx_set:
272
  indices.append(i)
273
+ idx_set.add(i)
274
+ if len(indices) >= 5000:
275
  break
276
  logger.info(f" Final subset: {len(indices)} samples")
277