Spaces:
Running
Running
Sync from GitHub
Browse files- tests/test_ai_engine.py +267 -0
- tests/test_api.py +250 -0
- tests/test_data_ingestion.py +233 -0
tests/test_ai_engine.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for AI Engine components.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from datetime import datetime, timezone, timedelta
|
| 9 |
+
from unittest.mock import patch, MagicMock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestFinBERTScoring:
    """FinBERT sentiment-scoring behaviour."""

    def test_score_text_empty_input(self):
        """An empty string yields the neutral fallback distribution."""
        from app.ai_engine import score_text_with_finbert

        pipe = MagicMock()

        result = score_text_with_finbert(pipe, "")

        assert result["prob_positive"] == 0.33
        assert result["prob_neutral"] == 0.34
        assert result["prob_negative"] == 0.33
        assert result["score"] == 0.0

    def test_score_text_short_input(self):
        """Inputs shorter than 10 characters fall back to a neutral score."""
        from app.ai_engine import score_text_with_finbert

        pipe = MagicMock()

        result = score_text_with_finbert(pipe, "hi")

        assert result["score"] == 0.0

    def test_score_text_normal_input(self):
        """A positive-leaning pipeline output maps onto the expected fields."""
        from app.ai_engine import score_text_with_finbert

        pipe = MagicMock()
        pipe.return_value = [[
            {"label": "positive", "score": 0.8},
            {"label": "neutral", "score": 0.15},
            {"label": "negative", "score": 0.05},
        ]]

        result = score_text_with_finbert(
            pipe,
            "Copper prices surge to new highs on strong demand"
        )

        assert result["prob_positive"] == 0.8
        assert result["prob_neutral"] == 0.15
        assert result["prob_negative"] == 0.05
        # score = P(positive) - P(negative) = 0.8 - 0.05
        assert result["score"] == 0.75

    def test_score_text_negative_sentiment(self):
        """A negative-leaning pipeline output produces a negative score."""
        from app.ai_engine import score_text_with_finbert

        pipe = MagicMock()
        pipe.return_value = [[
            {"label": "positive", "score": 0.1},
            {"label": "neutral", "score": 0.2},
            {"label": "negative", "score": 0.7},
        ]]

        result = score_text_with_finbert(
            pipe,
            "Copper prices crash amid recession fears"
        )

        # score = 0.1 - 0.7
        assert result["score"] == -0.6
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestSentimentAggregation:
    """Sentiment aggregation arithmetic."""

    def test_recency_weighting(self):
        """More recent articles receive exponentially larger weights."""
        tau = 12.0

        # One article published at 9am, another at 4pm.
        morning = 9.0
        afternoon = 16.0

        w_morning = np.exp(morning / tau)
        w_afternoon = np.exp(afternoon / tau)

        # The later article must dominate.
        assert w_afternoon > w_morning

    def test_weighted_average_calculation(self):
        """Weighted average matches the hand-computed dot product."""
        scores = np.array([0.5, -0.2, 0.3])
        weights = np.array([0.2, 0.3, 0.5])  # already normalized

        actual = np.sum(scores * weights)
        expected = 0.5 * 0.2 + (-0.2) * 0.3 + 0.3 * 0.5

        assert abs(actual - expected) < 1e-10

    def test_sentiment_index_range(self):
        """A convex combination of scores in [-1, 1] stays in [-1, 1]."""
        scores = np.array([0.9, -0.8, 0.5])
        weights = np.array([0.33, 0.33, 0.34])

        index = np.sum(scores * weights)

        assert -1 <= index <= 1
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TestFeatureEngineering:
|
| 122 |
+
"""Tests for feature engineering."""
|
| 123 |
+
|
| 124 |
+
def test_technical_indicators(self, sample_price_data):
|
| 125 |
+
"""Test that technical indicators are calculated correctly."""
|
| 126 |
+
df = sample_price_data
|
| 127 |
+
|
| 128 |
+
# Calculate SMA
|
| 129 |
+
sma_5 = df["close"].rolling(window=5).mean()
|
| 130 |
+
sma_10 = df["close"].rolling(window=10).mean()
|
| 131 |
+
|
| 132 |
+
# SMA calculations should not be NaN after sufficient data
|
| 133 |
+
assert not np.isnan(sma_5.iloc[-1])
|
| 134 |
+
assert not np.isnan(sma_10.iloc[-1])
|
| 135 |
+
|
| 136 |
+
# SMA10 should smooth more than SMA5
|
| 137 |
+
assert sma_10.std() < df["close"].std()
|
| 138 |
+
|
| 139 |
+
def test_return_calculation(self, sample_price_data):
|
| 140 |
+
"""Test return calculation."""
|
| 141 |
+
df = sample_price_data
|
| 142 |
+
|
| 143 |
+
# Calculate returns
|
| 144 |
+
returns = df["close"].pct_change()
|
| 145 |
+
|
| 146 |
+
# First return should be NaN
|
| 147 |
+
assert np.isnan(returns.iloc[0])
|
| 148 |
+
|
| 149 |
+
# Returns should be small (reasonable daily returns)
|
| 150 |
+
assert abs(returns.iloc[1:].mean()) < 0.1
|
| 151 |
+
|
| 152 |
+
def test_volatility_calculation(self, sample_price_data):
|
| 153 |
+
"""Test volatility calculation."""
|
| 154 |
+
df = sample_price_data
|
| 155 |
+
|
| 156 |
+
returns = df["close"].pct_change()
|
| 157 |
+
volatility_10 = returns.rolling(window=10).std()
|
| 158 |
+
|
| 159 |
+
# Volatility should be positive
|
| 160 |
+
assert all(v >= 0 or np.isnan(v) for v in volatility_10)
|
| 161 |
+
|
| 162 |
+
def test_lagged_features(self, sample_price_data):
|
| 163 |
+
"""Test lagged feature creation."""
|
| 164 |
+
df = sample_price_data
|
| 165 |
+
|
| 166 |
+
returns = df["close"].pct_change()
|
| 167 |
+
|
| 168 |
+
# Create lags
|
| 169 |
+
lag_1 = returns.shift(1)
|
| 170 |
+
lag_2 = returns.shift(2)
|
| 171 |
+
lag_3 = returns.shift(3)
|
| 172 |
+
|
| 173 |
+
# Lags should have correct offset
|
| 174 |
+
assert lag_1.iloc[5] == returns.iloc[4]
|
| 175 |
+
assert lag_2.iloc[5] == returns.iloc[3]
|
| 176 |
+
assert lag_3.iloc[5] == returns.iloc[2]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class TestModelTraining:
    """Model-training helper logic."""

    def test_train_test_split_temporal(self):
        """Temporal split keeps every training date before every validation date."""
        dates = pd.date_range(start="2025-01-01", periods=100, freq="D")

        validation_days = 20
        cutoff = dates.max() - timedelta(days=validation_days)

        train = dates[dates <= cutoff]
        val = dates[dates > cutoff]

        # No leakage: training strictly precedes validation.
        assert train.max() < val.min()

        # The validation window has exactly the requested length.
        assert len(val) == validation_days

    def test_feature_importance_normalized(self):
        """Normalised feature importances sum to one."""
        raw = {
            "feature_a": 10.0,
            "feature_b": 5.0,
            "feature_c": 3.0,
            "feature_d": 2.0,
        }

        total = sum(raw.values())
        normalized = {name: value / total for name, value in raw.items()}

        assert abs(sum(normalized.values()) - 1.0) < 1e-10

    def test_prediction_direction_from_return(self):
        """Direction label follows the signed threshold rule."""
        def get_direction(predicted_return, threshold=0.005):
            if predicted_return > threshold:
                return "up"
            elif predicted_return < -threshold:
                return "down"
            else:
                return "neutral"

        assert get_direction(0.02) == "up"
        assert get_direction(-0.02) == "down"
        assert get_direction(0.001) == "neutral"
        assert get_direction(-0.003) == "neutral"
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TestModelPersistence:
    """Model artefact saving-and-loading conventions."""

    def test_model_path_generation(self):
        """Filenames embed the sanitised symbol and use the .json suffix."""
        target_symbol = "HG=F"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # '=' is not filesystem-friendly, so it is replaced by '_'.
        safe_symbol = target_symbol.replace("=", "_")
        model_filename = f"xgb_{safe_symbol}_{timestamp}.json"
        latest_filename = f"xgb_{safe_symbol}_latest.json"

        assert "HG_F" in model_filename
        assert "HG_F" in latest_filename
        assert model_filename.endswith(".json")

    def test_metrics_json_structure(self):
        """Metrics payload round-trips through JSON unchanged."""
        import json

        metrics = {
            "target_symbol": "HG=F",
            "trained_at": datetime.now(timezone.utc).isoformat(),
            "train_samples": 200,
            "val_samples": 30,
            "train_mae": 0.01,
            "train_rmse": 0.015,
            "val_mae": 0.02,
            "val_rmse": 0.025,
            "best_iteration": 50,
            "feature_count": 58,
        }

        round_tripped = json.loads(json.dumps(metrics))

        assert round_tripped["target_symbol"] == "HG=F"
        assert round_tripped["val_mae"] == 0.02
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for API endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from unittest.mock import patch, MagicMock
|
| 7 |
+
from datetime import datetime, timezone
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestHealthEndpoint:
    """Behaviour of the /api/health response model."""

    def test_health_response_structure(self):
        """A fully populated HealthResponse exposes every expected field."""
        from app.schemas import HealthResponse

        payload = HealthResponse(
            status="healthy",
            db_type="postgresql",
            models_found=1,
            pipeline_locked=False,
            timestamp=datetime.now(timezone.utc).isoformat(),
            news_count=100,
            price_bars_count=500,
        )

        assert payload.status == "healthy"
        assert payload.db_type == "postgresql"
        assert payload.models_found == 1
        assert payload.pipeline_locked is False
        assert payload.news_count == 100
        assert payload.price_bars_count == 500

    def test_health_status_degraded_no_models(self):
        """Missing models are reported through the degraded status."""
        from app.schemas import HealthResponse

        payload = HealthResponse(
            status="degraded",
            db_type="postgresql",
            models_found=0,
            pipeline_locked=False,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )

        assert payload.status == "degraded"
        assert payload.models_found == 0
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TestAnalysisSchema:
    """Validation of the analysis report schema."""

    def test_analysis_report_structure(self):
        """AnalysisReport accepts a complete, well-formed payload."""
        from app.schemas import AnalysisReport, Influencer

        drivers = [
            Influencer(feature="HG=F_EMA_10", importance=0.15, description="Test"),
            Influencer(feature="DX-Y.NYB_ret1", importance=0.10, description="Test"),
        ]

        report = AnalysisReport(
            symbol="HG=F",
            prediction_direction="up",
            confidence_score=0.75,
            current_price=4.25,
            predicted_return=0.015,
            sentiment_index=0.35,
            news_count_24h=15,
            model_metrics={
                "val_mae": 0.02,
                "val_rmse": 0.025,
            },
            top_influencers=drivers,
            generated_at=datetime.now(timezone.utc).isoformat(),
        )

        assert report.symbol == "HG=F"
        assert report.prediction_direction == "up"
        assert report.confidence_score == 0.75
        assert len(report.top_influencers) == 2

    def test_prediction_direction_values(self):
        """Each of the three direction labels validates."""
        from app.schemas import AnalysisReport

        for direction in ["up", "down", "neutral"]:
            report = AnalysisReport(
                symbol="HG=F",
                prediction_direction=direction,
                confidence_score=0.5,
                current_price=4.0,
                predicted_return=0.0,
                sentiment_index=0.0,
                news_count_24h=0,
                model_metrics={},
                top_influencers=[],
                generated_at=datetime.now(timezone.utc).isoformat(),
            )
            assert report.prediction_direction == direction
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TestHistorySchema:
    """Validation of the history response schemas."""

    def test_history_data_point(self):
        """A populated HistoryDataPoint keeps every field value."""
        from app.schemas import HistoryDataPoint

        point = HistoryDataPoint(
            date="2026-01-01",
            price=4.25,
            sentiment_index=0.35,
            sentiment_news_count=10,
        )

        assert point.date == "2026-01-01"
        assert point.price == 4.25
        assert point.sentiment_index == 0.35
        assert point.sentiment_news_count == 10

    def test_history_data_point_nullable_sentiment(self):
        """Sentiment fields accept None."""
        from app.schemas import HistoryDataPoint

        point = HistoryDataPoint(
            date="2026-01-01",
            price=4.25,
            sentiment_index=None,
            sentiment_news_count=None,
        )

        assert point.sentiment_index is None
        assert point.sentiment_news_count is None

    def test_history_response(self):
        """HistoryResponse wraps a symbol plus its data points."""
        from app.schemas import HistoryResponse, HistoryDataPoint

        points = [
            HistoryDataPoint(date="2026-01-01", price=4.20),
            HistoryDataPoint(date="2026-01-02", price=4.25),
        ]

        response = HistoryResponse(symbol="HG=F", data=points)

        assert response.symbol == "HG=F"
        assert len(response.data) == 2
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class TestPipelineLock:
    """File-based pipeline lock behaviour."""

    def test_lock_acquire_release(self, tmp_path):
        """Acquiring creates the lock file; releasing removes it."""
        from app.lock import PipelineLock

        lock_file = tmp_path / "test.lock"
        lock = PipelineLock(lock_file=str(lock_file), timeout=0)

        assert lock.acquire() is True
        assert lock_file.exists()

        lock.release()
        assert not lock_file.exists()

    def test_lock_already_held(self, tmp_path):
        """A second acquire on a held lock is refused."""
        from app.lock import PipelineLock

        lock_file = tmp_path / "test.lock"
        holder = PipelineLock(lock_file=str(lock_file), timeout=0)
        contender = PipelineLock(lock_file=str(lock_file), timeout=0)

        assert holder.acquire() is True
        assert contender.acquire() is False

        holder.release()

    def test_is_pipeline_locked(self, tmp_path):
        """is_pipeline_locked tracks the lock file's existence."""
        from app.lock import PipelineLock

        lock_file = tmp_path / "test.lock"

        # Point the settings at our temporary lock file before importing
        # the helper, mirroring the original patch-then-import ordering.
        with patch("app.lock.get_settings") as mock_settings:
            mock_settings.return_value.pipeline_lock_file = str(lock_file)

            from app.lock import is_pipeline_locked

            assert is_pipeline_locked() is False

            lock_file.write_text("locked")
            assert is_pipeline_locked() is True

            lock_file.unlink()
            assert is_pipeline_locked() is False
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TestDataNormalization:
    """URL and text normalisation helpers."""

    def test_normalize_url(self):
        """Tracking parameters are stripped while real params survive."""
        from app.utils import normalize_url

        url = "https://example.com/article?id=123&utm_source=google&utm_medium=cpc"
        cleaned = normalize_url(url)

        assert "utm_source" not in cleaned
        assert "utm_medium" not in cleaned
        assert "id=123" in cleaned

    def test_generate_dedup_key(self):
        """Dedup keys are deterministic and sensitive to the title."""
        from app.utils import generate_dedup_key

        key_a = generate_dedup_key("Copper prices rise", "https://example.com/a")
        key_b = generate_dedup_key("Copper prices rise", "https://example.com/a")
        key_c = generate_dedup_key("Different title", "https://example.com/a")

        # Identical input -> identical key; different input -> different key.
        assert key_a == key_b
        assert key_a != key_c

    def test_truncate_text(self):
        """Long text is cut to max_length; short text passes through."""
        from app.utils import truncate_text

        long_text = "a" * 1000
        assert len(truncate_text(long_text, max_length=100)) == 100

        assert truncate_text("hello", max_length=100) == "hello"
|
tests/test_data_ingestion.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for data ingestion and management.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from datetime import datetime, timezone, timedelta
|
| 7 |
+
from unittest.mock import patch, MagicMock
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestLanguageDetection:
    """Language detection helper."""

    def test_detect_english(self):
        """Plain English text is identified as 'en'."""
        from app.data_manager import detect_language

        assert detect_language("Copper prices rose sharply today") == "en"

    def test_detect_non_english(self):
        """German text is not classified as English."""
        from app.data_manager import detect_language

        assert detect_language("Die Kupferpreise sind heute gestiegen") != "en"

    def test_detect_empty_text(self):
        """Empty input yields None rather than raising."""
        from app.data_manager import detect_language

        assert detect_language("") is None

    def test_detect_short_text(self):
        """Very short input is handled gracefully."""
        from app.data_manager import detect_language

        # Detection may fail on two characters; either outcome is acceptable.
        outcome = detect_language("Hi")
        assert outcome is None or isinstance(outcome, str)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestLanguageFiltering:
    """Language-based article filtering.

    Fix: ``test_filter_keeps_english`` previously requested the
    ``sample_articles`` fixture but never used it; the unused fixture
    parameter has been removed.
    """

    def test_filter_keeps_english(self):
        """All-English batches pass through with zero drops."""
        from app.data_manager import filter_by_language

        articles = [
            {"title": "Copper prices rise", "description": "Copper up today"},
            {"title": "Mining output increases", "description": "Good news"},
        ]

        kept, dropped = filter_by_language(articles, "en")

        assert len(kept) == 2
        assert dropped == 0

    def test_filter_removes_non_english(self):
        """Non-English entries are removed and counted."""
        from app.data_manager import filter_by_language

        articles = [
            {"title": "Copper prices rise", "description": "Copper up today"},
            {"title": "Kupferpreise steigen", "description": "Kupfer heute höher"},
        ]

        kept, dropped = filter_by_language(articles, "en")

        assert len(kept) == 1
        assert dropped == 1
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TestFuzzyDeduplication:
    """Fuzzy title-matching deduplication."""

    def test_exact_duplicate(self):
        """Identical titles are flagged as duplicates."""
        from app.data_manager import is_fuzzy_duplicate

        seen = ["Copper prices surge on supply concerns"]
        candidate = "Copper prices surge on supply concerns"

        assert is_fuzzy_duplicate(candidate, seen, threshold=85) is True

    def test_similar_titles(self):
        """Near-identical wording is caught at the default threshold."""
        from app.data_manager import is_fuzzy_duplicate

        seen = ["Copper prices surge on supply concerns"]
        candidate = "Copper prices rise on supply concerns"  # one word differs

        assert is_fuzzy_duplicate(candidate, seen, threshold=85) is True

    def test_different_titles(self):
        """Unrelated topics are not flagged."""
        from app.data_manager import is_fuzzy_duplicate

        seen = ["Copper prices surge on supply concerns"]
        candidate = "Gold reaches new all-time high"  # different topic

        assert is_fuzzy_duplicate(candidate, seen, threshold=85) is False

    def test_empty_existing_titles(self):
        """Nothing can be a duplicate of an empty history."""
        from app.data_manager import is_fuzzy_duplicate

        assert is_fuzzy_duplicate("Any title here", [], threshold=85) is False
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TestRSSParsing:
    """RSS feed query construction."""

    def test_rss_query_building(self):
        """The Google News search URL embeds the encoded query and language."""
        from urllib.parse import quote

        query = "copper OR copper price OR copper futures"
        language = "en"

        url = (
            "https://news.google.com/rss/search"
            f"?q={quote(query)}&hl={language}&gl=US&ceid=US:en"
        )

        assert "copper" in url
        assert "hl=en" in url
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class TestPriceIngestion:
    """Price data ingestion helpers."""

    def test_symbol_parsing(self):
        """A comma-separated symbol list splits into its parts."""
        symbols = "HG=F,DX-Y.NYB,CL=F,FXI".split(",")

        assert len(symbols) == 4
        assert "HG=F" in symbols
        assert "DX-Y.NYB" in symbols

    def test_lookback_calculation(self):
        """The lookback window spans exactly the requested number of days."""
        lookback_days = 365
        end_date = datetime.now(timezone.utc)
        start_date = end_date - timedelta(days=lookback_days)

        assert (end_date - start_date).days == lookback_days

    def test_price_bar_fields(self):
        """A price bar carries every OHLCV field."""
        required_fields = ["date", "open", "high", "low", "close", "volume"]

        bar = {
            "date": datetime.now(),
            "open": 4.0,
            "high": 4.1,
            "low": 3.9,
            "close": 4.05,
            "volume": 50000,
        }

        for field in required_fields:
            assert field in bar
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class TestDatabaseUpsert:
|
| 176 |
+
"""Tests for database upsert logic."""
|
| 177 |
+
|
| 178 |
+
def test_upsert_key_generation(self):
|
| 179 |
+
"""Test unique key generation for upsert."""
|
| 180 |
+
from app.utils import generate_dedup_key
|
| 181 |
+
|
| 182 |
+
# Same URL should give same key
|
| 183 |
+
url = "https://example.com/article/123"
|
| 184 |
+
key1 = generate_dedup_key("Title 1", url)
|
| 185 |
+
key2 = generate_dedup_key("Title 2", url)
|
| 186 |
+
|
| 187 |
+
# Keys based on URL should be consistent
|
| 188 |
+
# (depends on implementation - may include title or not)
|
| 189 |
+
assert isinstance(key1, str)
|
| 190 |
+
assert isinstance(key2, str)
|
| 191 |
+
|
| 192 |
+
def test_date_normalization(self):
|
| 193 |
+
"""Test date normalization for comparison."""
|
| 194 |
+
dt1 = datetime(2026, 1, 1, 10, 30, 0, tzinfo=timezone.utc)
|
| 195 |
+
dt2 = datetime(2026, 1, 1, 14, 45, 0, tzinfo=timezone.utc)
|
| 196 |
+
|
| 197 |
+
# Same date, different time
|
| 198 |
+
date1 = dt1.date()
|
| 199 |
+
date2 = dt2.date()
|
| 200 |
+
|
| 201 |
+
assert date1 == date2
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TestDataValidation:
    """Sanity checks on ingested data values."""

    def test_price_validation(self):
        """Prices must be strictly positive."""
        prices = [4.0, 4.1, 4.05, 3.95]

        assert all(p > 0 for p in prices)

    def test_volume_validation(self):
        """Volume may be zero but never negative."""
        volumes = [50000, 0, 100000]

        assert all(v >= 0 for v in volumes)

    def test_date_validation(self):
        """Historical dates must not lie in the future.

        Fix: the original assertion ended in ``or True``, making it a
        tautology that could never fail; the vacuous clause was removed so
        the check is real.
        """
        test_date = datetime(2025, 1, 1, tzinfo=timezone.utc)
        now = datetime.now(timezone.utc)

        assert test_date <= now

    def test_sentiment_score_range(self):
        """Sentiment scores are bounded to [-1, 1]."""
        scores = [0.5, -0.3, 0.8, -0.9, 0.0]

        assert all(-1 <= s <= 1 for s in scores)
|