Addyk24 committed on
Commit 92e2763 · 1 Parent(s): 7bafea6

Initialized RL environment
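
This commit adds WorkSpaceEnvironment, a multi-expert PRD-drafting environment with scripted (easy/mock) and Groq-backed (medium/hard/llm) modes, plus its error types. A minimal interaction sketch, assuming models.schemas.WorkSpaceAction accepts the action_type, target, and content fields the environment reads (keyword construction is an assumption here):

    from envs import WorkSpaceEnvironment
    from models.schemas import WorkSpaceAction

    env = WorkSpaceEnvironment(mode="easy")   # scripted expert replies, no API key needed
    obs = env.reset()                         # default topic: "Draft the new Mobile App PRD"
    for expert in ("Finance", "Security", "UX"):
        if obs.done:
            break
        # Ask each expert directly; messaging "All" is penalized.
        obs = env.step(WorkSpaceAction(
            action_type="message_expert",
            target=expert,
            content=f"What constraints does {expert} need the PRD to satisfy?",
        ))
        print(obs.feedback, obs.reward, obs.done)

In easy mode each scripted reply reveals that expert's hidden constraint, the discovery patterns award 0.33 per newly found constraint, and the episode terminates once all three are discovered (or on submit_final).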

Files changed (3)
  1. envs/__init__.py +9 -0
  2. envs/environment.py +511 -0
  3. envs/errors.py +11 -0
envs/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .environment import WorkSpaceEnvironment
+ from .errors import EnvironmentDoneError, EnvironmentNotResetError, EnvError
+
+ __all__ = [
+     "WorkSpaceEnvironment",
+     "EnvironmentDoneError",
+     "EnvironmentNotResetError",
+     "EnvError",
+ ]
envs/environment.py ADDED
@@ -0,0 +1,511 @@
+ import logging
+ import os
+ import re
+ import time
+
+ try:
+     from dotenv import load_dotenv
+ except ImportError:
+     # dotenv is optional; fall back to a no-op so the module still imports.
+     def load_dotenv():
+         return False
+
+ load_dotenv()
+
+ try:
+     from openai import OpenAI
+     from groq import Groq
+ except ImportError:
+     # Optional LLM clients; only required for the medium/hard/llm modes.
+     OpenAI = None
+     Groq = None
+
+ from envs.errors import EnvironmentDoneError, EnvironmentNotResetError
+ from models.schemas import ExpertState, WorkSpaceAction, WorkspaceObservation, WorkspaceState
+ from openenv.core import Environment
+ from prompter.system_prompt import SystemPrompt
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+
+ DISCOVERY_PATTERNS = {
+     "Finance": [
+         r"50\s*k",
+         r"50,000",
+         r"fifty thousand",
+         r"budget cap",
+         r"budget ceiling",
+         r"hard cap",
+         r"low[- ]five[- ]figure",
+         r"mid[- ]five[- ]figure",
+         r"five[- ]figure",
+         r"under (?:the )?ceiling",
+         r"under\s+\$?50k",
+         r"below\s+\$?50k",
+         r"sub-\$?50k",
+     ],
+     "Security": [
+         r"biometric",
+         r"2\s*fa",
+         r"m\s*fa",
+         r"two-factor",
+         r"second factor",
+         r"physiological",
+     ],
+     "UX": [
+         r"single[ -]click",
+         r"one[ -]click",
+         r"one[ -]tap",
+         r"single[ -]tap",
+         r"single[\u2011-]tap",
+         r"single[\u2011-]click",
+         r"frictionless purchase",
+         r"one decisive interaction",
+     ],
+ }
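+
+ # NOTE: these patterns are matched (via re.search) against the lowercased expert
+ # feedback in _calculate_multi_reward; a hit marks that constraint as discovered.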
+
+
+ def normalize_environment_mode(mode: str | None) -> str:
+     canonical = (mode or "").strip().lower()
+     aliases = {
+         "": "mock",
+         "easy": "easy",
+         "deterministic": "mock",
+         "medium": "medium",
+         "hard": "hard",
+         "scripted": "mock",
+         "llm": "llm",
+         "live": "llm",
+         "online": "llm",
+         "remote": "llm",
+         "api": "llm",
+     }
+     if canonical not in aliases:
+         raise ValueError(f"Unsupported environment mode: {mode}")
+     return aliases[canonical]
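+
+ # For example, normalize_environment_mode("live") and "online" both resolve to "llm",
+ # while "deterministic" and "scripted" resolve to "mock"; anything outside the alias
+ # table raises ValueError.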
+
+
+ class WorkSpaceEnvironment(Environment):
+     def __init__(self, mode: str | None = None):
+         self._state: WorkspaceState | None = None
+         self.system_prompt = SystemPrompt()
+
+         requested_mode = mode or os.getenv("BASELINE_ENV_MODE") or "easy"
+         self.mode = normalize_environment_mode(requested_mode)
+         self.env_model = os.getenv("ENV_MODEL_NAME") or os.getenv("MODEL_NAME") or "llama-3.1-8b-instant"
+         self._env_client: object | None = None
+
+         # The LLM-backed modes need a Groq client; easy/mock stay fully scripted.
+         if self.mode in ["medium", "hard", "llm"]:
+             if Groq is None:
+                 raise RuntimeError("The groq package is required for medium/hard/llm modes.")
+             self._env_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+
+             # Alternative OpenAI-compatible client, kept for reference:
+             # self._env_client = OpenAI(
+             #     base_url=base_url,
+             #     api_key=api_key,
+             #     timeout=45.0,
+             #     max_retries=2,
+             # )
+
+     def reset(self, topic="Draft the new Mobile App PRD") -> WorkspaceObservation:
+         experts = {
+             "Finance": ExpertState(name="Finance", hidden_constraint="Budget must not exceed $50k."),
+             "Security": ExpertState(name="Security", hidden_constraint="Must include biometric 2FA."),
+             "UX": ExpertState(name="UX", hidden_constraint="Checkout must be a single click."),
+         }
+
+         self._state = WorkspaceState(experts=experts, chat_history=[])
+
+         return WorkspaceObservation(
+             feedback=f"SYSTEM: You are the PM. {topic}. Message the experts to gather requirements.",
+             current_turn=0,
+             reward=0.0,
+             done=False,
+         )
+
+     def state(self) -> WorkspaceState:
+         if self._state is None:
+             raise EnvironmentNotResetError("Call reset() before reading state().")
+         return self._state
+
+     def step(self, action: WorkSpaceAction) -> WorkspaceObservation:
+         if self._state is None:
+             raise EnvironmentNotResetError("Call reset() before step().")
+         if self._state.is_done:
+             raise EnvironmentDoneError("Episode already terminated.")
+
+         self._state.turn_count += 1
+
+         feedback_text, _ = self._get_expert_feedback(action)
+
+         component_rewards = self._calculate_multi_reward(action, feedback_text)
+
+         self._state.chat_history.append({
+             "agent": action.content,
+             "world": feedback_text,
+         })
+
+         total_reward = 0.0
+
+         if self.mode == "easy":
+             # Goal: discover all 3 constraints. Reward is the sum of NEW discoveries.
+             total_reward = (
+                 component_rewards["discovery_finance"] +
+                 component_rewards["discovery_security"] +
+                 component_rewards["discovery_ux"]
+             )
+
+             # Termination: every constraint found, or the agent submits a final draft.
+             all_found = all(e.constraint_discovered_by_agent for e in self._state.experts.values())
+             if all_found or action.action_type == "submit_final":
+                 self._state.is_done = True
+                 if all_found:
+                     feedback_text += "\nSYSTEM: All constraints discovered. Task complete."
+
+         elif self.mode in ["medium", "hard", "llm"]:
+             # Goal: synthesis of all constraints into the final draft.
+             if action.action_type == "submit_final":
+                 self._state.is_done = True
+                 scores = [
+                     component_rewards["final_finance"],
+                     component_rewards["final_security"],
+                     component_rewards["final_ux"],
+                 ]
+                 # Harmonic mean: any zero score zeroes the whole reward.
+                 total_reward = 0.0 if any(s == 0 for s in scores) else 3 / sum(1 / s for s in scores)
+             else:
+                 # Dense discovery "nudges" (0.033 per constraint instead of 0.33).
+                 total_reward = (
+                     component_rewards["discovery_finance"] +
+                     component_rewards["discovery_security"] +
+                     component_rewards["discovery_ux"]
+                 ) * 0.1
+
+         total_reward += component_rewards["penalty"]
+
+         # Safety turn limit.
+         if self._state.turn_count >= self._state.max_turns:
+             self._state.is_done = True
+             feedback_text += "\nSYSTEM: Turn limit reached."
+
+         return WorkspaceObservation(
+             feedback=feedback_text,
+             current_turn=self._state.turn_count,
+             reward=round(max(0, total_reward), 3),
+             done=self._state.is_done,
+         )
+
+     def _get_expert_feedback(self, action: WorkSpaceAction) -> tuple[str, float]:
+         """
+         Executes the expert logic based on action type.
+         Returns: (feedback_text, internal_dense_reward)
+         """
+         all_feedback = []
+         total_internal_reward = 0.0
+
+         if action.action_type == "message_expert":
+             target = action.target
+
+             if target == "All":
+                 for name in self._state.experts:
+                     self._update_frustration(name, action)
+                     resp, reward = self.expert_response(name, action.content)
+                     all_feedback.append(f"{name}: {resp}")
+                     total_internal_reward += reward
+                 feedback_text = "\n\n".join(all_feedback)
+
+             elif target in self._state.experts:
+                 self._update_frustration(target, action)
+                 resp, reward = self.expert_response(target, action.content)
+                 feedback_text = f"{target}: {resp}"
+                 total_internal_reward += reward
+
+             else:
+                 feedback_text = f"SYSTEM: Unknown expert '{target}'."
+
+         elif action.action_type == "propose_draft":
+             for name in self._state.experts:
+                 self._update_frustration(name, action)
+                 resp, reward = self.expert_response(name, action.content)
+                 all_feedback.append(f"{name}: {resp}")
+                 # Small reward for progress, but less than discovery
+                 total_internal_reward += (reward * 0.5)
+             feedback_text = "\n".join(all_feedback)
+
+         elif action.action_type == "submit_final":
+             feedback_text = "SYSTEM: Final draft received for grading."
+             total_internal_reward = 0.0
+
+         else:
+             feedback_text = f"SYSTEM: Invalid action_type '{action.action_type}'."
+
+         return feedback_text, total_internal_reward
+
+     def expert_response(self, expert_name: str, agent_message: str) -> tuple[str, float]:
+         expert = self._state.experts[expert_name]
+         response = self._generate_expert_response(expert, expert_name, agent_message)
+         # Discovery state is awarded and flipped in _calculate_multi_reward so the
+         # environment has a single source of truth for easy-mode reward.
+         return response, 0.0
+
+     def harmonic_mean_reward(self, draft: str) -> float:
+         scores = [
+             self._grade_draft_against_constraint(draft, expert.hidden_constraint)
+             for expert in self._state.experts.values()
+         ]
+
+         if any(score == 0 for score in scores):
+             return 0.0
+
+         harmonic = len(scores) / sum(1 / score for score in scores)
+         return round(harmonic, 3)
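+
+     # Worked example for harmonic_mean_reward: constraint scores [1.0, 1.0, 0.5]
+     # give 3 / (1/1 + 1/1 + 1/0.5) = 0.75; any single 0.0 score zeroes the reward.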
+
+     def _calculate_multi_reward(self, action: WorkSpaceAction, feedback_text: str) -> dict:
+         r = {
+             "discovery_finance": 0.0, "discovery_security": 0.0, "discovery_ux": 0.0,
+             "final_finance": 0.0, "final_security": 0.0, "final_ux": 0.0,
+             "penalty": 0.0
+         }
+
+         # 1. DISCOVERY (Only grant if NOT already discovered)
+         text = feedback_text.lower()
+         for name, patterns in DISCOVERY_PATTERNS.items():
+             expert = self._state.experts[name]
+             if not expert.constraint_discovered_by_agent:
+                 if any(re.search(p, text) for p in patterns):
+                     r[f"discovery_{name.lower()}"] = 0.33
+                     expert.constraint_discovered_by_agent = True  # FLIP THE BIT
+
+         # 2. FINAL SUBMISSION
+         if action.action_type == "submit_final":
+             for name, expert in self._state.experts.items():
+                 r[f"final_{name.lower()}"] = self._grade_draft_against_constraint(
+                     action.content,
+                     expert.hidden_constraint,
+                 )
+
+         # 3. PENALTIES
+         if action.action_type == "message_expert" and action.target == "All":
+             r["penalty"] -= 1.0 if self.mode == "easy" else 0.5
+         elif action.action_type == "propose_draft" and action.target == "All":
+             r["penalty"] -= 0.1 if self.mode in ["medium", "hard", "llm"] else 0.0
+
+         if self._is_repeated_question(action.content, action.target or ""):
+             r["penalty"] -= 0.4  # Doubled the repeat penalty
+
+         return r
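+
+     # Discovery components cap at 3 x 0.33 = 0.99 per episode; penalties can push a
+     # step's total negative, but step() clamps the emitted reward at 0.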
+
+     def _grade_draft_against_constraint(self, draft: str, constraint: str) -> float:
+         # Deterministic verifier
+         text = draft.lower()
+
+         # Finance check
+         if "$50k" in constraint or "budget" in constraint:
+             mentions_amount = any(
+                 x in text
+                 for x in [
+                     "50k",
+                     "$50k",
+                     "50,000",
+                     "$50,000",
+                     "fifty thousand",
+                     "sub-$50k",
+                     "sub 50k",
+                 ]
+             )
+             mentions_limit = any(
+                 token in text
+                 for token in [
+                     "under",
+                     "below",
+                     "at or below",
+                     "not exceed",
+                     "cap",
+                     "ceiling",
+                     "budget cap",
+                 ]
+             )
+             if mentions_amount and mentions_limit:
+                 return 1.0
+
+         # Security check
+         if "biometric" in constraint:
+             if "biometric" in text and any(
+                 token in text for token in ("2fa", "mfa", "two-factor", "multi-factor")
+             ):
+                 return 1.0
+
+         # UX check
+         if "single click" in constraint:
+             if any(
+                 token in text
+                 for token in ("single-click", "one-click", "single click", "one click", "single-tap", "one-tap")
+             ) and "checkout" in text:
+                 return 1.0
+
+         # Fallback to LLM grading ONLY in llm mode ("live" normalizes to "llm").
+         if self.mode == "llm":
+             # (LLM grader logic not implemented yet)
+             pass
+
+         return 0.0
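+
+     # _grade_draft_against_constraint is all-or-nothing: a draft scores 1.0 only when it
+     # pairs the concrete term (the $50k figure, biometric factor, single-click) with its
+     # context (cap wording, 2FA, checkout); everything else currently falls through to
+     # 0.0, since the llm-mode fallback above is still a placeholder.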
+
+     def _update_frustration(self, expert_name: str, action: WorkSpaceAction):
+         expert = self._state.experts[expert_name]
+         repeated_question = self._is_repeated_question(action.content, expert_name)
+         if repeated_question:
+             expert.frustration_level = min(10.0, expert.frustration_level + 1.0)
+
+         if expert.frustration_level >= 5.0 and not expert.constraint_shifted:
+             expert.hidden_constraint += " Also requires board approval."
+             expert.constraint_shifted = True
+
+     def _call_llm(self, prompt: str, max_tokens: int = 300) -> str:
+         if self._env_client is None:
+             raise RuntimeError("Environment client is not configured for llm mode.")
+
+         time.sleep(4.0)  # crude rate limiting between API calls
+         try:
+             response = self._env_client.chat.completions.create(
+                 model=self.env_model,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=0.7,
+                 max_tokens=max_tokens,
+             )
+             return response.choices[0].message.content.strip()
+         except Exception as exc:
+             logger.error(f"Environment LLM Error: {exc}")
+             raise
+
+     def _generate_expert_response(self, expert: ExpertState, expert_name: str, agent_message: str) -> str:
+         # In easy mode, don't call the LLM at all; use pure string templates.
+         if self.mode == "easy":
+             responses = {
+                 "Finance": "The budget cap is $50k. Don't go over it.",
+                 "Security": "We require biometric 2FA. No exceptions.",
+                 "UX": "The checkout must be a single-click flow.",
+             }
+             return responses.get(expert_name, "I have no requirements.")
+
+         # Medium, hard, and llm modes still use the LLM.
+         prompt = self.system_prompt.get_expert_prompt(expert, expert_name, agent_message)
+         return self._call_llm(prompt, max_tokens=300)
+
+     def _mock_expert_response(self, expert: ExpertState, expert_name: str, agent_message: str) -> str:
+         draft_score = self._mock_grade_constraint(agent_message, expert.hidden_constraint)
+         lower_message = agent_message.lower()
+         is_question = "?" in agent_message or any(
+             token in lower_message for token in ("please", "could you", "can you", "what", "which", "how")
+         )
+
+         if expert_name == "Finance":
+             if is_question:
+                 response = (
+                     "We need the initial release budget capped at or below $50k. "
+                     "Please keep the scope lean and prioritize the highest-ROI features."
+                 )
+             elif draft_score >= 0.9:
+                 response = (
+                     "This draft respects the sub-$50k budget and keeps scope disciplined. "
+                     "From a finance perspective, the release plan looks viable."
+                 )
+             else:
+                 response = (
+                     "I still need the PRD to explicitly cap the first release budget at $50k or less. "
+                     "Right now the financial guardrails are too vague."
+                 )
+         elif expert_name == "Security":
+             if is_question:
+                 response = (
+                     "Passwords alone will not be enough for this app. "
+                     "We need biometric 2FA for sign-in and other sensitive actions."
+                 )
+             elif draft_score >= 0.9:
+                 response = (
+                     "The draft now captures biometric 2FA clearly, which addresses our baseline security requirement. "
+                     "That is the level of control we need."
+                 )
+             else:
+                 response = (
+                     "The PRD still needs to call out biometric 2FA explicitly. "
+                     "Without that requirement, the security posture is incomplete."
+                 )
+         else:
+             if is_question:
+                 response = (
+                     "Checkout has to feel immediate for the user. "
+                     "The flow should support a true single-click checkout with minimal friction."
+                 )
+             elif draft_score >= 0.9:
+                 response = (
+                     "This draft captures the single-click checkout requirement well. "
+                     "The flow now feels appropriately low-friction."
+                 )
+             else:
+                 response = (
+                     "I still need the PRD to commit to a single-click checkout experience. "
+                     "The current draft leaves too much friction in the funnel."
+                 )
+
+         if expert.constraint_shifted:
+             response += " Any change of this size would also need board approval."
+
+         return response
+
+     def _mock_grade_constraint(self, draft: str, constraint: str) -> float:
+         text = draft.lower()
+         checks = []
+
+         if "$50k" in constraint:
+             checks.append(
+                 any(token in text for token in ("$50k", "50k", "under 50k", "below 50k", "budget cap"))
+                 and "budget" in text
+             )
+         if "biometric 2FA" in constraint:
+             checks.append(
+                 "biometric" in text and any(token in text for token in ("2fa", "two-factor", "mfa", "multi-factor"))
+             )
+         if "single click" in constraint:
+             checks.append(
+                 any(token in text for token in ("single click", "single-click", "one click", "one-click"))
+                 and "checkout" in text
+             )
+         if "board approval" in constraint.lower():
+             checks.append("board approval" in text)
+
+         if not checks:
+             return 0.0
+
+         satisfied = sum(1 for check in checks if check)
+         return round(satisfied / len(checks), 3)
+
+     def _constraint_mentioned(self, response: str, constraint: str) -> bool:
+         constraint_keywords = constraint.lower().split()
+         stopwords = {"must", "the", "a", "an", "is", "be", "and", "or", "not", "to", "in"}
+         keywords = [word for word in constraint_keywords if word not in stopwords]
+         response_lower = response.lower()
+         matches = sum(1 for keyword in keywords if keyword in response_lower)
+         return matches >= max(1, len(keywords) // 2)
+
+     def _is_repeated_question(self, content: str, expert_name: str) -> bool:
+         previous = [
+             history["agent"] for history in self._state.chat_history if expert_name in history.get("world", "")
+         ]
+         if not previous:
+             return False
+
+         content_words = set(content.lower().split())
+         for prev in previous:
+             prev_words = set(prev.lower().split())
+             if not content_words:
+                 continue
+
+             overlap = len(content_words & prev_words) / len(content_words)
+             if overlap > 0.7:
+                 return True
+
+         return False
envs/errors.py ADDED
@@ -0,0 +1,11 @@
+ class EnvError(Exception):
+     """Base exception for environment errors."""
+     pass
+
+ class EnvironmentNotResetError(EnvError):
+     """Raised when stepping an environment before resetting it."""
+     pass
+
+ class EnvironmentDoneError(EnvError):
+     """Raised when stepping an environment that has already terminated."""
+     pass