HashirAwaiz commited on
Commit
3054db6
·
verified ·
1 Parent(s): ee8f3db

Create tests/test_api.py

Browse files
Files changed (1) hide show
  1. tests/test_api.py +68 -0
tests/test_api.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from fastapi.testclient import TestClient
4
+
5
+ # Add project root to path
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from app.main import app, load_models
9
+
10
+ # 1. FORCE LOAD MODELS MANUALLY
11
+ # This ensures they are ready before tests run
12
+ print("⚙️ Forcing Model Load for Testing...")
13
+ load_models()
14
+
15
+ client = TestClient(app)
16
+
17
+ def test_health_check():
18
+ """Test if the API is alive and models are loaded."""
19
+ print("\n🔍 Testing Health Endpoint...")
20
+ response = client.get("/health")
21
+
22
+ # Debugging print
23
+ print(f" Response: {response.json()}")
24
+
25
+ assert response.status_code == 200
26
+ data = response.json()
27
+
28
+ # Crucial check: Status MUST be healthy now
29
+ assert data["status"] == "healthy"
30
+ print("✅ Health Check Passed!")
31
+
32
+ def test_prediction_endpoint():
33
+ """Test the End-to-End Prediction pipeline with valid data."""
34
+ print("\n🔍 Testing Prediction Endpoint...")
35
+
36
+ payload = {
37
+ "tmmn": 290.0, "tmmx": 305.0, "rmin": 15.0, "rmax": 45.0,
38
+ "vs": 6.5, "pr": 0.0, "erc": 50.0,
39
+ "latitude": 34.0, "longitude": -118.0
40
+ }
41
+
42
+ response = client.post("/predict", json=payload)
43
+
44
+ # If this fails, print the error detail
45
+ if response.status_code != 200:
46
+ print(f"❌ API Error: {response.json()}")
47
+
48
+ assert response.status_code == 200
49
+ data = response.json()
50
+
51
+ # Verify Content
52
+ assert "burning_index_prediction" in data
53
+ assert "risk_level_prediction" in data
54
+ assert "cluster_zone" in data
55
+
56
+ print("✅ Prediction Logic Passed!")
57
+ print(f" 🔥 Predicted BI: {data['burning_index_prediction']}")
58
+ print(f" ⚠️ Risk Level: {data['risk_level_prediction']}")
59
+
60
+ if __name__ == "__main__":
61
+ try:
62
+ test_health_check()
63
+ test_prediction_endpoint()
64
+ print("\n🎉 ALL API TESTS PASSED SUCCESSFULLY.")
65
+ except AssertionError as e:
66
+ print(f"\n❌ TEST FAILED: Assertion Error")
67
+ except Exception as e:
68
+ print(f"\n❌ CRITICAL ERROR: {e}")