kavin57447 committed on
Commit
deef82c
·
1 Parent(s): ee5ddee

Fix truncation: 80 tokens, regex safety net, strict prompt

Browse files
Files changed (1) hide show
  1. cloud_arena/llm_training.py +26 -10
cloud_arena/llm_training.py CHANGED
@@ -29,7 +29,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  # ── GPU Optimization Constants ────────────────────────────────────────────────
30
  GRAD_ACCUM_STEPS = 4 # accumulate gradients over N episodes before stepping
31
  MAX_SEQ_LEN = 512 # shorter context = O(N²) attention is 4× faster than 1024
32
- MAX_GEN_TOKENS = 32 # force brief output just need the ACTION line, not essays
33
 
34
 
35
  def format_prompt(state_dict):
@@ -54,22 +54,38 @@ def format_prompt(state_dict):
54
  f"- Never delete/stop prod resources or those with >=5 deps\n"
55
  f"- Temp resources with 0-1 deps are safe to delete\n"
56
  f"- RESIZE is always safe\n\n"
 
57
  f"REASONING:"
58
  )
59
 
60
 
61
  def extract_action_and_reasoning(response_text):
 
62
  reasoning = response_text.strip()
63
- action = 2
64
- action_match = re.search(r'ACTION:\s*(\d)', response_text, re.IGNORECASE)
 
 
65
  if action_match:
66
- parsed = int(action_match.group(1))
67
- if 0 <= parsed <= 4:
68
- action = parsed
69
- else:
70
- digit_matches = re.findall(r'\b([0-4])\b', response_text[-50:])
71
- if digit_matches:
72
- action = int(digit_matches[-1])
 
 
 
 
 
 
 
 
 
 
 
 
73
  return action, reasoning
74
 
75
 
 
29
  # ── GPU Optimization Constants ────────────────────────────────────────────────
30
  GRAD_ACCUM_STEPS = 4 # accumulate gradients over N episodes before stepping
31
  MAX_SEQ_LEN = 512 # shorter context = O(N²) attention is 4× faster than 1024
32
+ MAX_GEN_TOKENS = 80 # enough room for reasoning + ACTION line, not enough to ramble
33
 
34
 
35
  def format_prompt(state_dict):
 
54
  f"- Never delete/stop prod resources or those with >=5 deps\n"
55
  f"- Temp resources with 0-1 deps are safe to delete\n"
56
  f"- RESIZE is always safe\n\n"
57
+ f"CRITICAL: Output ONLY a brief reason then ACTION: <number 0-4>. Nothing else.\n\n"
58
  f"REASONING:"
59
  )
60
 
61
 
62
  def extract_action_and_reasoning(response_text):
63
+ """Regex safety net: extracts action even from truncated/malformed output."""
64
  reasoning = response_text.strip()
65
+ action = 2 # Default to RESIZE (safest action)
66
+
67
+ # Try explicit ACTION: N format first
68
+ action_match = re.search(r'ACTION:\s*([0-4])', response_text, re.IGNORECASE)
69
  if action_match:
70
+ return int(action_match.group(1)), reasoning
71
+
72
+ # Try JSON format: {"action": N} or {"action": "DELETE"}
73
+ json_match = re.search(r'\{.*?\}', response_text, re.DOTALL)
74
+ if json_match:
75
+ try:
76
+ import json
77
+ parsed = json.loads(json_match.group(0))
78
+ a = parsed.get("action", 2)
79
+ if isinstance(a, int) and 0 <= a <= 4:
80
+ return a, reasoning
81
+ except (json.JSONDecodeError, ValueError):
82
+ pass
83
+
84
+ # Last resort: any digit 0-4 near the end
85
+ digit_matches = re.findall(r'\b([0-4])\b', response_text[-30:])
86
+ if digit_matches:
87
+ action = int(digit_matches[-1])
88
+
89
  return action, reasoning
90
 
91