Spaces:
Paused
Paused
Commit ·
deef82c
1
Parent(s): ee5ddee
Fix truncation: 80 tokens, regex safety net, strict prompt
Browse files- cloud_arena/llm_training.py +26 -10
cloud_arena/llm_training.py
CHANGED
|
@@ -29,7 +29,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| 29 |
# ── GPU Optimization Constants ────────────────────────────────────────────────
|
| 30 |
GRAD_ACCUM_STEPS = 4 # accumulate gradients over N episodes before stepping
|
| 31 |
MAX_SEQ_LEN = 512 # shorter context = O(N²) attention is 4× faster than 1024
|
| 32 |
-
MAX_GEN_TOKENS =
|
| 33 |
|
| 34 |
|
| 35 |
def format_prompt(state_dict):
|
|
@@ -54,22 +54,38 @@ def format_prompt(state_dict):
|
|
| 54 |
f"- Never delete/stop prod resources or those with >=5 deps\n"
|
| 55 |
f"- Temp resources with 0-1 deps are safe to delete\n"
|
| 56 |
f"- RESIZE is always safe\n\n"
|
|
|
|
| 57 |
f"REASONING:"
|
| 58 |
)
|
| 59 |
|
| 60 |
|
| 61 |
def extract_action_and_reasoning(response_text):
|
|
|
|
| 62 |
reasoning = response_text.strip()
|
| 63 |
-
action = 2
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
if action_match:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
return action, reasoning
|
| 74 |
|
| 75 |
|
|
|
|
| 29 |
# ── GPU Optimization Constants ────────────────────────────────────────────────
|
| 30 |
GRAD_ACCUM_STEPS = 4 # accumulate gradients over N episodes before stepping
|
| 31 |
MAX_SEQ_LEN = 512 # shorter context = O(N²) attention is 4× faster than 1024
|
| 32 |
+
MAX_GEN_TOKENS = 80 # enough room for reasoning + ACTION line, not enough to ramble
|
| 33 |
|
| 34 |
|
| 35 |
def format_prompt(state_dict):
|
|
|
|
| 54 |
f"- Never delete/stop prod resources or those with >=5 deps\n"
|
| 55 |
f"- Temp resources with 0-1 deps are safe to delete\n"
|
| 56 |
f"- RESIZE is always safe\n\n"
|
| 57 |
+
f"CRITICAL: Output ONLY a brief reason then ACTION: <number 0-4>. Nothing else.\n\n"
|
| 58 |
f"REASONING:"
|
| 59 |
)
|
| 60 |
|
| 61 |
|
| 62 |
def extract_action_and_reasoning(response_text):
|
| 63 |
+
"""Regex safety net: extracts action even from truncated/malformed output."""
|
| 64 |
reasoning = response_text.strip()
|
| 65 |
+
action = 2 # Default to RESIZE (safest action)
|
| 66 |
+
|
| 67 |
+
# Try explicit ACTION: N format first
|
| 68 |
+
action_match = re.search(r'ACTION:\s*([0-4])', response_text, re.IGNORECASE)
|
| 69 |
if action_match:
|
| 70 |
+
return int(action_match.group(1)), reasoning
|
| 71 |
+
|
| 72 |
+
# Try JSON format: {"action": N} or {"action": "DELETE"}
|
| 73 |
+
json_match = re.search(r'\{.*?\}', response_text, re.DOTALL)
|
| 74 |
+
if json_match:
|
| 75 |
+
try:
|
| 76 |
+
import json
|
| 77 |
+
parsed = json.loads(json_match.group(0))
|
| 78 |
+
a = parsed.get("action", 2)
|
| 79 |
+
if isinstance(a, int) and 0 <= a <= 4:
|
| 80 |
+
return a, reasoning
|
| 81 |
+
except (json.JSONDecodeError, ValueError):
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
# Last resort: any digit 0-4 near the end
|
| 85 |
+
digit_matches = re.findall(r'\b([0-4])\b', response_text[-30:])
|
| 86 |
+
if digit_matches:
|
| 87 |
+
action = int(digit_matches[-1])
|
| 88 |
+
|
| 89 |
return action, reasoning
|
| 90 |
|
| 91 |
|