Spaces:
Sleeping
Sleeping
File size: 4,469 Bytes
a4fc311 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.alert_mcp.main import app, get_db
from src.alert_mcp.models import Base, Alert
from src.alert_mcp.schemas import AlertCreate
# Set up an in-memory SQLite database for the tests.
# With ":memory:" we must make sure the connection is shared (the alternative
# would be a file-based database): each new connection gets a fresh database
# unless a shared cache or the same connection is used.
# FastAPI's dependency injection gives every request a new session, and if
# those sessions open new connections they may each see a different in-memory
# database. SQLAlchemy's engine normally pools connections, so we use
# StaticPool for in-memory testing to keep the data visible across sessions.
from sqlalchemy.pool import StaticPool
# In-memory SQLite URL; the database lives only as long as its connection.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

# StaticPool is used so sessions created from this engine share one
# connection; check_same_thread is disabled so that connection may be used
# from a thread other than the one that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and close it afterwards.

    Registered below as the replacement for the application's ``get_db``
    dependency.

    Yields:
        A session created by ``TestingSessionLocal``.
    """
    # Create the session *before* entering ``try``: if the factory call
    # itself raised inside the ``try``, ``db`` would be unbound and the
    # ``finally`` clause would mask the real error with UnboundLocalError.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# Make the app resolve ``get_db`` to the test session generator above.
app.dependency_overrides[get_db] = override_get_db
@pytest.fixture(autouse=True)
def init_test_db():
    """Create all model tables before each test and drop them afterwards."""
    schema = Base.metadata
    schema.create_all(bind=engine)
    yield
    # Teardown: remove every table so the next test starts from an empty DB.
    schema.drop_all(bind=engine)
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the application under test."""
    test_client = TestClient(app)
    return test_client
def test_log_alert(client):
    """A valid alert is accepted (200) and returned with an id, unresolved."""
    payload = {
        "provider_id": 1,
        "severity": "critical",
        "window_days": 30,
        "message": "Test alert",
        "channel": "ui",
    }

    response = client.post("/mcp/tools/log_alert", json=payload)
    assert response.status_code == 200

    alert = response.json()
    assert alert["provider_id"] == 1
    assert alert["severity"] == "critical"
    assert alert["id"] is not None
    # A freshly logged alert must not be marked as resolved.
    assert alert["resolved_at"] is None
def test_log_alert_invalid_severity(client):
    """An unknown severity value is rejected with 422."""
    bad_payload = {
        "provider_id": 1,
        "severity": "super_critical",
        "window_days": 30,
        "message": "Test alert",
    }

    response = client.post("/mcp/tools/log_alert", json=bad_payload)

    # FastAPI/Pydantic validation returns 422 Unprocessable Entity.
    assert response.status_code == 422
def test_get_open_alerts(client):
    """Two logged alerts are both listed, the critical one before the info one."""
    # Seed one info alert followed by one critical alert.
    seeds = (("info", "Info alert"), ("critical", "Critical alert"))
    for severity, message in seeds:
        client.post(
            "/mcp/tools/log_alert",
            json={
                "provider_id": 1,
                "severity": severity,
                "window_days": 30,
                "message": message,
            },
        )

    # Fetch all open alerts with an empty filter body.
    response = client.post("/mcp/tools/get_open_alerts", json={})
    assert response.status_code == 200

    open_alerts = response.json()
    assert len(open_alerts) == 2

    # Ordering check: critical is expected first even though it was logged last.
    first, second = open_alerts[0], open_alerts[1]
    assert first["severity"] == "critical"
    assert second["severity"] == "info"
def test_mark_alert_resolved(client):
    """Resolving an alert stamps it and removes it from the open-alert list."""
    # Create the alert to resolve.
    create_res = client.post(
        "/mcp/tools/log_alert",
        json={
            "provider_id": 1,
            "severity": "warning",
            "window_days": 30,
            "message": "Warning alert",
        },
    )
    # Fail here with a clear status mismatch rather than with a KeyError on
    # "id" below if the setup request was rejected.
    assert create_res.status_code == 200
    alert_id = create_res.json()["id"]

    # Resolve it; the id and the note are sent as query parameters.
    response = client.post(
        "/mcp/tools/mark_alert_resolved",
        params={"alert_id": alert_id, "resolution_note": "Fixed it"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["resolved_at"] is not None
    assert data["resolution_note"] == "Fixed it"

    # Verify it is no longer reported as open. Check the status first so an
    # error payload is not mistaken for an (empty or non-empty) alert list.
    open_res = client.post("/mcp/tools/get_open_alerts", json={})
    assert open_res.status_code == 200
    assert len(open_res.json()) == 0
def test_summarize_alerts(client):
    """The summary reports the total and a per-severity count, including zero."""
    # Seed two info alerts and one critical alert; no warning alerts.
    seeds = (("info", "1"), ("info", "2"), ("critical", "3"))
    for severity, message in seeds:
        client.post(
            "/mcp/tools/log_alert",
            json={
                "provider_id": 1,
                "severity": severity,
                "window_days": 30,
                "message": message,
            },
        )

    response = client.post("/mcp/tools/summarize_alerts", json={})
    assert response.status_code == 200

    summary = response.json()
    assert summary["total_alerts"] == 3

    counts = summary["by_severity"]
    assert counts["info"] == 2
    assert counts["critical"] == 1
    # A severity with no alerts is still expected in the breakdown, as zero.
    assert counts["warning"] == 0
|