10doshi12 committed on
Commit
25ec612
·
1 Parent(s): 2853152

fix: runtime issues with docker build on previous changes and openenv compliance

Browse files
Files changed (3) hide show
  1. inference.py +8 -4
  2. rewards.py +36 -17
  3. server/app.py +69 -0
inference.py CHANGED
@@ -336,18 +336,22 @@ def find_root_cause(services: dict, dep_graph: dict) -> Optional[str]:
336
 
337
  def _pick_remediation(service_name: str, fetched_logs: dict) -> dict:
338
  """Pick remediation action based on log keywords for the service."""
339
- logs = fetched_logs.get(service_name, [])
340
- log_text = " ".join(logs).lower()
 
 
 
 
341
  if "oomkilled" in log_text or "exit code 137" in log_text or "memory limit" in log_text:
342
  return {"action_type": "restart_service", "target_service": service_name}
 
 
343
  if "hikaripool" in log_text or "connection pool" in log_text or "timed out after" in log_text:
344
  return {"action_type": "revert_config", "target_service": service_name}
345
  if "connection refused" in log_text or "circuit breaker" in log_text:
346
  return {"action_type": "circuit_break", "target_service": service_name}
347
  if "memory leak" in log_text or "high latency" in log_text:
348
  return {"action_type": "scale_replicas", "target_service": service_name}
349
- if "nullpointerexception" in log_text or "deploy" in log_text or "version" in log_text:
350
- return {"action_type": "rollback_deploy", "target_service": service_name}
351
  return {"action_type": "restart_service", "target_service": service_name}
352
 
353
 
 
336
 
337
  def _pick_remediation(service_name: str, fetched_logs: dict) -> dict:
338
  """Pick remediation action based on log keywords for the service."""
339
+ raw = fetched_logs.get(service_name, [])
340
+ # Accept both str (single log blob) and list of log lines
341
+ if isinstance(raw, str):
342
+ log_text = raw.lower()
343
+ else:
344
+ log_text = " ".join(raw).lower()
345
  if "oomkilled" in log_text or "exit code 137" in log_text or "memory limit" in log_text:
346
  return {"action_type": "restart_service", "target_service": service_name}
347
+ if "nullpointerexception" in log_text or "deploy" in log_text or "version" in log_text:
348
+ return {"action_type": "rollback_deploy", "target_service": service_name}
349
  if "hikaripool" in log_text or "connection pool" in log_text or "timed out after" in log_text:
350
  return {"action_type": "revert_config", "target_service": service_name}
351
  if "connection refused" in log_text or "circuit breaker" in log_text:
352
  return {"action_type": "circuit_break", "target_service": service_name}
353
  if "memory leak" in log_text or "high latency" in log_text:
354
  return {"action_type": "scale_replicas", "target_service": service_name}
 
 
355
  return {"action_type": "restart_service", "target_service": service_name}
356
 
357
 
rewards.py CHANGED
@@ -13,23 +13,42 @@ from __future__ import annotations
13
 
14
  from dataclasses import dataclass, field
15
 
16
- from firewatch_env.models import SystemObservation, FirewatchAction
17
- from firewatch_env.config import (
18
- REWARD_WEIGHT_HEALTH,
19
- REWARD_WEIGHT_SLO,
20
- REWARD_MTTM_BONUS,
21
- REWARD_TIME_COST,
22
- REWARD_WRONG_ACTION_PENALTY,
23
- REWARD_SLO_BREACH_PENALTY,
24
- GRADER_WEIGHT_RECOVERY,
25
- GRADER_WEIGHT_SPEED,
26
- GRADER_WEIGHT_PRECISION,
27
- GRADER_WEIGHT_SLO,
28
- GRADER_WRONG_ACTION_PENALTY_PER_ACTION,
29
- GRADER_SPEED_MTTM_WEIGHT,
30
- GRADER_SPEED_BCM_WEIGHT,
31
- TASKS,
32
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  # ==========================================================================
 
13
 
14
  from dataclasses import dataclass, field
15
 
16
+ try:
17
+ from .models import SystemObservation, FirewatchAction
18
+ from .config import (
19
+ REWARD_WEIGHT_HEALTH,
20
+ REWARD_WEIGHT_SLO,
21
+ REWARD_MTTM_BONUS,
22
+ REWARD_TIME_COST,
23
+ REWARD_WRONG_ACTION_PENALTY,
24
+ REWARD_SLO_BREACH_PENALTY,
25
+ GRADER_WEIGHT_RECOVERY,
26
+ GRADER_WEIGHT_SPEED,
27
+ GRADER_WEIGHT_PRECISION,
28
+ GRADER_WEIGHT_SLO,
29
+ GRADER_WRONG_ACTION_PENALTY_PER_ACTION,
30
+ GRADER_SPEED_MTTM_WEIGHT,
31
+ GRADER_SPEED_BCM_WEIGHT,
32
+ TASKS,
33
+ )
34
+ except ImportError:
35
+ from models import SystemObservation, FirewatchAction
36
+ from config import (
37
+ REWARD_WEIGHT_HEALTH,
38
+ REWARD_WEIGHT_SLO,
39
+ REWARD_MTTM_BONUS,
40
+ REWARD_TIME_COST,
41
+ REWARD_WRONG_ACTION_PENALTY,
42
+ REWARD_SLO_BREACH_PENALTY,
43
+ GRADER_WEIGHT_RECOVERY,
44
+ GRADER_WEIGHT_SPEED,
45
+ GRADER_WEIGHT_PRECISION,
46
+ GRADER_WEIGHT_SLO,
47
+ GRADER_WRONG_ACTION_PENALTY_PER_ACTION,
48
+ GRADER_SPEED_MTTM_WEIGHT,
49
+ GRADER_SPEED_BCM_WEIGHT,
50
+ TASKS,
51
+ )
52
 
53
 
54
  # ==========================================================================
server/app.py CHANGED
@@ -28,9 +28,13 @@ Usage:
28
  python -m server.app
29
  """
30
 
 
 
31
  from fastapi import Request
32
  from fastapi.exceptions import RequestValidationError
33
  from fastapi.responses import JSONResponse
 
 
34
 
35
  try:
36
  from openenv.core.env_server.http_server import create_app
@@ -67,6 +71,71 @@ app = create_app(
67
  )
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Zero-crash policy (CLAUDE.md): invalid requests must return HTTP 200 with error
71
  # in the response body, never HTTP 422 or 500.
72
  @app.exception_handler(RequestValidationError)
 
28
  python -m server.app
29
  """
30
 
31
+ import json
32
+
33
  from fastapi import Request
34
  from fastapi.exceptions import RequestValidationError
35
  from fastapi.responses import JSONResponse
36
+ from starlette.middleware.base import BaseHTTPMiddleware
37
+ from starlette.responses import Response
38
 
39
  try:
40
  from openenv.core.env_server.http_server import create_app
 
71
  )
72
 
73
 
74
class StepInfoMiddleware(BaseHTTPMiddleware):
    """
    Middleware that injects an ``info`` dict into /step responses.

    The openenv-core framework serializes SystemObservation by promoting
    ``reward`` and ``done`` to the top level and dropping ``metadata``.
    This middleware re-attaches the metadata as ``info`` so downstream
    clients can read ``info["episode_score"]`` without digging into
    ``observation``.

    Only activates on POST /step responses with JSON content.
    """

    @staticmethod
    def _headers_without_length(response: Response) -> dict:
        """Copy *response* headers minus content-length, so Starlette
        recomputes the length for the (possibly resized) replacement body."""
        return {
            k: v for k, v in response.headers.items()
            if k.lower() != "content-length"
        }

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)

        # Only rewrite step responses; everything else passes through.
        if request.url.path != "/step" or request.method != "POST":
            return response

        # Phase 1: drain the streamed body. Once the iterator is consumed the
        # original Response object can no longer be returned as-is, so every
        # fallback below must rebuild a Response from the collected bytes.
        body_bytes = b""
        try:
            async for chunk in response.body_iterator:
                body_bytes += chunk
        except Exception:
            # Never crash (zero-crash policy): return whatever was read.
            return Response(
                content=body_bytes,
                status_code=response.status_code,
                headers=self._headers_without_length(response),
                media_type=response.media_type,
            )

        # Phase 2: parse and augment. A failure here returns the COMPLETE
        # original body — never a truncated one, which a single combined
        # try-block could produce if the read itself had failed.
        try:
            data = json.loads(body_bytes)
            obs = data.get("observation", {})
            # Build info from observation fields that belong in metadata
            info: dict = {}
            if "episode_score" in obs and obs["episode_score"] is not None:
                info["episode_score"] = float(obs["episode_score"])
            # Propagate any error info
            if "error" in obs:
                info["error"] = obs["error"]
            data["info"] = info
            new_body = json.dumps(data).encode("utf-8")
        except Exception:
            return Response(
                content=body_bytes,
                status_code=response.status_code,
                headers=self._headers_without_length(response),
                media_type=response.media_type,
            )

        return Response(
            content=new_body,
            status_code=response.status_code,
            headers=self._headers_without_length(response),
            media_type="application/json",
        )
134
+
135
+
136
+ app.add_middleware(StepInfoMiddleware)
137
+
138
+
139
  # Zero-crash policy (CLAUDE.md): invalid requests must return HTTP 200 with error
140
  # in the response body, never HTTP 422 or 500.
141
  @app.exception_handler(RequestValidationError)