Spaces:
Running
Running
Upload 52 files
Browse files- .dockerignore +31 -0
- .env.example +13 -0
- Dockerfile +41 -26
- src/westernfront/analytics/aggregator.py +4 -45
- src/westernfront/api/auth.py +4 -14
- src/westernfront/api/middleware/__init__.py +5 -0
- src/westernfront/api/middleware/rate_limit.py +89 -0
- src/westernfront/api/routes.py +38 -71
- src/westernfront/api/schemas.py +2 -3
- src/westernfront/config.py +1 -2
- src/westernfront/core/__init__.py +24 -0
- src/westernfront/core/constants.py +123 -0
- src/westernfront/core/exceptions.py +29 -0
- src/westernfront/core/models.py +3 -4
- src/westernfront/dependencies.py +54 -44
- src/westernfront/main.py +18 -3
- src/westernfront/prompts/analysis.py +119 -24
- src/westernfront/repositories/analysis.py +15 -43
- src/westernfront/repositories/vectors.py +33 -68
- src/westernfront/services/__init__.py +12 -0
- src/westernfront/services/analysis.py +239 -362
- src/westernfront/services/cache.py +65 -48
- src/westernfront/services/chain_analysis.py +108 -0
- src/westernfront/services/embeddings.py +37 -52
- src/westernfront/services/http.py +57 -0
- src/westernfront/services/newsapi.py +56 -65
- src/westernfront/services/parsing.py +88 -0
- src/westernfront/services/reddit.py +73 -114
- src/westernfront/services/retrieval.py +101 -0
- src/westernfront/services/rss.py +65 -143
- src/westernfront/services/scheduler.py +69 -0
- src/westernfront/services/validation.py +119 -0
- src/westernfront/utils/__init__.py +5 -0
- src/westernfront/utils/json_parser.py +42 -0
- tests/__pycache__/__init__.cpython-312.pyc +0 -0
- tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_services.cpython-312-pytest-8.4.2.pyc +0 -0
- tests/test_api.py +71 -0
- tests/test_parsing.py +111 -0
.dockerignore
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exclude local dev environment
|
| 2 |
+
.venv
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.pyd
|
| 7 |
+
|
| 8 |
+
# Exclude git history
|
| 9 |
+
.git
|
| 10 |
+
.gitignore
|
| 11 |
+
|
| 12 |
+
# Exclude test cache and coverage reports
|
| 13 |
+
.pytest_cache
|
| 14 |
+
.coverage
|
| 15 |
+
htmlcov
|
| 16 |
+
tests/
|
| 17 |
+
|
| 18 |
+
# Exclude logs and data (unless specific data is needed)
|
| 19 |
+
logs/
|
| 20 |
+
data/
|
| 21 |
+
|
| 22 |
+
# Exclude local env files (secrets should be passed as env vars)
|
| 23 |
+
.env
|
| 24 |
+
.env.example
|
| 25 |
+
|
| 26 |
+
# Exclude IDE settings
|
| 27 |
+
.vscode
|
| 28 |
+
.idea
|
| 29 |
+
|
| 30 |
+
# Exclude poetry cache (if local)
|
| 31 |
+
.cache
|
.env.example
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Google Gemini API Key
|
| 2 |
+
GEMINI_API_KEY=
|
| 3 |
+
|
| 4 |
+
# Application Settings
|
| 5 |
+
UPDATE_INTERVAL_MINUTES=
|
| 6 |
+
CACHE_EXPIRY_MINUTES=
|
| 7 |
+
LOG_LEVEL=
|
| 8 |
+
AUTO_UPDATE_ENABLED=
|
| 9 |
+
REDDIT_CLIENT_ID=
|
| 10 |
+
REDDIT_CLIENT_SECRET=
|
| 11 |
+
REDDIT_USER_AGENT=
|
| 12 |
+
NEWSAPI_KEY=
|
| 13 |
+
WESTERNFRONT_API_KEY=
|
Dockerfile
CHANGED
|
@@ -1,27 +1,42 @@
|
|
| 1 |
-
FROM python:3.11-slim
|
| 2 |
-
|
| 3 |
-
WORKDIR /app
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
CMD ["uvicorn", "westernfront.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
| 1 |
+
FROM python:3.11-slim AS builder
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
curl \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 8 |
+
&& pip install --no-cache-dir poetry==1.8.0
|
| 9 |
+
|
| 10 |
+
ENV POETRY_NO_INTERACTION=1 \
|
| 11 |
+
POETRY_VIRTUALENVS_IN_PROJECT=1 \
|
| 12 |
+
POETRY_VIRTUALENVS_CREATE=1 \
|
| 13 |
+
POETRY_CACHE_DIR=/tmp/poetry_cache
|
| 14 |
+
|
| 15 |
+
COPY pyproject.toml poetry.lock ./
|
| 16 |
+
|
| 17 |
+
RUN poetry install --only main --no-root && rm -rf $POETRY_CACHE_DIR
|
| 18 |
+
|
| 19 |
+
FROM python:3.11-slim AS runtime
|
| 20 |
+
|
| 21 |
+
WORKDIR /app
|
| 22 |
+
|
| 23 |
+
RUN groupadd -g 1000 appuser && \
|
| 24 |
+
useradd -u 1000 -g appuser -s /bin/bash -m appuser
|
| 25 |
+
|
| 26 |
+
RUN mkdir -p /app/data /app/logs && \
|
| 27 |
+
chown -R appuser:appuser /app
|
| 28 |
+
|
| 29 |
+
COPY --from=builder /app/.venv /app/.venv
|
| 30 |
+
COPY --chown=appuser:appuser src/ ./src/
|
| 31 |
+
|
| 32 |
+
ENV PATH="/app/.venv/bin:$PATH" \
|
| 33 |
+
PYTHONPATH="/app/src" \
|
| 34 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 35 |
+
PYTHONUNBUFFERED=1 \
|
| 36 |
+
PORT=7860
|
| 37 |
+
|
| 38 |
+
USER appuser
|
| 39 |
+
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
|
| 42 |
CMD ["uvicorn", "westernfront.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
src/westernfront/analytics/aggregator.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
"""Analytics aggregation for graph data."""
|
| 2 |
|
| 3 |
from collections import Counter
|
| 4 |
-
from datetime import datetime, timedelta, timezone
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
from westernfront.core.enums import TensionTrend
|
| 8 |
from westernfront.repositories.analysis import AnalysisRepository
|
|
@@ -12,24 +10,10 @@ class AnalyticsAggregator:
|
|
| 12 |
"""Aggregates analysis data for visualization."""
|
| 13 |
|
| 14 |
def __init__(self, repository: AnalysisRepository) -> None:
|
| 15 |
-
"""
|
| 16 |
-
Initialize the aggregator.
|
| 17 |
-
|
| 18 |
-
Args:
|
| 19 |
-
repository: Analysis repository for data access.
|
| 20 |
-
"""
|
| 21 |
self._repository = repository
|
| 22 |
|
| 23 |
async def get_tension_history(self, days: int = 30) -> dict:
|
| 24 |
-
"""
|
| 25 |
-
Get tension score history for graphing.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
days: Number of days to include.
|
| 29 |
-
|
| 30 |
-
Returns:
|
| 31 |
-
Dictionary with data points and summary statistics.
|
| 32 |
-
"""
|
| 33 |
history = await self._repository.get_tension_history(days)
|
| 34 |
|
| 35 |
if not history:
|
|
@@ -66,15 +50,7 @@ class AnalyticsAggregator:
|
|
| 66 |
}
|
| 67 |
|
| 68 |
async def get_source_breakdown(self, days: int = 7) -> dict:
|
| 69 |
-
"""
|
| 70 |
-
Get breakdown of sources used in recent analyses.
|
| 71 |
-
|
| 72 |
-
Args:
|
| 73 |
-
days: Number of days to include.
|
| 74 |
-
|
| 75 |
-
Returns:
|
| 76 |
-
Dictionary with source counts.
|
| 77 |
-
"""
|
| 78 |
snapshots = await self._repository.get_history(days=days)
|
| 79 |
|
| 80 |
return {
|
|
@@ -90,16 +66,7 @@ class AnalyticsAggregator:
|
|
| 90 |
}
|
| 91 |
|
| 92 |
async def get_entity_frequency(self, days: int = 30, top_n: int = 10) -> dict:
|
| 93 |
-
"""
|
| 94 |
-
Get most frequently mentioned entities.
|
| 95 |
-
|
| 96 |
-
Args:
|
| 97 |
-
days: Number of days to include.
|
| 98 |
-
top_n: Number of top entities to return.
|
| 99 |
-
|
| 100 |
-
Returns:
|
| 101 |
-
Dictionary with entity frequency data.
|
| 102 |
-
"""
|
| 103 |
snapshots = await self._repository.get_history(days=days)
|
| 104 |
|
| 105 |
all_entities: list[str] = []
|
|
@@ -118,15 +85,7 @@ class AnalyticsAggregator:
|
|
| 118 |
}
|
| 119 |
|
| 120 |
async def get_analysis_type_distribution(self, days: int = 30) -> dict:
|
| 121 |
-
"""
|
| 122 |
-
Get distribution of analysis types.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
days: Number of days to include.
|
| 126 |
-
|
| 127 |
-
Returns:
|
| 128 |
-
Dictionary with type distribution.
|
| 129 |
-
"""
|
| 130 |
snapshots = await self._repository.get_history(days=days)
|
| 131 |
|
| 132 |
counter = Counter(s.analysis_type.value for s in snapshots)
|
|
|
|
| 1 |
"""Analytics aggregation for graph data."""
|
| 2 |
|
| 3 |
from collections import Counter
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from westernfront.core.enums import TensionTrend
|
| 6 |
from westernfront.repositories.analysis import AnalysisRepository
|
|
|
|
| 10 |
"""Aggregates analysis data for visualization."""
|
| 11 |
|
| 12 |
def __init__(self, repository: AnalysisRepository) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
self._repository = repository
|
| 14 |
|
| 15 |
async def get_tension_history(self, days: int = 30) -> dict:
|
| 16 |
+
"""Get tension score history for graphing."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
history = await self._repository.get_tension_history(days)
|
| 18 |
|
| 19 |
if not history:
|
|
|
|
| 50 |
}
|
| 51 |
|
| 52 |
async def get_source_breakdown(self, days: int = 7) -> dict:
|
| 53 |
+
"""Get breakdown of sources used in recent analyses."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
snapshots = await self._repository.get_history(days=days)
|
| 55 |
|
| 56 |
return {
|
|
|
|
| 66 |
}
|
| 67 |
|
| 68 |
async def get_entity_frequency(self, days: int = 30, top_n: int = 10) -> dict:
|
| 69 |
+
"""Get most frequently mentioned entities."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
snapshots = await self._repository.get_history(days=days)
|
| 71 |
|
| 72 |
all_entities: list[str] = []
|
|
|
|
| 85 |
}
|
| 86 |
|
| 87 |
async def get_analysis_type_distribution(self, days: int = 30) -> dict:
|
| 88 |
+
"""Get distribution of analysis types."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
snapshots = await self._repository.get_history(days=days)
|
| 90 |
|
| 91 |
counter = Counter(s.analysis_type.value for s in snapshots)
|
src/westernfront/api/auth.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
"""API key authentication middleware."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from fastapi import HTTPException, Request, status
|
| 4 |
-
from fastapi.security import APIKeyHeader
|
| 5 |
|
| 6 |
from westernfront.config import get_settings
|
| 7 |
|
| 8 |
-
|
| 9 |
-
API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
|
| 10 |
-
|
| 11 |
PUBLIC_PATHS = frozenset([
|
| 12 |
"/",
|
| 13 |
"/health",
|
|
@@ -18,15 +16,7 @@ PUBLIC_PATHS = frozenset([
|
|
| 18 |
|
| 19 |
|
| 20 |
async def verify_api_key(request: Request) -> None:
|
| 21 |
-
"""
|
| 22 |
-
Verify the API key from request headers.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
request: The incoming request.
|
| 26 |
-
|
| 27 |
-
Raises:
|
| 28 |
-
HTTPException: If API key is missing or invalid.
|
| 29 |
-
"""
|
| 30 |
if request.url.path in PUBLIC_PATHS:
|
| 31 |
return
|
| 32 |
|
|
@@ -39,7 +29,7 @@ async def verify_api_key(request: Request) -> None:
|
|
| 39 |
detail="Missing API key. Include X-API-Key header.",
|
| 40 |
)
|
| 41 |
|
| 42 |
-
if api_key
|
| 43 |
raise HTTPException(
|
| 44 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 45 |
detail="Invalid API key.",
|
|
|
|
| 1 |
"""API key authentication middleware."""
|
| 2 |
|
| 3 |
+
import secrets
|
| 4 |
+
|
| 5 |
from fastapi import HTTPException, Request, status
|
|
|
|
| 6 |
|
| 7 |
from westernfront.config import get_settings
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
PUBLIC_PATHS = frozenset([
|
| 10 |
"/",
|
| 11 |
"/health",
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
async def verify_api_key(request: Request) -> None:
|
| 19 |
+
"""Verify the API key from request headers using timing-safe comparison."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
if request.url.path in PUBLIC_PATHS:
|
| 21 |
return
|
| 22 |
|
|
|
|
| 29 |
detail="Missing API key. Include X-API-Key header.",
|
| 30 |
)
|
| 31 |
|
| 32 |
+
if not secrets.compare_digest(api_key.encode(), settings.api_key.encode()):
|
| 33 |
raise HTTPException(
|
| 34 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 35 |
detail="Invalid API key.",
|
src/westernfront/api/middleware/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Middleware package exports."""
|
| 2 |
+
|
| 3 |
+
from westernfront.api.middleware.rate_limit import RateLimitMiddleware
|
| 4 |
+
|
| 5 |
+
__all__ = ["RateLimitMiddleware"]
|
src/westernfront/api/middleware/rate_limit.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Rate limiting middleware using token bucket algorithm with automatic cleanup."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
from cachetools import TTLCache
|
| 6 |
+
from fastapi import Request, Response
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TokenBucket:
|
| 12 |
+
"""Token bucket for rate limiting."""
|
| 13 |
+
|
| 14 |
+
__slots__ = ("capacity", "refill_rate", "tokens", "last_refill")
|
| 15 |
+
|
| 16 |
+
def __init__(self, capacity: int, refill_rate: float) -> None:
|
| 17 |
+
self.capacity = capacity
|
| 18 |
+
self.refill_rate = refill_rate
|
| 19 |
+
self.tokens = float(capacity)
|
| 20 |
+
self.last_refill = time.monotonic()
|
| 21 |
+
|
| 22 |
+
def consume(self) -> bool:
|
| 23 |
+
"""Attempt to consume a token. Returns True if successful."""
|
| 24 |
+
now = time.monotonic()
|
| 25 |
+
elapsed = now - self.last_refill
|
| 26 |
+
self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
|
| 27 |
+
self.last_refill = now
|
| 28 |
+
|
| 29 |
+
if self.tokens >= 1:
|
| 30 |
+
self.tokens -= 1
|
| 31 |
+
return True
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
def time_until_available(self) -> float:
|
| 35 |
+
"""Calculate seconds until next token is available."""
|
| 36 |
+
if self.tokens >= 1:
|
| 37 |
+
return 0.0
|
| 38 |
+
return (1 - self.tokens) / self.refill_rate
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class RateLimitMiddleware(BaseHTTPMiddleware):
|
| 42 |
+
"""Rate limiting middleware using per-IP token buckets with automatic cleanup."""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
app,
|
| 47 |
+
requests_per_minute: int = 60,
|
| 48 |
+
burst_size: int = 10,
|
| 49 |
+
bucket_ttl_seconds: int = 300,
|
| 50 |
+
max_buckets: int = 10000,
|
| 51 |
+
) -> None:
|
| 52 |
+
super().__init__(app)
|
| 53 |
+
self._requests_per_minute = requests_per_minute
|
| 54 |
+
self._burst_size = burst_size
|
| 55 |
+
self._refill_rate = requests_per_minute / 60.0
|
| 56 |
+
self._buckets: TTLCache[str, TokenBucket] = TTLCache(
|
| 57 |
+
maxsize=max_buckets,
|
| 58 |
+
ttl=bucket_ttl_seconds,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def _get_bucket(self, client_ip: str) -> TokenBucket:
|
| 62 |
+
"""Get or create a token bucket for the client IP."""
|
| 63 |
+
bucket = self._buckets.get(client_ip)
|
| 64 |
+
if bucket is None:
|
| 65 |
+
bucket = TokenBucket(self._burst_size, self._refill_rate)
|
| 66 |
+
self._buckets[client_ip] = bucket
|
| 67 |
+
return bucket
|
| 68 |
+
|
| 69 |
+
async def dispatch(self, request: Request, call_next) -> Response:
|
| 70 |
+
"""Process request with rate limiting."""
|
| 71 |
+
client_ip = self._get_client_ip(request)
|
| 72 |
+
bucket = self._get_bucket(client_ip)
|
| 73 |
+
|
| 74 |
+
if not bucket.consume():
|
| 75 |
+
retry_after = int(bucket.time_until_available()) + 1
|
| 76 |
+
return JSONResponse(
|
| 77 |
+
status_code=429,
|
| 78 |
+
content={"detail": "Rate limit exceeded. Please slow down."},
|
| 79 |
+
headers={"Retry-After": str(retry_after)},
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return await call_next(request)
|
| 83 |
+
|
| 84 |
+
def _get_client_ip(self, request: Request) -> str:
|
| 85 |
+
"""Extract client IP from request, handling proxies."""
|
| 86 |
+
forwarded = request.headers.get("X-Forwarded-For")
|
| 87 |
+
if forwarded:
|
| 88 |
+
return forwarded.split(",")[0].strip()
|
| 89 |
+
return request.client.host if request.client else "unknown"
|
src/westernfront/api/routes.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
"""API route definitions."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
| 4 |
-
from typing import Optional
|
| 5 |
|
| 6 |
-
from fastapi import APIRouter,
|
|
|
|
| 7 |
|
| 8 |
from westernfront import __version__
|
| 9 |
-
from westernfront.analytics import AnalyticsAggregator
|
| 10 |
from westernfront.api.schemas import (
|
| 11 |
AnalysisHistoryResponse,
|
| 12 |
AnalysisSnapshotResponse,
|
|
|
|
| 13 |
ConflictAnalysisResponse,
|
| 14 |
EntityFrequencyResponse,
|
| 15 |
HealthResponse,
|
|
@@ -20,25 +20,17 @@ from westernfront.api.schemas import (
|
|
| 20 |
SourcesResponse,
|
| 21 |
SubredditSourceResponse,
|
| 22 |
TensionHistoryResponse,
|
| 23 |
-
AnalysisTypeDistributionResponse,
|
| 24 |
)
|
| 25 |
from westernfront.core.enums import TensionLevel
|
| 26 |
-
from westernfront.dependencies import
|
| 27 |
-
get_analysis_service,
|
| 28 |
-
get_app_state,
|
| 29 |
-
get_repository,
|
| 30 |
-
)
|
| 31 |
-
from westernfront.repositories import AnalysisRepository
|
| 32 |
-
from westernfront.services import AnalysisService
|
| 33 |
-
|
| 34 |
|
| 35 |
router = APIRouter()
|
| 36 |
|
| 37 |
|
| 38 |
@router.get("/", response_model=RootResponse, tags=["General"])
|
| 39 |
-
async def root() -> RootResponse:
|
| 40 |
"""Root endpoint with API information."""
|
| 41 |
-
state =
|
| 42 |
return RootResponse(
|
| 43 |
name="WesternFront API",
|
| 44 |
description="AI-powered conflict tracker for India-Pakistan tensions",
|
|
@@ -49,9 +41,9 @@ async def root() -> RootResponse:
|
|
| 49 |
|
| 50 |
@router.get("/health", response_model=HealthResponse, tags=["General"])
|
| 51 |
@router.head("/health", response_model=HealthResponse, tags=["General"])
|
| 52 |
-
async def health_check() -> HealthResponse:
|
| 53 |
"""Health check endpoint."""
|
| 54 |
-
state =
|
| 55 |
latest = await state.repository.get_latest()
|
| 56 |
|
| 57 |
return HealthResponse(
|
|
@@ -68,16 +60,11 @@ async def health_check() -> HealthResponse:
|
|
| 68 |
)
|
| 69 |
|
| 70 |
|
| 71 |
-
@router.get(
|
| 72 |
-
|
| 73 |
-
response_model=ConflictAnalysisResponse,
|
| 74 |
-
tags=["Analysis"],
|
| 75 |
-
)
|
| 76 |
-
async def get_latest_analysis(
|
| 77 |
-
service: AnalysisService = Depends(get_analysis_service),
|
| 78 |
-
) -> ConflictAnalysisResponse:
|
| 79 |
"""Get the latest conflict analysis."""
|
| 80 |
-
|
|
|
|
| 81 |
|
| 82 |
if not analysis:
|
| 83 |
raise HTTPException(
|
|
@@ -88,18 +75,15 @@ async def get_latest_analysis(
|
|
| 88 |
return ConflictAnalysisResponse.model_validate(analysis.model_dump())
|
| 89 |
|
| 90 |
|
| 91 |
-
@router.get(
|
| 92 |
-
"/analysis/history",
|
| 93 |
-
response_model=AnalysisHistoryResponse,
|
| 94 |
-
tags=["Analysis"],
|
| 95 |
-
)
|
| 96 |
async def get_analysis_history(
|
|
|
|
| 97 |
days: int = Query(default=30, ge=1, le=90),
|
| 98 |
limit: int = Query(default=50, ge=1, le=100),
|
| 99 |
-
repository: AnalysisRepository = Depends(get_repository),
|
| 100 |
) -> AnalysisHistoryResponse:
|
| 101 |
"""Get historical analysis snapshots."""
|
| 102 |
-
|
|
|
|
| 103 |
|
| 104 |
return AnalysisHistoryResponse(
|
| 105 |
count=len(snapshots),
|
|
@@ -110,71 +94,55 @@ async def get_analysis_history(
|
|
| 110 |
)
|
| 111 |
|
| 112 |
|
| 113 |
-
@router.get(
|
| 114 |
-
"/analytics/tension-history",
|
| 115 |
-
response_model=TensionHistoryResponse,
|
| 116 |
-
tags=["Analytics"],
|
| 117 |
-
)
|
| 118 |
async def get_tension_history(
|
|
|
|
| 119 |
days: int = Query(default=30, ge=1, le=90),
|
| 120 |
-
repository: AnalysisRepository = Depends(get_repository),
|
| 121 |
) -> TensionHistoryResponse:
|
| 122 |
"""Get tension score history for graphing."""
|
| 123 |
-
|
| 124 |
-
result = await
|
| 125 |
return TensionHistoryResponse.model_validate(result)
|
| 126 |
|
| 127 |
|
| 128 |
-
@router.get(
|
| 129 |
-
"/analytics/source-breakdown",
|
| 130 |
-
response_model=SourceBreakdownResponse,
|
| 131 |
-
tags=["Analytics"],
|
| 132 |
-
)
|
| 133 |
async def get_source_breakdown(
|
|
|
|
| 134 |
days: int = Query(default=7, ge=1, le=30),
|
| 135 |
-
repository: AnalysisRepository = Depends(get_repository),
|
| 136 |
) -> SourceBreakdownResponse:
|
| 137 |
"""Get breakdown of sources used in analyses."""
|
| 138 |
-
|
| 139 |
-
result = await
|
| 140 |
return SourceBreakdownResponse.model_validate(result)
|
| 141 |
|
| 142 |
|
| 143 |
-
@router.get(
|
| 144 |
-
"/analytics/entity-frequency",
|
| 145 |
-
response_model=EntityFrequencyResponse,
|
| 146 |
-
tags=["Analytics"],
|
| 147 |
-
)
|
| 148 |
async def get_entity_frequency(
|
|
|
|
| 149 |
days: int = Query(default=30, ge=1, le=90),
|
| 150 |
top_n: int = Query(default=10, ge=1, le=50),
|
| 151 |
-
repository: AnalysisRepository = Depends(get_repository),
|
| 152 |
) -> EntityFrequencyResponse:
|
| 153 |
"""Get most frequently mentioned entities."""
|
| 154 |
-
|
| 155 |
-
result = await
|
| 156 |
return EntityFrequencyResponse.model_validate(result)
|
| 157 |
|
| 158 |
|
| 159 |
-
@router.get(
|
| 160 |
-
"/analytics/type-distribution",
|
| 161 |
-
response_model=AnalysisTypeDistributionResponse,
|
| 162 |
-
tags=["Analytics"],
|
| 163 |
-
)
|
| 164 |
async def get_type_distribution(
|
|
|
|
| 165 |
days: int = Query(default=30, ge=1, le=90),
|
| 166 |
-
repository: AnalysisRepository = Depends(get_repository),
|
| 167 |
) -> AnalysisTypeDistributionResponse:
|
| 168 |
"""Get distribution of analysis types."""
|
| 169 |
-
|
| 170 |
-
result = await
|
| 171 |
return AnalysisTypeDistributionResponse.model_validate(result)
|
| 172 |
|
| 173 |
|
| 174 |
@router.get("/sources", response_model=SourcesResponse, tags=["Configuration"])
|
| 175 |
-
async def get_sources() -> SourcesResponse:
|
| 176 |
"""Get current data sources configuration."""
|
| 177 |
-
state =
|
| 178 |
|
| 179 |
return SourcesResponse(
|
| 180 |
subreddits=[
|
|
@@ -190,13 +158,12 @@ async def get_sources() -> SourcesResponse:
|
|
| 190 |
|
| 191 |
|
| 192 |
@router.get("/keywords", response_model=KeywordsResponse, tags=["Configuration"])
|
| 193 |
-
async def get_keywords(
|
| 194 |
-
service: AnalysisService = Depends(get_analysis_service),
|
| 195 |
-
) -> KeywordsResponse:
|
| 196 |
"""Get current search keywords."""
|
|
|
|
| 197 |
return KeywordsResponse(
|
| 198 |
-
count=len(
|
| 199 |
-
keywords=
|
| 200 |
)
|
| 201 |
|
| 202 |
|
|
|
|
| 1 |
"""API route definitions."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
|
|
|
| 4 |
|
| 5 |
+
from fastapi import APIRouter, Query, Request, status
|
| 6 |
+
from fastapi.exceptions import HTTPException
|
| 7 |
|
| 8 |
from westernfront import __version__
|
|
|
|
| 9 |
from westernfront.api.schemas import (
|
| 10 |
AnalysisHistoryResponse,
|
| 11 |
AnalysisSnapshotResponse,
|
| 12 |
+
AnalysisTypeDistributionResponse,
|
| 13 |
ConflictAnalysisResponse,
|
| 14 |
EntityFrequencyResponse,
|
| 15 |
HealthResponse,
|
|
|
|
| 20 |
SourcesResponse,
|
| 21 |
SubredditSourceResponse,
|
| 22 |
TensionHistoryResponse,
|
|
|
|
| 23 |
)
|
| 24 |
from westernfront.core.enums import TensionLevel
|
| 25 |
+
from westernfront.dependencies import get_state_from_request
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
router = APIRouter()
|
| 28 |
|
| 29 |
|
| 30 |
@router.get("/", response_model=RootResponse, tags=["General"])
|
| 31 |
+
async def root(request: Request) -> RootResponse:
|
| 32 |
"""Root endpoint with API information."""
|
| 33 |
+
state = get_state_from_request(request)
|
| 34 |
return RootResponse(
|
| 35 |
name="WesternFront API",
|
| 36 |
description="AI-powered conflict tracker for India-Pakistan tensions",
|
|
|
|
| 41 |
|
| 42 |
@router.get("/health", response_model=HealthResponse, tags=["General"])
|
| 43 |
@router.head("/health", response_model=HealthResponse, tags=["General"])
|
| 44 |
+
async def health_check(request: Request) -> HealthResponse:
|
| 45 |
"""Health check endpoint."""
|
| 46 |
+
state = get_state_from_request(request)
|
| 47 |
latest = await state.repository.get_latest()
|
| 48 |
|
| 49 |
return HealthResponse(
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
|
| 63 |
+
@router.get("/analysis", response_model=ConflictAnalysisResponse, tags=["Analysis"])
|
| 64 |
+
async def get_latest_analysis(request: Request) -> ConflictAnalysisResponse:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
"""Get the latest conflict analysis."""
|
| 66 |
+
state = get_state_from_request(request)
|
| 67 |
+
analysis = await state.analysis.get_latest()
|
| 68 |
|
| 69 |
if not analysis:
|
| 70 |
raise HTTPException(
|
|
|
|
| 75 |
return ConflictAnalysisResponse.model_validate(analysis.model_dump())
|
| 76 |
|
| 77 |
|
| 78 |
+
@router.get("/analysis/history", response_model=AnalysisHistoryResponse, tags=["Analysis"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
async def get_analysis_history(
|
| 80 |
+
request: Request,
|
| 81 |
days: int = Query(default=30, ge=1, le=90),
|
| 82 |
limit: int = Query(default=50, ge=1, le=100),
|
|
|
|
| 83 |
) -> AnalysisHistoryResponse:
|
| 84 |
"""Get historical analysis snapshots."""
|
| 85 |
+
state = get_state_from_request(request)
|
| 86 |
+
snapshots = await state.repository.get_history(days=days, limit=limit)
|
| 87 |
|
| 88 |
return AnalysisHistoryResponse(
|
| 89 |
count=len(snapshots),
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
|
| 97 |
+
@router.get("/analytics/tension-history", response_model=TensionHistoryResponse, tags=["Analytics"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
async def get_tension_history(
|
| 99 |
+
request: Request,
|
| 100 |
days: int = Query(default=30, ge=1, le=90),
|
|
|
|
| 101 |
) -> TensionHistoryResponse:
|
| 102 |
"""Get tension score history for graphing."""
|
| 103 |
+
state = get_state_from_request(request)
|
| 104 |
+
result = await state.analytics.get_tension_history(days)
|
| 105 |
return TensionHistoryResponse.model_validate(result)
|
| 106 |
|
| 107 |
|
| 108 |
+
@router.get("/analytics/source-breakdown", response_model=SourceBreakdownResponse, tags=["Analytics"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
async def get_source_breakdown(
|
| 110 |
+
request: Request,
|
| 111 |
days: int = Query(default=7, ge=1, le=30),
|
|
|
|
| 112 |
) -> SourceBreakdownResponse:
|
| 113 |
"""Get breakdown of sources used in analyses."""
|
| 114 |
+
state = get_state_from_request(request)
|
| 115 |
+
result = await state.analytics.get_source_breakdown(days)
|
| 116 |
return SourceBreakdownResponse.model_validate(result)
|
| 117 |
|
| 118 |
|
| 119 |
+
@router.get("/analytics/entity-frequency", response_model=EntityFrequencyResponse, tags=["Analytics"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
async def get_entity_frequency(
|
| 121 |
+
request: Request,
|
| 122 |
days: int = Query(default=30, ge=1, le=90),
|
| 123 |
top_n: int = Query(default=10, ge=1, le=50),
|
|
|
|
| 124 |
) -> EntityFrequencyResponse:
|
| 125 |
"""Get most frequently mentioned entities."""
|
| 126 |
+
state = get_state_from_request(request)
|
| 127 |
+
result = await state.analytics.get_entity_frequency(days, top_n)
|
| 128 |
return EntityFrequencyResponse.model_validate(result)
|
| 129 |
|
| 130 |
|
| 131 |
+
@router.get("/analytics/type-distribution", response_model=AnalysisTypeDistributionResponse, tags=["Analytics"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
async def get_type_distribution(
|
| 133 |
+
request: Request,
|
| 134 |
days: int = Query(default=30, ge=1, le=90),
|
|
|
|
| 135 |
) -> AnalysisTypeDistributionResponse:
|
| 136 |
"""Get distribution of analysis types."""
|
| 137 |
+
state = get_state_from_request(request)
|
| 138 |
+
result = await state.analytics.get_analysis_type_distribution(days)
|
| 139 |
return AnalysisTypeDistributionResponse.model_validate(result)
|
| 140 |
|
| 141 |
|
| 142 |
@router.get("/sources", response_model=SourcesResponse, tags=["Configuration"])
|
| 143 |
+
async def get_sources(request: Request) -> SourcesResponse:
|
| 144 |
"""Get current data sources configuration."""
|
| 145 |
+
state = get_state_from_request(request)
|
| 146 |
|
| 147 |
return SourcesResponse(
|
| 148 |
subreddits=[
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
@router.get("/keywords", response_model=KeywordsResponse, tags=["Configuration"])
|
| 161 |
+
async def get_keywords(request: Request) -> KeywordsResponse:
|
|
|
|
|
|
|
| 162 |
"""Get current search keywords."""
|
| 163 |
+
state = get_state_from_request(request)
|
| 164 |
return KeywordsResponse(
|
| 165 |
+
count=len(state.analysis.keywords),
|
| 166 |
+
keywords=state.analysis.keywords,
|
| 167 |
)
|
| 168 |
|
| 169 |
|
src/westernfront/api/schemas.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""API request and response schemas."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
| 4 |
-
from typing import Optional
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
@@ -14,7 +13,7 @@ class HealthResponse(BaseModel):
|
|
| 14 |
status: str
|
| 15 |
version: str
|
| 16 |
timestamp: datetime
|
| 17 |
-
last_update:
|
| 18 |
components: dict[str, bool]
|
| 19 |
|
| 20 |
|
|
@@ -93,7 +92,7 @@ class KeyDevelopmentResponse(BaseModel):
|
|
| 93 |
title: str
|
| 94 |
description: str
|
| 95 |
sources: list[str]
|
| 96 |
-
timestamp:
|
| 97 |
|
| 98 |
|
| 99 |
class ReliabilityAssessmentResponse(BaseModel):
|
|
|
|
| 1 |
"""API request and response schemas."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
|
|
|
| 4 |
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
|
|
|
| 13 |
status: str
|
| 14 |
version: str
|
| 15 |
timestamp: datetime
|
| 16 |
+
last_update: datetime | None = None
|
| 17 |
components: dict[str, bool]
|
| 18 |
|
| 19 |
|
|
|
|
| 92 |
title: str
|
| 93 |
description: str
|
| 94 |
sources: list[str]
|
| 95 |
+
timestamp: datetime | None = None
|
| 96 |
|
| 97 |
|
| 98 |
class ReliabilityAssessmentResponse(BaseModel):
|
src/westernfront/config.py
CHANGED
|
@@ -6,7 +6,6 @@ for type-safe environment variable parsing and validation.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from functools import lru_cache
|
| 9 |
-
from typing import Optional
|
| 10 |
|
| 11 |
from pydantic import Field
|
| 12 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
@@ -34,7 +33,7 @@ class Settings(BaseSettings):
|
|
| 34 |
gemini_api_key: str = Field(alias="GEMINI_API_KEY")
|
| 35 |
|
| 36 |
# NewsAPI (optional)
|
| 37 |
-
newsapi_key:
|
| 38 |
|
| 39 |
# Application Settings
|
| 40 |
update_interval_minutes: int = Field(default=60, alias="UPDATE_INTERVAL_MINUTES")
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from functools import lru_cache
|
|
|
|
| 9 |
|
| 10 |
from pydantic import Field
|
| 11 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| 33 |
gemini_api_key: str = Field(alias="GEMINI_API_KEY")
|
| 34 |
|
| 35 |
# NewsAPI (optional)
|
| 36 |
+
newsapi_key: str | None = Field(default=None, alias="NEWSAPI_KEY")
|
| 37 |
|
| 38 |
# Application Settings
|
| 39 |
update_interval_minutes: int = Field(default=60, alias="UPDATE_INTERVAL_MINUTES")
|
src/westernfront/core/__init__.py
CHANGED
|
@@ -1,20 +1,44 @@
|
|
| 1 |
"""Core package exports."""
|
| 2 |
|
| 3 |
from westernfront.core.enums import AnalysisType, SourceType, TensionLevel, TensionTrend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from westernfront.core.models import (
|
|
|
|
| 5 |
ConflictAnalysis,
|
| 6 |
KeyDevelopment,
|
| 7 |
NewsItem,
|
|
|
|
|
|
|
|
|
|
| 8 |
SubredditSource,
|
| 9 |
)
|
| 10 |
|
| 11 |
__all__ = [
|
|
|
|
|
|
|
| 12 |
"AnalysisType",
|
|
|
|
| 13 |
"ConflictAnalysis",
|
|
|
|
| 14 |
"KeyDevelopment",
|
| 15 |
"NewsItem",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"SourceType",
|
| 17 |
"SubredditSource",
|
| 18 |
"TensionLevel",
|
| 19 |
"TensionTrend",
|
|
|
|
|
|
|
| 20 |
]
|
|
|
|
| 1 |
"""Core package exports."""
|
| 2 |
|
| 3 |
from westernfront.core.enums import AnalysisType, SourceType, TensionLevel, TensionTrend
|
| 4 |
+
from westernfront.core.exceptions import (
|
| 5 |
+
AnalysisError,
|
| 6 |
+
AuthenticationError,
|
| 7 |
+
DataFetchError,
|
| 8 |
+
RateLimitExceededError,
|
| 9 |
+
ServiceNotInitializedError,
|
| 10 |
+
VectorStoreError,
|
| 11 |
+
WesternFrontError,
|
| 12 |
+
)
|
| 13 |
from westernfront.core.models import (
|
| 14 |
+
AnalysisSnapshot,
|
| 15 |
ConflictAnalysis,
|
| 16 |
KeyDevelopment,
|
| 17 |
NewsItem,
|
| 18 |
+
RegionalImplications,
|
| 19 |
+
ReliabilityAssessment,
|
| 20 |
+
RssFeed,
|
| 21 |
SubredditSource,
|
| 22 |
)
|
| 23 |
|
| 24 |
__all__ = [
|
| 25 |
+
"AnalysisError",
|
| 26 |
+
"AnalysisSnapshot",
|
| 27 |
"AnalysisType",
|
| 28 |
+
"AuthenticationError",
|
| 29 |
"ConflictAnalysis",
|
| 30 |
+
"DataFetchError",
|
| 31 |
"KeyDevelopment",
|
| 32 |
"NewsItem",
|
| 33 |
+
"RateLimitExceededError",
|
| 34 |
+
"RegionalImplications",
|
| 35 |
+
"ReliabilityAssessment",
|
| 36 |
+
"RssFeed",
|
| 37 |
+
"ServiceNotInitializedError",
|
| 38 |
"SourceType",
|
| 39 |
"SubredditSource",
|
| 40 |
"TensionLevel",
|
| 41 |
"TensionTrend",
|
| 42 |
+
"VectorStoreError",
|
| 43 |
+
"WesternFrontError",
|
| 44 |
]
|
src/westernfront/core/constants.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constants and default configurations for WesternFront."""
|
| 2 |
+
|
| 3 |
+
from westernfront.core.models import RssFeed, SubredditSource
|
| 4 |
+
|
| 5 |
+
RELIABLE_DOMAINS = frozenset([
|
| 6 |
+
"bbc.com",
|
| 7 |
+
"reuters.com",
|
| 8 |
+
"apnews.com",
|
| 9 |
+
"aljazeera.com",
|
| 10 |
+
"nytimes.com",
|
| 11 |
+
"wsj.com",
|
| 12 |
+
"ft.com",
|
| 13 |
+
"economist.com",
|
| 14 |
+
"thediplomat.com",
|
| 15 |
+
"foreignpolicy.com",
|
| 16 |
+
"foreignaffairs.com",
|
| 17 |
+
"dawn.com",
|
| 18 |
+
"timesofindia.indiatimes.com",
|
| 19 |
+
"ndtv.com",
|
| 20 |
+
"geo.tv",
|
| 21 |
+
])
|
| 22 |
+
|
| 23 |
+
DEFAULT_SUBREDDITS = [
|
| 24 |
+
SubredditSource(name="geopolitics", reliability_score=0.85),
|
| 25 |
+
SubredditSource(name="CredibleDefense", reliability_score=0.9),
|
| 26 |
+
SubredditSource(name="worldnews", reliability_score=0.8),
|
| 27 |
+
SubredditSource(name="neutralnews", reliability_score=0.8),
|
| 28 |
+
SubredditSource(name="DefenseNews", reliability_score=0.85),
|
| 29 |
+
SubredditSource(name="GeopoliticsIndia", reliability_score=0.75),
|
| 30 |
+
SubredditSource(name="SouthAsia", reliability_score=0.7),
|
| 31 |
+
SubredditSource(name="india", reliability_score=0.7),
|
| 32 |
+
SubredditSource(name="pakistan", reliability_score=0.7),
|
| 33 |
+
SubredditSource(name="Nepal", reliability_score=0.65),
|
| 34 |
+
SubredditSource(name="bangladesh", reliability_score=0.65),
|
| 35 |
+
SubredditSource(name="srilanka", reliability_score=0.65),
|
| 36 |
+
SubredditSource(name="China", reliability_score=0.6),
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
DEFAULT_RSS_FEEDS = [
|
| 40 |
+
RssFeed(name="Dawn (Pakistan)", url="https://www.dawn.com/feeds/home", reliability_score=0.85),
|
| 41 |
+
RssFeed(name="Geo News", url="https://www.geo.tv/rss/1/1", reliability_score=0.8),
|
| 42 |
+
RssFeed(name="Express Tribune", url="https://tribune.com.pk/feed/home", reliability_score=0.75),
|
| 43 |
+
RssFeed(name="Times of India", url="https://timesofindia.indiatimes.com/rssfeeds/296589292.cms", reliability_score=0.75),
|
| 44 |
+
RssFeed(name="NDTV India", url="https://feeds.feedburner.com/ndtvnews-india-news", reliability_score=0.8),
|
| 45 |
+
RssFeed(name="The Hindu", url="https://www.thehindu.com/news/national/feeder/default.rss", reliability_score=0.85),
|
| 46 |
+
RssFeed(name="Indian Express", url="https://indianexpress.com/section/india/feed/", reliability_score=0.85),
|
| 47 |
+
RssFeed(name="South China Morning Post - Asia", url="https://www.scmp.com/rss/91/feed", reliability_score=0.85),
|
| 48 |
+
RssFeed(name="Kathmandu Post", url="https://kathmandupost.com/rss", reliability_score=0.75),
|
| 49 |
+
RssFeed(name="Dhaka Tribune", url="https://www.dhakatribune.com/rss", reliability_score=0.75),
|
| 50 |
+
RssFeed(name="Daily Star Bangladesh", url="https://www.thedailystar.net/rss.xml", reliability_score=0.75),
|
| 51 |
+
RssFeed(name="Daily Mirror Sri Lanka", url="http://www.dailymirror.lk/RSS_Feeds/breaking-news", reliability_score=0.7),
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
NEWSAPI_QUERIES = [
|
| 55 |
+
"India Pakistan",
|
| 56 |
+
"Kashmir conflict",
|
| 57 |
+
"India Pakistan border",
|
| 58 |
+
"LOC firing",
|
| 59 |
+
"Indo-Pak",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
RAG_QUERY_TOPICS = [
|
| 63 |
+
"India Pakistan military conflict border tensions ceasefire violation",
|
| 64 |
+
"Kashmir territorial dispute LOC Line of Control",
|
| 65 |
+
"India China LAC Ladakh Arunachal standoff",
|
| 66 |
+
"Nepal Bangladesh Sri Lanka India bilateral relations",
|
| 67 |
+
"South Asia terrorism cross-border insurgency",
|
| 68 |
+
"India diplomatic relations regional geopolitics",
|
| 69 |
+
"Military exercises defense buildup South Asia",
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
NEWSAPI_BASE_URL = "https://newsapi.org/v2"
|
| 73 |
+
|
| 74 |
+
HTTP_TIMEOUT_SECONDS = 30
|
| 75 |
+
MAX_CONCURRENT_REQUESTS = 10
|
| 76 |
+
|
| 77 |
+
# Source diversity rules for retrieval
|
| 78 |
+
SOURCE_DIVERSITY_RULES = {
|
| 79 |
+
"reddit": {"min_pct": 0.25, "max_pct": 0.50},
|
| 80 |
+
"rss": {"min_pct": 0.30, "max_pct": 0.55},
|
| 81 |
+
"newsapi": {"min_pct": 0.10, "max_pct": 0.30},
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# Temporal weighting for recency boost
|
| 85 |
+
RECENCY_BOOST = {
|
| 86 |
+
"hours_24": 1.5,
|
| 87 |
+
"hours_48": 1.25,
|
| 88 |
+
"days_7": 1.0,
|
| 89 |
+
"older": 0.75,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# Tension level criteria for validation
|
| 93 |
+
TENSION_LEVEL_CRITERIA = {
|
| 94 |
+
"LOW": {
|
| 95 |
+
"score_range": (1, 3),
|
| 96 |
+
"description": "Normal diplomatic activity, routine border incidents, no escalation",
|
| 97 |
+
},
|
| 98 |
+
"MEDIUM": {
|
| 99 |
+
"score_range": (4, 5),
|
| 100 |
+
"description": "Heightened rhetoric, minor military movements, diplomatic notes exchanged",
|
| 101 |
+
},
|
| 102 |
+
"HIGH": {
|
| 103 |
+
"score_range": (6, 8),
|
| 104 |
+
"description": "Military mobilization, cross-border firing, diplomatic summoning",
|
| 105 |
+
},
|
| 106 |
+
"CRITICAL": {
|
| 107 |
+
"score_range": (9, 10),
|
| 108 |
+
"description": "Active military engagement, imminent conflict, emergency measures",
|
| 109 |
+
},
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
SEARCH_KEYWORDS = [
|
| 113 |
+
"India Pakistan",
|
| 114 |
+
"Kashmir",
|
| 115 |
+
"LOC",
|
| 116 |
+
"ceasefire",
|
| 117 |
+
"border tension",
|
| 118 |
+
"military",
|
| 119 |
+
"diplomatic",
|
| 120 |
+
"terrorist",
|
| 121 |
+
"strike",
|
| 122 |
+
"conflict",
|
| 123 |
+
]
|
src/westernfront/core/exceptions.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom exceptions for WesternFront."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class WesternFrontError(Exception):
|
| 5 |
+
"""Base exception for all WesternFront errors."""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AuthenticationError(WesternFrontError):
|
| 9 |
+
"""Raised when API authentication fails."""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RateLimitExceededError(WesternFrontError):
|
| 13 |
+
"""Raised when rate limit is exceeded."""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ServiceNotInitializedError(WesternFrontError):
|
| 17 |
+
"""Raised when a service is accessed before initialization."""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DataFetchError(WesternFrontError):
|
| 21 |
+
"""Raised when fetching data from external sources fails."""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AnalysisError(WesternFrontError):
|
| 25 |
+
"""Raised when AI analysis fails."""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class VectorStoreError(WesternFrontError):
|
| 29 |
+
"""Raised when vector store operations fail."""
|
src/westernfront/core/models.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""Core domain models for WesternFront."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
| 4 |
-
from typing import Optional
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
@@ -36,8 +35,8 @@ class NewsItem(BaseModel):
|
|
| 36 |
source_type: SourceType
|
| 37 |
published_at: datetime
|
| 38 |
reliability_score: float = Field(default=0.5, ge=0.0, le=1.0)
|
| 39 |
-
author:
|
| 40 |
-
score:
|
| 41 |
|
| 42 |
|
| 43 |
class KeyDevelopment(BaseModel):
|
|
@@ -46,7 +45,7 @@ class KeyDevelopment(BaseModel):
|
|
| 46 |
title: str
|
| 47 |
description: str
|
| 48 |
sources: list[str]
|
| 49 |
-
timestamp:
|
| 50 |
|
| 51 |
|
| 52 |
class ReliabilityAssessment(BaseModel):
|
|
|
|
| 1 |
"""Core domain models for WesternFront."""
|
| 2 |
|
| 3 |
from datetime import datetime
|
|
|
|
| 4 |
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
|
|
|
| 35 |
source_type: SourceType
|
| 36 |
published_at: datetime
|
| 37 |
reliability_score: float = Field(default=0.5, ge=0.0, le=1.0)
|
| 38 |
+
author: str | None = None
|
| 39 |
+
score: int | None = None
|
| 40 |
|
| 41 |
|
| 42 |
class KeyDevelopment(BaseModel):
|
|
|
|
| 45 |
title: str
|
| 46 |
description: str
|
| 47 |
sources: list[str]
|
| 48 |
+
timestamp: datetime | None = None
|
| 49 |
|
| 50 |
|
| 51 |
class ReliabilityAssessment(BaseModel):
|
src/westernfront/dependencies.py
CHANGED
|
@@ -1,21 +1,24 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Dependency injection container for WesternFront.
|
| 3 |
-
|
| 4 |
-
Provides FastAPI dependencies for services with proper lifecycle management.
|
| 5 |
-
"""
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
from dataclasses import dataclass
|
| 9 |
-
from typing import
|
| 10 |
|
| 11 |
from loguru import logger
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from westernfront.config import Settings, get_settings
|
| 14 |
from westernfront.repositories.analysis import AnalysisRepository
|
| 15 |
from westernfront.repositories.vectors import VectorRepository
|
| 16 |
from westernfront.services.analysis import AnalysisService
|
| 17 |
from westernfront.services.cache import CacheService
|
| 18 |
from westernfront.services.embeddings import EmbeddingService
|
|
|
|
| 19 |
from westernfront.services.newsapi import NewsApiService
|
| 20 |
from westernfront.services.reddit import RedditService
|
| 21 |
from westernfront.services.rss import RssService
|
|
@@ -26,6 +29,7 @@ class AppState:
|
|
| 26 |
"""Container for application-scoped services."""
|
| 27 |
|
| 28 |
settings: Settings
|
|
|
|
| 29 |
cache: CacheService
|
| 30 |
reddit: RedditService
|
| 31 |
rss: RssService
|
|
@@ -34,16 +38,28 @@ class AppState:
|
|
| 34 |
embeddings: EmbeddingService
|
| 35 |
vectors: VectorRepository
|
| 36 |
analysis: AnalysisService
|
|
|
|
|
|
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
|
| 41 |
|
| 42 |
async def init_services() -> AppState:
|
| 43 |
-
"""Initialize all services for the application."""
|
| 44 |
settings = get_settings()
|
| 45 |
|
|
|
|
| 46 |
cache = CacheService(ttl_seconds=settings.cache_expiry_minutes * 60)
|
|
|
|
| 47 |
|
| 48 |
reddit = RedditService(
|
| 49 |
client_id=settings.reddit_client_id,
|
|
@@ -51,23 +67,23 @@ async def init_services() -> AppState:
|
|
| 51 |
user_agent=settings.reddit_user_agent,
|
| 52 |
cache=cache,
|
| 53 |
)
|
| 54 |
-
await reddit.initialize()
|
| 55 |
-
|
| 56 |
-
rss = RssService(cache=cache)
|
| 57 |
|
| 58 |
-
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
await
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
|
| 67 |
-
vectors =
|
| 68 |
-
vectors.initialize()
|
| 69 |
logger.info(f"Vector repository initialized with {vectors.get_count()} items")
|
| 70 |
|
|
|
|
|
|
|
| 71 |
analysis = AnalysisService(
|
| 72 |
gemini_api_key=settings.gemini_api_key,
|
| 73 |
reddit=reddit,
|
|
@@ -81,6 +97,7 @@ async def init_services() -> AppState:
|
|
| 81 |
|
| 82 |
return AppState(
|
| 83 |
settings=settings,
|
|
|
|
| 84 |
cache=cache,
|
| 85 |
reddit=reddit,
|
| 86 |
rss=rss,
|
|
@@ -89,6 +106,7 @@ async def init_services() -> AppState:
|
|
| 89 |
embeddings=embeddings,
|
| 90 |
vectors=vectors,
|
| 91 |
analysis=analysis,
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
|
|
@@ -96,43 +114,35 @@ async def shutdown_services(state: AppState) -> None:
|
|
| 96 |
"""Clean up all services."""
|
| 97 |
await state.analysis.close()
|
| 98 |
await state.reddit.close()
|
|
|
|
| 99 |
await state.repository.close()
|
| 100 |
|
| 101 |
|
| 102 |
@asynccontextmanager
|
| 103 |
async def lifespan_context() -> AsyncGenerator[AppState, None]:
|
| 104 |
"""Lifespan context manager for FastAPI."""
|
| 105 |
-
|
| 106 |
-
_app_state = await init_services()
|
| 107 |
try:
|
| 108 |
-
yield
|
| 109 |
finally:
|
| 110 |
-
await shutdown_services(
|
| 111 |
-
_app_state = None
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def get_app_state() -> AppState:
|
| 115 |
-
"""Get the current application state."""
|
| 116 |
-
if _app_state is None:
|
| 117 |
-
raise RuntimeError("Application not initialized")
|
| 118 |
-
return _app_state
|
| 119 |
|
| 120 |
|
| 121 |
-
def
|
| 122 |
-
"""
|
| 123 |
-
return
|
| 124 |
|
| 125 |
|
| 126 |
-
def
|
| 127 |
-
"""
|
| 128 |
-
return
|
| 129 |
|
| 130 |
|
| 131 |
-
def
|
| 132 |
-
"""
|
| 133 |
-
return
|
| 134 |
|
| 135 |
|
| 136 |
-
def
|
| 137 |
-
"""
|
| 138 |
-
return
|
|
|
|
| 1 |
+
"""Dependency injection container for WesternFront."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
+
from collections.abc import AsyncGenerator
|
| 5 |
from contextlib import asynccontextmanager
|
| 6 |
from dataclasses import dataclass
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from fastapi import Request
|
| 13 |
+
|
| 14 |
+
from westernfront.analytics import AnalyticsAggregator
|
| 15 |
from westernfront.config import Settings, get_settings
|
| 16 |
from westernfront.repositories.analysis import AnalysisRepository
|
| 17 |
from westernfront.repositories.vectors import VectorRepository
|
| 18 |
from westernfront.services.analysis import AnalysisService
|
| 19 |
from westernfront.services.cache import CacheService
|
| 20 |
from westernfront.services.embeddings import EmbeddingService
|
| 21 |
+
from westernfront.services.http import HttpService
|
| 22 |
from westernfront.services.newsapi import NewsApiService
|
| 23 |
from westernfront.services.reddit import RedditService
|
| 24 |
from westernfront.services.rss import RssService
|
|
|
|
| 29 |
"""Container for application-scoped services."""
|
| 30 |
|
| 31 |
settings: Settings
|
| 32 |
+
http: HttpService
|
| 33 |
cache: CacheService
|
| 34 |
reddit: RedditService
|
| 35 |
rss: RssService
|
|
|
|
| 38 |
embeddings: EmbeddingService
|
| 39 |
vectors: VectorRepository
|
| 40 |
analysis: AnalysisService
|
| 41 |
+
analytics: AnalyticsAggregator
|
| 42 |
+
|
| 43 |
|
| 44 |
+
async def _init_sync_services(settings: Settings) -> tuple[EmbeddingService, VectorRepository]:
|
| 45 |
+
"""Initialize synchronous services in a thread pool."""
|
| 46 |
+
def _init_embeddings_and_vectors() -> tuple[EmbeddingService, VectorRepository]:
|
| 47 |
+
embeddings = EmbeddingService()
|
| 48 |
+
embeddings.initialize()
|
| 49 |
+
vectors = VectorRepository(embedding_service=embeddings)
|
| 50 |
+
vectors.initialize()
|
| 51 |
+
return embeddings, vectors
|
| 52 |
|
| 53 |
+
return await asyncio.to_thread(_init_embeddings_and_vectors)
|
| 54 |
|
| 55 |
|
| 56 |
async def init_services() -> AppState:
|
| 57 |
+
"""Initialize all services for the application with parallel execution."""
|
| 58 |
settings = get_settings()
|
| 59 |
|
| 60 |
+
http = HttpService()
|
| 61 |
cache = CacheService(ttl_seconds=settings.cache_expiry_minutes * 60)
|
| 62 |
+
repository = AnalysisRepository(db_path=settings.database_path)
|
| 63 |
|
| 64 |
reddit = RedditService(
|
| 65 |
client_id=settings.reddit_client_id,
|
|
|
|
| 67 |
user_agent=settings.reddit_user_agent,
|
| 68 |
cache=cache,
|
| 69 |
)
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
rss = RssService(cache=cache, http=http)
|
| 72 |
+
newsapi = NewsApiService(api_key=settings.newsapi_key, cache=cache, http=http)
|
| 73 |
|
| 74 |
+
# Parallel initialization of independent services
|
| 75 |
+
http_init, reddit_init, repo_init, vectors_result = await asyncio.gather(
|
| 76 |
+
http.initialize(),
|
| 77 |
+
reddit.initialize(),
|
| 78 |
+
repository.initialize(),
|
| 79 |
+
_init_sync_services(settings),
|
| 80 |
+
)
|
| 81 |
|
| 82 |
+
embeddings, vectors = vectors_result
|
|
|
|
| 83 |
logger.info(f"Vector repository initialized with {vectors.get_count()} items")
|
| 84 |
|
| 85 |
+
analytics = AnalyticsAggregator(repository)
|
| 86 |
+
|
| 87 |
analysis = AnalysisService(
|
| 88 |
gemini_api_key=settings.gemini_api_key,
|
| 89 |
reddit=reddit,
|
|
|
|
| 97 |
|
| 98 |
return AppState(
|
| 99 |
settings=settings,
|
| 100 |
+
http=http,
|
| 101 |
cache=cache,
|
| 102 |
reddit=reddit,
|
| 103 |
rss=rss,
|
|
|
|
| 106 |
embeddings=embeddings,
|
| 107 |
vectors=vectors,
|
| 108 |
analysis=analysis,
|
| 109 |
+
analytics=analytics,
|
| 110 |
)
|
| 111 |
|
| 112 |
|
|
|
|
| 114 |
"""Clean up all services."""
|
| 115 |
await state.analysis.close()
|
| 116 |
await state.reddit.close()
|
| 117 |
+
await state.http.close()
|
| 118 |
await state.repository.close()
|
| 119 |
|
| 120 |
|
| 121 |
@asynccontextmanager
|
| 122 |
async def lifespan_context() -> AsyncGenerator[AppState, None]:
|
| 123 |
"""Lifespan context manager for FastAPI."""
|
| 124 |
+
state = await init_services()
|
|
|
|
| 125 |
try:
|
| 126 |
+
yield state
|
| 127 |
finally:
|
| 128 |
+
await shutdown_services(state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
+
def get_state_from_request(request: "Request") -> AppState:
|
| 132 |
+
"""Get application state from request."""
|
| 133 |
+
return request.app.state.westernfront
|
| 134 |
|
| 135 |
|
| 136 |
+
def get_analysis_service(request: "Request") -> AnalysisService:
|
| 137 |
+
"""Get AnalysisService from request."""
|
| 138 |
+
return get_state_from_request(request).analysis
|
| 139 |
|
| 140 |
|
| 141 |
+
def get_repository(request: "Request") -> AnalysisRepository:
|
| 142 |
+
"""Get AnalysisRepository from request."""
|
| 143 |
+
return get_state_from_request(request).repository
|
| 144 |
|
| 145 |
|
| 146 |
+
def get_analytics(request: "Request") -> AnalyticsAggregator:
|
| 147 |
+
"""Get AnalyticsAggregator from request."""
|
| 148 |
+
return get_state_from_request(request).analytics
|
src/westernfront/main.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
| 1 |
"""FastAPI application factory and entry point."""
|
| 2 |
|
| 3 |
import os
|
|
|
|
| 4 |
from contextlib import asynccontextmanager
|
| 5 |
-
from typing import AsyncGenerator
|
| 6 |
|
| 7 |
-
from fastapi import FastAPI, Request
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
from westernfront import __version__
|
| 12 |
from westernfront.api.auth import verify_api_key
|
|
|
|
| 13 |
from westernfront.api.routes import router
|
| 14 |
from westernfront.config import get_settings
|
| 15 |
from westernfront.dependencies import lifespan_context
|
|
@@ -43,6 +44,8 @@ def create_app() -> FastAPI:
|
|
| 43 |
lifespan=lifespan,
|
| 44 |
)
|
| 45 |
|
|
|
|
|
|
|
| 46 |
app.add_middleware(
|
| 47 |
CORSMiddleware,
|
| 48 |
allow_origins=settings.allowed_origins,
|
|
@@ -52,7 +55,19 @@ def create_app() -> FastAPI:
|
|
| 52 |
)
|
| 53 |
|
| 54 |
@app.middleware("http")
|
| 55 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
await verify_api_key(request)
|
| 57 |
return await call_next(request)
|
| 58 |
|
|
|
|
| 1 |
"""FastAPI application factory and entry point."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
from collections.abc import AsyncGenerator
|
| 5 |
from contextlib import asynccontextmanager
|
|
|
|
| 6 |
|
| 7 |
+
from fastapi import FastAPI, Request, Response
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
from westernfront import __version__
|
| 12 |
from westernfront.api.auth import verify_api_key
|
| 13 |
+
from westernfront.api.middleware import RateLimitMiddleware
|
| 14 |
from westernfront.api.routes import router
|
| 15 |
from westernfront.config import get_settings
|
| 16 |
from westernfront.dependencies import lifespan_context
|
|
|
|
| 44 |
lifespan=lifespan,
|
| 45 |
)
|
| 46 |
|
| 47 |
+
app.add_middleware(RateLimitMiddleware, requests_per_minute=120, burst_size=20)
|
| 48 |
+
|
| 49 |
app.add_middleware(
|
| 50 |
CORSMiddleware,
|
| 51 |
allow_origins=settings.allowed_origins,
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
@app.middleware("http")
|
| 58 |
+
async def security_headers_middleware(request: Request, call_next) -> Response:
|
| 59 |
+
"""Add security headers to all responses."""
|
| 60 |
+
response = await call_next(request)
|
| 61 |
+
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 62 |
+
response.headers["X-Frame-Options"] = "DENY"
|
| 63 |
+
response.headers["X-XSS-Protection"] = "1; mode=block"
|
| 64 |
+
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
| 65 |
+
response.headers["Content-Security-Policy"] = "default-src 'self'; frame-ancestors 'none'"
|
| 66 |
+
return response
|
| 67 |
+
|
| 68 |
+
@app.middleware("http")
|
| 69 |
+
async def api_key_middleware(request: Request, call_next) -> Response:
|
| 70 |
+
"""Verify API key for protected endpoints."""
|
| 71 |
await verify_api_key(request)
|
| 72 |
return await call_next(request)
|
| 73 |
|
src/westernfront/prompts/analysis.py
CHANGED
|
@@ -2,33 +2,53 @@
|
|
| 2 |
|
| 3 |
from datetime import datetime
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
""
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
Args:
|
| 11 |
-
retrieved_items: Items retrieved from vector search with metadata.
|
| 12 |
-
total_in_memory: Total items in the vector database.
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
"""
|
| 17 |
source_entries = []
|
| 18 |
for i, item in enumerate(retrieved_items):
|
| 19 |
meta = item.get("metadata", {})
|
| 20 |
doc = item.get("document", "")
|
| 21 |
-
score = item.get("similarity_score", 0)
|
| 22 |
reliability_val = meta.get("reliability_score", 0.5)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
entry = (
|
| 29 |
-
f"INTEL #{i + 1} [Relevance: {score:.0%}] [Reliability: {reliability}]:\n"
|
| 30 |
f"Source: {meta.get('source_name', 'Unknown')} ({meta.get('source_type', 'unknown')})\n"
|
| 31 |
-
f"Date: {
|
| 32 |
f"Content: {doc}"
|
| 33 |
)
|
| 34 |
source_entries.append(entry)
|
|
@@ -37,7 +57,13 @@ def build_rag_prompt(retrieved_items: list[dict], total_in_memory: int = 0) -> s
|
|
| 37 |
|
| 38 |
memory_note = ""
|
| 39 |
if total_in_memory > 0:
|
| 40 |
-
memory_note =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
return f"""**TOP SECRET // FOR OFFICIAL USE ONLY**
|
| 43 |
|
|
@@ -59,6 +85,11 @@ Analyze ALL matters relevant to India's regional relationships, prioritized as:
|
|
| 59 |
**NEUTRALITY DIRECTIVE:**
|
| 60 |
Maintain absolute neutrality. Present multiple perspectives when conflicting information exists. Acknowledge information gaps.
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
**INTELLIGENCE FEEDS:**
|
| 63 |
---
|
| 64 |
{intelligence_data}
|
|
@@ -67,8 +98,9 @@ Maintain absolute neutrality. Present multiple perspectives when conflicting inf
|
|
| 67 |
**ANALYTICAL DIRECTIVES:**
|
| 68 |
1. **Synthesize, Do Not Summarize:** Integrate all data into a coherent assessment.
|
| 69 |
2. **Impersonal Tone:** Use formal, analytical language.
|
| 70 |
-
3. **
|
| 71 |
4. **Acknowledge Uncertainty:** Indicate confidence levels and information gaps.
|
|
|
|
| 72 |
|
| 73 |
**REQUIRED OUTPUT FORMAT (Strict JSON):**
|
| 74 |
Produce a single, valid JSON object:
|
|
@@ -83,8 +115,8 @@ Produce a single, valid JSON object:
|
|
| 83 |
}}
|
| 84 |
],
|
| 85 |
"reliability_assessment": {{
|
| 86 |
-
"source_credibility": "Overall credibility assessment.",
|
| 87 |
-
"information_gaps": "What critical information is missing.",
|
| 88 |
"confidence_rating": "HIGH, MEDIUM, or LOW with justification."
|
| 89 |
}},
|
| 90 |
"regional_implications": {{
|
|
@@ -93,9 +125,72 @@ Produce a single, valid JSON object:
|
|
| 93 |
"economic": "Potential economic consequences."
|
| 94 |
}},
|
| 95 |
"tension_level": "LOW|MEDIUM|HIGH|CRITICAL",
|
| 96 |
-
"tension_rationale": "Justification for the assessed tension level.",
|
| 97 |
-
"tension_score": "Integer 1-10
|
| 98 |
"tension_trend": "INCREASING|DECREASING|STABLE",
|
| 99 |
"analysis_type": "Military|Diplomatic|Internal Security|Political|Other",
|
| 100 |
-
"key_entities": ["3-5 key actors, locations, or organizations."]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
}}"""
|
|
|
|
| 2 |
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
+
from westernfront.core.constants import TENSION_LEVEL_CRITERIA
|
| 6 |
|
| 7 |
+
RELIABILITY_THRESHOLDS = {
|
| 8 |
+
"HIGH": 0.8,
|
| 9 |
+
"MEDIUM": 0.6,
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _get_reliability_label(score: float) -> str:
|
| 14 |
+
"""Convert reliability score to label."""
|
| 15 |
+
if score > RELIABILITY_THRESHOLDS["HIGH"]:
|
| 16 |
+
return "HIGH"
|
| 17 |
+
if score > RELIABILITY_THRESHOLDS["MEDIUM"]:
|
| 18 |
+
return "MEDIUM"
|
| 19 |
+
return "LOW"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _build_tension_criteria() -> str:
|
| 23 |
+
"""Build tension level criteria section for prompt."""
|
| 24 |
+
lines = []
|
| 25 |
+
for level, criteria in TENSION_LEVEL_CRITERIA.items():
|
| 26 |
+
score_range = criteria["score_range"]
|
| 27 |
+
desc = criteria["description"]
|
| 28 |
+
lines.append(f"- **{level}** (Score {score_range[0]}-{score_range[1]}): {desc}")
|
| 29 |
+
return "\n".join(lines)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
def build_rag_prompt(retrieved_items: list[dict], total_in_memory: int = 0) -> str:
|
| 33 |
+
"""Build prompt for RAG-enhanced analysis using vector-retrieved items."""
|
|
|
|
| 34 |
source_entries = []
|
| 35 |
for i, item in enumerate(retrieved_items):
|
| 36 |
meta = item.get("metadata", {})
|
| 37 |
doc = item.get("document", "")
|
| 38 |
+
score = item.get("boosted_score", item.get("similarity_score", 0))
|
| 39 |
reliability_val = meta.get("reliability_score", 0.5)
|
| 40 |
+
reliability = _get_reliability_label(reliability_val)
|
| 41 |
+
|
| 42 |
+
published = meta.get("published_at", "")
|
| 43 |
+
date_str = published[:10] if published else "Unknown"
|
| 44 |
+
|
| 45 |
+
recency = item.get("recency_multiplier", 1.0)
|
| 46 |
+
recency_label = "FRESH" if recency >= 1.5 else ("RECENT" if recency >= 1.0 else "OLDER")
|
| 47 |
+
|
| 48 |
entry = (
|
| 49 |
+
f"INTEL #{i + 1} [Relevance: {score:.0%}] [Reliability: {reliability}] [{recency_label}]:\n"
|
| 50 |
f"Source: {meta.get('source_name', 'Unknown')} ({meta.get('source_type', 'unknown')})\n"
|
| 51 |
+
f"Date: {date_str}\n"
|
| 52 |
f"Content: {doc}"
|
| 53 |
)
|
| 54 |
source_entries.append(entry)
|
|
|
|
| 57 |
|
| 58 |
memory_note = ""
|
| 59 |
if total_in_memory > 0:
|
| 60 |
+
memory_note = (
|
| 61 |
+
f"\n\n**INSTITUTIONAL MEMORY:** This analysis draws from a database of "
|
| 62 |
+
f"{total_in_memory:,} indexed news items. The items shown below are the most "
|
| 63 |
+
f"semantically relevant to South Asia conflict dynamics, weighted by recency.\n"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
tension_criteria = _build_tension_criteria()
|
| 67 |
|
| 68 |
return f"""**TOP SECRET // FOR OFFICIAL USE ONLY**
|
| 69 |
|
|
|
|
| 85 |
**NEUTRALITY DIRECTIVE:**
|
| 86 |
Maintain absolute neutrality. Present multiple perspectives when conflicting information exists. Acknowledge information gaps.
|
| 87 |
|
| 88 |
+
**TENSION LEVEL ASSESSMENT CRITERIA:**
|
| 89 |
+
{tension_criteria}
|
| 90 |
+
|
| 91 |
+
Use these criteria strictly. Your tension_score MUST align with your tension_level.
|
| 92 |
+
|
| 93 |
**INTELLIGENCE FEEDS:**
|
| 94 |
---
|
| 95 |
{intelligence_data}
|
|
|
|
| 98 |
**ANALYTICAL DIRECTIVES:**
|
| 99 |
1. **Synthesize, Do Not Summarize:** Integrate all data into a coherent assessment.
|
| 100 |
2. **Impersonal Tone:** Use formal, analytical language.
|
| 101 |
+
3. **Ground All Claims:** Every key entity and event must be supported by the intelligence feeds above.
|
| 102 |
4. **Acknowledge Uncertainty:** Indicate confidence levels and information gaps.
|
| 103 |
+
5. **Prioritize Recency:** Weight FRESH sources more heavily than OLDER ones.
|
| 104 |
|
| 105 |
**REQUIRED OUTPUT FORMAT (Strict JSON):**
|
| 106 |
Produce a single, valid JSON object:
|
|
|
|
| 115 |
}}
|
| 116 |
],
|
| 117 |
"reliability_assessment": {{
|
| 118 |
+
"source_credibility": "Overall credibility assessment of available sources.",
|
| 119 |
+
"information_gaps": "What critical information is missing from the intelligence feeds.",
|
| 120 |
"confidence_rating": "HIGH, MEDIUM, or LOW with justification."
|
| 121 |
}},
|
| 122 |
"regional_implications": {{
|
|
|
|
| 125 |
"economic": "Potential economic consequences."
|
| 126 |
}},
|
| 127 |
"tension_level": "LOW|MEDIUM|HIGH|CRITICAL",
|
| 128 |
+
"tension_rationale": "Justification for the assessed tension level using specific evidence from sources.",
|
| 129 |
+
"tension_score": "Integer 1-10 matching to tension level per criteria above.",
|
| 130 |
"tension_trend": "INCREASING|DECREASING|STABLE",
|
| 131 |
"analysis_type": "Military|Diplomatic|Internal Security|Political|Other",
|
| 132 |
+
"key_entities": ["3-5 key actors, locations, or organizations ONLY from the sources above."]
|
| 133 |
+
}}"""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def build_extraction_prompt(items: list[dict]) -> str:
|
| 137 |
+
"""Build prompt for fact extraction stage of chain analysis."""
|
| 138 |
+
source_entries = []
|
| 139 |
+
for i, item in enumerate(items[:20]):
|
| 140 |
+
doc = item.get("document", "")
|
| 141 |
+
source_entries.append(f"SOURCE {i+1}: {doc[:500]}")
|
| 142 |
+
|
| 143 |
+
sources_text = "\n\n".join(source_entries)
|
| 144 |
+
|
| 145 |
+
return f"""Extract key facts from the following intelligence sources.
|
| 146 |
+
|
| 147 |
+
SOURCES:
|
| 148 |
+
{sources_text}
|
| 149 |
+
|
| 150 |
+
OUTPUT FORMAT (JSON):
|
| 151 |
+
{{
|
| 152 |
+
"facts": [
|
| 153 |
+
{{
|
| 154 |
+
"fact": "Clear, objective statement of fact",
|
| 155 |
+
"source_index": 1,
|
| 156 |
+
"type": "military|diplomatic|political|economic|other",
|
| 157 |
+
"date_mentioned": "YYYY-MM-DD or null",
|
| 158 |
+
"entities": ["entity1", "entity2"]
|
| 159 |
+
}}
|
| 160 |
+
],
|
| 161 |
+
"total_sources_analyzed": {len(items)}
|
| 162 |
+
}}"""
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def build_synthesis_prompt(facts: list[dict], historical_context: str = "") -> str:
|
| 166 |
+
"""Build prompt for synthesis stage of chain analysis."""
|
| 167 |
+
facts_text = "\n".join([f"- {f.get('fact', '')}" for f in facts[:30]])
|
| 168 |
+
|
| 169 |
+
return f"""Synthesize the following extracted facts into a coherent assessment.
|
| 170 |
+
|
| 171 |
+
EXTRACTED FACTS:
|
| 172 |
+
{facts_text}
|
| 173 |
+
|
| 174 |
+
{f"HISTORICAL CONTEXT: {historical_context}" if historical_context else ""}
|
| 175 |
+
|
| 176 |
+
TASK:
|
| 177 |
+
1. Identify the 3-5 most significant developments
|
| 178 |
+
2. Assess overall tension level (LOW/MEDIUM/HIGH/CRITICAL)
|
| 179 |
+
3. Identify trends and patterns
|
| 180 |
+
4. Note any contradictions or information gaps
|
| 181 |
+
|
| 182 |
+
OUTPUT FORMAT (JSON):
|
| 183 |
+
{{
|
| 184 |
+
"significant_developments": [
|
| 185 |
+
{{
|
| 186 |
+
"title": "Development title",
|
| 187 |
+
"description": "Synthesized description",
|
| 188 |
+
"supporting_facts": [0, 1, 2]
|
| 189 |
+
}}
|
| 190 |
+
],
|
| 191 |
+
"preliminary_tension": "LOW|MEDIUM|HIGH|CRITICAL",
|
| 192 |
+
"tension_reasoning": "Brief reasoning",
|
| 193 |
+
"trends": ["trend1", "trend2"],
|
| 194 |
+
"contradictions": ["any contradictory information"],
|
| 195 |
+
"gaps": ["information gaps identified"]
|
| 196 |
}}"""
|
src/westernfront/repositories/analysis.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
"""SQLite repository for storing analysis history."""
|
| 2 |
|
| 3 |
import json
|
| 4 |
-
from datetime import
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Optional
|
| 7 |
|
| 8 |
import aiosqlite
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
|
|
|
|
| 12 |
from westernfront.core.models import AnalysisSnapshot, ConflictAnalysis
|
| 13 |
|
| 14 |
|
|
@@ -16,14 +16,8 @@ class AnalysisRepository:
|
|
| 16 |
"""SQLite-based repository for analysis storage and retrieval."""
|
| 17 |
|
| 18 |
def __init__(self, db_path: str = "data/westernfront.db") -> None:
|
| 19 |
-
"""
|
| 20 |
-
Initialize the repository.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
db_path: Path to the SQLite database file.
|
| 24 |
-
"""
|
| 25 |
self._db_path = Path(db_path)
|
| 26 |
-
self._conn:
|
| 27 |
|
| 28 |
async def initialize(self) -> None:
|
| 29 |
"""Initialize the database and create tables."""
|
|
@@ -43,7 +37,7 @@ class AnalysisRepository:
|
|
| 43 |
async def _create_tables(self) -> None:
|
| 44 |
"""Create required database tables."""
|
| 45 |
if not self._conn:
|
| 46 |
-
raise
|
| 47 |
|
| 48 |
await self._conn.execute("""
|
| 49 |
CREATE TABLE IF NOT EXISTS analyses (
|
|
@@ -70,7 +64,7 @@ class AnalysisRepository:
|
|
| 70 |
if not self._conn:
|
| 71 |
return
|
| 72 |
|
| 73 |
-
cutoff = (datetime.now(
|
| 74 |
cursor = await self._conn.execute(
|
| 75 |
"DELETE FROM analyses WHERE generated_at < ?",
|
| 76 |
(cutoff,),
|
|
@@ -81,14 +75,9 @@ class AnalysisRepository:
|
|
| 81 |
logger.info(f"Cleaned up {cursor.rowcount} old analysis records")
|
| 82 |
|
| 83 |
async def save(self, analysis: ConflictAnalysis) -> None:
|
| 84 |
-
"""
|
| 85 |
-
Save an analysis to the database.
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
analysis: The conflict analysis to save.
|
| 89 |
-
"""
|
| 90 |
if not self._conn:
|
| 91 |
-
raise
|
| 92 |
|
| 93 |
await self._conn.execute(
|
| 94 |
"""
|
|
@@ -112,10 +101,10 @@ class AnalysisRepository:
|
|
| 112 |
await self._conn.commit()
|
| 113 |
logger.debug(f"Saved analysis {analysis.analysis_id}")
|
| 114 |
|
| 115 |
-
async def get_latest(self) ->
|
| 116 |
"""Get the most recent analysis."""
|
| 117 |
if not self._conn:
|
| 118 |
-
raise
|
| 119 |
|
| 120 |
cursor = await self._conn.execute(
|
| 121 |
"SELECT full_analysis FROM analyses ORDER BY generated_at DESC LIMIT 1"
|
|
@@ -131,20 +120,11 @@ class AnalysisRepository:
|
|
| 131 |
days: int = 30,
|
| 132 |
limit: int = 100,
|
| 133 |
) -> list[AnalysisSnapshot]:
|
| 134 |
-
"""
|
| 135 |
-
Get historical analysis snapshots.
|
| 136 |
-
|
| 137 |
-
Args:
|
| 138 |
-
days: Number of days to look back.
|
| 139 |
-
limit: Maximum number of records to return.
|
| 140 |
-
|
| 141 |
-
Returns:
|
| 142 |
-
List of analysis snapshots.
|
| 143 |
-
"""
|
| 144 |
if not self._conn:
|
| 145 |
-
raise
|
| 146 |
|
| 147 |
-
cutoff = (datetime.now(
|
| 148 |
|
| 149 |
cursor = await self._conn.execute(
|
| 150 |
"""
|
|
@@ -178,19 +158,11 @@ class AnalysisRepository:
|
|
| 178 |
return snapshots
|
| 179 |
|
| 180 |
async def get_tension_history(self, days: int = 30) -> list[dict]:
|
| 181 |
-
"""
|
| 182 |
-
Get tension score history for graphing.
|
| 183 |
-
|
| 184 |
-
Args:
|
| 185 |
-
days: Number of days to look back.
|
| 186 |
-
|
| 187 |
-
Returns:
|
| 188 |
-
List of date/score pairs.
|
| 189 |
-
"""
|
| 190 |
if not self._conn:
|
| 191 |
-
raise
|
| 192 |
|
| 193 |
-
cutoff = (datetime.now(
|
| 194 |
|
| 195 |
cursor = await self._conn.execute(
|
| 196 |
"""
|
|
|
|
| 1 |
"""SQLite repository for storing analysis history."""
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
from datetime import UTC, datetime, timedelta
|
| 5 |
from pathlib import Path
|
|
|
|
| 6 |
|
| 7 |
import aiosqlite
|
| 8 |
from loguru import logger
|
| 9 |
|
| 10 |
from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
|
| 11 |
+
from westernfront.core.exceptions import ServiceNotInitializedError
|
| 12 |
from westernfront.core.models import AnalysisSnapshot, ConflictAnalysis
|
| 13 |
|
| 14 |
|
|
|
|
| 16 |
"""SQLite-based repository for analysis storage and retrieval."""
|
| 17 |
|
| 18 |
def __init__(self, db_path: str = "data/westernfront.db") -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
self._db_path = Path(db_path)
|
| 20 |
+
self._conn: aiosqlite.Connection | None = None
|
| 21 |
|
| 22 |
async def initialize(self) -> None:
|
| 23 |
"""Initialize the database and create tables."""
|
|
|
|
| 37 |
async def _create_tables(self) -> None:
|
| 38 |
"""Create required database tables."""
|
| 39 |
if not self._conn:
|
| 40 |
+
raise ServiceNotInitializedError("Analysis repository not initialized")
|
| 41 |
|
| 42 |
await self._conn.execute("""
|
| 43 |
CREATE TABLE IF NOT EXISTS analyses (
|
|
|
|
| 64 |
if not self._conn:
|
| 65 |
return
|
| 66 |
|
| 67 |
+
cutoff = (datetime.now(UTC) - timedelta(days=retention_days)).isoformat()
|
| 68 |
cursor = await self._conn.execute(
|
| 69 |
"DELETE FROM analyses WHERE generated_at < ?",
|
| 70 |
(cutoff,),
|
|
|
|
| 75 |
logger.info(f"Cleaned up {cursor.rowcount} old analysis records")
|
| 76 |
|
| 77 |
async def save(self, analysis: ConflictAnalysis) -> None:
|
| 78 |
+
"""Save an analysis to the database."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
if not self._conn:
|
| 80 |
+
raise ServiceNotInitializedError("Analysis repository not initialized")
|
| 81 |
|
| 82 |
await self._conn.execute(
|
| 83 |
"""
|
|
|
|
| 101 |
await self._conn.commit()
|
| 102 |
logger.debug(f"Saved analysis {analysis.analysis_id}")
|
| 103 |
|
| 104 |
+
async def get_latest(self) -> ConflictAnalysis | None:
|
| 105 |
"""Get the most recent analysis."""
|
| 106 |
if not self._conn:
|
| 107 |
+
raise ServiceNotInitializedError("Analysis repository not initialized")
|
| 108 |
|
| 109 |
cursor = await self._conn.execute(
|
| 110 |
"SELECT full_analysis FROM analyses ORDER BY generated_at DESC LIMIT 1"
|
|
|
|
| 120 |
days: int = 30,
|
| 121 |
limit: int = 100,
|
| 122 |
) -> list[AnalysisSnapshot]:
|
| 123 |
+
"""Get historical analysis snapshots."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if not self._conn:
|
| 125 |
+
raise ServiceNotInitializedError("Analysis repository not initialized")
|
| 126 |
|
| 127 |
+
cutoff = (datetime.now(UTC) - timedelta(days=days)).isoformat()
|
| 128 |
|
| 129 |
cursor = await self._conn.execute(
|
| 130 |
"""
|
|
|
|
| 158 |
return snapshots
|
| 159 |
|
| 160 |
async def get_tension_history(self, days: int = 30) -> list[dict]:
|
| 161 |
+
"""Get tension score history for graphing."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
if not self._conn:
|
| 163 |
+
raise ServiceNotInitializedError("Analysis repository not initialized")
|
| 164 |
|
| 165 |
+
cutoff = (datetime.now(UTC) - timedelta(days=days)).isoformat()
|
| 166 |
|
| 167 |
cursor = await self._conn.execute(
|
| 168 |
"""
|
src/westernfront/repositories/vectors.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
"""ChromaDB vector repository for semantic search."""
|
| 2 |
|
| 3 |
-
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
import chromadb
|
| 8 |
from chromadb.config import Settings as ChromaSettings
|
| 9 |
from loguru import logger
|
| 10 |
|
|
|
|
| 11 |
from westernfront.core.models import NewsItem
|
| 12 |
|
| 13 |
if TYPE_CHECKING:
|
|
@@ -21,16 +22,9 @@ class VectorRepository:
|
|
| 21 |
|
| 22 |
def __init__(
|
| 23 |
self,
|
| 24 |
-
persist_dir:
|
| 25 |
-
embedding_service:
|
| 26 |
) -> None:
|
| 27 |
-
"""
|
| 28 |
-
Initialize the vector repository.
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
persist_dir: Directory for ChromaDB persistence (defaults to server/data/chroma).
|
| 32 |
-
embedding_service: Service for generating embeddings.
|
| 33 |
-
"""
|
| 34 |
if persist_dir:
|
| 35 |
self._persist_dir = Path(persist_dir)
|
| 36 |
else:
|
|
@@ -39,54 +33,36 @@ class VectorRepository:
|
|
| 39 |
self._persist_dir.mkdir(parents=True, exist_ok=True)
|
| 40 |
|
| 41 |
self._embeddings = embedding_service
|
| 42 |
-
self._client:
|
| 43 |
self._collection = None
|
| 44 |
|
| 45 |
def initialize(self) -> bool:
|
| 46 |
-
"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
Returns:
|
| 50 |
-
True if initialization was successful.
|
| 51 |
-
"""
|
| 52 |
-
try:
|
| 53 |
-
logger.info(f"Initializing ChromaDB at {self._persist_dir}")
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
self._collection = self._client.get_or_create_collection(
|
| 61 |
-
name=self.COLLECTION_NAME,
|
| 62 |
-
metadata={"hnsw:space": "cosine"},
|
| 63 |
-
)
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
@property
|
| 74 |
def is_initialized(self) -> bool:
|
| 75 |
"""Check if the repository is initialized."""
|
| 76 |
return self._collection is not None
|
| 77 |
|
| 78 |
-
def
|
| 79 |
-
"""
|
| 80 |
-
Add news items to the vector store.
|
| 81 |
-
|
| 82 |
-
Args:
|
| 83 |
-
items: News items to add.
|
| 84 |
-
|
| 85 |
-
Returns:
|
| 86 |
-
Number of items added.
|
| 87 |
-
"""
|
| 88 |
if not self._collection or not self._embeddings:
|
| 89 |
-
raise
|
| 90 |
|
| 91 |
existing_ids = set(self._collection.get(ids=[item.id for item in items])["ids"])
|
| 92 |
new_items = [item for item in items if item.id not in existing_ids]
|
|
@@ -117,25 +93,23 @@ class VectorRepository:
|
|
| 117 |
logger.info(f"Added {len(new_items)} items to vector store")
|
| 118 |
return len(new_items)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def query_similar(
|
| 121 |
self,
|
| 122 |
query: str,
|
| 123 |
n_results: int = 20,
|
| 124 |
min_score: float = 0.3,
|
| 125 |
) -> list[dict]:
|
| 126 |
-
"""
|
| 127 |
-
Query for similar items.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
query: Query text.
|
| 131 |
-
n_results: Maximum number of results.
|
| 132 |
-
min_score: Minimum similarity score (0-1).
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
List of similar items with metadata.
|
| 136 |
-
"""
|
| 137 |
if not self._collection or not self._embeddings:
|
| 138 |
-
raise
|
| 139 |
|
| 140 |
query_embedding = self._embeddings.embed(query)
|
| 141 |
|
|
@@ -164,16 +138,7 @@ class VectorRepository:
|
|
| 164 |
topics: list[str],
|
| 165 |
n_per_topic: int = 10,
|
| 166 |
) -> list[dict]:
|
| 167 |
-
"""
|
| 168 |
-
Query for items related to multiple topics.
|
| 169 |
-
|
| 170 |
-
Args:
|
| 171 |
-
topics: List of topic queries.
|
| 172 |
-
n_per_topic: Results per topic.
|
| 173 |
-
|
| 174 |
-
Returns:
|
| 175 |
-
Deduplicated list of relevant items.
|
| 176 |
-
"""
|
| 177 |
seen_ids: set[str] = set()
|
| 178 |
all_items: list[dict] = []
|
| 179 |
|
|
|
|
| 1 |
"""ChromaDB vector repository for semantic search."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
import chromadb
|
| 8 |
from chromadb.config import Settings as ChromaSettings
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
+
from westernfront.core.exceptions import VectorStoreError
|
| 12 |
from westernfront.core.models import NewsItem
|
| 13 |
|
| 14 |
if TYPE_CHECKING:
|
|
|
|
| 22 |
|
| 23 |
def __init__(
|
| 24 |
self,
|
| 25 |
+
persist_dir: str | None = None,
|
| 26 |
+
embedding_service: "EmbeddingService | None" = None,
|
| 27 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
if persist_dir:
|
| 29 |
self._persist_dir = Path(persist_dir)
|
| 30 |
else:
|
|
|
|
| 33 |
self._persist_dir.mkdir(parents=True, exist_ok=True)
|
| 34 |
|
| 35 |
self._embeddings = embedding_service
|
| 36 |
+
self._client: chromadb.PersistentClient | None = None
|
| 37 |
self._collection = None
|
| 38 |
|
| 39 |
def initialize(self) -> bool:
|
| 40 |
+
"""Initialize ChromaDB client and collection."""
|
| 41 |
+
logger.info(f"Initializing ChromaDB at {self._persist_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
self._client = chromadb.PersistentClient(
|
| 44 |
+
path=str(self._persist_dir),
|
| 45 |
+
settings=ChromaSettings(anonymized_telemetry=False),
|
| 46 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
self._collection = self._client.get_or_create_collection(
|
| 49 |
+
name=self.COLLECTION_NAME,
|
| 50 |
+
metadata={"hnsw:space": "cosine"},
|
| 51 |
+
)
|
| 52 |
|
| 53 |
+
count = self._collection.count()
|
| 54 |
+
logger.info(f"ChromaDB initialized with {count} existing documents")
|
| 55 |
+
return True
|
| 56 |
|
| 57 |
@property
|
| 58 |
def is_initialized(self) -> bool:
|
| 59 |
"""Check if the repository is initialized."""
|
| 60 |
return self._collection is not None
|
| 61 |
|
| 62 |
+
def _add_items_sync(self, items: list[NewsItem]) -> int:
|
| 63 |
+
"""Synchronous implementation of add_items."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if not self._collection or not self._embeddings:
|
| 65 |
+
raise VectorStoreError("Vector repository not initialized")
|
| 66 |
|
| 67 |
existing_ids = set(self._collection.get(ids=[item.id for item in items])["ids"])
|
| 68 |
new_items = [item for item in items if item.id not in existing_ids]
|
|
|
|
| 93 |
logger.info(f"Added {len(new_items)} items to vector store")
|
| 94 |
return len(new_items)
|
| 95 |
|
| 96 |
+
def add_items(self, items: list[NewsItem]) -> int:
|
| 97 |
+
"""Add news items to the vector store (blocking)."""
|
| 98 |
+
return self._add_items_sync(items)
|
| 99 |
+
|
| 100 |
+
async def add_items_async(self, items: list[NewsItem]) -> int:
|
| 101 |
+
"""Add news items to the vector store (non-blocking)."""
|
| 102 |
+
return await asyncio.to_thread(self._add_items_sync, items)
|
| 103 |
+
|
| 104 |
def query_similar(
|
| 105 |
self,
|
| 106 |
query: str,
|
| 107 |
n_results: int = 20,
|
| 108 |
min_score: float = 0.3,
|
| 109 |
) -> list[dict]:
|
| 110 |
+
"""Query for similar items."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
if not self._collection or not self._embeddings:
|
| 112 |
+
raise VectorStoreError("Vector repository not initialized")
|
| 113 |
|
| 114 |
query_embedding = self._embeddings.embed(query)
|
| 115 |
|
|
|
|
| 138 |
topics: list[str],
|
| 139 |
n_per_topic: int = 10,
|
| 140 |
) -> list[dict]:
|
| 141 |
+
"""Query for items related to multiple topics."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
seen_ids: set[str] = set()
|
| 143 |
all_items: list[dict] = []
|
| 144 |
|
src/westernfront/services/__init__.py
CHANGED
|
@@ -2,16 +2,28 @@
|
|
| 2 |
|
| 3 |
from westernfront.services.analysis import AnalysisService
|
| 4 |
from westernfront.services.cache import CacheService
|
|
|
|
| 5 |
from westernfront.services.embeddings import EmbeddingService
|
|
|
|
| 6 |
from westernfront.services.newsapi import NewsApiService
|
|
|
|
| 7 |
from westernfront.services.reddit import RedditService
|
|
|
|
| 8 |
from westernfront.services.rss import RssService
|
|
|
|
|
|
|
| 9 |
|
| 10 |
__all__ = [
|
|
|
|
| 11 |
"AnalysisService",
|
|
|
|
| 12 |
"CacheService",
|
|
|
|
| 13 |
"EmbeddingService",
|
|
|
|
| 14 |
"NewsApiService",
|
| 15 |
"RedditService",
|
|
|
|
|
|
|
| 16 |
"RssService",
|
| 17 |
]
|
|
|
|
| 2 |
|
| 3 |
from westernfront.services.analysis import AnalysisService
|
| 4 |
from westernfront.services.cache import CacheService
|
| 5 |
+
from westernfront.services.chain_analysis import ChainAnalysisService
|
| 6 |
from westernfront.services.embeddings import EmbeddingService
|
| 7 |
+
from westernfront.services.http import HttpService
|
| 8 |
from westernfront.services.newsapi import NewsApiService
|
| 9 |
+
from westernfront.services.parsing import ResponseParser
|
| 10 |
from westernfront.services.reddit import RedditService
|
| 11 |
+
from westernfront.services.retrieval import RetrievalService
|
| 12 |
from westernfront.services.rss import RssService
|
| 13 |
+
from westernfront.services.scheduler import AnalysisScheduler
|
| 14 |
+
from westernfront.services.validation import AnalysisValidator
|
| 15 |
|
| 16 |
__all__ = [
|
| 17 |
+
"AnalysisScheduler",
|
| 18 |
"AnalysisService",
|
| 19 |
+
"AnalysisValidator",
|
| 20 |
"CacheService",
|
| 21 |
+
"ChainAnalysisService",
|
| 22 |
"EmbeddingService",
|
| 23 |
+
"HttpService",
|
| 24 |
"NewsApiService",
|
| 25 |
"RedditService",
|
| 26 |
+
"ResponseParser",
|
| 27 |
+
"RetrievalService",
|
| 28 |
"RssService",
|
| 29 |
]
|
src/westernfront/services/analysis.py
CHANGED
|
@@ -1,362 +1,239 @@
|
|
| 1 |
-
"""AI-powered conflict analysis service with RAG enhancement."""
|
| 2 |
-
|
| 3 |
-
import asyncio
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
import
|
| 11 |
-
|
| 12 |
-
from
|
| 13 |
-
|
| 14 |
-
from westernfront.
|
| 15 |
-
from westernfront.core.
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
from westernfront.
|
| 24 |
-
from westernfront.
|
| 25 |
-
from westernfront.
|
| 26 |
-
from westernfront.services.
|
| 27 |
-
from westernfront.services.reddit import RedditService
|
| 28 |
-
from westernfront.services.
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
"
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
self.
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
self.
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
logger.
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
"
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
return TensionTrend.STABLE
|
| 241 |
-
|
| 242 |
-
def _parse_analysis_type(self, value: str) -> AnalysisType:
|
| 243 |
-
"""Parse analysis type from string."""
|
| 244 |
-
value = value.upper()
|
| 245 |
-
if "MILITARY" in value:
|
| 246 |
-
return AnalysisType.MILITARY
|
| 247 |
-
if "DIPLOMATIC" in value:
|
| 248 |
-
return AnalysisType.DIPLOMATIC
|
| 249 |
-
if "INTERNAL" in value:
|
| 250 |
-
return AnalysisType.INTERNAL_SECURITY
|
| 251 |
-
if "POLITICAL" in value:
|
| 252 |
-
return AnalysisType.POLITICAL
|
| 253 |
-
return AnalysisType.OTHER
|
| 254 |
-
|
| 255 |
-
def _parse_key_developments(self, data: list[dict]) -> list[KeyDevelopment]:
|
| 256 |
-
"""Parse key developments from response data."""
|
| 257 |
-
developments = []
|
| 258 |
-
for item in data:
|
| 259 |
-
if not isinstance(item, dict):
|
| 260 |
-
continue
|
| 261 |
-
developments.append(
|
| 262 |
-
KeyDevelopment(
|
| 263 |
-
title=item.get("title", "Unnamed"),
|
| 264 |
-
description=item.get("description", "No description"),
|
| 265 |
-
sources=item.get("sources", []),
|
| 266 |
-
timestamp=datetime.now(),
|
| 267 |
-
)
|
| 268 |
-
)
|
| 269 |
-
return developments
|
| 270 |
-
|
| 271 |
-
def _count_sources(self, items: list[dict]) -> dict[str, int]:
|
| 272 |
-
"""Count items by source type from retrieved results."""
|
| 273 |
-
counts: dict[str, int] = {}
|
| 274 |
-
for item in items:
|
| 275 |
-
meta = item.get("metadata", {})
|
| 276 |
-
key = meta.get("source_type", "unknown")
|
| 277 |
-
counts[key] = counts.get(key, 0) + 1
|
| 278 |
-
return counts
|
| 279 |
-
|
| 280 |
-
async def generate_analysis(self, trigger: str = "scheduled") -> Optional[ConflictAnalysis]:
|
| 281 |
-
"""
|
| 282 |
-
Generate a new conflict analysis using RAG.
|
| 283 |
-
|
| 284 |
-
Flow:
|
| 285 |
-
1. Ingest ALL news into vector store
|
| 286 |
-
2. Retrieve most relevant items via semantic search
|
| 287 |
-
3. Send retrieved items to Gemini for analysis
|
| 288 |
-
|
| 289 |
-
Args:
|
| 290 |
-
trigger: What triggered this analysis.
|
| 291 |
-
|
| 292 |
-
Returns:
|
| 293 |
-
The generated analysis or None.
|
| 294 |
-
"""
|
| 295 |
-
await self._ingest_all_news()
|
| 296 |
-
|
| 297 |
-
retrieved_items = self._retrieve_relevant_items(
|
| 298 |
-
max_items=self._settings.max_posts_for_analysis
|
| 299 |
-
)
|
| 300 |
-
|
| 301 |
-
if len(retrieved_items) < self._settings.min_posts_for_analysis:
|
| 302 |
-
logger.warning(f"Insufficient data: {len(retrieved_items)} items")
|
| 303 |
-
return None
|
| 304 |
-
|
| 305 |
-
prompt = build_rag_prompt(retrieved_items, self._vectors.get_count())
|
| 306 |
-
|
| 307 |
-
try:
|
| 308 |
-
data = await self._call_gemini(prompt)
|
| 309 |
-
except RetryError as e:
|
| 310 |
-
logger.error(f"Gemini failed after retries: {e}")
|
| 311 |
-
return None
|
| 312 |
-
|
| 313 |
-
if not data:
|
| 314 |
-
return None
|
| 315 |
-
|
| 316 |
-
tension_score = 1
|
| 317 |
-
raw_score = data.get("tension_score")
|
| 318 |
-
if isinstance(raw_score, (int, float)):
|
| 319 |
-
tension_score = max(1, min(10, int(raw_score)))
|
| 320 |
-
elif isinstance(raw_score, str) and raw_score.isdigit():
|
| 321 |
-
tension_score = max(1, min(10, int(raw_score)))
|
| 322 |
-
|
| 323 |
-
key_entities = data.get("key_entities", [])
|
| 324 |
-
if isinstance(key_entities, str):
|
| 325 |
-
key_entities = [e.strip() for e in key_entities.split(",") if e.strip()]
|
| 326 |
-
|
| 327 |
-
reliability_data = data.get("reliability_assessment", {})
|
| 328 |
-
regional_data = data.get("regional_implications", {})
|
| 329 |
-
|
| 330 |
-
analysis = ConflictAnalysis(
|
| 331 |
-
analysis_id=str(uuid.uuid4()),
|
| 332 |
-
generated_at=datetime.now(),
|
| 333 |
-
latest_status=data.get("latest_status", "No status available"),
|
| 334 |
-
situation_summary=data.get("situation_summary", "No summary available"),
|
| 335 |
-
key_developments=self._parse_key_developments(data.get("key_developments", [])),
|
| 336 |
-
reliability_assessment=ReliabilityAssessment(
|
| 337 |
-
source_credibility=reliability_data.get("source_credibility", "Unknown"),
|
| 338 |
-
information_gaps=reliability_data.get("information_gaps", "Unknown"),
|
| 339 |
-
confidence_rating=reliability_data.get("confidence_rating", "LOW"),
|
| 340 |
-
),
|
| 341 |
-
regional_implications=RegionalImplications(
|
| 342 |
-
security=regional_data.get("security", "No assessment"),
|
| 343 |
-
diplomatic=regional_data.get("diplomatic", "No assessment"),
|
| 344 |
-
economic=regional_data.get("economic", "No assessment"),
|
| 345 |
-
),
|
| 346 |
-
tension_level=self._parse_tension_level(data.get("tension_level", "LOW")),
|
| 347 |
-
tension_rationale=data.get("tension_rationale", "No rationale"),
|
| 348 |
-
tension_score=tension_score,
|
| 349 |
-
tension_trend=self._parse_tension_trend(data.get("tension_trend", "STABLE")),
|
| 350 |
-
analysis_type=self._parse_analysis_type(data.get("analysis_type", "OTHER")),
|
| 351 |
-
key_entities=key_entities if isinstance(key_entities, list) else [],
|
| 352 |
-
source_count=len(retrieved_items),
|
| 353 |
-
source_breakdown=self._count_sources(retrieved_items),
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
await self._repository.save(analysis)
|
| 357 |
-
logger.info(f"Generated RAG-enhanced analysis {analysis.analysis_id}")
|
| 358 |
-
return analysis
|
| 359 |
-
|
| 360 |
-
async def get_latest(self) -> Optional[ConflictAnalysis]:
|
| 361 |
-
"""Get the latest analysis from the repository."""
|
| 362 |
-
return await self._repository.get_latest()
|
|
|
|
| 1 |
+
"""AI-powered conflict analysis service with RAG enhancement and quality improvements."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import uuid
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import google.generativeai as genai
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
|
| 11 |
+
|
| 12 |
+
from westernfront.config import Settings
|
| 13 |
+
from westernfront.core.constants import RAG_QUERY_TOPICS, SEARCH_KEYWORDS
|
| 14 |
+
from westernfront.core.exceptions import ServiceNotInitializedError
|
| 15 |
+
from westernfront.core.models import (
|
| 16 |
+
ConflictAnalysis,
|
| 17 |
+
NewsItem,
|
| 18 |
+
RegionalImplications,
|
| 19 |
+
ReliabilityAssessment,
|
| 20 |
+
)
|
| 21 |
+
from westernfront.prompts.analysis import build_rag_prompt
|
| 22 |
+
from westernfront.repositories.analysis import AnalysisRepository
|
| 23 |
+
from westernfront.repositories.vectors import VectorRepository
|
| 24 |
+
from westernfront.services.chain_analysis import ChainAnalysisService
|
| 25 |
+
from westernfront.services.newsapi import NewsApiService
|
| 26 |
+
from westernfront.services.parsing import ResponseParser
|
| 27 |
+
from westernfront.services.reddit import RedditService
|
| 28 |
+
from westernfront.services.retrieval import RetrievalService
|
| 29 |
+
from westernfront.services.rss import RssService
|
| 30 |
+
from westernfront.services.scheduler import AnalysisScheduler
|
| 31 |
+
from westernfront.services.validation import AnalysisValidator
|
| 32 |
+
from westernfront.utils import extract_json_from_response
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AnalysisService:
|
| 36 |
+
"""Service for generating AI-powered conflict analysis with RAG and quality improvements."""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
gemini_api_key: str,
|
| 41 |
+
reddit: RedditService,
|
| 42 |
+
rss: RssService,
|
| 43 |
+
newsapi: NewsApiService,
|
| 44 |
+
repository: AnalysisRepository,
|
| 45 |
+
vectors: VectorRepository,
|
| 46 |
+
settings: Settings,
|
| 47 |
+
) -> None:
|
| 48 |
+
self._api_key = gemini_api_key
|
| 49 |
+
self._reddit = reddit
|
| 50 |
+
self._rss = rss
|
| 51 |
+
self._newsapi = newsapi
|
| 52 |
+
self._repository = repository
|
| 53 |
+
self._vectors = vectors
|
| 54 |
+
self._settings = settings
|
| 55 |
+
self._model: genai.GenerativeModel | None = None
|
| 56 |
+
self._retrieval: RetrievalService | None = None
|
| 57 |
+
self._chain: ChainAnalysisService | None = None
|
| 58 |
+
self._parser = ResponseParser()
|
| 59 |
+
self._validator = AnalysisValidator()
|
| 60 |
+
self._scheduler: AnalysisScheduler | None = None
|
| 61 |
+
|
| 62 |
+
async def initialize(self) -> None:
|
| 63 |
+
"""Initialize the Gemini model and start background updates."""
|
| 64 |
+
logger.info("Initializing Gemini AI")
|
| 65 |
+
genai.configure(api_key=self._api_key)
|
| 66 |
+
self._model = genai.GenerativeModel(
|
| 67 |
+
"gemma-3-27b-it",
|
| 68 |
+
generation_config={
|
| 69 |
+
"temperature": 0.2,
|
| 70 |
+
"top_p": 0.95,
|
| 71 |
+
"top_k": 40,
|
| 72 |
+
},
|
| 73 |
+
)
|
| 74 |
+
logger.info("Gemini AI initialized")
|
| 75 |
+
|
| 76 |
+
self._retrieval = RetrievalService(self._vectors)
|
| 77 |
+
self._chain = ChainAnalysisService(self._model)
|
| 78 |
+
logger.info("Quality services initialized (retrieval, chain analysis)")
|
| 79 |
+
|
| 80 |
+
self._scheduler = AnalysisScheduler(self, self._settings.update_interval_minutes)
|
| 81 |
+
self._scheduler.start()
|
| 82 |
+
|
| 83 |
+
async def close(self) -> None:
|
| 84 |
+
"""Clean up resources."""
|
| 85 |
+
if self._scheduler:
|
| 86 |
+
await self._scheduler.stop()
|
| 87 |
+
logger.info("Analysis service closed")
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def is_initialized(self) -> bool:
|
| 91 |
+
"""Check if the service is initialized."""
|
| 92 |
+
return self._model is not None
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def keywords(self) -> list[str]:
|
| 96 |
+
"""Get the search keywords used for analysis."""
|
| 97 |
+
return SEARCH_KEYWORDS
|
| 98 |
+
|
| 99 |
+
async def _ingest_all_news(self) -> int:
|
| 100 |
+
"""Ingest all news from all sources into vector store in parallel."""
|
| 101 |
+
days = self._settings.analysis_days_back
|
| 102 |
+
|
| 103 |
+
results = await asyncio.gather(
|
| 104 |
+
self._reddit.get_all_posts(days),
|
| 105 |
+
self._rss.get_all_articles(days),
|
| 106 |
+
self._newsapi.get_related_articles(days_back=days),
|
| 107 |
+
return_exceptions=True,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
all_items: list[NewsItem] = []
|
| 111 |
+
source_names = ["Reddit", "RSS", "NewsAPI"]
|
| 112 |
+
|
| 113 |
+
for i, result in enumerate(results):
|
| 114 |
+
if isinstance(result, Exception):
|
| 115 |
+
logger.error(f"Error fetching from {source_names[i]}: {result}")
|
| 116 |
+
else:
|
| 117 |
+
all_items.extend(result)
|
| 118 |
+
|
| 119 |
+
logger.info(f"Ingested {len(all_items)} total news items from all sources")
|
| 120 |
+
|
| 121 |
+
if not self._vectors.is_initialized:
|
| 122 |
+
logger.warning("Vector store not initialized, skipping ingestion")
|
| 123 |
+
return 0
|
| 124 |
+
|
| 125 |
+
stored = await self._vectors.add_items_async(all_items)
|
| 126 |
+
total_count = self._vectors.get_count()
|
| 127 |
+
logger.info(f"Stored {stored} new items. Total in vector store: {total_count}")
|
| 128 |
+
return stored
|
| 129 |
+
|
| 130 |
+
def _retrieve_relevant_items(self, max_items: int = 40) -> list[dict]:
|
| 131 |
+
"""Retrieve relevant items with quality weighting."""
|
| 132 |
+
if self._retrieval:
|
| 133 |
+
return self._retrieval.retrieve_with_quality(
|
| 134 |
+
RAG_QUERY_TOPICS,
|
| 135 |
+
max_items=max_items,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
if not self._vectors.is_initialized:
|
| 139 |
+
logger.warning("Vector store not initialized")
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
return self._vectors.query_by_topics(
|
| 143 |
+
RAG_QUERY_TOPICS,
|
| 144 |
+
n_per_topic=max_items // len(RAG_QUERY_TOPICS) + 1,
|
| 145 |
+
)[:max_items]
|
| 146 |
+
|
| 147 |
+
@retry(wait=wait_exponential(min=2, max=60), stop=stop_after_attempt(3))
|
| 148 |
+
async def _call_gemini(self, prompt: str) -> dict[str, Any] | None:
|
| 149 |
+
"""Call Gemini API with retry logic."""
|
| 150 |
+
if not self._model:
|
| 151 |
+
raise ServiceNotInitializedError("Gemini model not initialized")
|
| 152 |
+
|
| 153 |
+
logger.info("Calling Gemini API")
|
| 154 |
+
response = await self._model.generate_content_async(prompt)
|
| 155 |
+
result = extract_json_from_response(response.text)
|
| 156 |
+
if not result:
|
| 157 |
+
raise ValueError("Could not parse JSON from response")
|
| 158 |
+
return result
|
| 159 |
+
|
| 160 |
+
async def generate_analysis(self, use_chain: bool = True) -> ConflictAnalysis | None:
|
| 161 |
+
"""Generate a new conflict analysis using RAG with quality improvements."""
|
| 162 |
+
await self._ingest_all_news()
|
| 163 |
+
|
| 164 |
+
retrieved_items = self._retrieve_relevant_items(
|
| 165 |
+
max_items=self._settings.max_posts_for_analysis
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if len(retrieved_items) < self._settings.min_posts_for_analysis:
|
| 169 |
+
logger.warning(f"Insufficient data: {len(retrieved_items)} items")
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
total_count = self._vectors.get_count()
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
if use_chain and self._chain and len(retrieved_items) >= 10:
|
| 176 |
+
data = await self._chain.run_chain_analysis(
|
| 177 |
+
retrieved_items,
|
| 178 |
+
total_in_memory=total_count,
|
| 179 |
+
use_full_chain=True,
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
prompt = build_rag_prompt(retrieved_items, total_count)
|
| 183 |
+
data = await self._call_gemini(prompt)
|
| 184 |
+
except RetryError as e:
|
| 185 |
+
logger.error(f"AI analysis failed after retries: {e}")
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
if not data:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
source_texts = [item.get("document", "") for item in retrieved_items]
|
| 192 |
+
is_valid, issues = self._validator.validate_analysis(data, source_texts)
|
| 193 |
+
if not is_valid:
|
| 194 |
+
logger.warning(f"Validation issues: {issues}")
|
| 195 |
+
|
| 196 |
+
reliability_data = data.get("reliability_assessment", {})
|
| 197 |
+
regional_data = data.get("regional_implications", {})
|
| 198 |
+
|
| 199 |
+
analysis = ConflictAnalysis(
|
| 200 |
+
analysis_id=str(uuid.uuid4()),
|
| 201 |
+
generated_at=datetime.now(),
|
| 202 |
+
latest_status=data.get("latest_status", "No status available"),
|
| 203 |
+
situation_summary=data.get("situation_summary", "No summary available"),
|
| 204 |
+
key_developments=self._parser.parse_key_developments(
|
| 205 |
+
data.get("key_developments", [])
|
| 206 |
+
),
|
| 207 |
+
reliability_assessment=ReliabilityAssessment(
|
| 208 |
+
source_credibility=reliability_data.get("source_credibility", "Unknown"),
|
| 209 |
+
information_gaps=reliability_data.get("information_gaps", "Unknown"),
|
| 210 |
+
confidence_rating=reliability_data.get("confidence_rating", "LOW"),
|
| 211 |
+
),
|
| 212 |
+
regional_implications=RegionalImplications(
|
| 213 |
+
security=regional_data.get("security", "No assessment"),
|
| 214 |
+
diplomatic=regional_data.get("diplomatic", "No assessment"),
|
| 215 |
+
economic=regional_data.get("economic", "No assessment"),
|
| 216 |
+
),
|
| 217 |
+
tension_level=self._parser.parse_tension_level(
|
| 218 |
+
data.get("tension_level", "LOW")
|
| 219 |
+
),
|
| 220 |
+
tension_rationale=data.get("tension_rationale", "No rationale"),
|
| 221 |
+
tension_score=self._parser.parse_tension_score(data.get("tension_score")),
|
| 222 |
+
tension_trend=self._parser.parse_tension_trend(
|
| 223 |
+
data.get("tension_trend", "STABLE")
|
| 224 |
+
),
|
| 225 |
+
analysis_type=self._parser.parse_analysis_type(
|
| 226 |
+
data.get("analysis_type", "OTHER")
|
| 227 |
+
),
|
| 228 |
+
key_entities=self._parser.parse_key_entities(data.get("key_entities")),
|
| 229 |
+
source_count=len(retrieved_items),
|
| 230 |
+
source_breakdown=self._parser.count_sources(retrieved_items),
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
await self._repository.save(analysis)
|
| 234 |
+
logger.info(f"Generated quality-enhanced analysis {analysis.analysis_id}")
|
| 235 |
+
return analysis
|
| 236 |
+
|
| 237 |
+
async def get_latest(self) -> ConflictAnalysis | None:
|
| 238 |
+
"""Get the latest analysis from the repository."""
|
| 239 |
+
return await self._repository.get_latest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/westernfront/services/cache.py
CHANGED
|
@@ -1,73 +1,90 @@
|
|
| 1 |
-
"""Thread-safe async cache service."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
from cachetools import TTLCache
|
| 7 |
|
| 8 |
|
| 9 |
-
class
|
| 10 |
-
"""Async-
|
| 11 |
|
| 12 |
-
def __init__(self
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
async def
|
| 24 |
-
"""
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
return self._cache.get(key)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
async def set(self, key: str, value: Any) -> None:
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
key: The cache key.
|
| 42 |
-
value: The value to cache.
|
| 43 |
-
"""
|
| 44 |
-
async with self._lock:
|
| 45 |
self._cache[key] = value
|
|
|
|
|
|
|
| 46 |
|
| 47 |
async def delete(self, key: str) -> None:
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
Args:
|
| 52 |
-
key: The cache key.
|
| 53 |
-
"""
|
| 54 |
-
async with self._lock:
|
| 55 |
self._cache.pop(key, None)
|
|
|
|
|
|
|
| 56 |
|
| 57 |
async def clear(self) -> None:
|
| 58 |
"""Clear all entries from the cache."""
|
| 59 |
-
|
|
|
|
| 60 |
self._cache.clear()
|
|
|
|
|
|
|
| 61 |
|
| 62 |
async def has(self, key: str) -> bool:
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
key: The cache key.
|
| 68 |
-
|
| 69 |
-
Returns:
|
| 70 |
-
True if the key exists, False otherwise.
|
| 71 |
-
"""
|
| 72 |
-
async with self._lock:
|
| 73 |
return key in self._cache
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Thread-safe async cache service with read-write lock pattern."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
+
from typing import Any
|
| 5 |
|
| 6 |
from cachetools import TTLCache
|
| 7 |
|
| 8 |
|
| 9 |
+
class ReadWriteLock:
|
| 10 |
+
"""Async read-write lock allowing concurrent reads but exclusive writes."""
|
| 11 |
|
| 12 |
+
def __init__(self) -> None:
|
| 13 |
+
self._readers = 0
|
| 14 |
+
self._writer = False
|
| 15 |
+
self._condition = asyncio.Condition()
|
| 16 |
|
| 17 |
+
async def acquire_read(self) -> None:
|
| 18 |
+
"""Acquire read lock (allows concurrent readers)."""
|
| 19 |
+
async with self._condition:
|
| 20 |
+
while self._writer:
|
| 21 |
+
await self._condition.wait()
|
| 22 |
+
self._readers += 1
|
| 23 |
+
|
| 24 |
+
async def release_read(self) -> None:
|
| 25 |
+
"""Release read lock."""
|
| 26 |
+
async with self._condition:
|
| 27 |
+
self._readers -= 1
|
| 28 |
+
if self._readers == 0:
|
| 29 |
+
self._condition.notify_all()
|
| 30 |
|
| 31 |
+
async def acquire_write(self) -> None:
|
| 32 |
+
"""Acquire exclusive write lock."""
|
| 33 |
+
async with self._condition:
|
| 34 |
+
while self._writer or self._readers > 0:
|
| 35 |
+
await self._condition.wait()
|
| 36 |
+
self._writer = True
|
| 37 |
|
| 38 |
+
async def release_write(self) -> None:
|
| 39 |
+
"""Release write lock."""
|
| 40 |
+
async with self._condition:
|
| 41 |
+
self._writer = False
|
| 42 |
+
self._condition.notify_all()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CacheService:
|
| 46 |
+
"""Async-safe caching with TTL support and read-write lock for concurrent reads."""
|
| 47 |
|
| 48 |
+
def __init__(self, ttl_seconds: int = 3600, max_size: int = 100) -> None:
|
| 49 |
+
self._cache: TTLCache[str, Any] = TTLCache(maxsize=max_size, ttl=ttl_seconds)
|
| 50 |
+
self._lock = ReadWriteLock()
|
| 51 |
+
|
| 52 |
+
async def get(self, key: str) -> Any | None:
|
| 53 |
+
"""Get a value from the cache."""
|
| 54 |
+
await self._lock.acquire_read()
|
| 55 |
+
try:
|
| 56 |
return self._cache.get(key)
|
| 57 |
+
finally:
|
| 58 |
+
await self._lock.release_read()
|
| 59 |
|
| 60 |
async def set(self, key: str, value: Any) -> None:
|
| 61 |
+
"""Set a value in the cache."""
|
| 62 |
+
await self._lock.acquire_write()
|
| 63 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
self._cache[key] = value
|
| 65 |
+
finally:
|
| 66 |
+
await self._lock.release_write()
|
| 67 |
|
| 68 |
async def delete(self, key: str) -> None:
|
| 69 |
+
"""Delete a value from the cache."""
|
| 70 |
+
await self._lock.acquire_write()
|
| 71 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._cache.pop(key, None)
|
| 73 |
+
finally:
|
| 74 |
+
await self._lock.release_write()
|
| 75 |
|
| 76 |
async def clear(self) -> None:
|
| 77 |
"""Clear all entries from the cache."""
|
| 78 |
+
await self._lock.acquire_write()
|
| 79 |
+
try:
|
| 80 |
self._cache.clear()
|
| 81 |
+
finally:
|
| 82 |
+
await self._lock.release_write()
|
| 83 |
|
| 84 |
async def has(self, key: str) -> bool:
|
| 85 |
+
"""Check if a key exists in the cache."""
|
| 86 |
+
await self._lock.acquire_read()
|
| 87 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
return key in self._cache
|
| 89 |
+
finally:
|
| 90 |
+
await self._lock.release_read()
|
src/westernfront/services/chain_analysis.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-pass chain analysis for improved quality."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
from loguru import logger
|
| 7 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 8 |
+
|
| 9 |
+
from westernfront.prompts.analysis import (
|
| 10 |
+
build_extraction_prompt,
|
| 11 |
+
build_rag_prompt,
|
| 12 |
+
build_synthesis_prompt,
|
| 13 |
+
)
|
| 14 |
+
from westernfront.utils import extract_json_from_response
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChainAnalysisService:
|
| 18 |
+
"""Multi-pass analysis using chain-of-thought for better quality."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, model: genai.GenerativeModel) -> None:
|
| 21 |
+
self._model = model
|
| 22 |
+
|
| 23 |
+
@retry(wait=wait_exponential(min=2, max=30), stop=stop_after_attempt(3))
|
| 24 |
+
async def _call_model(self, prompt: str) -> dict[str, Any] | None:
|
| 25 |
+
"""Call model with retry and JSON parsing."""
|
| 26 |
+
response = await self._model.generate_content_async(prompt)
|
| 27 |
+
return extract_json_from_response(response.text)
|
| 28 |
+
|
| 29 |
+
async def extract_facts(self, items: list[dict]) -> list[dict]:
|
| 30 |
+
"""Stage 1: Extract facts from sources."""
|
| 31 |
+
logger.info("Chain Analysis Stage 1: Extracting facts")
|
| 32 |
+
prompt = build_extraction_prompt(items)
|
| 33 |
+
|
| 34 |
+
result = await self._call_model(prompt)
|
| 35 |
+
if not result:
|
| 36 |
+
logger.warning("Fact extraction failed, continuing with empty facts")
|
| 37 |
+
return []
|
| 38 |
+
|
| 39 |
+
facts = result.get("facts", [])
|
| 40 |
+
logger.info(f"Extracted {len(facts)} facts from sources")
|
| 41 |
+
return facts
|
| 42 |
+
|
| 43 |
+
async def synthesize_facts(
|
| 44 |
+
self,
|
| 45 |
+
facts: list[dict],
|
| 46 |
+
historical_context: str = "",
|
| 47 |
+
) -> dict:
|
| 48 |
+
"""Stage 2: Synthesize facts into preliminary assessment."""
|
| 49 |
+
logger.info("Chain Analysis Stage 2: Synthesizing facts")
|
| 50 |
+
prompt = build_synthesis_prompt(facts, historical_context)
|
| 51 |
+
|
| 52 |
+
result = await self._call_model(prompt)
|
| 53 |
+
if not result:
|
| 54 |
+
logger.warning("Synthesis failed, using defaults")
|
| 55 |
+
return {
|
| 56 |
+
"significant_developments": [],
|
| 57 |
+
"preliminary_tension": "MEDIUM",
|
| 58 |
+
"tension_reasoning": "Unable to synthesize",
|
| 59 |
+
"trends": [],
|
| 60 |
+
"contradictions": [],
|
| 61 |
+
"gaps": ["Synthesis stage failed"],
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
logger.info(f"Synthesis complete: preliminary tension = {result.get('preliminary_tension')}")
|
| 65 |
+
return result
|
| 66 |
+
|
| 67 |
+
async def generate_final_report(
|
| 68 |
+
self,
|
| 69 |
+
items: list[dict],
|
| 70 |
+
synthesis: dict,
|
| 71 |
+
total_in_memory: int = 0,
|
| 72 |
+
) -> dict[str, Any] | None:
|
| 73 |
+
"""Stage 3: Generate final analysis report."""
|
| 74 |
+
logger.info("Chain Analysis Stage 3: Generating final report")
|
| 75 |
+
|
| 76 |
+
prompt = build_rag_prompt(items, total_in_memory)
|
| 77 |
+
|
| 78 |
+
synthesis_context = f"""
|
| 79 |
+
PRELIMINARY ASSESSMENT (from internal analysis):
|
| 80 |
+
- Tension Level Estimate: {synthesis.get('preliminary_tension', 'UNKNOWN')}
|
| 81 |
+
- Reasoning: {synthesis.get('tension_reasoning', 'N/A')}
|
| 82 |
+
- Key Trends: {', '.join(synthesis.get('trends', []))}
|
| 83 |
+
- Information Gaps: {', '.join(synthesis.get('gaps', []))}
|
| 84 |
+
|
| 85 |
+
Consider this preliminary assessment but verify against the source data.
|
| 86 |
+
"""
|
| 87 |
+
enhanced_prompt = prompt + "\n\n" + synthesis_context
|
| 88 |
+
|
| 89 |
+
result = await self._call_model(enhanced_prompt)
|
| 90 |
+
if result:
|
| 91 |
+
logger.info("Final report generated successfully")
|
| 92 |
+
return result
|
| 93 |
+
|
| 94 |
+
async def run_chain_analysis(
|
| 95 |
+
self,
|
| 96 |
+
items: list[dict],
|
| 97 |
+
total_in_memory: int = 0,
|
| 98 |
+
use_full_chain: bool = True,
|
| 99 |
+
) -> dict[str, Any] | None:
|
| 100 |
+
"""Run complete chain analysis."""
|
| 101 |
+
if not use_full_chain or len(items) < 10:
|
| 102 |
+
logger.info("Using direct analysis (chain disabled or insufficient data)")
|
| 103 |
+
prompt = build_rag_prompt(items, total_in_memory)
|
| 104 |
+
return await self._call_model(prompt)
|
| 105 |
+
|
| 106 |
+
facts = await self.extract_facts(items)
|
| 107 |
+
synthesis = await self.synthesize_facts(facts)
|
| 108 |
+
return await self.generate_final_report(items, synthesis, total_in_memory)
|
src/westernfront/services/embeddings.py
CHANGED
|
@@ -1,25 +1,21 @@
|
|
| 1 |
"""Local embedding service using sentence-transformers MiniLM model."""
|
| 2 |
|
|
|
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
from loguru import logger
|
| 8 |
|
|
|
|
| 9 |
|
| 10 |
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class EmbeddingService:
|
| 14 |
"""Service for generating text embeddings using local MiniLM model."""
|
| 15 |
|
| 16 |
-
def __init__(self, model_cache_dir:
|
| 17 |
-
"""
|
| 18 |
-
Initialize the embedding service.
|
| 19 |
-
|
| 20 |
-
Args:
|
| 21 |
-
model_cache_dir: Directory to cache the model (defaults to server/models).
|
| 22 |
-
"""
|
| 23 |
if model_cache_dir:
|
| 24 |
self._cache_dir = Path(model_cache_dir)
|
| 25 |
else:
|
|
@@ -32,28 +28,19 @@ class EmbeddingService:
|
|
| 32 |
os.environ["TRANSFORMERS_CACHE"] = str(self._cache_dir)
|
| 33 |
|
| 34 |
self._model = None
|
| 35 |
-
self._dimension =
|
| 36 |
|
| 37 |
def initialize(self) -> bool:
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
self._model = SentenceTransformer(
|
| 49 |
-
MODEL_NAME,
|
| 50 |
-
cache_folder=str(self._cache_dir),
|
| 51 |
-
)
|
| 52 |
-
logger.info(f"Embedding model loaded: {MODEL_NAME}")
|
| 53 |
-
return True
|
| 54 |
-
except Exception as e:
|
| 55 |
-
logger.error(f"Failed to load embedding model: {e}")
|
| 56 |
-
return False
|
| 57 |
|
| 58 |
@property
|
| 59 |
def is_initialized(self) -> bool:
|
|
@@ -65,34 +52,32 @@ class EmbeddingService:
|
|
| 65 |
"""Get the embedding dimension."""
|
| 66 |
return self._dimension
|
| 67 |
|
| 68 |
-
def
|
| 69 |
-
"""
|
| 70 |
-
Generate embedding for a single text.
|
| 71 |
-
|
| 72 |
-
Args:
|
| 73 |
-
text: Text to embed.
|
| 74 |
-
|
| 75 |
-
Returns:
|
| 76 |
-
List of floats representing the embedding.
|
| 77 |
-
"""
|
| 78 |
if not self._model:
|
| 79 |
-
raise
|
| 80 |
-
|
| 81 |
embedding = self._model.encode(text, convert_to_numpy=True)
|
| 82 |
return embedding.tolist()
|
| 83 |
|
| 84 |
-
def
|
| 85 |
-
"""
|
| 86 |
-
Generate embeddings for multiple texts.
|
| 87 |
-
|
| 88 |
-
Args:
|
| 89 |
-
texts: List of texts to embed.
|
| 90 |
-
|
| 91 |
-
Returns:
|
| 92 |
-
List of embeddings.
|
| 93 |
-
"""
|
| 94 |
if not self._model:
|
| 95 |
-
raise
|
| 96 |
-
|
| 97 |
embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
|
| 98 |
return embeddings.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Local embedding service using sentence-transformers MiniLM model."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import os
|
| 5 |
from pathlib import Path
|
|
|
|
| 6 |
|
| 7 |
from loguru import logger
|
| 8 |
|
| 9 |
+
from westernfront.core.exceptions import ServiceNotInitializedError
|
| 10 |
|
| 11 |
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 12 |
+
EMBEDDING_DIMENSION = 384
|
| 13 |
|
| 14 |
|
| 15 |
class EmbeddingService:
|
| 16 |
"""Service for generating text embeddings using local MiniLM model."""
|
| 17 |
|
| 18 |
+
def __init__(self, model_cache_dir: str | None = None) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
if model_cache_dir:
|
| 20 |
self._cache_dir = Path(model_cache_dir)
|
| 21 |
else:
|
|
|
|
| 28 |
os.environ["TRANSFORMERS_CACHE"] = str(self._cache_dir)
|
| 29 |
|
| 30 |
self._model = None
|
| 31 |
+
self._dimension = EMBEDDING_DIMENSION
|
| 32 |
|
| 33 |
def initialize(self) -> bool:
|
| 34 |
+
"""Initialize the embedding model."""
|
| 35 |
+
logger.info(f"Loading embedding model from {self._cache_dir}")
|
| 36 |
+
from sentence_transformers import SentenceTransformer
|
| 37 |
+
|
| 38 |
+
self._model = SentenceTransformer(
|
| 39 |
+
MODEL_NAME,
|
| 40 |
+
cache_folder=str(self._cache_dir),
|
| 41 |
+
)
|
| 42 |
+
logger.info(f"Embedding model loaded: {MODEL_NAME}")
|
| 43 |
+
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
@property
|
| 46 |
def is_initialized(self) -> bool:
|
|
|
|
| 52 |
"""Get the embedding dimension."""
|
| 53 |
return self._dimension
|
| 54 |
|
| 55 |
+
def _embed_sync(self, text: str) -> list[float]:
|
| 56 |
+
"""Synchronous embedding generation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
if not self._model:
|
| 58 |
+
raise ServiceNotInitializedError("Embedding model not initialized")
|
|
|
|
| 59 |
embedding = self._model.encode(text, convert_to_numpy=True)
|
| 60 |
return embedding.tolist()
|
| 61 |
|
| 62 |
+
def _embed_batch_sync(self, texts: list[str]) -> list[list[float]]:
|
| 63 |
+
"""Synchronous batch embedding generation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if not self._model:
|
| 65 |
+
raise ServiceNotInitializedError("Embedding model not initialized")
|
|
|
|
| 66 |
embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
|
| 67 |
return embeddings.tolist()
|
| 68 |
+
|
| 69 |
+
def embed(self, text: str) -> list[float]:
|
| 70 |
+
"""Generate embedding for a single text (sync version)."""
|
| 71 |
+
return self._embed_sync(text)
|
| 72 |
+
|
| 73 |
+
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
| 74 |
+
"""Generate embeddings for multiple texts (sync version)."""
|
| 75 |
+
return self._embed_batch_sync(texts)
|
| 76 |
+
|
| 77 |
+
async def embed_async(self, text: str) -> list[float]:
|
| 78 |
+
"""Generate embedding for a single text without blocking the event loop."""
|
| 79 |
+
return await asyncio.to_thread(self._embed_sync, text)
|
| 80 |
+
|
| 81 |
+
async def embed_batch_async(self, texts: list[str]) -> list[list[float]]:
|
| 82 |
+
"""Generate embeddings for multiple texts without blocking the event loop."""
|
| 83 |
+
return await asyncio.to_thread(self._embed_batch_sync, texts)
|
src/westernfront/services/http.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared HTTP client service with connection pooling."""
|
| 2 |
+
|
| 3 |
+
from collections.abc import AsyncGenerator
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
|
| 6 |
+
import aiohttp
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
from westernfront.core.constants import HTTP_TIMEOUT_SECONDS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HttpService:
|
| 13 |
+
"""Shared HTTP client with connection pooling for all external requests."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, timeout_seconds: int = HTTP_TIMEOUT_SECONDS) -> None:
|
| 16 |
+
self._timeout = aiohttp.ClientTimeout(total=timeout_seconds)
|
| 17 |
+
self._session: aiohttp.ClientSession | None = None
|
| 18 |
+
|
| 19 |
+
async def initialize(self) -> None:
|
| 20 |
+
"""Initialize the shared HTTP session with connection pooling."""
|
| 21 |
+
if self._session is None or self._session.closed:
|
| 22 |
+
connector = aiohttp.TCPConnector(
|
| 23 |
+
limit=100,
|
| 24 |
+
limit_per_host=10,
|
| 25 |
+
ttl_dns_cache=300,
|
| 26 |
+
enable_cleanup_closed=True,
|
| 27 |
+
)
|
| 28 |
+
self._session = aiohttp.ClientSession(
|
| 29 |
+
connector=connector,
|
| 30 |
+
timeout=self._timeout,
|
| 31 |
+
)
|
| 32 |
+
logger.info("HTTP service initialized with connection pooling")
|
| 33 |
+
|
| 34 |
+
async def close(self) -> None:
|
| 35 |
+
"""Close the HTTP session."""
|
| 36 |
+
if self._session and not self._session.closed:
|
| 37 |
+
await self._session.close()
|
| 38 |
+
self._session = None
|
| 39 |
+
logger.info("HTTP service closed")
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def session(self) -> aiohttp.ClientSession:
|
| 43 |
+
"""Get the shared HTTP session."""
|
| 44 |
+
if self._session is None or self._session.closed:
|
| 45 |
+
raise RuntimeError("HTTP service not initialized")
|
| 46 |
+
return self._session
|
| 47 |
+
|
| 48 |
+
@asynccontextmanager
|
| 49 |
+
async def get(
|
| 50 |
+
self,
|
| 51 |
+
url: str,
|
| 52 |
+
params: dict | None = None,
|
| 53 |
+
headers: dict | None = None,
|
| 54 |
+
) -> AsyncGenerator[aiohttp.ClientResponse, None]:
|
| 55 |
+
"""Perform a GET request with automatic error handling."""
|
| 56 |
+
async with self.session.get(url, params=params, headers=headers) as response:
|
| 57 |
+
yield response
|
src/westernfront/services/newsapi.py
CHANGED
|
@@ -1,41 +1,32 @@
|
|
| 1 |
-
"""NewsAPI integration service."""
|
| 2 |
|
|
|
|
| 3 |
import hashlib
|
| 4 |
-
from datetime import
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
import aiohttp
|
| 8 |
from loguru import logger
|
| 9 |
|
|
|
|
| 10 |
from westernfront.core.enums import SourceType
|
| 11 |
from westernfront.core.models import NewsItem
|
| 12 |
from westernfront.services.cache import CacheService
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
NEWSAPI_BASE_URL = "https://newsapi.org/v2"
|
| 16 |
-
|
| 17 |
-
INDIA_PAKISTAN_QUERIES = [
|
| 18 |
-
"India Pakistan",
|
| 19 |
-
"Kashmir conflict",
|
| 20 |
-
"India Pakistan border",
|
| 21 |
-
"LOC firing",
|
| 22 |
-
"Indo-Pak",
|
| 23 |
-
]
|
| 24 |
|
| 25 |
|
| 26 |
class NewsApiService:
|
| 27 |
-
"""Service for fetching news from NewsAPI.org."""
|
| 28 |
-
|
| 29 |
-
def __init__(self, api_key: Optional[str], cache: CacheService) -> None:
|
| 30 |
-
"""
|
| 31 |
-
Initialize the NewsAPI service.
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
self._api_key = api_key
|
| 38 |
self._cache = cache
|
|
|
|
|
|
|
| 39 |
|
| 40 |
@property
|
| 41 |
def is_enabled(self) -> bool:
|
|
@@ -49,14 +40,14 @@ class NewsApiService:
|
|
| 49 |
raw = f"newsapi:{url}:{title}"
|
| 50 |
return hashlib.sha256(raw.encode()).hexdigest()[:16]
|
| 51 |
|
| 52 |
-
def _parse_date(self, date_str:
|
| 53 |
"""Parse ISO date string from NewsAPI."""
|
| 54 |
if not date_str:
|
| 55 |
-
return datetime.now(
|
| 56 |
try:
|
| 57 |
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
| 58 |
except ValueError:
|
| 59 |
-
return datetime.now(
|
| 60 |
|
| 61 |
async def _search_news(self, query: str, days_back: int = 2) -> list[dict]:
|
| 62 |
"""Search NewsAPI for articles matching query."""
|
|
@@ -69,22 +60,25 @@ class NewsApiService:
|
|
| 69 |
logger.debug(f"Cache hit for NewsAPI query: {query}")
|
| 70 |
return cached
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
if response.status == 401:
|
| 89 |
logger.error("NewsAPI: Invalid API key")
|
| 90 |
return []
|
|
@@ -101,37 +95,34 @@ class NewsApiService:
|
|
| 101 |
logger.info(f"NewsAPI: Found {len(articles)} articles for '{query}'")
|
| 102 |
return articles
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
|
| 108 |
async def get_related_articles(
|
| 109 |
self,
|
| 110 |
-
keywords:
|
| 111 |
days_back: int = 2,
|
| 112 |
) -> list[NewsItem]:
|
| 113 |
-
"""
|
| 114 |
-
Get articles related to India-Pakistan from NewsAPI.
|
| 115 |
-
|
| 116 |
-
Args:
|
| 117 |
-
keywords: Optional additional keywords (defaults to built-in queries).
|
| 118 |
-
days_back: How many days back to search.
|
| 119 |
-
|
| 120 |
-
Returns:
|
| 121 |
-
List of news items.
|
| 122 |
-
"""
|
| 123 |
if not self.is_enabled:
|
| 124 |
logger.debug("NewsAPI service is disabled (no API key)")
|
| 125 |
return []
|
| 126 |
|
| 127 |
-
queries = keywords if keywords else
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
seen_urls: set[str] = set()
|
| 129 |
-
|
| 130 |
|
| 131 |
-
for
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
for article in
|
| 135 |
url = article.get("url", "")
|
| 136 |
if not url or url in seen_urls:
|
| 137 |
continue
|
|
@@ -154,9 +145,9 @@ class NewsApiService:
|
|
| 154 |
reliability_score=0.9,
|
| 155 |
author=article.get("author"),
|
| 156 |
)
|
| 157 |
-
|
| 158 |
seen_urls.add(url)
|
| 159 |
|
| 160 |
-
|
| 161 |
-
logger.info(f"Found {len(
|
| 162 |
-
return
|
|
|
|
| 1 |
+
"""NewsAPI integration service with parallel query fetching."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import hashlib
|
| 5 |
+
from datetime import UTC, datetime, timedelta
|
|
|
|
| 6 |
|
| 7 |
import aiohttp
|
| 8 |
from loguru import logger
|
| 9 |
|
| 10 |
+
from westernfront.core.constants import MAX_CONCURRENT_REQUESTS, NEWSAPI_BASE_URL, NEWSAPI_QUERIES
|
| 11 |
from westernfront.core.enums import SourceType
|
| 12 |
from westernfront.core.models import NewsItem
|
| 13 |
from westernfront.services.cache import CacheService
|
| 14 |
+
from westernfront.services.http import HttpService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class NewsApiService:
|
| 18 |
+
"""Service for fetching news from NewsAPI.org with parallel query execution."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
api_key: str | None,
|
| 23 |
+
cache: CacheService,
|
| 24 |
+
http: HttpService,
|
| 25 |
+
) -> None:
|
| 26 |
self._api_key = api_key
|
| 27 |
self._cache = cache
|
| 28 |
+
self._http = http
|
| 29 |
+
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
| 30 |
|
| 31 |
@property
|
| 32 |
def is_enabled(self) -> bool:
|
|
|
|
| 40 |
raw = f"newsapi:{url}:{title}"
|
| 41 |
return hashlib.sha256(raw.encode()).hexdigest()[:16]
|
| 42 |
|
| 43 |
+
def _parse_date(self, date_str: str | None) -> datetime:
|
| 44 |
"""Parse ISO date string from NewsAPI."""
|
| 45 |
if not date_str:
|
| 46 |
+
return datetime.now(UTC)
|
| 47 |
try:
|
| 48 |
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
| 49 |
except ValueError:
|
| 50 |
+
return datetime.now(UTC)
|
| 51 |
|
| 52 |
async def _search_news(self, query: str, days_back: int = 2) -> list[dict]:
|
| 53 |
"""Search NewsAPI for articles matching query."""
|
|
|
|
| 60 |
logger.debug(f"Cache hit for NewsAPI query: {query}")
|
| 61 |
return cached
|
| 62 |
|
| 63 |
+
async with self._semaphore:
|
| 64 |
+
from_date = (datetime.now(UTC) - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
| 65 |
+
|
| 66 |
+
params = {
|
| 67 |
+
"q": query,
|
| 68 |
+
"from": from_date,
|
| 69 |
+
"language": "en",
|
| 70 |
+
"sortBy": "publishedAt",
|
| 71 |
+
"pageSize": 50,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
headers = {"X-Api-Key": self._api_key}
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
async with self._http.get(
|
| 78 |
+
f"{NEWSAPI_BASE_URL}/everything",
|
| 79 |
+
params=params,
|
| 80 |
+
headers=headers,
|
| 81 |
+
) as response:
|
| 82 |
if response.status == 401:
|
| 83 |
logger.error("NewsAPI: Invalid API key")
|
| 84 |
return []
|
|
|
|
| 95 |
logger.info(f"NewsAPI: Found {len(articles)} articles for '{query}'")
|
| 96 |
return articles
|
| 97 |
|
| 98 |
+
except aiohttp.ClientError as e:
|
| 99 |
+
logger.error(f"NewsAPI HTTP error: {e}")
|
| 100 |
+
return []
|
| 101 |
|
| 102 |
async def get_related_articles(
|
| 103 |
self,
|
| 104 |
+
keywords: list[str] | None = None,
|
| 105 |
days_back: int = 2,
|
| 106 |
) -> list[NewsItem]:
|
| 107 |
+
"""Get articles related to India-Pakistan from NewsAPI in parallel."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if not self.is_enabled:
|
| 109 |
logger.debug("NewsAPI service is disabled (no API key)")
|
| 110 |
return []
|
| 111 |
|
| 112 |
+
queries = keywords if keywords else NEWSAPI_QUERIES[:3]
|
| 113 |
+
|
| 114 |
+
tasks = [self._search_news(query, days_back) for query in queries]
|
| 115 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 116 |
+
|
| 117 |
seen_urls: set[str] = set()
|
| 118 |
+
all_articles: list[NewsItem] = []
|
| 119 |
|
| 120 |
+
for i, result in enumerate(results):
|
| 121 |
+
if isinstance(result, Exception):
|
| 122 |
+
logger.error(f"NewsAPI query '{queries[i]}' failed: {result}")
|
| 123 |
+
continue
|
| 124 |
|
| 125 |
+
for article in result:
|
| 126 |
url = article.get("url", "")
|
| 127 |
if not url or url in seen_urls:
|
| 128 |
continue
|
|
|
|
| 145 |
reliability_score=0.9,
|
| 146 |
author=article.get("author"),
|
| 147 |
)
|
| 148 |
+
all_articles.append(item)
|
| 149 |
seen_urls.add(url)
|
| 150 |
|
| 151 |
+
all_articles.sort(key=lambda i: (-i.published_at.timestamp(), -i.reliability_score))
|
| 152 |
+
logger.info(f"Found {len(all_articles)} articles from NewsAPI")
|
| 153 |
+
return all_articles
|
src/westernfront/services/parsing.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Response parsing service for AI-generated analysis data."""
|
| 2 |
+
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
|
| 6 |
+
from westernfront.core.models import KeyDevelopment
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ResponseParser:
|
| 10 |
+
"""Parses and validates AI-generated analysis response data."""
|
| 11 |
+
|
| 12 |
+
def parse_tension_level(self, value: str) -> TensionLevel:
|
| 13 |
+
"""Parse tension level from string."""
|
| 14 |
+
value = value.upper()
|
| 15 |
+
if "CRITICAL" in value:
|
| 16 |
+
return TensionLevel.CRITICAL
|
| 17 |
+
if "HIGH" in value:
|
| 18 |
+
return TensionLevel.HIGH
|
| 19 |
+
if "MEDIUM" in value:
|
| 20 |
+
return TensionLevel.MEDIUM
|
| 21 |
+
return TensionLevel.LOW
|
| 22 |
+
|
| 23 |
+
def parse_tension_trend(self, value: str) -> TensionTrend:
|
| 24 |
+
"""Parse tension trend from string."""
|
| 25 |
+
value = value.upper()
|
| 26 |
+
if "INCREASING" in value:
|
| 27 |
+
return TensionTrend.INCREASING
|
| 28 |
+
if "DECREASING" in value:
|
| 29 |
+
return TensionTrend.DECREASING
|
| 30 |
+
return TensionTrend.STABLE
|
| 31 |
+
|
| 32 |
+
def parse_analysis_type(self, value: str) -> AnalysisType:
|
| 33 |
+
"""Parse analysis type from string."""
|
| 34 |
+
value = value.upper()
|
| 35 |
+
if "MILITARY" in value:
|
| 36 |
+
return AnalysisType.MILITARY
|
| 37 |
+
if "DIPLOMATIC" in value:
|
| 38 |
+
return AnalysisType.DIPLOMATIC
|
| 39 |
+
if "INTERNAL" in value:
|
| 40 |
+
return AnalysisType.INTERNAL_SECURITY
|
| 41 |
+
if "POLITICAL" in value:
|
| 42 |
+
return AnalysisType.POLITICAL
|
| 43 |
+
return AnalysisType.OTHER
|
| 44 |
+
|
| 45 |
+
def parse_key_developments(self, data: list[dict]) -> list[KeyDevelopment]:
|
| 46 |
+
"""Parse key developments from response data."""
|
| 47 |
+
developments = []
|
| 48 |
+
for item in data:
|
| 49 |
+
if not isinstance(item, dict):
|
| 50 |
+
continue
|
| 51 |
+
developments.append(
|
| 52 |
+
KeyDevelopment(
|
| 53 |
+
title=item.get("title", "Unnamed"),
|
| 54 |
+
description=item.get("description", "No description"),
|
| 55 |
+
sources=item.get("sources", []),
|
| 56 |
+
timestamp=datetime.now(),
|
| 57 |
+
)
|
| 58 |
+
)
|
| 59 |
+
return developments
|
| 60 |
+
|
| 61 |
+
def parse_tension_score(self, raw_score: int | float | str | None) -> int:
|
| 62 |
+
"""Parse and clamp tension score to valid range [1, 10]."""
|
| 63 |
+
if raw_score is None:
|
| 64 |
+
return 1
|
| 65 |
+
if isinstance(raw_score, int | float):
|
| 66 |
+
return max(1, min(10, int(raw_score)))
|
| 67 |
+
if isinstance(raw_score, str) and raw_score.isdigit():
|
| 68 |
+
return max(1, min(10, int(raw_score)))
|
| 69 |
+
return 1
|
| 70 |
+
|
| 71 |
+
def parse_key_entities(self, entities: list[str] | str | None) -> list[str]:
|
| 72 |
+
"""Parse key entities from response, handling string or list format."""
|
| 73 |
+
if entities is None:
|
| 74 |
+
return []
|
| 75 |
+
if isinstance(entities, str):
|
| 76 |
+
return [e.strip() for e in entities.split(",") if e.strip()]
|
| 77 |
+
if isinstance(entities, list):
|
| 78 |
+
return entities
|
| 79 |
+
return []
|
| 80 |
+
|
| 81 |
+
def count_sources(self, items: list[dict]) -> dict[str, int]:
|
| 82 |
+
"""Count items by source type."""
|
| 83 |
+
counts: dict[str, int] = {}
|
| 84 |
+
for item in items:
|
| 85 |
+
meta = item.get("metadata", {})
|
| 86 |
+
key = meta.get("source_type", "unknown")
|
| 87 |
+
counts[key] = counts.get(key, 0) + 1
|
| 88 |
+
return counts
|
src/westernfront/services/reddit.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
"""Reddit data collection service using AsyncPRAW."""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
from
|
| 5 |
from urllib.parse import urlparse
|
| 6 |
|
| 7 |
import aiohttp
|
|
@@ -10,40 +10,19 @@ import asyncprawcore
|
|
| 10 |
from loguru import logger
|
| 11 |
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from westernfront.core.enums import SourceType
|
|
|
|
| 14 |
from westernfront.core.models import NewsItem, SubredditSource
|
| 15 |
from westernfront.services.cache import CacheService
|
| 16 |
|
| 17 |
|
| 18 |
-
RELIABLE_DOMAINS = frozenset([
|
| 19 |
-
"bbc.com", "reuters.com", "apnews.com", "aljazeera.com",
|
| 20 |
-
"nytimes.com", "wsj.com", "ft.com", "economist.com",
|
| 21 |
-
"thediplomat.com", "foreignpolicy.com", "foreignaffairs.com",
|
| 22 |
-
"dawn.com", "timesofindia.indiatimes.com", "ndtv.com", "geo.tv",
|
| 23 |
-
])
|
| 24 |
-
|
| 25 |
-
DEFAULT_SUBREDDITS = [
|
| 26 |
-
# High-quality geopolitics sources
|
| 27 |
-
SubredditSource(name="geopolitics", reliability_score=0.85),
|
| 28 |
-
SubredditSource(name="CredibleDefense", reliability_score=0.9),
|
| 29 |
-
SubredditSource(name="worldnews", reliability_score=0.8),
|
| 30 |
-
SubredditSource(name="neutralnews", reliability_score=0.8),
|
| 31 |
-
SubredditSource(name="DefenseNews", reliability_score=0.85),
|
| 32 |
-
# South Asia focused
|
| 33 |
-
SubredditSource(name="GeopoliticsIndia", reliability_score=0.75),
|
| 34 |
-
SubredditSource(name="SouthAsia", reliability_score=0.7),
|
| 35 |
-
SubredditSource(name="india", reliability_score=0.7),
|
| 36 |
-
SubredditSource(name="pakistan", reliability_score=0.7),
|
| 37 |
-
# Regional neighbors
|
| 38 |
-
SubredditSource(name="Nepal", reliability_score=0.65),
|
| 39 |
-
SubredditSource(name="bangladesh", reliability_score=0.65),
|
| 40 |
-
SubredditSource(name="srilanka", reliability_score=0.65),
|
| 41 |
-
SubredditSource(name="China", reliability_score=0.6),
|
| 42 |
-
]
|
| 43 |
-
|
| 44 |
-
|
| 45 |
class RedditService:
|
| 46 |
-
"""Service for collecting posts from Reddit via AsyncPRAW."""
|
| 47 |
|
| 48 |
def __init__(
|
| 49 |
self,
|
|
@@ -52,30 +31,17 @@ class RedditService:
|
|
| 52 |
user_agent: str,
|
| 53 |
cache: CacheService,
|
| 54 |
) -> None:
|
| 55 |
-
"""
|
| 56 |
-
Initialize the Reddit service.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
client_id: Reddit API client ID.
|
| 60 |
-
client_secret: Reddit API client secret.
|
| 61 |
-
user_agent: User agent string for API requests.
|
| 62 |
-
cache: Cache service for storing results.
|
| 63 |
-
"""
|
| 64 |
self._client_id = client_id
|
| 65 |
self._client_secret = client_secret
|
| 66 |
self._user_agent = user_agent
|
| 67 |
self._cache = cache
|
| 68 |
-
self._reddit:
|
| 69 |
-
self._session:
|
| 70 |
self._sources = list(DEFAULT_SUBREDDITS)
|
|
|
|
| 71 |
|
| 72 |
async def initialize(self) -> bool:
|
| 73 |
-
"""
|
| 74 |
-
Initialize the Reddit API client.
|
| 75 |
-
|
| 76 |
-
Returns:
|
| 77 |
-
True if initialization was successful.
|
| 78 |
-
"""
|
| 79 |
logger.info("Initializing Reddit service")
|
| 80 |
self._session = aiohttp.ClientSession()
|
| 81 |
self._reddit = asyncpraw.Reddit(
|
|
@@ -115,7 +81,7 @@ class RedditService:
|
|
| 115 |
try:
|
| 116 |
domain = urlparse(url).netloc
|
| 117 |
return domain.replace("www.", "")
|
| 118 |
-
except
|
| 119 |
return ""
|
| 120 |
|
| 121 |
def _calculate_reliability(self, post_url: str, score: int, base_score: float) -> float:
|
|
@@ -139,9 +105,9 @@ class RedditService:
|
|
| 139 |
source: SubredditSource,
|
| 140 |
limit: int = 100,
|
| 141 |
) -> list[NewsItem]:
|
| 142 |
-
"""Fetch posts from a subreddit with retry logic."""
|
| 143 |
if not self._reddit:
|
| 144 |
-
raise
|
| 145 |
|
| 146 |
cache_key = f"reddit_{source.name}_{limit}"
|
| 147 |
cached = await self._cache.get(cache_key)
|
|
@@ -149,69 +115,62 @@ class RedditService:
|
|
| 149 |
logger.debug(f"Cache hit for r/{source.name}")
|
| 150 |
return cached
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
Get ALL posts from all active subreddits without keyword filtering.
|
| 186 |
-
|
| 187 |
-
Used for RAG ingestion where vector search determines relevance.
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
|
| 192 |
-
Returns:
|
| 193 |
-
List of all news items.
|
| 194 |
-
"""
|
| 195 |
-
active_sources = [s for s in self._sources if s.is_active]
|
| 196 |
-
cutoff = datetime.now(timezone.utc) - timedelta(days=days_back)
|
| 197 |
seen_ids: set[str] = set()
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
for
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
| 1 |
+
"""Reddit data collection service using AsyncPRAW with full parallelization."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
+
from datetime import UTC, datetime, timedelta
|
| 5 |
from urllib.parse import urlparse
|
| 6 |
|
| 7 |
import aiohttp
|
|
|
|
| 10 |
from loguru import logger
|
| 11 |
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 12 |
|
| 13 |
+
from westernfront.core.constants import (
|
| 14 |
+
DEFAULT_SUBREDDITS,
|
| 15 |
+
MAX_CONCURRENT_REQUESTS,
|
| 16 |
+
RELIABLE_DOMAINS,
|
| 17 |
+
)
|
| 18 |
from westernfront.core.enums import SourceType
|
| 19 |
+
from westernfront.core.exceptions import ServiceNotInitializedError
|
| 20 |
from westernfront.core.models import NewsItem, SubredditSource
|
| 21 |
from westernfront.services.cache import CacheService
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
class RedditService:
|
| 25 |
+
"""Service for collecting posts from Reddit via AsyncPRAW with parallel fetching."""
|
| 26 |
|
| 27 |
def __init__(
|
| 28 |
self,
|
|
|
|
| 31 |
user_agent: str,
|
| 32 |
cache: CacheService,
|
| 33 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
self._client_id = client_id
|
| 35 |
self._client_secret = client_secret
|
| 36 |
self._user_agent = user_agent
|
| 37 |
self._cache = cache
|
| 38 |
+
self._reddit: asyncpraw.Reddit | None = None
|
| 39 |
+
self._session: aiohttp.ClientSession | None = None
|
| 40 |
self._sources = list(DEFAULT_SUBREDDITS)
|
| 41 |
+
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
| 42 |
|
| 43 |
async def initialize(self) -> bool:
|
| 44 |
+
"""Initialize the Reddit API client."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
logger.info("Initializing Reddit service")
|
| 46 |
self._session = aiohttp.ClientSession()
|
| 47 |
self._reddit = asyncpraw.Reddit(
|
|
|
|
| 81 |
try:
|
| 82 |
domain = urlparse(url).netloc
|
| 83 |
return domain.replace("www.", "")
|
| 84 |
+
except (ValueError, AttributeError):
|
| 85 |
return ""
|
| 86 |
|
| 87 |
def _calculate_reliability(self, post_url: str, score: int, base_score: float) -> float:
|
|
|
|
| 105 |
source: SubredditSource,
|
| 106 |
limit: int = 100,
|
| 107 |
) -> list[NewsItem]:
|
| 108 |
+
"""Fetch posts from a subreddit with retry logic and rate limiting."""
|
| 109 |
if not self._reddit:
|
| 110 |
+
raise ServiceNotInitializedError("Reddit service not initialized")
|
| 111 |
|
| 112 |
cache_key = f"reddit_{source.name}_{limit}"
|
| 113 |
cached = await self._cache.get(cache_key)
|
|
|
|
| 115 |
logger.debug(f"Cache hit for r/{source.name}")
|
| 116 |
return cached
|
| 117 |
|
| 118 |
+
async with self._semaphore:
|
| 119 |
+
logger.info(f"Fetching posts from r/{source.name}")
|
| 120 |
+
subreddit = await self._reddit.subreddit(source.name)
|
| 121 |
+
posts: list[NewsItem] = []
|
| 122 |
+
|
| 123 |
+
async for submission in subreddit.new(limit=limit):
|
| 124 |
+
content = f"{submission.title}\n{getattr(submission, 'selftext', '')}"
|
| 125 |
+
author = str(submission.author) if submission.author else "[deleted]"
|
| 126 |
+
|
| 127 |
+
post = NewsItem(
|
| 128 |
+
id=submission.id,
|
| 129 |
+
title=submission.title,
|
| 130 |
+
content=content,
|
| 131 |
+
url=submission.url,
|
| 132 |
+
source_name=f"r/{source.name}",
|
| 133 |
+
source_type=SourceType.REDDIT,
|
| 134 |
+
published_at=datetime.fromtimestamp(submission.created_utc, tz=UTC),
|
| 135 |
+
reliability_score=self._calculate_reliability(
|
| 136 |
+
submission.url, submission.score, source.reliability_score
|
| 137 |
+
),
|
| 138 |
+
author=author,
|
| 139 |
+
score=submission.score,
|
| 140 |
+
)
|
| 141 |
+
posts.append(post)
|
| 142 |
+
|
| 143 |
+
await self._cache.set(cache_key, posts)
|
| 144 |
+
logger.info(f"Fetched {len(posts)} posts from r/{source.name}")
|
| 145 |
+
return posts
|
| 146 |
+
|
| 147 |
+
async def get_all_posts(self, days_back: int = 2) -> list[NewsItem]:
    """Collect posts from every active subreddit concurrently.

    Per-subreddit failures are logged and skipped; results are
    deduplicated by post id and filtered to the last ``days_back`` days.
    """
    sources = [src for src in self._sources if src.is_active]
    oldest_allowed = datetime.now(UTC) - timedelta(days=days_back)

    # Fan out one fetch task per subreddit; exceptions come back as values
    # so a single failing subreddit cannot abort the whole collection.
    outcomes = await asyncio.gather(
        *(self._fetch_subreddit_posts(src) for src in sources),
        return_exceptions=True,
    )

    known_ids: set[str] = set()
    collected: list[NewsItem] = []

    for source, outcome in zip(sources, outcomes):
        if isinstance(outcome, Exception):
            if isinstance(outcome, asyncprawcore.exceptions.RequestException):
                logger.error(f"Reddit API error for r/{source.name}: {outcome}")
            else:
                logger.error(f"Error fetching r/{source.name}: {outcome}")
            continue

        for post in outcome:
            # Keep only fresh, not-yet-seen posts.
            if post.published_at < oldest_allowed or post.id in known_ids:
                continue
            collected.append(post)
            known_ids.add(post.id)

    logger.info(f"Collected {len(collected)} total Reddit posts for ingestion")
    return collected
|
src/westernfront/services/retrieval.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enhanced retrieval service with diversity and recency weighting."""
|
| 2 |
+
|
| 3 |
+
from datetime import UTC, datetime
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
from westernfront.core.constants import RECENCY_BOOST, SOURCE_DIVERSITY_RULES
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from westernfront.repositories.vectors import VectorRepository
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RetrievalService:
|
| 15 |
+
"""Enhanced retrieval with source diversity and temporal weighting."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, vectors: "VectorRepository") -> None:
|
| 18 |
+
self._vectors = vectors
|
| 19 |
+
|
| 20 |
+
def _calculate_recency_score(self, published_at: str) -> float:
|
| 21 |
+
"""Calculate recency boost based on publication time."""
|
| 22 |
+
try:
|
| 23 |
+
pub_date = datetime.fromisoformat(published_at.replace("Z", "+00:00"))
|
| 24 |
+
now = datetime.now(UTC)
|
| 25 |
+
hours_old = (now - pub_date).total_seconds() / 3600
|
| 26 |
+
|
| 27 |
+
if hours_old <= 24:
|
| 28 |
+
return RECENCY_BOOST["hours_24"]
|
| 29 |
+
if hours_old <= 48:
|
| 30 |
+
return RECENCY_BOOST["hours_48"]
|
| 31 |
+
if hours_old <= 168:
|
| 32 |
+
return RECENCY_BOOST["days_7"]
|
| 33 |
+
return RECENCY_BOOST["older"]
|
| 34 |
+
except (ValueError, TypeError):
|
| 35 |
+
return 1.0
|
| 36 |
+
|
| 37 |
+
def _apply_recency_boost(self, items: list[dict]) -> list[dict]:
|
| 38 |
+
"""Apply recency boost to similarity scores."""
|
| 39 |
+
for item in items:
|
| 40 |
+
meta = item.get("metadata", {})
|
| 41 |
+
published_at = meta.get("published_at", "")
|
| 42 |
+
recency_mult = self._calculate_recency_score(published_at)
|
| 43 |
+
original_score = item.get("similarity_score", 0.5)
|
| 44 |
+
item["boosted_score"] = min(1.0, original_score * recency_mult)
|
| 45 |
+
item["recency_multiplier"] = recency_mult
|
| 46 |
+
|
| 47 |
+
items.sort(key=lambda x: x.get("boosted_score", 0), reverse=True)
|
| 48 |
+
return items
|
| 49 |
+
|
| 50 |
+
def _enforce_diversity(self, items: list[dict], max_items: int) -> list[dict]:
|
| 51 |
+
"""Enforce source type diversity in results."""
|
| 52 |
+
by_source: dict[str, list[dict]] = {"reddit": [], "rss": [], "newsapi": []}
|
| 53 |
+
|
| 54 |
+
for item in items:
|
| 55 |
+
meta = item.get("metadata", {})
|
| 56 |
+
source_type = meta.get("source_type", "unknown")
|
| 57 |
+
if source_type in by_source:
|
| 58 |
+
by_source[source_type].append(item)
|
| 59 |
+
|
| 60 |
+
result = []
|
| 61 |
+
for source_type, rules in SOURCE_DIVERSITY_RULES.items():
|
| 62 |
+
min_count = int(max_items * rules["min_pct"])
|
| 63 |
+
available = by_source.get(source_type, [])
|
| 64 |
+
to_add = available[:min_count]
|
| 65 |
+
result.extend(to_add)
|
| 66 |
+
logger.debug(f"Diversity: Added {len(to_add)} from {source_type} (min: {min_count})")
|
| 67 |
+
|
| 68 |
+
seen_ids = {item["id"] for item in result}
|
| 69 |
+
remaining_items = [item for item in items if item["id"] not in seen_ids]
|
| 70 |
+
|
| 71 |
+
space_left = max_items - len(result)
|
| 72 |
+
for item in remaining_items[:space_left]:
|
| 73 |
+
meta = item.get("metadata", {})
|
| 74 |
+
source_type = meta.get("source_type", "unknown")
|
| 75 |
+
max_count = int(max_items * SOURCE_DIVERSITY_RULES.get(source_type, {}).get("max_pct", 1.0))
|
| 76 |
+
current_count = sum(1 for r in result if r.get("metadata", {}).get("source_type") == source_type)
|
| 77 |
+
if current_count < max_count:
|
| 78 |
+
result.append(item)
|
| 79 |
+
|
| 80 |
+
result.sort(key=lambda x: x.get("boosted_score", 0), reverse=True)
|
| 81 |
+
return result[:max_items]
|
| 82 |
+
|
| 83 |
+
def retrieve_with_quality(
|
| 84 |
+
self,
|
| 85 |
+
topics: list[str],
|
| 86 |
+
max_items: int = 40,
|
| 87 |
+
n_per_topic: int = 10,
|
| 88 |
+
) -> list[dict]:
|
| 89 |
+
"""Retrieve items with recency boost and source diversity."""
|
| 90 |
+
if not self._vectors.is_initialized:
|
| 91 |
+
logger.warning("Vector store not initialized")
|
| 92 |
+
return []
|
| 93 |
+
|
| 94 |
+
raw_results = self._vectors.query_by_topics(topics, n_per_topic=n_per_topic)
|
| 95 |
+
logger.info(f"Retrieved {len(raw_results)} raw items from vector store")
|
| 96 |
+
|
| 97 |
+
boosted = self._apply_recency_boost(raw_results)
|
| 98 |
+
diverse = self._enforce_diversity(boosted, max_items)
|
| 99 |
+
logger.info(f"After diversity enforcement: {len(diverse)} items")
|
| 100 |
+
|
| 101 |
+
return diverse
|
src/westernfront/services/rss.py
CHANGED
|
@@ -1,99 +1,29 @@
|
|
| 1 |
-
"""RSS feed collection service."""
|
| 2 |
|
|
|
|
| 3 |
import hashlib
|
| 4 |
-
from datetime import datetime,
|
| 5 |
-
from typing import Optional
|
| 6 |
from email.utils import parsedate_to_datetime
|
| 7 |
|
| 8 |
import aiohttp
|
| 9 |
import feedparser
|
| 10 |
from loguru import logger
|
| 11 |
|
|
|
|
| 12 |
from westernfront.core.enums import SourceType
|
| 13 |
from westernfront.core.models import NewsItem, RssFeed
|
| 14 |
from westernfront.services.cache import CacheService
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
DEFAULT_RSS_FEEDS = [
|
| 18 |
-
# Tier 1: Pakistan
|
| 19 |
-
RssFeed(
|
| 20 |
-
name="Dawn (Pakistan)",
|
| 21 |
-
url="https://www.dawn.com/feeds/home",
|
| 22 |
-
reliability_score=0.85,
|
| 23 |
-
),
|
| 24 |
-
RssFeed(
|
| 25 |
-
name="Geo News",
|
| 26 |
-
url="https://www.geo.tv/rss/1/1",
|
| 27 |
-
reliability_score=0.8,
|
| 28 |
-
),
|
| 29 |
-
RssFeed(
|
| 30 |
-
name="Express Tribune",
|
| 31 |
-
url="https://tribune.com.pk/feed/home",
|
| 32 |
-
reliability_score=0.75,
|
| 33 |
-
),
|
| 34 |
-
# India
|
| 35 |
-
RssFeed(
|
| 36 |
-
name="Times of India",
|
| 37 |
-
url="https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
|
| 38 |
-
reliability_score=0.75,
|
| 39 |
-
),
|
| 40 |
-
RssFeed(
|
| 41 |
-
name="NDTV India",
|
| 42 |
-
url="https://feeds.feedburner.com/ndtvnews-india-news",
|
| 43 |
-
reliability_score=0.8,
|
| 44 |
-
),
|
| 45 |
-
RssFeed(
|
| 46 |
-
name="The Hindu",
|
| 47 |
-
url="https://www.thehindu.com/news/national/feeder/default.rss",
|
| 48 |
-
reliability_score=0.85,
|
| 49 |
-
),
|
| 50 |
-
RssFeed(
|
| 51 |
-
name="Indian Express",
|
| 52 |
-
url="https://indianexpress.com/section/india/feed/",
|
| 53 |
-
reliability_score=0.85,
|
| 54 |
-
),
|
| 55 |
-
# Tier 2: China/Nepal/Bangladesh
|
| 56 |
-
RssFeed(
|
| 57 |
-
name="South China Morning Post - Asia",
|
| 58 |
-
url="https://www.scmp.com/rss/91/feed",
|
| 59 |
-
reliability_score=0.85,
|
| 60 |
-
),
|
| 61 |
-
RssFeed(
|
| 62 |
-
name="Kathmandu Post",
|
| 63 |
-
url="https://kathmandupost.com/rss",
|
| 64 |
-
reliability_score=0.75,
|
| 65 |
-
),
|
| 66 |
-
RssFeed(
|
| 67 |
-
name="Dhaka Tribune",
|
| 68 |
-
url="https://www.dhakatribune.com/rss",
|
| 69 |
-
reliability_score=0.75,
|
| 70 |
-
),
|
| 71 |
-
RssFeed(
|
| 72 |
-
name="Daily Star Bangladesh",
|
| 73 |
-
url="https://www.thedailystar.net/rss.xml",
|
| 74 |
-
reliability_score=0.75,
|
| 75 |
-
),
|
| 76 |
-
# Tier 3: Others
|
| 77 |
-
RssFeed(
|
| 78 |
-
name="Daily Mirror Sri Lanka",
|
| 79 |
-
url="http://www.dailymirror.lk/RSS_Feeds/breaking-news",
|
| 80 |
-
reliability_score=0.7,
|
| 81 |
-
),
|
| 82 |
-
]
|
| 83 |
|
| 84 |
|
| 85 |
class RssService:
|
| 86 |
-
"""Service for collecting news from RSS feeds."""
|
| 87 |
-
|
| 88 |
-
def __init__(self, cache: CacheService) -> None:
|
| 89 |
-
"""
|
| 90 |
-
Initialize the RSS service.
|
| 91 |
|
| 92 |
-
|
| 93 |
-
cache: Cache service for storing results.
|
| 94 |
-
"""
|
| 95 |
self._cache = cache
|
|
|
|
| 96 |
self._feeds = list(DEFAULT_RSS_FEEDS)
|
|
|
|
| 97 |
|
| 98 |
@property
|
| 99 |
def feeds(self) -> list[RssFeed]:
|
|
@@ -112,11 +42,11 @@ class RssService:
|
|
| 112 |
try:
|
| 113 |
parsed = parsedate_to_datetime(entry[date_field])
|
| 114 |
if parsed.tzinfo is None:
|
| 115 |
-
return parsed.replace(tzinfo=
|
| 116 |
return parsed
|
| 117 |
except (ValueError, TypeError):
|
| 118 |
pass
|
| 119 |
-
return datetime.now(
|
| 120 |
|
| 121 |
def _generate_id(self, entry: dict, feed_name: str) -> str:
|
| 122 |
"""Generate a unique ID for an entry."""
|
|
@@ -133,81 +63,73 @@ class RssService:
|
|
| 133 |
logger.debug(f"Cache hit for RSS: {feed.name}")
|
| 134 |
return cached
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
async with session.get(feed.url, timeout=30) as response:
|
| 142 |
if response.status != 200:
|
| 143 |
logger.warning(f"RSS feed {feed.name} returned {response.status}")
|
| 144 |
return []
|
| 145 |
content = await response.text()
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
title=title,
|
| 160 |
-
content=f"{title}\n{description}",
|
| 161 |
-
url=link,
|
| 162 |
-
source_name=feed.name,
|
| 163 |
-
source_type=SourceType.RSS,
|
| 164 |
-
published_at=self._parse_date(entry),
|
| 165 |
-
reliability_score=feed.reliability_score,
|
| 166 |
-
author=entry.get("author"),
|
| 167 |
-
)
|
| 168 |
-
items.append(item)
|
| 169 |
-
|
| 170 |
-
await self._cache.set(cache_key, items)
|
| 171 |
-
logger.info(f"Fetched {len(items)} items from {feed.name}")
|
| 172 |
-
|
| 173 |
-
except aiohttp.ClientError as e:
|
| 174 |
-
logger.error(f"HTTP error fetching {feed.name}: {e}")
|
| 175 |
-
except Exception as e:
|
| 176 |
-
logger.error(f"Error parsing {feed.name}: {e}")
|
| 177 |
-
|
| 178 |
-
return items
|
| 179 |
-
|
| 180 |
-
async def get_all_articles(
|
| 181 |
-
self,
|
| 182 |
-
days_back: int = 2,
|
| 183 |
-
) -> list[NewsItem]:
|
| 184 |
-
"""
|
| 185 |
-
Get ALL articles from all active feeds without keyword filtering.
|
| 186 |
-
|
| 187 |
-
Used for RAG ingestion where vector search determines relevance.
|
| 188 |
-
|
| 189 |
-
Args:
|
| 190 |
-
days_back: How many days back to search.
|
| 191 |
-
|
| 192 |
-
Returns:
|
| 193 |
-
List of all news items.
|
| 194 |
-
"""
|
| 195 |
-
from datetime import timedelta
|
| 196 |
|
| 197 |
-
active_feeds = [f for f in self._feeds if f.is_active]
|
| 198 |
-
cutoff = datetime.now(timezone.utc) - timedelta(days=days_back)
|
| 199 |
seen_urls: set[str] = set()
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
|
| 203 |
-
items = await self._fetch_feed(feed)
|
| 204 |
-
for item in items:
|
| 205 |
if item.published_at < cutoff:
|
| 206 |
continue
|
| 207 |
if item.url in seen_urls:
|
| 208 |
continue
|
| 209 |
-
|
| 210 |
seen_urls.add(item.url)
|
| 211 |
|
| 212 |
-
logger.info(f"Collected {len(
|
| 213 |
-
return
|
|
|
|
| 1 |
+
"""RSS feed collection service with parallel fetching."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import hashlib
|
| 5 |
+
from datetime import UTC, datetime, timedelta
|
|
|
|
| 6 |
from email.utils import parsedate_to_datetime
|
| 7 |
|
| 8 |
import aiohttp
|
| 9 |
import feedparser
|
| 10 |
from loguru import logger
|
| 11 |
|
| 12 |
+
from westernfront.core.constants import DEFAULT_RSS_FEEDS, MAX_CONCURRENT_REQUESTS
|
| 13 |
from westernfront.core.enums import SourceType
|
| 14 |
from westernfront.core.models import NewsItem, RssFeed
|
| 15 |
from westernfront.services.cache import CacheService
|
| 16 |
+
from westernfront.services.http import HttpService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class RssService:
|
| 20 |
+
"""Service for collecting news from RSS feeds with parallel fetching."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
def __init__(self, cache: CacheService, http: HttpService) -> None:
    """Initialize the RSS service.

    Args:
        cache: Cache service used to memoize per-feed fetch results.
        http: Shared HTTP service used to download feed content.
    """
    self._cache = cache
    self._http = http
    # Copy so per-instance feed mutations never touch the shared default list.
    self._feeds = list(DEFAULT_RSS_FEEDS)
    # Caps how many feeds are downloaded simultaneously.
    self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
| 27 |
|
| 28 |
@property
|
| 29 |
def feeds(self) -> list[RssFeed]:
|
|
|
|
| 42 |
try:
|
| 43 |
parsed = parsedate_to_datetime(entry[date_field])
|
| 44 |
if parsed.tzinfo is None:
|
| 45 |
+
return parsed.replace(tzinfo=UTC)
|
| 46 |
return parsed
|
| 47 |
except (ValueError, TypeError):
|
| 48 |
pass
|
| 49 |
+
return datetime.now(UTC)
|
| 50 |
|
| 51 |
def _generate_id(self, entry: dict, feed_name: str) -> str:
|
| 52 |
"""Generate a unique ID for an entry."""
|
|
|
|
| 63 |
logger.debug(f"Cache hit for RSS: {feed.name}")
|
| 64 |
return cached
|
| 65 |
|
| 66 |
+
async with self._semaphore:
|
| 67 |
+
logger.info(f"Fetching RSS feed: {feed.name}")
|
| 68 |
+
items: list[NewsItem] = []
|
| 69 |
|
| 70 |
+
try:
|
| 71 |
+
async with self._http.get(feed.url) as response:
|
|
|
|
| 72 |
if response.status != 200:
|
| 73 |
logger.warning(f"RSS feed {feed.name} returned {response.status}")
|
| 74 |
return []
|
| 75 |
content = await response.text()
|
| 76 |
|
| 77 |
+
parsed = feedparser.parse(content)
|
| 78 |
+
|
| 79 |
+
for entry in parsed.entries:
|
| 80 |
+
title = entry.get("title", "").strip()
|
| 81 |
+
description = entry.get("description", "") or entry.get("summary", "")
|
| 82 |
+
link = entry.get("link", "")
|
| 83 |
+
|
| 84 |
+
if not title or not link:
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
item = NewsItem(
|
| 88 |
+
id=self._generate_id(entry, feed.name),
|
| 89 |
+
title=title,
|
| 90 |
+
content=f"{title}\n{description}",
|
| 91 |
+
url=link,
|
| 92 |
+
source_name=feed.name,
|
| 93 |
+
source_type=SourceType.RSS,
|
| 94 |
+
published_at=self._parse_date(entry),
|
| 95 |
+
reliability_score=feed.reliability_score,
|
| 96 |
+
author=entry.get("author"),
|
| 97 |
+
)
|
| 98 |
+
items.append(item)
|
| 99 |
+
|
| 100 |
+
await self._cache.set(cache_key, items)
|
| 101 |
+
logger.info(f"Fetched {len(items)} items from {feed.name}")
|
| 102 |
+
|
| 103 |
+
except aiohttp.ClientError as e:
|
| 104 |
+
logger.error(f"HTTP error fetching {feed.name}: {e}")
|
| 105 |
+
except feedparser.CharacterEncodingOverride as e:
|
| 106 |
+
logger.error(f"Parse error for {feed.name}: {e}")
|
| 107 |
+
|
| 108 |
+
return items
|
| 109 |
+
|
| 110 |
+
async def get_all_articles(self, days_back: int = 2) -> list[NewsItem]:
    """Fetch every active feed concurrently and merge the results.

    Articles older than ``days_back`` days are dropped and duplicates
    (by URL) are removed; per-feed failures are logged and skipped.
    """
    feeds = [feed for feed in self._feeds if feed.is_active]
    oldest_allowed = datetime.now(UTC) - timedelta(days=days_back)

    # One task per feed; return_exceptions keeps a single failing feed
    # from aborting the whole gather.
    outcomes = await asyncio.gather(
        *(self._fetch_feed(feed) for feed in feeds),
        return_exceptions=True,
    )

    known_urls: set[str] = set()
    merged: list[NewsItem] = []

    for feed, outcome in zip(feeds, outcomes):
        if isinstance(outcome, Exception):
            logger.error(f"Error fetching {feed.name}: {outcome}")
            continue

        for article in outcome:
            if article.published_at < oldest_allowed or article.url in known_urls:
                continue
            merged.append(article)
            known_urls.add(article.url)

    logger.info(f"Collected {len(merged)} total RSS articles for ingestion")
    return merged
|
src/westernfront/services/scheduler.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Background task scheduler for periodic analysis updates."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import contextlib
|
| 5 |
+
from typing import TYPE_CHECKING, Protocol
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from westernfront.core.models import ConflictAnalysis
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AnalysisGenerator(Protocol):
    """Structural interface for objects that can produce a conflict analysis.

    Any object exposing a matching ``generate_analysis`` coroutine satisfies
    this protocol — no explicit inheritance required.  ``AnalysisScheduler``
    depends on this protocol rather than on a concrete service class.
    """

    async def generate_analysis(self) -> "ConflictAnalysis | None":
        """Generate a new conflict analysis.

        Returns:
            The generated analysis, or ``None`` when no analysis could be
            produced.
        """
        ...
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AnalysisScheduler:
    """Manages periodic background analysis updates.

    Runs ``generator.generate_analysis()`` once shortly after start-up and
    then once per configured interval.  A failure in a single cycle is logged
    and does not stop the schedule.
    """

    # Grace delay before the first update so application start-up can finish.
    _STARTUP_DELAY_SECONDS = 5

    def __init__(
        self,
        generator: AnalysisGenerator,
        interval_minutes: int = 60,
    ) -> None:
        """
        Args:
            generator: Object providing an async ``generate_analysis()``.
            interval_minutes: Minutes between update cycles (default 60).
        """
        self._generator = generator
        self._interval_seconds = interval_minutes * 60
        self._task: asyncio.Task[None] | None = None

    def start(self) -> None:
        """Start the periodic update task.

        No-op when a task is already running; also restarts the loop if a
        previous task has finished or crashed (the original only checked
        ``is None``, so a dead task could never be restarted).
        """
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._periodic_update())
            logger.info(f"Scheduler started with {self._interval_seconds // 60}min interval")

    async def stop(self) -> None:
        """Cancel the periodic update task and wait for it to finish."""
        if self._task:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
            self._task = None
            logger.info("Scheduler stopped")

    @property
    def is_running(self) -> bool:
        """Return True while the background task exists and has not finished."""
        return self._task is not None and not self._task.done()

    async def _periodic_update(self) -> None:
        """Background loop: first update after a short delay, then on the interval."""
        await asyncio.sleep(self._STARTUP_DELAY_SECONDS)
        await self._run_update()

        while True:
            await asyncio.sleep(self._interval_seconds)
            await self._run_update()

    async def _run_update(self) -> None:
        """Execute a single analysis update cycle.

        Exceptions are caught and logged so that one failed cycle does not
        silently kill the background loop for the lifetime of the process.
        """
        logger.info("Starting scheduled analysis update")
        try:
            analysis = await self._generator.generate_analysis()
        except Exception:
            # Keep the scheduler alive; the next cycle may succeed.
            logger.exception("Scheduled analysis update failed")
            return
        if analysis:
            logger.info(f"Scheduled analysis complete. Tension: {analysis.tension_level.value}")
        else:
            logger.warning("Scheduled analysis produced no result")
|
src/westernfront/services/validation.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analysis validation service for quality assurance."""
|
| 2 |
+
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
from westernfront.core.constants import TENSION_LEVEL_CRITERIA
|
| 6 |
+
from westernfront.core.enums import TensionLevel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AnalysisValidator:
    """Validates AI-generated analysis for quality and consistency."""

    def validate_tension_consistency(
        self,
        level: TensionLevel,
        score: int,
        rationale: str,
    ) -> tuple[bool, list[str]]:
        """Validate that the tension level matches the score and rationale.

        Args:
            level: Reported tension level.
            score: Numeric tension score (expected 1-10).
            rationale: Free-text justification produced by the model.

        Returns:
            ``(is_valid, issues)`` where ``issues`` describes each mismatch.
        """
        issues: list[str] = []
        criteria = TENSION_LEVEL_CRITERIA.get(level.value.upper(), {})
        expected_range = criteria.get("score_range", (1, 10))

        if not (expected_range[0] <= score <= expected_range[1]):
            issues.append(
                f"Score {score} inconsistent with {level.value} level "
                f"(expected {expected_range[0]}-{expected_range[1]})"
            )

        # Heuristic keyword check: the rationale should contain language
        # characteristic of the claimed level.
        level_keywords = {
            TensionLevel.LOW: ["calm", "normal", "routine", "stable", "peaceful"],
            TensionLevel.MEDIUM: ["elevated", "heightened", "tension", "concern", "monitoring"],
            TensionLevel.HIGH: ["serious", "alert", "mobilization", "firing", "escalation"],
            TensionLevel.CRITICAL: ["urgent", "imminent", "emergency", "active", "combat"],
        }

        keywords = level_keywords.get(level, [])
        rationale_lower = rationale.lower()
        if keywords and not any(kw in rationale_lower for kw in keywords):
            issues.append(
                f"Rationale may not justify {level.value} level "
                f"(expected keywords like: {', '.join(keywords[:3])})"
            )

        return len(issues) == 0, issues

    def validate_entities(
        self,
        entities: list[str],
        source_texts: list[str],
    ) -> tuple[bool, list[str]]:
        """Validate that key entities are grounded in source texts.

        An entity is considered grounded when any of its words longer than
        three characters appears in the combined source text.  At most one
        ungrounded entity is tolerated.

        Returns:
            ``(is_valid, ungrounded_entities)``.
        """
        combined_text = " ".join(source_texts).lower()
        ungrounded: list[str] = []

        for entity in entities:
            entity_words = entity.lower().split()
            # Short words (<= 3 chars) are too ambiguous to count as evidence.
            found = any(word in combined_text for word in entity_words if len(word) > 3)
            if not found:
                ungrounded.append(entity)

        is_valid = len(ungrounded) <= 1
        return is_valid, ungrounded

    def validate_dates(
        self,
        developments: list[dict],
    ) -> tuple[bool, list[str]]:
        """Validate that development timestamps are not in the future.

        Handles both naive and timezone-aware timestamps.  The original
        implementation compared against a naive ``datetime.now()``, which
        raises ``TypeError`` when a timestamp is timezone-aware; here "now"
        is taken in the timestamp's own timezone frame.

        Returns:
            ``(is_valid, issues)``.
        """
        issues: list[str] = []

        for dev in developments:
            title = dev.get("title", "Unknown")
            timestamp = dev.get("timestamp")
            if not isinstance(timestamp, datetime):
                continue
            # datetime.now(tz=None) is naive, matching a naive timestamp;
            # an aware timestamp gets an aware "now" in the same zone.
            now = datetime.now(tz=timestamp.tzinfo)
            if timestamp > now:
                issues.append(f"Future date in development: {title}")

        return len(issues) == 0, issues

    def validate_analysis(
        self,
        analysis_data: dict,
        source_texts: list[str],
    ) -> tuple[bool, list[str]]:
        """Perform comprehensive validation on a raw analysis dict.

        Combines tension-consistency, entity-grounding, and date checks.

        Returns:
            ``(is_valid, all_issues)``.
        """
        all_issues: list[str] = []

        level_str = analysis_data.get("tension_level", "LOW")
        try:
            level = TensionLevel(level_str)
        except ValueError:
            level = TensionLevel.LOW
            all_issues.append(f"Invalid tension level: {level_str}")

        score = analysis_data.get("tension_score", 1)
        if not isinstance(score, int):
            try:
                # Accept numeric strings and floats (e.g. "5.7" -> 5);
                # plain int("5.7") would raise ValueError.
                score = int(float(score))
            except (ValueError, TypeError):
                score = 1
                all_issues.append(f"Invalid tension score: {analysis_data.get('tension_score')}")

        rationale = analysis_data.get("tension_rationale", "")
        _, tension_issues = self.validate_tension_consistency(level, score, rationale)
        all_issues.extend(tension_issues)

        entities = analysis_data.get("key_entities", [])
        if isinstance(entities, list):
            _, ungrounded = self.validate_entities(entities, source_texts)
            if ungrounded:
                all_issues.append(f"Possibly ungrounded entities: {', '.join(ungrounded)}")

        developments = analysis_data.get("key_developments", [])
        if isinstance(developments, list):
            _, date_issues = self.validate_dates(developments)
            all_issues.extend(date_issues)

        return len(all_issues) == 0, all_issues
|
src/westernfront/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities package exports."""
|
| 2 |
+
|
| 3 |
+
from westernfront.utils.json_parser import extract_json_from_response
|
| 4 |
+
|
| 5 |
+
__all__ = ["extract_json_from_response"]
|
src/westernfront/utils/json_parser.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for parsing JSON from LLM responses."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def extract_json_from_response(text: str) -> dict[str, Any] | None:
    """
    Extract a JSON object from an LLM response.

    The response may be raw JSON, JSON wrapped in a markdown code fence, or
    JSON embedded in surrounding prose.  Strategies are tried in order:

    1. Direct JSON parsing of the whole text.
    2. Extraction from a ``` / ```json fenced block (whitespace around the
       payload is optional — the previous pattern required newlines, so
       single-line fences like ```json{...}``` were rejected).
    3. The widest ``{...}`` span found anywhere in the text.

    Returns:
        The parsed value, or ``None`` if no strategy yields valid JSON.
        NOTE(review): ``json.loads`` may return non-dict values (list, str,
        number) on strategy 1; the annotation assumes callers send objects —
        confirm against call sites.
    """
    # Strategy 1: the whole response is already valid JSON.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Strategy 2: markdown code fence, tolerating absent newlines.
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass

    # Strategy 3: widest brace-delimited span (greedy match).
    embedded = re.search(r"\{.*\}", text, re.DOTALL)
    if embedded:
        try:
            return json.loads(embedded.group(0))
        except json.JSONDecodeError:
            pass

    logger.warning(f"Failed to parse JSON from response: {text[:200]}...")
    return None
|
tests/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (188 Bytes). View file
|
|
|
tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
tests/__pycache__/test_services.cpython-312-pytest-8.4.2.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for API routes."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from httpx import ASGITransport, AsyncClient
|
| 5 |
+
|
| 6 |
+
from westernfront.main import app
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient wired to the FastAPI app via ASGI transport."""
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as test_client:
        yield test_client
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestPublicEndpoints:
    """Tests for endpoints that do not require authentication."""

    async def test_root_endpoint(self, client):
        """Root endpoint exposes the API name, version, and readiness status."""
        resp = await client.get("/")
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "WesternFront API"
        assert "version" in body
        assert body["status"] in ["ready", "initializing"]

    async def test_health_endpoint(self, client):
        """Health endpoint reports status plus version/timestamp/components."""
        resp = await client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] in ["healthy", "initializing"]
        for field in ("version", "timestamp", "components"):
            assert field in body

    async def test_health_head_endpoint(self, client):
        """A HEAD request to /health succeeds."""
        resp = await client.head("/health")
        assert resp.status_code == 200

    async def test_tension_levels_endpoint(self, client):
        """Tension levels endpoint answers with data or an auth error."""
        resp = await client.get(
            "/tension-levels",
            headers={"X-API-Key": "test-key"},
        )
        # 401 when "test-key" is not a configured key, 200 with levels otherwise.
        assert resp.status_code in [200, 401]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TestProtectedEndpoints:
    """Tests for endpoints that require an API key."""

    async def test_analysis_requires_auth(self, client):
        """Requests to /analysis without a key are rejected with a message."""
        resp = await client.get("/analysis")
        assert resp.status_code == 401
        assert "API key" in resp.json()["detail"]

    async def test_sources_requires_auth(self, client):
        """Requests to /sources without a key are rejected."""
        resp = await client.get("/sources")
        assert resp.status_code == 401

    async def test_keywords_requires_auth(self, client):
        """Requests to /keywords without a key are rejected."""
        resp = await client.get("/keywords")
        assert resp.status_code == 401
|
tests/test_parsing.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ResponseParser service."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
|
| 7 |
+
from westernfront.services.parsing import ResponseParser
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestResponseParser:
    """Unit tests for ResponseParser."""

    @pytest.fixture
    def parser(self):
        """Provide a fresh ResponseParser for each test."""
        return ResponseParser()

    def test_parse_tension_level_low(self, parser):
        """'low' in any case — and any unknown value — maps to LOW."""
        for raw in ("low", "LOW", "unknown"):
            assert parser.parse_tension_level(raw) == TensionLevel.LOW

    def test_parse_tension_level_medium(self, parser):
        """'medium' parses case-insensitively to MEDIUM."""
        for raw in ("medium", "MEDIUM"):
            assert parser.parse_tension_level(raw) == TensionLevel.MEDIUM

    def test_parse_tension_level_high(self, parser):
        """'high' parses case-insensitively to HIGH."""
        for raw in ("high", "HIGH"):
            assert parser.parse_tension_level(raw) == TensionLevel.HIGH

    def test_parse_tension_level_critical(self, parser):
        """'critical' parses case-insensitively to CRITICAL."""
        for raw in ("critical", "CRITICAL"):
            assert parser.parse_tension_level(raw) == TensionLevel.CRITICAL

    def test_parse_tension_trend(self, parser):
        """Known trends map to their enum; unknown values fall back to STABLE."""
        expected = {
            "increasing": TensionTrend.INCREASING,
            "decreasing": TensionTrend.DECREASING,
            "stable": TensionTrend.STABLE,
            "unknown": TensionTrend.STABLE,
        }
        for raw, trend in expected.items():
            assert parser.parse_tension_trend(raw) == trend

    def test_parse_analysis_type(self, parser):
        """Known types map to their enum; unknown values fall back to OTHER."""
        expected = {
            "military": AnalysisType.MILITARY,
            "diplomatic": AnalysisType.DIPLOMATIC,
            "internal security": AnalysisType.INTERNAL_SECURITY,
            "political": AnalysisType.POLITICAL,
            "unknown": AnalysisType.OTHER,
        }
        for raw, kind in expected.items():
            assert parser.parse_analysis_type(raw) == kind

    def test_parse_tension_score_valid(self, parser):
        """Ints, floats, and numeric strings all parse to ints."""
        assert parser.parse_tension_score(5) == 5
        assert parser.parse_tension_score(5.7) == 5
        assert parser.parse_tension_score("7") == 7

    def test_parse_tension_score_clamped(self, parser):
        """Out-of-range and missing scores are clamped into 1..10."""
        assert parser.parse_tension_score(0) == 1
        assert parser.parse_tension_score(-5) == 1
        assert parser.parse_tension_score(15) == 10
        assert parser.parse_tension_score(None) == 1

    def test_parse_key_entities_list(self, parser):
        """A list of entities passes through unchanged."""
        parsed = parser.parse_key_entities(["India", "Pakistan", "Kashmir"])
        assert parsed == ["India", "Pakistan", "Kashmir"]

    def test_parse_key_entities_string(self, parser):
        """A comma-separated string is split into individual entities."""
        parsed = parser.parse_key_entities("India, Pakistan, Kashmir")
        assert parsed == ["India", "Pakistan", "Kashmir"]

    def test_parse_key_entities_none(self, parser):
        """None yields an empty list."""
        assert parser.parse_key_entities(None) == []

    def test_parse_key_developments_valid(self, parser):
        """Well-formed development dicts are converted to objects."""
        raw = [
            {
                "title": "Test Event",
                "description": "Test description",
                "sources": ["Military Activity"],
            }
        ]
        parsed = parser.parse_key_developments(raw)
        assert len(parsed) == 1
        first = parsed[0]
        assert first.title == "Test Event"
        assert first.description == "Test description"
        assert first.sources == ["Military Activity"]

    def test_parse_key_developments_skips_non_dict(self, parser):
        """Entries that are not dicts are silently dropped."""
        parsed = parser.parse_key_developments([{"title": "Valid"}, "invalid", None])
        assert len(parsed) == 1

    def test_count_sources(self, parser):
        """Items are tallied by metadata source_type, defaulting to 'unknown'."""
        items = [
            {"metadata": {"source_type": "reddit"}},
            {"metadata": {"source_type": "reddit"}},
            {"metadata": {"source_type": "rss"}},
            {"metadata": {}},
        ]
        assert parser.count_sources(items) == {"reddit": 2, "rss": 1, "unknown": 1}
|