# CausalOps-Env / data / db_loader.py
# omm7's picture
# Upload folder using huggingface_hub
# bc2ead7 verified
"""
Database and scenario loading helpers.
"""
from __future__ import annotations
import os
import random
import sqlite3
from pathlib import Path
from typing import Any, Dict, Iterable, List
from tasks.catalog import TASK_SPECS
# Default database location: "novatech_logs.db" in the parent of the directory
# that contains this module file.
DEFAULT_DB_PATH = Path(__file__).resolve().parents[1].joinpath("novatech_logs.db")

# Effective database location: the DB_PATH environment variable when it is set,
# otherwise the default above; "~" is expanded and the path is made absolute.
DB_PATH = Path(os.environ.get("DB_PATH", str(DEFAULT_DB_PATH))).expanduser().resolve()
def _connect() -> sqlite3.Connection:
    """Open a SQLite connection to the database file at ``DB_PATH``.

    Returns:
        A new ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        FileNotFoundError: If no file exists at ``DB_PATH``.
    """
    # The existence check comes first because sqlite3.connect() would otherwise
    # create a new, empty database file at a missing path.
    if DB_PATH.exists():
        return sqlite3.connect(str(DB_PATH))
    raise FileNotFoundError(f"Database not found at '{DB_PATH}'")
def load_thresholds() -> Dict[str, Dict[str, float]]:
    """Load the per-metric thresholds stored in the ``anomaly_thresholds`` table.

    Returns:
        A mapping from ``metric_name`` to a dict with the keys ``"warning"``,
        ``"critical"`` and ``"consecutive"``, each converted to ``float``.
        If a metric name occurs more than once, the last row read wins.

    Raises:
        FileNotFoundError: Propagated from ``_connect`` when the database
            file does not exist.
    """
    conn = _connect()
    try:
        rows = conn.execute(
            "SELECT metric_name, warning_threshold, critical_threshold, consecutive_count "
            "FROM anomaly_thresholds"
        ).fetchall()
    finally:
        # Close in ``finally`` so a failing query does not leak the connection.
        conn.close()

    return {
        metric_name: {
            "warning": float(warning),
            "critical": float(critical),
            "consecutive": float(consecutive),
        }
        for metric_name, warning, critical, consecutive in rows
    }
def load_patterns() -> Dict[str, Dict[str, str]]:
    """Load the rows of the ``known_error_patterns`` table, keyed by keyword.

    Rows are read in ``pattern_id`` order, so the returned dict iterates in
    that order; if a keyword occurs more than once, the row with the later
    ``pattern_id`` wins.

    Returns:
        A mapping from ``pattern_keyword`` to a dict with the keys
        ``"severity"`` and ``"description"``.

    Raises:
        FileNotFoundError: Propagated from ``_connect`` when the database
            file does not exist.
    """
    conn = _connect()
    try:
        rows = conn.execute(
            "SELECT pattern_keyword, severity, description FROM known_error_patterns ORDER BY pattern_id"
        ).fetchall()
    finally:
        # Close in ``finally`` so a failing query does not leak the connection.
        conn.close()

    return {
        keyword: {"severity": severity, "description": description}
        for keyword, severity, description in rows
    }
def load_all_logs() -> List[Dict[str, Any]]:
    """Load every row of the ``server_logs`` table as a list of dicts.

    Rows are ordered by ``timestamp`` and then ``log_id``, both ascending.
    Values are coerced to fixed Python types; a NULL (or otherwise falsy)
    ``response_time_ms``, ``cpu_usage_percent`` or ``memory_usage_percent``
    becomes ``0`` / ``0.0``.

    Returns:
        One dict per log row with the keys ``log_id``, ``timestamp``,
        ``server_id``, ``log_level``, ``service_name``, ``message``,
        ``response_time_ms``, ``cpu_usage_percent`` and
        ``memory_usage_percent``.

    Raises:
        FileNotFoundError: Propagated from ``_connect`` when the database
            file does not exist.
    """
    conn = _connect()
    try:
        rows = conn.execute(
            """
            SELECT log_id, timestamp, server_id, log_level, service_name,
                   message, response_time_ms, cpu_usage_percent, memory_usage_percent
            FROM server_logs
            ORDER BY timestamp ASC, log_id ASC
            """
        ).fetchall()
    finally:
        # Close in ``finally`` so a failing query does not leak the connection.
        conn.close()

    return [
        {
            "log_id": int(log_id),
            "timestamp": str(timestamp),
            "server_id": str(server_id),
            "log_level": str(log_level),
            "service_name": str(service_name),
            "message": str(message),
            # ``or`` maps NULL metrics to zero before the numeric conversion.
            "response_time_ms": int(response_time_ms or 0),
            "cpu_usage_percent": float(cpu_usage or 0.0),
            "memory_usage_percent": float(memory_usage or 0.0),
        }
        for (
            log_id,
            timestamp,
            server_id,
            log_level,
            service_name,
            message,
            response_time_ms,
            cpu_usage,
            memory_usage,
        ) in rows
    ]
def _within_window(log: Dict[str, Any], start: str, end: str) -> bool:
return start <= str(log["timestamp"]) <= end
def _base_scope(task_id: str, logs: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Select the logs that belong to the base scope of a task.

    A log is kept when its ``server_id`` and ``service_name`` are both listed
    in the task spec (``scope_servers`` / ``scope_services``) and it either
    falls inside the spec's incident window or its ``log_id`` is listed in
    ``must_include_ids``.

    Args:
        task_id: Key into ``TASK_SPECS``.
        logs: Log dicts carrying at least ``server_id``, ``service_name``,
            ``timestamp`` and ``log_id``.

    Returns:
        The matching log dicts, in input order (the same objects, not copies).

    Raises:
        KeyError: If ``task_id`` is not in ``TASK_SPECS`` or the spec lacks
            one of the keys read here.
    """
    spec = TASK_SPECS[task_id]
    scope_servers = set(spec["scope_servers"])
    scope_services = set(spec["scope_services"])
    # Built once up front rather than once per log inside the filter below.
    must_include_ids = set(spec["must_include_ids"])
    start = str(spec["incident_window_start"])
    end = str(spec["incident_window_end"])

    return [
        log
        for log in logs
        if log["server_id"] in scope_servers
        and log["service_name"] in scope_services
        and (_within_window(log, start, end) or log["log_id"] in must_include_ids)
    ]
def build_task_log_pool(task_id: str, seed: int) -> List[Dict[str, Any]]:
    """Build the seeded pool of logs for a task.

    The pool consists of:

    1. Every log selected by ``_base_scope`` plus every log whose ``log_id``
       is in the spec's ``must_include_ids``, in ``load_all_logs`` order.
    2. Up to ``noise_sample_size`` additional logs sampled at random from the
       remaining logs on the spec's scope servers and scope services.

    All randomness comes from ``random.Random(seed)``.

    Args:
        task_id: Key into ``TASK_SPECS``.
        seed: Seed for the random generator used for noise sampling and for
            the ``_seed_rank`` values.

    Returns:
        Shallow copies of the selected logs, each with an extra
        ``"_seed_rank"`` float (a random value in ``[0, 1)`` plus
        ``index * 0.00001``, where ``index`` is the log's position in the pool).

    Raises:
        KeyError: If ``task_id`` is not in ``TASK_SPECS`` or the spec lacks
            one of the keys read here.
    """
    spec = TASK_SPECS[task_id]
    rng = random.Random(seed)
    all_logs = load_all_logs()

    # Build the spec-derived sets once instead of once per log in the filters.
    scope_servers = set(spec["scope_servers"])
    scope_services = set(spec["scope_services"])
    must_include_ids = set(spec["must_include_ids"])

    # Must-include ids are always in scope, even when the log is outside the
    # scope servers/services. Ids with no matching log are harmless because
    # both filters below iterate over ``all_logs``.
    scope_ids = {log["log_id"] for log in _base_scope(task_id, all_logs)}
    scope_ids |= must_include_ids

    scope_logs = [log for log in all_logs if log["log_id"] in scope_ids]
    noise_candidates = [
        log
        for log in all_logs
        if log["log_id"] not in scope_ids
        and log["server_id"] in scope_servers
        and log["service_name"] in scope_services
    ]

    sample_size = min(int(spec["noise_sample_size"]), len(noise_candidates))
    if sample_size:
        scope_logs.extend(rng.sample(noise_candidates, sample_size))

    # Results for a given seed depend on the order of RNG calls: the sample
    # above comes first, then one ``rng.random()`` per log in pool order.
    enriched = []
    for index, log in enumerate(scope_logs):
        log_copy = dict(log)
        log_copy["_seed_rank"] = rng.random() + index * 0.00001
        enriched.append(log_copy)
    return enriched