# CausalOps-Env / inference.py
# Uploaded by omm7 via huggingface_hub (commit f84289a, verified).
from __future__ import annotations
import os
from typing import Any, Dict, List, Optional
import requests
from openai import OpenAI
# --- Runtime configuration (each value can be overridden by an environment variable) ---
# HF_TOKEN is used when set and non-empty; otherwise the OpenAI key is reused.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN") or OPENAI_API_KEY
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
# NOTE(review): not referenced in the code visible here — confirm it is still needed.
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
# Base URL of the incident environment HTTP server (/reset and /step endpoints).
LOGENV_URL = os.environ.get("LOGENV_URL", "http://localhost:7860")
BENCHMARK = "NovaTechIncidentCommand"
# An episode counts as a success when its final reward reaches this value.
SUCCESS_THRESHOLD = 0.70
# The "placeholder" key lets the client be constructed even when no token is
# configured (the OpenAI SDK refuses to build a client without any key).
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "placeholder")
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log output."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line.

    *reward* is rendered with two decimals, *done* as ``true``/``false``, and
    a falsy *error* (``None`` or an empty string) as the literal ``null``.
    """
    done_text = str(done).lower()
    error_text = error if error else "null"
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_text} error={error_text}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line for an episode.

    *score* is clamped into ``[0, 1]`` and shown with three decimals;
    *rewards* are listed comma-separated with two decimals each (an empty
    list renders as nothing after ``rewards=``).
    """
    clamped = max(0.0, min(1.0, score))
    reward_text = ",".join(f"{value:.2f}" for value in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={clamped:.3f} rewards={reward_text}",
        flush=True,
    )
def api_reset(task_id: str) -> Dict[str, Any]:
    """Start a new episode for *task_id* via ``POST {LOGENV_URL}/reset``.

    Uses a 30 s timeout. Raises ``requests.HTTPError`` on a 4xx/5xx status
    (via ``raise_for_status``); otherwise returns the decoded JSON body.
    """
    endpoint = f"{LOGENV_URL}/reset"
    body = {"task_id": task_id}
    reply = requests.post(endpoint, json=body, timeout=30)
    reply.raise_for_status()
    return reply.json()
def api_step(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Send one action to the environment via ``POST {LOGENV_URL}/step``.

    *payload* is sent unchanged as the JSON body, with a 60 s timeout.
    Raises ``requests.HTTPError`` on a 4xx/5xx status; otherwise returns the
    decoded JSON body.
    """
    endpoint = f"{LOGENV_URL}/step"
    reply = requests.post(endpoint, json=payload, timeout=60)
    reply.raise_for_status()
    return reply.json()
def _allowed_hypotheses(task_id: str) -> List[Dict[str, Any]]:
if task_id == "easy":
return [
{
"primary_service": "auth-service",
"failure_mode": "resource_exhaustion",
"dependency": "none",
"customer_impact": "login_failures",
"confidence": 0.88,
},
{
"primary_service": "user-service",
"failure_mode": "traffic_abuse",
"dependency": "ldap-directory",
"customer_impact": "login_failures",
"confidence": 0.52,
},
]
if task_id == "medium":
return [
{
"primary_service": "payment-api",
"failure_mode": "dependency_outage",
"dependency": "payment-gateway",
"customer_impact": "checkout_delays",
"confidence": 0.87,
},
{
"primary_service": "auth-service",
"failure_mode": "resource_exhaustion",
"dependency": "none",
"customer_impact": "login_failures",
"confidence": 0.61,
},
]
return [
{
"primary_service": "auth-service",
"failure_mode": "resource_exhaustion",
"dependency": "payment-api",
"customer_impact": "cross_service_major_incident",
"confidence": 0.92,
},
{
"primary_service": "order-service",
"failure_mode": "storage_saturation",
"dependency": "mysql",
"customer_impact": "order_write_failures",
"confidence": 0.71,
},
{
"primary_service": "notification-service",
"failure_mode": "certificate_expiry",
"dependency": "email-relay",
"customer_impact": "notification_delivery_failure",
"confidence": 0.68,
},
]
def _model_select_hypothesis(task_id: str, observation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Let the LLM pick one of the fixed candidate hypotheses for *task_id*.

    The model is only asked for an index into ``_allowed_hypotheses(task_id)``,
    so it can never invent hypothesis fields. Returns the selected candidate
    dict, or ``None`` when no token is configured or when building the prompt,
    the API call, response parsing, or the index check fails. ``None`` lets
    the caller fall back to its rule-based heuristic.
    """
    import json

    if not HF_TOKEN:
        return None
    candidates = _allowed_hypotheses(task_id)
    log_fields = (
        "log_id",
        "service_name",
        "log_level",
        "message",
        "response_time_ms",
        "cpu_usage_percent",
        "memory_usage_percent",
    )
    try:
        # Only the first 8 visible logs, trimmed to the fields the model needs,
        # to keep the prompt short. ``.get`` so an incomplete record becomes
        # null instead of aborting this best-effort path.
        compact_logs = [
            {field: log.get(field) for field in log_fields}
            for log in (observation.get("visible_logs") or [])[:8]
        ]
        prompt = {
            "task_id": task_id,
            "briefing": observation.get("briefing", {}),
            "visible_logs": compact_logs,
            "candidates": candidates,
            "instruction": (
                "Choose the single best candidate hypothesis index for the incident. "
                "Return strict JSON with keys selected_index and rationale. "
                "Do not invent any fields. Use only the provided candidates."
            ),
        }
        response = client.responses.create(
            model=MODEL_NAME,
            input=[
                {
                    "role": "system",
                    "content": "You are a deterministic incident triage assistant. Return only valid JSON.",
                },
                # Real JSON rather than a Python repr (single quotes, True/None),
                # so the input is in the same format the model must answer in.
                {"role": "user", "content": json.dumps(prompt)},
            ],
            temperature=0,
            max_output_tokens=120,
        )
        text = getattr(response, "output_text", "") or ""
        # Models often wrap JSON in a Markdown fence or add prose; parse only
        # the outermost {...} span.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            return None
        payload = json.loads(text[start : end + 1])
        if not isinstance(payload, dict):
            return None
        idx = int(payload.get("selected_index", -1))
    except Exception:
        # Deliberate best-effort boundary: SDK, network and parse errors have
        # varied types, and any of them simply means "no model choice".
        return None
    if 0 <= idx < len(candidates):
        return candidates[idx]
    return None
def _severity_score(log: Dict[str, Any]) -> float:
level_weight = {"CRITICAL": 4.0, "ERROR": 3.0, "WARN": 1.0, "INFO": 0.2}
score = level_weight.get(str(log["log_level"]).upper(), 0.0)
if float(log.get("cpu_usage_percent", 0.0)) >= 90.0:
score += 1.0
if float(log.get("memory_usage_percent", 0.0)) >= 95.0:
score += 1.0
if int(log.get("response_time_ms", 0)) >= 3000:
score += 1.0
message = str(log["message"]).lower()
for needle, bonus in {
"outofmemoryerror": 2.0,
"connection refused": 2.0,
"disk full": 2.0,
"ssl certificate expired": 1.8,
"segmentation fault": 1.8,
"timeout exceeded": 1.0,
}.items():
if needle in message:
score += bonus
return score
def _infer_hypothesis(observation: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the root-cause hypothesis to submit for the current observation.

    A candidate chosen by the model (``_model_select_hypothesis``) wins when
    one is returned. Otherwise keyword rules over the visible log messages
    are tried in priority order:

    1. OutOfMemoryError while payment-api / order-service /
       notification-service logs are visible -> cross-service major incident
       rooted in auth-service.
    2. "connection refused" or "payment confirmation" -> payment gateway outage.
    3. "disk full" -> order-service storage saturation.
    4. "ssl certificate expired" or "email-relay" -> notification certificate
       expiry.
    5. An OutOfMemoryError logged by auth-service itself -> auth-service
       resource exhaustion. This is the same hypothesis offered as a model
       candidate for the easy task; previously the heuristic could not
       produce it.
    6. Otherwise a low-confidence traffic-abuse guess on the first suspected
       service. If none is listed, the most severe log's service is used.
    """
    task_id = observation.get("task_id", "easy")
    model_choice = _model_select_hypothesis(task_id, observation)
    if model_choice is not None:
        return model_choice

    # Most severe first, so rule 6 can fall back to the top log's service.
    logs = sorted(observation.get("visible_logs", []), key=_severity_score, reverse=True)
    services = {log["service_name"] for log in logs}
    messages = " ".join(str(log["message"]).lower() for log in logs)

    if "outofmemoryerror" in messages and {"payment-api", "order-service", "notification-service"} & services:
        return {
            "primary_service": "auth-service",
            "failure_mode": "resource_exhaustion",
            "dependency": "payment-api",
            "customer_impact": "cross_service_major_incident",
            "confidence": 0.92,
        }
    if "connection refused" in messages or "payment confirmation" in messages:
        return {
            "primary_service": "payment-api",
            "failure_mode": "dependency_outage",
            "dependency": "payment-gateway",
            "customer_impact": "checkout_delays",
            "confidence": 0.87,
        }
    if "disk full" in messages:
        return {
            "primary_service": "order-service",
            "failure_mode": "storage_saturation",
            "dependency": "mysql",
            "customer_impact": "order_write_failures",
            "confidence": 0.82,
        }
    if "ssl certificate expired" in messages or "email-relay" in messages:
        return {
            "primary_service": "notification-service",
            "failure_mode": "certificate_expiry",
            "dependency": "email-relay",
            "customer_impact": "notification_delivery_failure",
            "confidence": 0.81,
        }
    # Checked only after the more specific rules, so every case they already
    # matched keeps its previous answer.
    auth_oom = any(
        log["service_name"] == "auth-service" and "outofmemoryerror" in str(log["message"]).lower()
        for log in logs
    )
    if auth_oom:
        return {
            "primary_service": "auth-service",
            "failure_mode": "resource_exhaustion",
            "dependency": "none",
            "customer_impact": "login_failures",
            "confidence": 0.88,
        }

    suspected = (observation.get("briefing") or {}).get("suspected_services") or []
    if suspected:
        primary = suspected[0]
    elif logs:
        primary = logs[0]["service_name"]
    else:
        # Last resort when neither a suspect nor any log is available.
        primary = "auth-service"
    return {
        "primary_service": primary,
        "failure_mode": "traffic_abuse",
        "dependency": "none",
        "customer_impact": "login_failures",
        "confidence": 0.55,
    }
def _containment_for_hypothesis(hypothesis: Dict[str, Any]) -> List[str]:
if hypothesis["primary_service"] == "auth-service" and hypothesis["customer_impact"] == "cross_service_major_incident":
return [
"increase_auth_heap",
"enable_login_rate_limiting",
"restore_payment_gateway_connectivity",
"free_order_log_disk",
"renew_smtp_certificate",
"page_major_incident_team",
]
if hypothesis["primary_service"] == "payment-api":
return ["restore_payment_gateway_connectivity", "reduce_checkout_retry_pressure"]
if hypothesis["primary_service"] == "order-service":
return ["free_order_log_disk", "reset_mysql_connection_pool"]
if hypothesis["primary_service"] == "notification-service":
return ["renew_smtp_certificate", "reroute_notification_traffic"]
return ["increase_auth_heap", "enable_login_rate_limiting"]
def _build_report(observation: Dict[str, Any], hypothesis: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble the ``submit_report`` payload from the visible logs and a hypothesis.

    Visible logs are ranked by ``_severity_score``, highest first, with ties
    keeping their original order.

    * The ten top-ranked log ids (as ints) become the evidence.
    * Every service with at least one log scoring >= 3.0 is listed as
      impacted, sorted by name. If no log scores that high, the hypothesis'
      primary service is listed instead.
    * ``root_cause`` is the *hypothesis* object itself, not a copy.
    """
    ranked = sorted(
        ((_severity_score(entry), entry) for entry in observation.get("visible_logs", [])),
        key=lambda pair: pair[0],
        reverse=True,
    )
    evidence = [int(entry["log_id"]) for _, entry in ranked[:10]]
    impacted = sorted({entry["service_name"] for score, entry in ranked if score >= 3.0})
    if not impacted:
        impacted = [hypothesis["primary_service"]]
    plan = _containment_for_hypothesis(hypothesis)
    summary = (
        f"The most likely incident source is {hypothesis['primary_service']} with failure mode "
        f"{hypothesis['failure_mode']}, creating customer impact {hypothesis['customer_impact']}."
    )
    return {
        "evidence_log_ids": evidence,
        "impacted_services": impacted,
        "root_cause": hypothesis,
        "containment_plan": plan,
        "summary": summary,
    }
def _run_episode(task_id: str, rewards: List[float]) -> float:
    """Play the fixed five-action script for one episode and return its final reward.

    Each step's reward is appended to *rewards* as soon as it arrives, so the
    caller can still see partial progress if an exception escapes. As soon as
    the environment reports ``done``, no further actions are sent and the
    latest reward is returned.
    """
    observation = api_reset(task_id)
    session_id = observation["session_id"]

    def act(label: str, action: Dict[str, Any]) -> Dict[str, Any]:
        # Send one action for this session, then record and log its reward.
        result = api_step({"session_id": session_id, **action})
        reward = float(result["reward"]["value"])
        done = bool(result["done"])
        rewards.append(reward)
        log_step(len(rewards), label, reward, done, None)
        return result

    # 1. Pull CRITICAL/ERROR logs from the briefing's incident window.
    briefing = observation["briefing"]
    result = act(
        "query_logs",
        {
            "action_type": "query_logs",
            "query": {
                "levels": ["CRITICAL", "ERROR"],
                "start_time": briefing["incident_window_start"],
                "end_time": briefing["incident_window_end"],
                "limit": 6,
            },
        },
    )
    if result["done"]:
        return rewards[-1]
    observation = result["observation"]

    # 2. Inspect dependencies of the suspected service seen most often in the
    #    logs just returned (ties go to the earliest-listed suspect).
    visible = observation["visible_logs"]
    target_service = max(
        observation["briefing"]["suspected_services"],
        key=lambda service: sum(1 for log in visible if log["service_name"] == service),
    )
    result = act(
        f"inspect_dependencies({target_service})",
        {"action_type": "inspect_dependencies", "target_service": target_service},
    )
    if result["done"]:
        return rewards[-1]
    observation = result["observation"]

    # 3. Commit to a root-cause hypothesis.
    hypothesis = _infer_hypothesis(observation)
    result = act("update_hypothesis", {"action_type": "update_hypothesis", "hypothesis": hypothesis})
    if result["done"]:
        return rewards[-1]

    # 4. Apply the containment plan that matches the hypothesis.
    result = act(
        "execute_containment",
        {"action_type": "execute_containment", "containment_plan": _containment_for_hypothesis(hypothesis)},
    )
    if result["done"]:
        return rewards[-1]
    observation = result["observation"]

    # 5. Submit the report; its reward is the episode's final score.
    act("submit_report", {"action_type": "submit_report", "report": _build_report(observation, hypothesis)})
    return rewards[-1]


def run_task(task_id: str) -> float:
    """Run one scripted episode for *task_id*, printing START/STEP/END lines.

    Returns the final reward: the ``submit_report`` reward, or the latest
    reward if the environment ended the episode early. Any ``Exception``
    raised along the way is reported as an ``error`` STEP line and scored
    ``0.0`` instead of propagating.
    """
    rewards: List[float] = []
    final_score = 0.0
    success = False
    log_start(task_id, BENCHMARK, MODEL_NAME)
    try:
        final_score = _run_episode(task_id, rewards)
        success = final_score >= SUCCESS_THRESHOLD
    except Exception as exc:
        # Top-level boundary for one task: report the failure in the log
        # format (one line, no newlines) so the remaining tasks still run.
        log_step(len(rewards) + 1, "error", 0.0, True, str(exc).replace("\n", " "))
        final_score = 0.0
        success = False
    finally:
        # Always close with [END]; report at least one step and one reward.
        log_end(success, len(rewards) or 1, final_score, rewards or [0.0])
    return final_score
if __name__ == "__main__":
    # Run each difficulty tier in turn; every run prints its own START/STEP/END lines.
    for tier in ("easy", "medium", "hard"):
        run_task(tier)