Brajmovech commited on
Commit
95c7e7f
·
1 Parent(s): 97f6726

feat: add validation gate to API endpoints with rate limiting and health check

Browse files
Files changed (2) hide show
  1. app.py +107 -5
  2. tests/test_api.py +100 -0
app.py CHANGED
@@ -3,7 +3,10 @@ from flask_cors import CORS
3
  import traceback
4
  import os
5
  import json
 
 
6
  import uuid
 
7
  from datetime import datetime, timezone
8
  from pathlib import Path
9
  import numpy as np
@@ -146,6 +149,49 @@ def _report_matches_symbol(report: dict, target: str) -> bool:
146
  return False
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def get_latest_llm_reports(symbol: str) -> dict:
150
  """Read the latest reports for the given symbol from the configured LLM models."""
151
  # Try all candidate paths; use the first that exists on disk.
@@ -396,11 +442,27 @@ def analyze_ticker():
396
  if not iris_app:
397
  return jsonify({"error": "IRIS System failed to initialize on the server."}), 500
398
 
399
- ticker = request.args.get('ticker')
400
- if not ticker:
401
  return jsonify({"error": "Ticker parameter is required"}), 400
402
 
403
- ticker = str(ticker).strip().upper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  timeframe = str(request.args.get('timeframe', '') or '').strip().upper()
405
 
406
  if timeframe:
@@ -416,8 +478,8 @@ def analyze_ticker():
416
 
417
  try:
418
  print(
419
- f"API Request for Analysis: {ticker} | timeframe={timeframe or 'custom'} | "
420
- f"period={period} interval={interval}"
421
  )
422
  # Run the analysis for the single ticker quietly
423
  report = iris_app.run_one_ticker(
@@ -638,6 +700,46 @@ def latest_session_summary():
638
  return jsonify({"error": "No session summary found yet."}), 404
639
  return send_file(str(path), mimetype="application/json")
640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
  if __name__ == '__main__':
642
  # Run the Flask app
643
  app.run(debug=True, port=5000)
 
3
  import traceback
4
  import os
5
  import json
6
+ import logging
7
+ import time
8
  import uuid
9
+ from collections import defaultdict
10
  from datetime import datetime, timezone
11
  from pathlib import Path
12
  import numpy as np
 
149
  return False
150
 
151
 
152
+ # ---------------------------------------------------------------------------
153
+ # Ticker validation setup
154
+ # ---------------------------------------------------------------------------
155
+
156
+ try:
157
+ from ticker_validator import validate_ticker as _validate_ticker
158
+ from ticker_db import load_ticker_db as _load_ticker_db
159
+ _VALIDATOR_AVAILABLE = True
160
+ except ImportError:
161
+ _VALIDATOR_AVAILABLE = False
162
+ _load_ticker_db = None
163
+
164
+ _validation_logger = logging.getLogger("iris.ticker_validation")
165
+
166
+ # Simple in-memory rate limiter: {ip: [unix_timestamp, ...]}
167
+ _rate_limit_store: dict[str, list[float]] = defaultdict(list)
168
+ _RATE_LIMIT_MAX = 30
169
+ _RATE_LIMIT_WINDOW = 60 # seconds
170
+
171
+
172
+ def _check_rate_limit(ip: str) -> bool:
173
+ """Return True if request is allowed, False if rate limit exceeded."""
174
+ now = time.time()
175
+ cutoff = now - _RATE_LIMIT_WINDOW
176
+ _rate_limit_store[ip] = [t for t in _rate_limit_store[ip] if t > cutoff]
177
+ if len(_rate_limit_store[ip]) >= _RATE_LIMIT_MAX:
178
+ return False
179
+ _rate_limit_store[ip].append(now)
180
+ return True
181
+
182
+
183
+ def _log_validation(raw_input: str, result) -> None:
184
+ _validation_logger.info(
185
+ "TICKER_VALIDATION | input=%s | valid=%s | source=%s | error=%s",
186
+ raw_input,
187
+ result.valid if result else False,
188
+ result.source if result else "",
189
+ result.error if result else "validator_unavailable",
190
+ )
191
+
192
+
193
+ # ---------------------------------------------------------------------------
194
+
195
  def get_latest_llm_reports(symbol: str) -> dict:
196
  """Read the latest reports for the given symbol from the configured LLM models."""
197
  # Try all candidate paths; use the first that exists on disk.
 
442
  if not iris_app:
443
  return jsonify({"error": "IRIS System failed to initialize on the server."}), 500
444
 
445
+ raw_ticker = request.args.get('ticker')
446
+ if not raw_ticker:
447
  return jsonify({"error": "Ticker parameter is required"}), 400
448
 
449
+ # --- Validation gate (Layer 1-3) before any LLM / heavy computation -----
450
+ if _VALIDATOR_AVAILABLE:
451
+ val_result = _validate_ticker(str(raw_ticker))
452
+ _log_validation(raw_ticker, val_result)
453
+ if not val_result.valid:
454
+ return jsonify({
455
+ "error": val_result.error,
456
+ "suggestions": val_result.suggestions,
457
+ "valid": False,
458
+ }), 422
459
+ ticker = val_result.ticker
460
+ company_name = val_result.company_name # confirmed context for LLM
461
+ else:
462
+ ticker = str(raw_ticker).strip().upper()
463
+ company_name = ""
464
+ # -------------------------------------------------------------------------
465
+
466
  timeframe = str(request.args.get('timeframe', '') or '').strip().upper()
467
 
468
  if timeframe:
 
478
 
479
  try:
480
  print(
481
+ f"API Request for Analysis: {ticker} ({company_name or 'unknown'}) | "
482
+ f"timeframe={timeframe or 'custom'} | period={period} interval={interval}"
483
  )
484
  # Run the analysis for the single ticker quietly
485
  report = iris_app.run_one_ticker(
 
700
  return jsonify({"error": "No session summary found yet."}), 404
701
  return send_file(str(path), mimetype="application/json")
702
 
703
+ @app.route('/api/validate-ticker', methods=['POST'])
704
+ def validate_ticker_endpoint():
705
+ """Real-time ticker validation for the frontend (always returns HTTP 200)."""
706
+ ip = request.remote_addr or "unknown"
707
+ if not _check_rate_limit(ip):
708
+ return jsonify({"error": "Too many requests. Please wait before trying again."}), 429
709
+
710
+ body = request.get_json(silent=True) or {}
711
+ raw = body.get("ticker", "")
712
+
713
+ if not _VALIDATOR_AVAILABLE:
714
+ return jsonify({"valid": True, "ticker": str(raw).strip().upper(),
715
+ "company_name": ""}), 200
716
+
717
+ result = _validate_ticker(str(raw))
718
+ _log_validation(raw, result)
719
+
720
+ if result.valid:
721
+ return jsonify({"valid": True, "ticker": result.ticker,
722
+ "company_name": result.company_name}), 200
723
+ return jsonify({"valid": False, "error": result.error,
724
+ "suggestions": result.suggestions}), 200
725
+
726
+
727
+ @app.route('/api/health', methods=['GET'])
728
+ def health_check():
729
+ """Report service health and ticker database status."""
730
+ ticker_db_loaded = False
731
+ ticker_count = 0
732
+ if _VALIDATOR_AVAILABLE and _load_ticker_db is not None:
733
+ try:
734
+ db = _load_ticker_db()
735
+ ticker_db_loaded = True
736
+ ticker_count = len(db)
737
+ except Exception:
738
+ pass
739
+ return jsonify({"status": "ok", "ticker_db_loaded": ticker_db_loaded,
740
+ "ticker_count": ticker_count}), 200
741
+
742
+
743
  if __name__ == '__main__':
744
  # Run the Flask app
745
  app.run(debug=True, port=5000)
tests/test_api.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration tests for the Flask API validation endpoints.
2
+
3
+ Run with: python -m unittest tests.test_api -v
4
+ Slow tests (live yfinance) are marked; skip them with:
5
+ python -m unittest tests.test_api.TestAPIFast -v
6
+ """
7
+
8
+ import sys
9
+ import os
10
+ import json
11
+ import unittest
12
+
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
+
15
+ # Import the Flask app (IRIS_System init may fail in CI — that's fine)
16
+ from app import app as flask_app
17
+
18
+
19
+ class TestAPIFast(unittest.TestCase):
20
+ """Tests that don't require live network access."""
21
+
22
+ def setUp(self):
23
+ flask_app.config["TESTING"] = True
24
+ self.client = flask_app.test_client()
25
+
26
+ def test_health_endpoint(self):
27
+ """GET /api/health should return 200 with ticker_db_loaded=true."""
28
+ resp = self.client.get("/api/health")
29
+ self.assertEqual(resp.status_code, 200)
30
+ data = json.loads(resp.data)
31
+ self.assertIn("status", data)
32
+ self.assertEqual(data["status"], "ok")
33
+ self.assertTrue(data.get("ticker_db_loaded"), "Ticker DB should be loaded")
34
+ self.assertGreater(data.get("ticker_count", 0), 0)
35
+
36
+ def test_validate_missing_body(self):
37
+ """POST /api/validate-ticker with no body should return valid=false."""
38
+ resp = self.client.post("/api/validate-ticker",
39
+ content_type="application/json",
40
+ data=json.dumps({}))
41
+ self.assertEqual(resp.status_code, 200)
42
+ data = json.loads(resp.data)
43
+ self.assertFalse(data.get("valid"))
44
+
45
+ def test_validate_invalid_format(self):
46
+ """POST /api/validate-ticker with bad format should return valid=false immediately."""
47
+ resp = self.client.post("/api/validate-ticker",
48
+ content_type="application/json",
49
+ data=json.dumps({"ticker": "123ABC"}))
50
+ self.assertEqual(resp.status_code, 200)
51
+ data = json.loads(resp.data)
52
+ self.assertFalse(data.get("valid"))
53
+ self.assertIn("error", data)
54
+
55
+ def test_analyze_rejects_invalid_ticker(self):
56
+ """GET /api/analyze with a clearly invalid ticker should return 422."""
57
+ resp = self.client.get("/api/analyze?ticker=XYZQW")
58
+ self.assertEqual(resp.status_code, 422)
59
+ data = json.loads(resp.data)
60
+ self.assertFalse(data.get("valid"))
61
+ self.assertIn("error", data)
62
+
63
+
64
+ class TestAPISlow(unittest.TestCase):
65
+ """Tests that hit live yfinance — skip in CI with -k 'not slow'."""
66
+
67
+ def setUp(self):
68
+ flask_app.config["TESTING"] = True
69
+ self.client = flask_app.test_client()
70
+
71
+ def test_validate_valid_ticker(self):
72
+ """POST /api/validate-ticker with AAPL should return valid=true."""
73
+ resp = self.client.post("/api/validate-ticker",
74
+ content_type="application/json",
75
+ data=json.dumps({"ticker": "AAPL"}))
76
+ self.assertEqual(resp.status_code, 200)
77
+ data = json.loads(resp.data)
78
+ self.assertTrue(data.get("valid"))
79
+ self.assertEqual(data.get("ticker"), "AAPL")
80
+ self.assertIn("company_name", data)
81
+
82
+ def test_validate_invalid_ticker(self):
83
+ """POST /api/validate-ticker with XYZQW should return valid=false with error."""
84
+ resp = self.client.post("/api/validate-ticker",
85
+ content_type="application/json",
86
+ data=json.dumps({"ticker": "XYZQW"}))
87
+ self.assertEqual(resp.status_code, 200)
88
+ data = json.loads(resp.data)
89
+ self.assertFalse(data.get("valid"))
90
+ self.assertIn("error", data)
91
+
92
+ def test_analyze_accepts_valid_ticker(self):
93
+ """GET /api/analyze with AAPL should NOT return 422 (validation must pass)."""
94
+ resp = self.client.get("/api/analyze?ticker=AAPL")
95
+ self.assertNotEqual(resp.status_code, 422,
96
+ "Validation gate should not reject AAPL")
97
+
98
+
99
+ if __name__ == "__main__":
100
+ unittest.main()