abhisheksan committed
Commit 14aaf07 · verified · 1 Parent(s): dcf5c87

Upload 52 files

Files changed (39)
  1. .dockerignore +31 -0
  2. .env.example +13 -0
  3. Dockerfile +41 -26
  4. src/westernfront/analytics/aggregator.py +4 -45
  5. src/westernfront/api/auth.py +4 -14
  6. src/westernfront/api/middleware/__init__.py +5 -0
  7. src/westernfront/api/middleware/rate_limit.py +89 -0
  8. src/westernfront/api/routes.py +38 -71
  9. src/westernfront/api/schemas.py +2 -3
  10. src/westernfront/config.py +1 -2
  11. src/westernfront/core/__init__.py +24 -0
  12. src/westernfront/core/constants.py +123 -0
  13. src/westernfront/core/exceptions.py +29 -0
  14. src/westernfront/core/models.py +3 -4
  15. src/westernfront/dependencies.py +54 -44
  16. src/westernfront/main.py +18 -3
  17. src/westernfront/prompts/analysis.py +119 -24
  18. src/westernfront/repositories/analysis.py +15 -43
  19. src/westernfront/repositories/vectors.py +33 -68
  20. src/westernfront/services/__init__.py +12 -0
  21. src/westernfront/services/analysis.py +239 -362
  22. src/westernfront/services/cache.py +65 -48
  23. src/westernfront/services/chain_analysis.py +108 -0
  24. src/westernfront/services/embeddings.py +37 -52
  25. src/westernfront/services/http.py +57 -0
  26. src/westernfront/services/newsapi.py +56 -65
  27. src/westernfront/services/parsing.py +88 -0
  28. src/westernfront/services/reddit.py +73 -114
  29. src/westernfront/services/retrieval.py +101 -0
  30. src/westernfront/services/rss.py +65 -143
  31. src/westernfront/services/scheduler.py +69 -0
  32. src/westernfront/services/validation.py +119 -0
  33. src/westernfront/utils/__init__.py +5 -0
  34. src/westernfront/utils/json_parser.py +42 -0
  35. tests/__pycache__/__init__.cpython-312.pyc +0 -0
  36. tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc +0 -0
  37. tests/__pycache__/test_services.cpython-312-pytest-8.4.2.pyc +0 -0
  38. tests/test_api.py +71 -0
  39. tests/test_parsing.py +111 -0
.dockerignore ADDED
@@ -0,0 +1,31 @@
+ # Exclude local dev environment
+ .venv
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+
+ # Exclude git history
+ .git
+ .gitignore
+
+ # Exclude test cache and coverage reports
+ .pytest_cache
+ .coverage
+ htmlcov
+ tests/
+
+ # Exclude logs and data (unless specific data is needed)
+ logs/
+ data/
+
+ # Exclude local env files (secrets should be passed as env vars)
+ .env
+ .env.example
+
+ # Exclude IDE settings
+ .vscode
+ .idea
+
+ # Exclude poetry cache (if local)
+ .cache
.env.example ADDED
@@ -0,0 +1,13 @@
+ # Google Gemini API Key
+ GEMINI_API_KEY=
+
+ # Application Settings
+ UPDATE_INTERVAL_MINUTES=
+ CACHE_EXPIRY_MINUTES=
+ LOG_LEVEL=
+ AUTO_UPDATE_ENABLED=
+ REDDIT_CLIENT_ID=
+ REDDIT_CLIENT_SECRET=
+ REDDIT_USER_AGENT=
+ NEWSAPI_KEY=
+ WESTERNFRONT_API_KEY=
Dockerfile CHANGED
@@ -1,27 +1,42 @@
- FROM python:3.11-slim
-
- WORKDIR /app
-
- # Install Poetry
- RUN pip install --no-cache-dir poetry==1.8.0
-
- # Copy dependency files
- COPY pyproject.toml poetry.lock ./
-
- # Install dependencies
- RUN poetry config virtualenvs.create false \
-     && poetry install --only main --no-interaction --no-ansi
-
- # Copy application code
- COPY src/ ./src/
-
- # Create data directory
- RUN mkdir -p /app/data /app/logs
-
- ENV PYTHONDONTWRITEBYTECODE=1
- ENV PYTHONUNBUFFERED=1
- ENV PYTHONPATH=/app/src
-
- EXPOSE 7860
-
+ FROM python:3.11-slim AS builder
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     curl \
+     && rm -rf /var/lib/apt/lists/* \
+     && pip install --no-cache-dir poetry==1.8.0
+
+ ENV POETRY_NO_INTERACTION=1 \
+     POETRY_VIRTUALENVS_IN_PROJECT=1 \
+     POETRY_VIRTUALENVS_CREATE=1 \
+     POETRY_CACHE_DIR=/tmp/poetry_cache
+
+ COPY pyproject.toml poetry.lock ./
+
+ RUN poetry install --only main --no-root && rm -rf $POETRY_CACHE_DIR
+
+ FROM python:3.11-slim AS runtime
+
+ WORKDIR /app
+
+ RUN groupadd -g 1000 appuser && \
+     useradd -u 1000 -g appuser -s /bin/bash -m appuser
+
+ RUN mkdir -p /app/data /app/logs && \
+     chown -R appuser:appuser /app
+
+ COPY --from=builder /app/.venv /app/.venv
+ COPY --chown=appuser:appuser src/ ./src/
+
+ ENV PATH="/app/.venv/bin:$PATH" \
+     PYTHONPATH="/app/src" \
+     PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PORT=7860
+
+ USER appuser
+
+ EXPOSE 7860
+
  CMD ["uvicorn", "westernfront.main:app", "--host", "0.0.0.0", "--port", "7860"]
src/westernfront/analytics/aggregator.py CHANGED
@@ -1,8 +1,6 @@
  """Analytics aggregation for graph data."""

  from collections import Counter
- from datetime import datetime, timedelta, timezone
- from typing import Optional

  from westernfront.core.enums import TensionTrend
  from westernfront.repositories.analysis import AnalysisRepository
@@ -12,24 +10,10 @@ class AnalyticsAggregator:
      """Aggregates analysis data for visualization."""

      def __init__(self, repository: AnalysisRepository) -> None:
-         """
-         Initialize the aggregator.
-
-         Args:
-             repository: Analysis repository for data access.
-         """
          self._repository = repository

      async def get_tension_history(self, days: int = 30) -> dict:
-         """
-         Get tension score history for graphing.
-
-         Args:
-             days: Number of days to include.
-
-         Returns:
-             Dictionary with data points and summary statistics.
-         """
+         """Get tension score history for graphing."""
          history = await self._repository.get_tension_history(days)

          if not history:
@@ -66,15 +50,7 @@
          }

      async def get_source_breakdown(self, days: int = 7) -> dict:
-         """
-         Get breakdown of sources used in recent analyses.
-
-         Args:
-             days: Number of days to include.
-
-         Returns:
-             Dictionary with source counts.
-         """
+         """Get breakdown of sources used in recent analyses."""
          snapshots = await self._repository.get_history(days=days)

          return {
@@ -90,16 +66,7 @@
          }

      async def get_entity_frequency(self, days: int = 30, top_n: int = 10) -> dict:
-         """
-         Get most frequently mentioned entities.
-
-         Args:
-             days: Number of days to include.
-             top_n: Number of top entities to return.
-
-         Returns:
-             Dictionary with entity frequency data.
-         """
+         """Get most frequently mentioned entities."""
          snapshots = await self._repository.get_history(days=days)

          all_entities: list[str] = []
@@ -118,15 +85,7 @@
          }

      async def get_analysis_type_distribution(self, days: int = 30) -> dict:
-         """
-         Get distribution of analysis types.
-
-         Args:
-             days: Number of days to include.
-
-         Returns:
-             Dictionary with type distribution.
-         """
+         """Get distribution of analysis types."""
          snapshots = await self._repository.get_history(days=days)

          counter = Counter(s.analysis_type.value for s in snapshots)
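The type-distribution method above leans on collections.Counter; a quick illustration with hypothetical snapshot values:

from collections import Counter

# Hypothetical stand-ins for AnalysisSnapshot.analysis_type.value
type_values = ["Military", "Diplomatic", "Military", "Political", "Military"]

counter = Counter(type_values)
print(counter.most_common())  # [('Military', 3), ('Diplomatic', 1), ('Political', 1)]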
src/westernfront/api/auth.py CHANGED
@@ -1,13 +1,11 @@
  """API key authentication middleware."""

+ import secrets
+
  from fastapi import HTTPException, Request, status
- from fastapi.security import APIKeyHeader

  from westernfront.config import get_settings

-
- API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
-
  PUBLIC_PATHS = frozenset([
      "/",
      "/health",
@@ -18,15 +16,7 @@ PUBLIC_PATHS = frozenset([


  async def verify_api_key(request: Request) -> None:
-     """
-     Verify the API key from request headers.
-
-     Args:
-         request: The incoming request.
-
-     Raises:
-         HTTPException: If API key is missing or invalid.
-     """
+     """Verify the API key from request headers using timing-safe comparison."""
      if request.url.path in PUBLIC_PATHS:
          return

@@ -39,7 +29,7 @@ async def verify_api_key(request: Request) -> None:
              detail="Missing API key. Include X-API-Key header.",
          )

-     if api_key != settings.api_key:
+     if not secrets.compare_digest(api_key.encode(), settings.api_key.encode()):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Invalid API key.",
          )
src/westernfront/api/middleware/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Middleware package exports."""
+
+ from westernfront.api.middleware.rate_limit import RateLimitMiddleware
+
+ __all__ = ["RateLimitMiddleware"]
src/westernfront/api/middleware/rate_limit.py ADDED
@@ -0,0 +1,89 @@
+ """Rate limiting middleware using token bucket algorithm with automatic cleanup."""
+
+ import time
+
+ from cachetools import TTLCache
+ from fastapi import Request, Response
+ from fastapi.responses import JSONResponse
+ from starlette.middleware.base import BaseHTTPMiddleware
+
+
+ class TokenBucket:
+     """Token bucket for rate limiting."""
+
+     __slots__ = ("capacity", "refill_rate", "tokens", "last_refill")
+
+     def __init__(self, capacity: int, refill_rate: float) -> None:
+         self.capacity = capacity
+         self.refill_rate = refill_rate
+         self.tokens = float(capacity)
+         self.last_refill = time.monotonic()
+
+     def consume(self) -> bool:
+         """Attempt to consume a token. Returns True if successful."""
+         now = time.monotonic()
+         elapsed = now - self.last_refill
+         self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
+         self.last_refill = now
+
+         if self.tokens >= 1:
+             self.tokens -= 1
+             return True
+         return False
+
+     def time_until_available(self) -> float:
+         """Calculate seconds until next token is available."""
+         if self.tokens >= 1:
+             return 0.0
+         return (1 - self.tokens) / self.refill_rate
+
+
+ class RateLimitMiddleware(BaseHTTPMiddleware):
+     """Rate limiting middleware using per-IP token buckets with automatic cleanup."""
+
+     def __init__(
+         self,
+         app,
+         requests_per_minute: int = 60,
+         burst_size: int = 10,
+         bucket_ttl_seconds: int = 300,
+         max_buckets: int = 10000,
+     ) -> None:
+         super().__init__(app)
+         self._requests_per_minute = requests_per_minute
+         self._burst_size = burst_size
+         self._refill_rate = requests_per_minute / 60.0
+         self._buckets: TTLCache[str, TokenBucket] = TTLCache(
+             maxsize=max_buckets,
+             ttl=bucket_ttl_seconds,
+         )
+
+     def _get_bucket(self, client_ip: str) -> TokenBucket:
+         """Get or create a token bucket for the client IP."""
+         bucket = self._buckets.get(client_ip)
+         if bucket is None:
+             bucket = TokenBucket(self._burst_size, self._refill_rate)
+             self._buckets[client_ip] = bucket
+         return bucket
+
+     async def dispatch(self, request: Request, call_next) -> Response:
+         """Process request with rate limiting."""
+         client_ip = self._get_client_ip(request)
+         bucket = self._get_bucket(client_ip)
+
+         if not bucket.consume():
+             retry_after = int(bucket.time_until_available()) + 1
+             return JSONResponse(
+                 status_code=429,
+                 content={"detail": "Rate limit exceeded. Please slow down."},
+                 headers={"Retry-After": str(retry_after)},
+             )
+
+         return await call_next(request)
+
+     def _get_client_ip(self, request: Request) -> str:
+         """Extract client IP from request, handling proxies."""
+         forwarded = request.headers.get("X-Forwarded-For")
+         if forwarded:
+             return forwarded.split(",")[0].strip()
+         return request.client.host if request.client else "unknown"
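A minimal sketch of the bucket arithmetic in isolation: with capacity 10 and a refill rate of 1 token/second (60 requests per minute), a burst of 10 calls succeeds, the 11th is rejected, and a token returns about a second later. Timings here are illustrative:

import time

from westernfront.api.middleware.rate_limit import TokenBucket

bucket = TokenBucket(capacity=10, refill_rate=1.0)

results = [bucket.consume() for _ in range(11)]
print(results.count(True), results[-1])  # 10 False -- burst spent

time.sleep(bucket.time_until_available())  # wait for one token to refill
print(bucket.consume())  # True again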
src/westernfront/api/routes.py CHANGED
@@ -1,15 +1,15 @@
  """API route definitions."""

  from datetime import datetime
- from typing import Optional

- from fastapi import APIRouter, Depends, HTTPException, Query, status
+ from fastapi import APIRouter, Query, Request, status
+ from fastapi.exceptions import HTTPException

  from westernfront import __version__
- from westernfront.analytics import AnalyticsAggregator
  from westernfront.api.schemas import (
      AnalysisHistoryResponse,
      AnalysisSnapshotResponse,
+     AnalysisTypeDistributionResponse,
      ConflictAnalysisResponse,
      EntityFrequencyResponse,
      HealthResponse,
@@ -20,25 +20,17 @@ from westernfront.api.schemas import (
      SourcesResponse,
      SubredditSourceResponse,
      TensionHistoryResponse,
-     AnalysisTypeDistributionResponse,
  )
  from westernfront.core.enums import TensionLevel
- from westernfront.dependencies import (
-     get_analysis_service,
-     get_app_state,
-     get_repository,
- )
- from westernfront.repositories import AnalysisRepository
- from westernfront.services import AnalysisService
-
+ from westernfront.dependencies import get_state_from_request

  router = APIRouter()


  @router.get("/", response_model=RootResponse, tags=["General"])
- async def root() -> RootResponse:
+ async def root(request: Request) -> RootResponse:
      """Root endpoint with API information."""
-     state = get_app_state()
+     state = get_state_from_request(request)
      return RootResponse(
          name="WesternFront API",
          description="AI-powered conflict tracker for India-Pakistan tensions",
@@ -49,9 +41,9 @@ async def root() -> RootResponse:

  @router.get("/health", response_model=HealthResponse, tags=["General"])
  @router.head("/health", response_model=HealthResponse, tags=["General"])
- async def health_check() -> HealthResponse:
+ async def health_check(request: Request) -> HealthResponse:
      """Health check endpoint."""
-     state = get_app_state()
+     state = get_state_from_request(request)
      latest = await state.repository.get_latest()

      return HealthResponse(
@@ -68,16 +60,11 @@ async def health_check() -> HealthResponse:
      )


- @router.get(
-     "/analysis",
-     response_model=ConflictAnalysisResponse,
-     tags=["Analysis"],
- )
- async def get_latest_analysis(
-     service: AnalysisService = Depends(get_analysis_service),
- ) -> ConflictAnalysisResponse:
+ @router.get("/analysis", response_model=ConflictAnalysisResponse, tags=["Analysis"])
+ async def get_latest_analysis(request: Request) -> ConflictAnalysisResponse:
      """Get the latest conflict analysis."""
-     analysis = await service.get_latest()
+     state = get_state_from_request(request)
+     analysis = await state.analysis.get_latest()

      if not analysis:
          raise HTTPException(
@@ -88,18 +75,15 @@ async def get_latest_analysis(
      return ConflictAnalysisResponse.model_validate(analysis.model_dump())


- @router.get(
-     "/analysis/history",
-     response_model=AnalysisHistoryResponse,
-     tags=["Analysis"],
- )
+ @router.get("/analysis/history", response_model=AnalysisHistoryResponse, tags=["Analysis"])
  async def get_analysis_history(
+     request: Request,
      days: int = Query(default=30, ge=1, le=90),
      limit: int = Query(default=50, ge=1, le=100),
-     repository: AnalysisRepository = Depends(get_repository),
  ) -> AnalysisHistoryResponse:
      """Get historical analysis snapshots."""
-     snapshots = await repository.get_history(days=days, limit=limit)
+     state = get_state_from_request(request)
+     snapshots = await state.repository.get_history(days=days, limit=limit)

      return AnalysisHistoryResponse(
          count=len(snapshots),
@@ -110,71 +94,55 @@ async def get_analysis_history(
      )


- @router.get(
-     "/analytics/tension-history",
-     response_model=TensionHistoryResponse,
-     tags=["Analytics"],
- )
+ @router.get("/analytics/tension-history", response_model=TensionHistoryResponse, tags=["Analytics"])
  async def get_tension_history(
+     request: Request,
      days: int = Query(default=30, ge=1, le=90),
-     repository: AnalysisRepository = Depends(get_repository),
  ) -> TensionHistoryResponse:
      """Get tension score history for graphing."""
-     aggregator = AnalyticsAggregator(repository)
-     result = await aggregator.get_tension_history(days)
+     state = get_state_from_request(request)
+     result = await state.analytics.get_tension_history(days)
      return TensionHistoryResponse.model_validate(result)


- @router.get(
-     "/analytics/source-breakdown",
-     response_model=SourceBreakdownResponse,
-     tags=["Analytics"],
- )
+ @router.get("/analytics/source-breakdown", response_model=SourceBreakdownResponse, tags=["Analytics"])
  async def get_source_breakdown(
+     request: Request,
      days: int = Query(default=7, ge=1, le=30),
-     repository: AnalysisRepository = Depends(get_repository),
  ) -> SourceBreakdownResponse:
      """Get breakdown of sources used in analyses."""
-     aggregator = AnalyticsAggregator(repository)
-     result = await aggregator.get_source_breakdown(days)
+     state = get_state_from_request(request)
+     result = await state.analytics.get_source_breakdown(days)
      return SourceBreakdownResponse.model_validate(result)


- @router.get(
-     "/analytics/entity-frequency",
-     response_model=EntityFrequencyResponse,
-     tags=["Analytics"],
- )
+ @router.get("/analytics/entity-frequency", response_model=EntityFrequencyResponse, tags=["Analytics"])
  async def get_entity_frequency(
+     request: Request,
      days: int = Query(default=30, ge=1, le=90),
      top_n: int = Query(default=10, ge=1, le=50),
-     repository: AnalysisRepository = Depends(get_repository),
  ) -> EntityFrequencyResponse:
      """Get most frequently mentioned entities."""
-     aggregator = AnalyticsAggregator(repository)
-     result = await aggregator.get_entity_frequency(days, top_n)
+     state = get_state_from_request(request)
+     result = await state.analytics.get_entity_frequency(days, top_n)
      return EntityFrequencyResponse.model_validate(result)


- @router.get(
-     "/analytics/type-distribution",
-     response_model=AnalysisTypeDistributionResponse,
-     tags=["Analytics"],
- )
+ @router.get("/analytics/type-distribution", response_model=AnalysisTypeDistributionResponse, tags=["Analytics"])
  async def get_type_distribution(
+     request: Request,
      days: int = Query(default=30, ge=1, le=90),
-     repository: AnalysisRepository = Depends(get_repository),
  ) -> AnalysisTypeDistributionResponse:
      """Get distribution of analysis types."""
-     aggregator = AnalyticsAggregator(repository)
-     result = await aggregator.get_analysis_type_distribution(days)
+     state = get_state_from_request(request)
+     result = await state.analytics.get_analysis_type_distribution(days)
      return AnalysisTypeDistributionResponse.model_validate(result)


  @router.get("/sources", response_model=SourcesResponse, tags=["Configuration"])
- async def get_sources() -> SourcesResponse:
+ async def get_sources(request: Request) -> SourcesResponse:
      """Get current data sources configuration."""
-     state = get_app_state()
+     state = get_state_from_request(request)

      return SourcesResponse(
          subreddits=[
@@ -190,13 +158,12 @@ async def get_sources() -> SourcesResponse:


  @router.get("/keywords", response_model=KeywordsResponse, tags=["Configuration"])
- async def get_keywords(
-     service: AnalysisService = Depends(get_analysis_service),
- ) -> KeywordsResponse:
+ async def get_keywords(request: Request) -> KeywordsResponse:
      """Get current search keywords."""
+     state = get_state_from_request(request)
      return KeywordsResponse(
-         count=len(service.keywords),
-         keywords=service.keywords,
+         count=len(state.analysis.keywords),
+         keywords=state.analysis.keywords,
      )
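A stripped-down sketch of the request-scoped state pattern these handlers now share, assuming (per dependencies.py below) that the lifespan stores the AppState container on app.state.westernfront; the stub class here is hypothetical:

from fastapi import FastAPI, Request

app = FastAPI()

class AppStateStub:
    """Hypothetical stand-in for the real AppState container."""
    keywords = ["India Pakistan", "Kashmir"]

app.state.westernfront = AppStateStub()  # normally done during the lifespan

@app.get("/keywords-demo")
async def keywords_demo(request: Request) -> dict:
    # Same lookup get_state_from_request performs.
    state = request.app.state.westernfront
    return {"count": len(state.keywords), "keywords": state.keywords}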
src/westernfront/api/schemas.py CHANGED
@@ -1,7 +1,6 @@
  """API request and response schemas."""

  from datetime import datetime
- from typing import Optional

  from pydantic import BaseModel, Field

@@ -14,7 +13,7 @@ class HealthResponse(BaseModel):
      status: str
      version: str
      timestamp: datetime
-     last_update: Optional[datetime] = None
+     last_update: datetime | None = None
      components: dict[str, bool]


@@ -93,7 +92,7 @@ class KeyDevelopmentResponse(BaseModel):
      title: str
      description: str
      sources: list[str]
-     timestamp: Optional[datetime] = None
+     timestamp: datetime | None = None


  class ReliabilityAssessmentResponse(BaseModel):
src/westernfront/config.py CHANGED
@@ -6,7 +6,6 @@ for type-safe environment variable parsing and validation.
  """

  from functools import lru_cache
- from typing import Optional

  from pydantic import Field
  from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -34,7 +33,7 @@ class Settings(BaseSettings):
      gemini_api_key: str = Field(alias="GEMINI_API_KEY")

      # NewsAPI (optional)
-     newsapi_key: Optional[str] = Field(default=None, alias="NEWSAPI_KEY")
+     newsapi_key: str | None = Field(default=None, alias="NEWSAPI_KEY")

      # Application Settings
      update_interval_minutes: int = Field(default=60, alias="UPDATE_INTERVAL_MINUTES")
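For context, the pattern config.py relies on: pydantic-settings fields aliased to environment variables, with get_settings wrapped in lru_cache so the environment is parsed once per process. A trimmed, self-contained sketch (two fields only):

import os
from functools import lru_cache

from pydantic import Field
from pydantic_settings import BaseSettings

class DemoSettings(BaseSettings):
    # Trimmed illustration of the Settings class above.
    gemini_api_key: str = Field(alias="GEMINI_API_KEY")
    newsapi_key: str | None = Field(default=None, alias="NEWSAPI_KEY")

@lru_cache
def get_demo_settings() -> DemoSettings:
    # Cached: repeated calls return the same parsed instance.
    return DemoSettings()

os.environ["GEMINI_API_KEY"] = "dummy-key"
print(get_demo_settings().gemini_api_key)  # dummy-key
print(get_demo_settings().newsapi_key)     # None (optional)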
src/westernfront/core/__init__.py CHANGED
@@ -1,20 +1,44 @@
  """Core package exports."""

  from westernfront.core.enums import AnalysisType, SourceType, TensionLevel, TensionTrend
+ from westernfront.core.exceptions import (
+     AnalysisError,
+     AuthenticationError,
+     DataFetchError,
+     RateLimitExceededError,
+     ServiceNotInitializedError,
+     VectorStoreError,
+     WesternFrontError,
+ )
  from westernfront.core.models import (
+     AnalysisSnapshot,
      ConflictAnalysis,
      KeyDevelopment,
      NewsItem,
+     RegionalImplications,
+     ReliabilityAssessment,
+     RssFeed,
      SubredditSource,
  )

  __all__ = [
+     "AnalysisError",
+     "AnalysisSnapshot",
      "AnalysisType",
+     "AuthenticationError",
      "ConflictAnalysis",
+     "DataFetchError",
      "KeyDevelopment",
      "NewsItem",
+     "RateLimitExceededError",
+     "RegionalImplications",
+     "ReliabilityAssessment",
+     "RssFeed",
+     "ServiceNotInitializedError",
      "SourceType",
      "SubredditSource",
      "TensionLevel",
      "TensionTrend",
+     "VectorStoreError",
+     "WesternFrontError",
  ]
src/westernfront/core/constants.py ADDED
@@ -0,0 +1,123 @@
+ """Constants and default configurations for WesternFront."""
+
+ from westernfront.core.models import RssFeed, SubredditSource
+
+ RELIABLE_DOMAINS = frozenset([
+     "bbc.com",
+     "reuters.com",
+     "apnews.com",
+     "aljazeera.com",
+     "nytimes.com",
+     "wsj.com",
+     "ft.com",
+     "economist.com",
+     "thediplomat.com",
+     "foreignpolicy.com",
+     "foreignaffairs.com",
+     "dawn.com",
+     "timesofindia.indiatimes.com",
+     "ndtv.com",
+     "geo.tv",
+ ])
+
+ DEFAULT_SUBREDDITS = [
+     SubredditSource(name="geopolitics", reliability_score=0.85),
+     SubredditSource(name="CredibleDefense", reliability_score=0.9),
+     SubredditSource(name="worldnews", reliability_score=0.8),
+     SubredditSource(name="neutralnews", reliability_score=0.8),
+     SubredditSource(name="DefenseNews", reliability_score=0.85),
+     SubredditSource(name="GeopoliticsIndia", reliability_score=0.75),
+     SubredditSource(name="SouthAsia", reliability_score=0.7),
+     SubredditSource(name="india", reliability_score=0.7),
+     SubredditSource(name="pakistan", reliability_score=0.7),
+     SubredditSource(name="Nepal", reliability_score=0.65),
+     SubredditSource(name="bangladesh", reliability_score=0.65),
+     SubredditSource(name="srilanka", reliability_score=0.65),
+     SubredditSource(name="China", reliability_score=0.6),
+ ]
+
+ DEFAULT_RSS_FEEDS = [
+     RssFeed(name="Dawn (Pakistan)", url="https://www.dawn.com/feeds/home", reliability_score=0.85),
+     RssFeed(name="Geo News", url="https://www.geo.tv/rss/1/1", reliability_score=0.8),
+     RssFeed(name="Express Tribune", url="https://tribune.com.pk/feed/home", reliability_score=0.75),
+     RssFeed(name="Times of India", url="https://timesofindia.indiatimes.com/rssfeeds/296589292.cms", reliability_score=0.75),
+     RssFeed(name="NDTV India", url="https://feeds.feedburner.com/ndtvnews-india-news", reliability_score=0.8),
+     RssFeed(name="The Hindu", url="https://www.thehindu.com/news/national/feeder/default.rss", reliability_score=0.85),
+     RssFeed(name="Indian Express", url="https://indianexpress.com/section/india/feed/", reliability_score=0.85),
+     RssFeed(name="South China Morning Post - Asia", url="https://www.scmp.com/rss/91/feed", reliability_score=0.85),
+     RssFeed(name="Kathmandu Post", url="https://kathmandupost.com/rss", reliability_score=0.75),
+     RssFeed(name="Dhaka Tribune", url="https://www.dhakatribune.com/rss", reliability_score=0.75),
+     RssFeed(name="Daily Star Bangladesh", url="https://www.thedailystar.net/rss.xml", reliability_score=0.75),
+     RssFeed(name="Daily Mirror Sri Lanka", url="http://www.dailymirror.lk/RSS_Feeds/breaking-news", reliability_score=0.7),
+ ]
+
+ NEWSAPI_QUERIES = [
+     "India Pakistan",
+     "Kashmir conflict",
+     "India Pakistan border",
+     "LOC firing",
+     "Indo-Pak",
+ ]
+
+ RAG_QUERY_TOPICS = [
+     "India Pakistan military conflict border tensions ceasefire violation",
+     "Kashmir territorial dispute LOC Line of Control",
+     "India China LAC Ladakh Arunachal standoff",
+     "Nepal Bangladesh Sri Lanka India bilateral relations",
+     "South Asia terrorism cross-border insurgency",
+     "India diplomatic relations regional geopolitics",
+     "Military exercises defense buildup South Asia",
+ ]
+
+ NEWSAPI_BASE_URL = "https://newsapi.org/v2"
+
+ HTTP_TIMEOUT_SECONDS = 30
+ MAX_CONCURRENT_REQUESTS = 10
+
+ # Source diversity rules for retrieval
+ SOURCE_DIVERSITY_RULES = {
+     "reddit": {"min_pct": 0.25, "max_pct": 0.50},
+     "rss": {"min_pct": 0.30, "max_pct": 0.55},
+     "newsapi": {"min_pct": 0.10, "max_pct": 0.30},
+ }
+
+ # Temporal weighting for recency boost
+ RECENCY_BOOST = {
+     "hours_24": 1.5,
+     "hours_48": 1.25,
+     "days_7": 1.0,
+     "older": 0.75,
+ }
+
+ # Tension level criteria for validation
+ TENSION_LEVEL_CRITERIA = {
+     "LOW": {
+         "score_range": (1, 3),
+         "description": "Normal diplomatic activity, routine border incidents, no escalation",
+     },
+     "MEDIUM": {
+         "score_range": (4, 5),
+         "description": "Heightened rhetoric, minor military movements, diplomatic notes exchanged",
+     },
+     "HIGH": {
+         "score_range": (6, 8),
+         "description": "Military mobilization, cross-border firing, diplomatic summoning",
+     },
+     "CRITICAL": {
+         "score_range": (9, 10),
+         "description": "Active military engagement, imminent conflict, emergency measures",
+     },
+ }
+
+ SEARCH_KEYWORDS = [
+     "India Pakistan",
+     "Kashmir",
+     "LOC",
+     "ceasefire",
+     "border tension",
+     "military",
+     "diplomatic",
+     "terrorist",
+     "strike",
+     "conflict",
+ ]
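How the RECENCY_BOOST table is consumed is not shown in this view (presumably services/retrieval.py), so the helper below is a hypothetical illustration of weighting a similarity score by item age:

from datetime import datetime, timedelta, timezone

# Mirrors RECENCY_BOOST from constants.py above
RECENCY_BOOST = {"hours_24": 1.5, "hours_48": 1.25, "days_7": 1.0, "older": 0.75}

def boost_score(similarity: float, published_at: datetime) -> float:
    """Hypothetical helper: scale a similarity score by recency."""
    age = datetime.now(timezone.utc) - published_at
    if age <= timedelta(hours=24):
        key = "hours_24"
    elif age <= timedelta(hours=48):
        key = "hours_48"
    elif age <= timedelta(days=7):
        key = "days_7"
    else:
        key = "older"
    return similarity * RECENCY_BOOST[key]

fresh = datetime.now(timezone.utc) - timedelta(hours=3)
print(boost_score(0.80, fresh))  # 1.2 -- a fresh 0.80 outranks a stale 0.90 * 0.75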
src/westernfront/core/exceptions.py ADDED
@@ -0,0 +1,29 @@
+ """Custom exceptions for WesternFront."""
+
+
+ class WesternFrontError(Exception):
+     """Base exception for all WesternFront errors."""
+
+
+ class AuthenticationError(WesternFrontError):
+     """Raised when API authentication fails."""
+
+
+ class RateLimitExceededError(WesternFrontError):
+     """Raised when rate limit is exceeded."""
+
+
+ class ServiceNotInitializedError(WesternFrontError):
+     """Raised when a service is accessed before initialization."""
+
+
+ class DataFetchError(WesternFrontError):
+     """Raised when fetching data from external sources fails."""
+
+
+ class AnalysisError(WesternFrontError):
+     """Raised when AI analysis fails."""
+
+
+ class VectorStoreError(WesternFrontError):
+     """Raised when vector store operations fail."""
src/westernfront/core/models.py CHANGED
@@ -1,7 +1,6 @@
  """Core domain models for WesternFront."""

  from datetime import datetime
- from typing import Optional

  from pydantic import BaseModel, Field

@@ -36,8 +35,8 @@ class NewsItem(BaseModel):
      source_type: SourceType
      published_at: datetime
      reliability_score: float = Field(default=0.5, ge=0.0, le=1.0)
-     author: Optional[str] = None
-     score: Optional[int] = None
+     author: str | None = None
+     score: int | None = None


  class KeyDevelopment(BaseModel):
@@ -46,7 +45,7 @@ class KeyDevelopment(BaseModel):
      title: str
      description: str
      sources: list[str]
-     timestamp: Optional[datetime] = None
+     timestamp: datetime | None = None


  class ReliabilityAssessment(BaseModel):
src/westernfront/dependencies.py CHANGED
@@ -1,21 +1,24 @@
- """
- Dependency injection container for WesternFront.
-
- Provides FastAPI dependencies for services with proper lifecycle management.
- """
+ """Dependency injection container for WesternFront."""

+ import asyncio
+ from collections.abc import AsyncGenerator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass
- from typing import AsyncGenerator
+ from typing import TYPE_CHECKING

  from loguru import logger

+ if TYPE_CHECKING:
+     from fastapi import Request
+
+ from westernfront.analytics import AnalyticsAggregator
  from westernfront.config import Settings, get_settings
  from westernfront.repositories.analysis import AnalysisRepository
  from westernfront.repositories.vectors import VectorRepository
  from westernfront.services.analysis import AnalysisService
  from westernfront.services.cache import CacheService
  from westernfront.services.embeddings import EmbeddingService
+ from westernfront.services.http import HttpService
  from westernfront.services.newsapi import NewsApiService
  from westernfront.services.reddit import RedditService
  from westernfront.services.rss import RssService
@@ -26,6 +29,7 @@ class AppState:
      """Container for application-scoped services."""

      settings: Settings
+     http: HttpService
      cache: CacheService
      reddit: RedditService
      rss: RssService
@@ -34,16 +38,28 @@ class AppState:
      embeddings: EmbeddingService
      vectors: VectorRepository
      analysis: AnalysisService
+     analytics: AnalyticsAggregator


- _app_state: AppState | None = None
+ async def _init_sync_services(settings: Settings) -> tuple[EmbeddingService, VectorRepository]:
+     """Initialize synchronous services in a thread pool."""
+     def _init_embeddings_and_vectors() -> tuple[EmbeddingService, VectorRepository]:
+         embeddings = EmbeddingService()
+         embeddings.initialize()
+         vectors = VectorRepository(embedding_service=embeddings)
+         vectors.initialize()
+         return embeddings, vectors
+
+     return await asyncio.to_thread(_init_embeddings_and_vectors)


  async def init_services() -> AppState:
-     """Initialize all services for the application."""
+     """Initialize all services for the application with parallel execution."""
      settings = get_settings()

+     http = HttpService()
      cache = CacheService(ttl_seconds=settings.cache_expiry_minutes * 60)
+     repository = AnalysisRepository(db_path=settings.database_path)

      reddit = RedditService(
          client_id=settings.reddit_client_id,
@@ -51,23 +67,23 @@ async def init_services() -> AppState:
          user_agent=settings.reddit_user_agent,
          cache=cache,
      )
-     await reddit.initialize()
-
-     rss = RssService(cache=cache)

-     newsapi = NewsApiService(api_key=settings.newsapi_key, cache=cache)
+     rss = RssService(cache=cache, http=http)
+     newsapi = NewsApiService(api_key=settings.newsapi_key, cache=cache, http=http)

-     repository = AnalysisRepository(db_path=settings.database_path)
-     await repository.initialize()
-
-     embeddings = EmbeddingService()
-     embeddings.initialize()
-     logger.info("Embedding service initialized")
+     # Parallel initialization of independent services
+     http_init, reddit_init, repo_init, vectors_result = await asyncio.gather(
+         http.initialize(),
+         reddit.initialize(),
+         repository.initialize(),
+         _init_sync_services(settings),
+     )

-     vectors = VectorRepository(embedding_service=embeddings)
-     vectors.initialize()
+     embeddings, vectors = vectors_result
      logger.info(f"Vector repository initialized with {vectors.get_count()} items")

+     analytics = AnalyticsAggregator(repository)
+
      analysis = AnalysisService(
          gemini_api_key=settings.gemini_api_key,
          reddit=reddit,
@@ -81,6 +97,7 @@ async def init_services() -> AppState:

      return AppState(
          settings=settings,
+         http=http,
          cache=cache,
          reddit=reddit,
          rss=rss,
@@ -89,6 +106,7 @@ async def init_services() -> AppState:
          embeddings=embeddings,
          vectors=vectors,
          analysis=analysis,
+         analytics=analytics,
      )


@@ -96,43 +114,35 @@ async def shutdown_services(state: AppState) -> None:
      """Clean up all services."""
      await state.analysis.close()
      await state.reddit.close()
+     await state.http.close()
      await state.repository.close()


  @asynccontextmanager
  async def lifespan_context() -> AsyncGenerator[AppState, None]:
      """Lifespan context manager for FastAPI."""
-     global _app_state
-     _app_state = await init_services()
+     state = await init_services()
      try:
-         yield _app_state
+         yield state
      finally:
-         await shutdown_services(_app_state)
-         _app_state = None
-
-
- def get_app_state() -> AppState:
-     """Get the current application state."""
-     if _app_state is None:
-         raise RuntimeError("Application not initialized")
-     return _app_state
+         await shutdown_services(state)


- def get_analysis_service() -> AnalysisService:
-     """FastAPI dependency for AnalysisService."""
-     return get_app_state().analysis
+ def get_state_from_request(request: "Request") -> AppState:
+     """Get application state from request."""
+     return request.app.state.westernfront


- def get_repository() -> AnalysisRepository:
-     """FastAPI dependency for AnalysisRepository."""
-     return get_app_state().repository
+ def get_analysis_service(request: "Request") -> AnalysisService:
+     """Get AnalysisService from request."""
+     return get_state_from_request(request).analysis


- def get_vectors() -> VectorRepository:
-     """FastAPI dependency for VectorRepository."""
-     return get_app_state().vectors
+ def get_repository(request: "Request") -> AnalysisRepository:
+     """Get AnalysisRepository from request."""
+     return get_state_from_request(request).repository


- def get_settings_dep() -> Settings:
-     """FastAPI dependency for Settings."""
-     return get_app_state().settings
+ def get_analytics(request: "Request") -> AnalyticsAggregator:
+     """Get AnalyticsAggregator from request."""
+     return get_state_from_request(request).analytics
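The startup path above overlaps async I/O (HTTP, Reddit, SQLite) with blocking model loading by pushing the synchronous work onto a worker thread. A stripped-down sketch of that shape with stand-in services; sleep durations are illustrative:

import asyncio
import time

def load_model() -> str:
    # Stand-in for the blocking EmbeddingService/VectorRepository init.
    time.sleep(1.0)
    return "model"

async def connect(name: str) -> str:
    # Stand-in for http/reddit/repository initialize().
    await asyncio.sleep(1.0)
    return name

async def main() -> None:
    start = time.perf_counter()
    results = await asyncio.gather(
        connect("http"),
        connect("reddit"),
        connect("repository"),
        asyncio.to_thread(load_model),  # blocking call runs off the event loop
    )
    print(results, f"{time.perf_counter() - start:.1f}s")  # ~1.0s, not ~4.0s

asyncio.run(main())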
src/westernfront/main.py CHANGED
@@ -1,15 +1,16 @@
  """FastAPI application factory and entry point."""

  import os
+ from collections.abc import AsyncGenerator
  from contextlib import asynccontextmanager
- from typing import AsyncGenerator

- from fastapi import FastAPI, Request
+ from fastapi import FastAPI, Request, Response
  from fastapi.middleware.cors import CORSMiddleware
  from loguru import logger

  from westernfront import __version__
  from westernfront.api.auth import verify_api_key
+ from westernfront.api.middleware import RateLimitMiddleware
  from westernfront.api.routes import router
  from westernfront.config import get_settings
  from westernfront.dependencies import lifespan_context
@@ -43,6 +44,8 @@ def create_app() -> FastAPI:
          lifespan=lifespan,
      )

+     app.add_middleware(RateLimitMiddleware, requests_per_minute=120, burst_size=20)
+
      app.add_middleware(
          CORSMiddleware,
          allow_origins=settings.allowed_origins,
@@ -52,7 +55,19 @@ def create_app() -> FastAPI:
      )

      @app.middleware("http")
-     async def api_key_middleware(request: Request, call_next):
+     async def security_headers_middleware(request: Request, call_next) -> Response:
+         """Add security headers to all responses."""
+         response = await call_next(request)
+         response.headers["X-Content-Type-Options"] = "nosniff"
+         response.headers["X-Frame-Options"] = "DENY"
+         response.headers["X-XSS-Protection"] = "1; mode=block"
+         response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
+         response.headers["Content-Security-Policy"] = "default-src 'self'; frame-ancestors 'none'"
+         return response
+
+     @app.middleware("http")
+     async def api_key_middleware(request: Request, call_next) -> Response:
+         """Verify API key for protected endpoints."""
          await verify_api_key(request)
          return await call_next(request)
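One subtlety in the stack above: Starlette builds the middleware onion so that the most recently registered middleware is outermost and sees the request first. A tiny sketch to observe the ordering (requires httpx for the test client):

from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

app = FastAPI()
order: list[str] = []

@app.middleware("http")
async def registered_first(request: Request, call_next):
    order.append("first-registered")
    return await call_next(request)

@app.middleware("http")
async def registered_second(request: Request, call_next):
    order.append("second-registered")
    return await call_next(request)

@app.get("/")
async def root() -> dict:
    return {}

TestClient(app).get("/")
print(order)  # ['second-registered', 'first-registered']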
src/westernfront/prompts/analysis.py CHANGED
@@ -2,33 +2,53 @@
2
 
3
  from datetime import datetime
4
 
 
5
 
6
- def build_rag_prompt(retrieved_items: list[dict], total_in_memory: int = 0) -> str:
7
- """
8
- Build prompt for RAG-enhanced analysis using vector-retrieved items.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- Args:
11
- retrieved_items: Items retrieved from vector search with metadata.
12
- total_in_memory: Total items in the vector database.
13
 
14
- Returns:
15
- The formatted prompt string.
16
- """
17
  source_entries = []
18
  for i, item in enumerate(retrieved_items):
19
  meta = item.get("metadata", {})
20
  doc = item.get("document", "")
21
- score = item.get("similarity_score", 0)
22
  reliability_val = meta.get("reliability_score", 0.5)
23
-
24
- reliability = "HIGH" if reliability_val > 0.8 else (
25
- "MEDIUM" if reliability_val > 0.6 else "LOW"
26
- )
27
-
 
 
 
28
  entry = (
29
- f"INTEL #{i + 1} [Relevance: {score:.0%}] [Reliability: {reliability}]:\n"
30
  f"Source: {meta.get('source_name', 'Unknown')} ({meta.get('source_type', 'unknown')})\n"
31
- f"Date: {meta.get('published_at', 'Unknown')[:10] if meta.get('published_at') else 'Unknown'}\n"
32
  f"Content: {doc}"
33
  )
34
  source_entries.append(entry)
@@ -37,7 +57,13 @@ def build_rag_prompt(retrieved_items: list[dict], total_in_memory: int = 0) -> s
37
 
38
  memory_note = ""
39
  if total_in_memory > 0:
40
- memory_note = f"\n\n**INSTITUTIONAL MEMORY:** This analysis draws from a database of {total_in_memory:,} indexed news items. The items shown below are the most semantically relevant to South Asia conflict dynamics.\n"
 
 
 
 
 
 
41
 
42
  return f"""**TOP SECRET // FOR OFFICIAL USE ONLY**
43
 
@@ -59,6 +85,11 @@ Analyze ALL matters relevant to India's regional relationships, prioritized as:
59
  **NEUTRALITY DIRECTIVE:**
60
  Maintain absolute neutrality. Present multiple perspectives when conflicting information exists. Acknowledge information gaps.
61
 
 
 
 
 
 
62
  **INTELLIGENCE FEEDS:**
63
  ---
64
  {intelligence_data}
@@ -67,8 +98,9 @@ Maintain absolute neutrality. Present multiple perspectives when conflicting inf
67
  **ANALYTICAL DIRECTIVES:**
68
  1. **Synthesize, Do Not Summarize:** Integrate all data into a coherent assessment.
69
  2. **Impersonal Tone:** Use formal, analytical language.
70
- 3. **No Direct Attribution:** Your report is standalone. Do not attribute to specific sources.
71
  4. **Acknowledge Uncertainty:** Indicate confidence levels and information gaps.
 
72
 
73
  **REQUIRED OUTPUT FORMAT (Strict JSON):**
74
  Produce a single, valid JSON object:
@@ -83,8 +115,8 @@ Produce a single, valid JSON object:
83
  }}
84
  ],
85
  "reliability_assessment": {{
86
- "source_credibility": "Overall credibility assessment.",
87
- "information_gaps": "What critical information is missing.",
88
  "confidence_rating": "HIGH, MEDIUM, or LOW with justification."
89
  }},
90
  "regional_implications": {{
@@ -93,9 +125,72 @@ Produce a single, valid JSON object:
93
  "economic": "Potential economic consequences."
94
  }},
95
  "tension_level": "LOW|MEDIUM|HIGH|CRITICAL",
96
- "tension_rationale": "Justification for the assessed tension level.",
97
- "tension_score": "Integer 1-10 (1=calm, 10=active conflict imminent).",
98
  "tension_trend": "INCREASING|DECREASING|STABLE",
99
  "analysis_type": "Military|Diplomatic|Internal Security|Political|Other",
100
- "key_entities": ["3-5 key actors, locations, or organizations."]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  }}"""
 
2
 
3
  from datetime import datetime
4
 
5
+ from westernfront.core.constants import TENSION_LEVEL_CRITERIA
6
 
7
+ RELIABILITY_THRESHOLDS = {
8
+ "HIGH": 0.8,
9
+ "MEDIUM": 0.6,
10
+ }
11
+
12
+
13
+ def _get_reliability_label(score: float) -> str:
14
+ """Convert reliability score to label."""
15
+ if score > RELIABILITY_THRESHOLDS["HIGH"]:
16
+ return "HIGH"
17
+ if score > RELIABILITY_THRESHOLDS["MEDIUM"]:
18
+ return "MEDIUM"
19
+ return "LOW"
20
+
21
+
22
+ def _build_tension_criteria() -> str:
23
+ """Build tension level criteria section for prompt."""
24
+ lines = []
25
+ for level, criteria in TENSION_LEVEL_CRITERIA.items():
26
+ score_range = criteria["score_range"]
27
+ desc = criteria["description"]
28
+ lines.append(f"- **{level}** (Score {score_range[0]}-{score_range[1]}): {desc}")
29
+ return "\n".join(lines)
30
 
 
 
 
31
 
32
+ def build_rag_prompt(retrieved_items: list[dict], total_in_memory: int = 0) -> str:
33
+ """Build prompt for RAG-enhanced analysis using vector-retrieved items."""
 
34
  source_entries = []
35
  for i, item in enumerate(retrieved_items):
36
  meta = item.get("metadata", {})
37
  doc = item.get("document", "")
38
+ score = item.get("boosted_score", item.get("similarity_score", 0))
39
  reliability_val = meta.get("reliability_score", 0.5)
40
+ reliability = _get_reliability_label(reliability_val)
41
+
42
+ published = meta.get("published_at", "")
43
+ date_str = published[:10] if published else "Unknown"
44
+
45
+ recency = item.get("recency_multiplier", 1.0)
46
+ recency_label = "FRESH" if recency >= 1.5 else ("RECENT" if recency >= 1.0 else "OLDER")
47
+
48
  entry = (
49
+ f"INTEL #{i + 1} [Relevance: {score:.0%}] [Reliability: {reliability}] [{recency_label}]:\n"
50
  f"Source: {meta.get('source_name', 'Unknown')} ({meta.get('source_type', 'unknown')})\n"
51
+ f"Date: {date_str}\n"
52
  f"Content: {doc}"
53
  )
54
  source_entries.append(entry)
 
57
 
58
  memory_note = ""
59
  if total_in_memory > 0:
60
+ memory_note = (
61
+ f"\n\n**INSTITUTIONAL MEMORY:** This analysis draws from a database of "
62
+ f"{total_in_memory:,} indexed news items. The items shown below are the most "
63
+ f"semantically relevant to South Asia conflict dynamics, weighted by recency.\n"
64
+ )
65
+
66
+ tension_criteria = _build_tension_criteria()
67
 
68
  return f"""**TOP SECRET // FOR OFFICIAL USE ONLY**
69
 
 
85
  **NEUTRALITY DIRECTIVE:**
86
  Maintain absolute neutrality. Present multiple perspectives when conflicting information exists. Acknowledge information gaps.
87
 
88
+ **TENSION LEVEL ASSESSMENT CRITERIA:**
89
+ {tension_criteria}
90
+
91
+ Use these criteria strictly. Your tension_score MUST align with your tension_level.
92
+
93
  **INTELLIGENCE FEEDS:**
94
  ---
95
  {intelligence_data}
 
98
  **ANALYTICAL DIRECTIVES:**
99
  1. **Synthesize, Do Not Summarize:** Integrate all data into a coherent assessment.
100
  2. **Impersonal Tone:** Use formal, analytical language.
101
+ 3. **Ground All Claims:** Every key entity and event must be supported by the intelligence feeds above.
102
  4. **Acknowledge Uncertainty:** Indicate confidence levels and information gaps.
103
+ 5. **Prioritize Recency:** Weight FRESH sources more heavily than OLDER ones.
104
 
105
  **REQUIRED OUTPUT FORMAT (Strict JSON):**
106
  Produce a single, valid JSON object:
 
115
  }}
116
  ],
117
  "reliability_assessment": {{
118
+ "source_credibility": "Overall credibility assessment of available sources.",
119
+ "information_gaps": "What critical information is missing from the intelligence feeds.",
120
  "confidence_rating": "HIGH, MEDIUM, or LOW with justification."
121
  }},
122
  "regional_implications": {{
 
125
  "economic": "Potential economic consequences."
126
  }},
127
  "tension_level": "LOW|MEDIUM|HIGH|CRITICAL",
128
+ "tension_rationale": "Justification for the assessed tension level using specific evidence from sources.",
129
+ "tension_score": "Integer 1-10 matching to tension level per criteria above.",
130
  "tension_trend": "INCREASING|DECREASING|STABLE",
131
  "analysis_type": "Military|Diplomatic|Internal Security|Political|Other",
132
+ "key_entities": ["3-5 key actors, locations, or organizations ONLY from the sources above."]
133
+ }}"""
134
+
135
+
136
+ def build_extraction_prompt(items: list[dict]) -> str:
137
+ """Build prompt for fact extraction stage of chain analysis."""
138
+ source_entries = []
139
+ for i, item in enumerate(items[:20]):
140
+ doc = item.get("document", "")
141
+ source_entries.append(f"SOURCE {i+1}: {doc[:500]}")
142
+
143
+ sources_text = "\n\n".join(source_entries)
144
+
145
+ return f"""Extract key facts from the following intelligence sources.
146
+
147
+ SOURCES:
148
+ {sources_text}
149
+
150
+ OUTPUT FORMAT (JSON):
151
+ {{
152
+ "facts": [
153
+ {{
154
+ "fact": "Clear, objective statement of fact",
155
+ "source_index": 1,
156
+ "type": "military|diplomatic|political|economic|other",
157
+ "date_mentioned": "YYYY-MM-DD or null",
158
+ "entities": ["entity1", "entity2"]
159
+ }}
160
+ ],
161
+ "total_sources_analyzed": {len(items)}
162
+ }}"""
163
+
164
+
165
+ def build_synthesis_prompt(facts: list[dict], historical_context: str = "") -> str:
166
+ """Build prompt for synthesis stage of chain analysis."""
167
+ facts_text = "\n".join([f"- {f.get('fact', '')}" for f in facts[:30]])
168
+
169
+ return f"""Synthesize the following extracted facts into a coherent assessment.
170
+
171
+ EXTRACTED FACTS:
172
+ {facts_text}
173
+
174
+ {f"HISTORICAL CONTEXT: {historical_context}" if historical_context else ""}
175
+
176
+ TASK:
177
+ 1. Identify the 3-5 most significant developments
178
+ 2. Assess overall tension level (LOW/MEDIUM/HIGH/CRITICAL)
179
+ 3. Identify trends and patterns
180
+ 4. Note any contradictions or information gaps
181
+
182
+ OUTPUT FORMAT (JSON):
183
+ {{
184
+ "significant_developments": [
185
+ {{
186
+ "title": "Development title",
187
+ "description": "Synthesized description",
188
+ "supporting_facts": [0, 1, 2]
189
+ }}
190
+ ],
191
+ "preliminary_tension": "LOW|MEDIUM|HIGH|CRITICAL",
192
+ "tension_reasoning": "Brief reasoning",
193
+ "trends": ["trend1", "trend2"],
194
+ "contradictions": ["any contradictory information"],
195
+ "gaps": ["information gaps identified"]
196
  }}"""
src/westernfront/repositories/analysis.py CHANGED
@@ -1,14 +1,14 @@
1
  """SQLite repository for storing analysis history."""
2
 
3
  import json
4
- from datetime import datetime, timedelta, timezone
5
  from pathlib import Path
6
- from typing import Optional
7
 
8
  import aiosqlite
9
  from loguru import logger
10
 
11
  from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
 
12
  from westernfront.core.models import AnalysisSnapshot, ConflictAnalysis
13
 
14
 
@@ -16,14 +16,8 @@ class AnalysisRepository:
16
  """SQLite-based repository for analysis storage and retrieval."""
17
 
18
  def __init__(self, db_path: str = "data/westernfront.db") -> None:
19
- """
20
- Initialize the repository.
21
-
22
- Args:
23
- db_path: Path to the SQLite database file.
24
- """
25
  self._db_path = Path(db_path)
26
- self._conn: Optional[aiosqlite.Connection] = None
27
 
28
  async def initialize(self) -> None:
29
  """Initialize the database and create tables."""
@@ -43,7 +37,7 @@ class AnalysisRepository:
43
  async def _create_tables(self) -> None:
44
  """Create required database tables."""
45
  if not self._conn:
46
- raise RuntimeError("Repository not initialized")
47
 
48
  await self._conn.execute("""
49
  CREATE TABLE IF NOT EXISTS analyses (
@@ -70,7 +64,7 @@ class AnalysisRepository:
70
  if not self._conn:
71
  return
72
 
73
- cutoff = (datetime.now(timezone.utc) - timedelta(days=retention_days)).isoformat()
74
  cursor = await self._conn.execute(
75
  "DELETE FROM analyses WHERE generated_at < ?",
76
  (cutoff,),
@@ -81,14 +75,9 @@ class AnalysisRepository:
81
  logger.info(f"Cleaned up {cursor.rowcount} old analysis records")
82
 
83
  async def save(self, analysis: ConflictAnalysis) -> None:
84
- """
85
- Save an analysis to the database.
86
-
87
- Args:
88
- analysis: The conflict analysis to save.
89
- """
90
  if not self._conn:
91
- raise RuntimeError("Repository not initialized")
92
 
93
  await self._conn.execute(
94
  """
@@ -112,10 +101,10 @@ class AnalysisRepository:
112
  await self._conn.commit()
113
  logger.debug(f"Saved analysis {analysis.analysis_id}")
114
 
115
- async def get_latest(self) -> Optional[ConflictAnalysis]:
116
  """Get the most recent analysis."""
117
  if not self._conn:
118
- raise RuntimeError("Repository not initialized")
119
 
120
  cursor = await self._conn.execute(
121
  "SELECT full_analysis FROM analyses ORDER BY generated_at DESC LIMIT 1"
@@ -131,20 +120,11 @@ class AnalysisRepository:
131
  days: int = 30,
132
  limit: int = 100,
133
  ) -> list[AnalysisSnapshot]:
134
- """
135
- Get historical analysis snapshots.
136
-
137
- Args:
138
- days: Number of days to look back.
139
- limit: Maximum number of records to return.
140
-
141
- Returns:
142
- List of analysis snapshots.
143
- """
144
  if not self._conn:
145
- raise RuntimeError("Repository not initialized")
146
 
147
- cutoff = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()
148
 
149
  cursor = await self._conn.execute(
150
  """
@@ -178,19 +158,11 @@ class AnalysisRepository:
178
  return snapshots
179
 
180
  async def get_tension_history(self, days: int = 30) -> list[dict]:
181
- """
182
- Get tension score history for graphing.
183
-
184
- Args:
185
- days: Number of days to look back.
186
-
187
- Returns:
188
- List of date/score pairs.
189
- """
190
  if not self._conn:
191
- raise RuntimeError("Repository not initialized")
192
 
193
- cutoff = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()
194
 
195
  cursor = await self._conn.execute(
196
  """
 
1
  """SQLite repository for storing analysis history."""
2
 
3
  import json
4
+ from datetime import UTC, datetime, timedelta
5
  from pathlib import Path
 
6
 
7
  import aiosqlite
8
  from loguru import logger
9
 
10
  from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
11
+ from westernfront.core.exceptions import ServiceNotInitializedError
12
  from westernfront.core.models import AnalysisSnapshot, ConflictAnalysis
13
 
14
 
 
16
  """SQLite-based repository for analysis storage and retrieval."""
17
 
18
  def __init__(self, db_path: str = "data/westernfront.db") -> None:
 
 
 
 
 
 
19
  self._db_path = Path(db_path)
20
+ self._conn: aiosqlite.Connection | None = None
21
 
22
  async def initialize(self) -> None:
23
  """Initialize the database and create tables."""
 
37
  async def _create_tables(self) -> None:
38
  """Create required database tables."""
39
  if not self._conn:
40
+ raise ServiceNotInitializedError("Analysis repository not initialized")
41
 
42
  await self._conn.execute("""
43
  CREATE TABLE IF NOT EXISTS analyses (
 
64
  if not self._conn:
65
  return
66
 
67
+ cutoff = (datetime.now(UTC) - timedelta(days=retention_days)).isoformat()
68
  cursor = await self._conn.execute(
69
  "DELETE FROM analyses WHERE generated_at < ?",
70
  (cutoff,),
 
75
  logger.info(f"Cleaned up {cursor.rowcount} old analysis records")
76
 
77
  async def save(self, analysis: ConflictAnalysis) -> None:
78
+ """Save an analysis to the database."""
 
 
 
 
 
79
  if not self._conn:
80
+ raise ServiceNotInitializedError("Analysis repository not initialized")
81
 
82
  await self._conn.execute(
83
  """
 
101
  await self._conn.commit()
102
  logger.debug(f"Saved analysis {analysis.analysis_id}")
103
 
104
+ async def get_latest(self) -> ConflictAnalysis | None:
105
  """Get the most recent analysis."""
106
  if not self._conn:
107
+ raise ServiceNotInitializedError("Analysis repository not initialized")
108
 
109
  cursor = await self._conn.execute(
110
  "SELECT full_analysis FROM analyses ORDER BY generated_at DESC LIMIT 1"
 
120
  days: int = 30,
121
  limit: int = 100,
122
  ) -> list[AnalysisSnapshot]:
123
+ """Get historical analysis snapshots."""
 
 
 
 
 
 
 
 
 
124
  if not self._conn:
125
+ raise ServiceNotInitializedError("Analysis repository not initialized")
126
 
127
+ cutoff = (datetime.now(UTC) - timedelta(days=days)).isoformat()
128
 
129
  cursor = await self._conn.execute(
130
  """
 
158
  return snapshots
159
 
160
  async def get_tension_history(self, days: int = 30) -> list[dict]:
161
+ """Get tension score history for graphing."""
 
 
 
 
 
 
 
 
162
  if not self._conn:
163
+ raise ServiceNotInitializedError("Analysis repository not initialized")
164
 
165
+ cutoff = (datetime.now(UTC) - timedelta(days=days)).isoformat()
166
 
167
  cursor = await self._conn.execute(
168
  """
src/westernfront/repositories/vectors.py CHANGED
@@ -1,13 +1,14 @@
1
  """ChromaDB vector repository for semantic search."""
2
 
3
- from datetime import datetime
4
  from pathlib import Path
5
- from typing import TYPE_CHECKING, Any, Optional
6
 
7
  import chromadb
8
  from chromadb.config import Settings as ChromaSettings
9
  from loguru import logger
10
 
 
11
  from westernfront.core.models import NewsItem
12
 
13
  if TYPE_CHECKING:
@@ -21,16 +22,9 @@ class VectorRepository:
21
 
22
  def __init__(
23
  self,
24
- persist_dir: Optional[str] = None,
25
- embedding_service: Optional["EmbeddingService"] = None,
26
  ) -> None:
27
- """
28
- Initialize the vector repository.
29
-
30
- Args:
31
- persist_dir: Directory for ChromaDB persistence (defaults to server/data/chroma).
32
- embedding_service: Service for generating embeddings.
33
- """
34
  if persist_dir:
35
  self._persist_dir = Path(persist_dir)
36
  else:
@@ -39,54 +33,36 @@ class VectorRepository:
39
  self._persist_dir.mkdir(parents=True, exist_ok=True)
40
 
41
  self._embeddings = embedding_service
42
- self._client: Optional[chromadb.PersistentClient] = None
43
  self._collection = None
44
 
45
  def initialize(self) -> bool:
46
- """
47
- Initialize ChromaDB client and collection.
48
-
49
- Returns:
50
- True if initialization was successful.
51
- """
52
- try:
53
- logger.info(f"Initializing ChromaDB at {self._persist_dir}")
54
 
55
- self._client = chromadb.PersistentClient(
56
- path=str(self._persist_dir),
57
- settings=ChromaSettings(anonymized_telemetry=False),
58
- )
59
-
60
- self._collection = self._client.get_or_create_collection(
61
- name=self.COLLECTION_NAME,
62
- metadata={"hnsw:space": "cosine"},
63
- )
64
 
65
- count = self._collection.count()
66
- logger.info(f"ChromaDB initialized with {count} existing documents")
67
- return True
 
68
 
69
- except Exception as e:
70
- logger.error(f"Failed to initialize ChromaDB: {e}")
71
- return False
72
 
73
  @property
74
  def is_initialized(self) -> bool:
75
  """Check if the repository is initialized."""
76
  return self._collection is not None
77
 
78
- def add_items(self, items: list[NewsItem]) -> int:
79
- """
80
- Add news items to the vector store.
81
-
82
- Args:
83
- items: News items to add.
84
-
85
- Returns:
86
- Number of items added.
87
- """
88
  if not self._collection or not self._embeddings:
89
- raise RuntimeError("Repository not initialized")
90
 
91
  existing_ids = set(self._collection.get(ids=[item.id for item in items])["ids"])
92
  new_items = [item for item in items if item.id not in existing_ids]
@@ -117,25 +93,23 @@ class VectorRepository:
117
  logger.info(f"Added {len(new_items)} items to vector store")
118
  return len(new_items)
119
 
 
 
 
 
 
 
 
 
120
  def query_similar(
121
  self,
122
  query: str,
123
  n_results: int = 20,
124
  min_score: float = 0.3,
125
  ) -> list[dict]:
126
- """
127
- Query for similar items.
128
-
129
- Args:
130
- query: Query text.
131
- n_results: Maximum number of results.
132
- min_score: Minimum similarity score (0-1).
133
-
134
- Returns:
135
- List of similar items with metadata.
136
- """
137
  if not self._collection or not self._embeddings:
138
- raise RuntimeError("Repository not initialized")
139
 
140
  query_embedding = self._embeddings.embed(query)
141
 
@@ -164,16 +138,7 @@ class VectorRepository:
164
  topics: list[str],
165
  n_per_topic: int = 10,
166
  ) -> list[dict]:
167
- """
168
- Query for items related to multiple topics.
169
-
170
- Args:
171
- topics: List of topic queries.
172
- n_per_topic: Results per topic.
173
-
174
- Returns:
175
- Deduplicated list of relevant items.
176
- """
177
  seen_ids: set[str] = set()
178
  all_items: list[dict] = []
179
 
 
1
  """ChromaDB vector repository for semantic search."""
2
 
3
+ import asyncio
4
  from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
 
7
  import chromadb
8
  from chromadb.config import Settings as ChromaSettings
9
  from loguru import logger
10
 
11
+ from westernfront.core.exceptions import VectorStoreError
12
  from westernfront.core.models import NewsItem
13
 
14
  if TYPE_CHECKING:
 
22
 
23
  def __init__(
24
  self,
25
+ persist_dir: str | None = None,
26
+ embedding_service: "EmbeddingService | None" = None,
27
  ) -> None:
 
 
 
 
 
 
 
28
  if persist_dir:
29
  self._persist_dir = Path(persist_dir)
30
  else:
 
33
  self._persist_dir.mkdir(parents=True, exist_ok=True)
34
 
35
  self._embeddings = embedding_service
36
+ self._client: chromadb.PersistentClient | None = None
37
  self._collection = None
38
 
39
  def initialize(self) -> bool:
40
+ """Initialize ChromaDB client and collection."""
41
+ logger.info(f"Initializing ChromaDB at {self._persist_dir}")
 
 
 
 
 
 
42
 
43
+ self._client = chromadb.PersistentClient(
44
+ path=str(self._persist_dir),
45
+ settings=ChromaSettings(anonymized_telemetry=False),
46
+ )
 
 
 
 
 
47
 
48
+ self._collection = self._client.get_or_create_collection(
49
+ name=self.COLLECTION_NAME,
50
+ metadata={"hnsw:space": "cosine"},
51
+ )
52
 
53
+ count = self._collection.count()
54
+ logger.info(f"ChromaDB initialized with {count} existing documents")
55
+ return True
56
 
57
  @property
58
  def is_initialized(self) -> bool:
59
  """Check if the repository is initialized."""
60
  return self._collection is not None
61
 
62
+ def _add_items_sync(self, items: list[NewsItem]) -> int:
63
+ """Synchronous implementation of add_items."""
 
 
 
 
 
 
 
 
64
  if not self._collection or not self._embeddings:
65
+ raise VectorStoreError("Vector repository not initialized")
66
 
67
  existing_ids = set(self._collection.get(ids=[item.id for item in items])["ids"])
68
  new_items = [item for item in items if item.id not in existing_ids]
 
93
  logger.info(f"Added {len(new_items)} items to vector store")
94
  return len(new_items)
95
 
96
+ def add_items(self, items: list[NewsItem]) -> int:
97
+ """Add news items to the vector store (blocking)."""
98
+ return self._add_items_sync(items)
99
+
100
+ async def add_items_async(self, items: list[NewsItem]) -> int:
101
+ """Add news items to the vector store (non-blocking)."""
102
+ return await asyncio.to_thread(self._add_items_sync, items)
103
+
104
  def query_similar(
105
  self,
106
  query: str,
107
  n_results: int = 20,
108
  min_score: float = 0.3,
109
  ) -> list[dict]:
110
+ """Query for similar items."""
 
 
 
 
 
 
 
 
 
 
111
  if not self._collection or not self._embeddings:
112
+ raise VectorStoreError("Vector repository not initialized")
113
 
114
  query_embedding = self._embeddings.embed(query)
115
 
 
138
  topics: list[str],
139
  n_per_topic: int = 10,
140
  ) -> list[dict]:
141
+ """Query for items related to multiple topics."""
 
 
 
 
 
 
 
 
 
142
  seen_ids: set[str] = set()
143
  all_items: list[dict] = []
144
 
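The new add_items_async wrapper moves the blocking ChromaDB write off the event loop via asyncio.to_thread. A rough usage sketch; the embedding service, item list, and persist path are assumptions supplied by the caller:

# Sketch only: `embeddings` and `items` are assumed to come from the caller,
# and the persist_dir is illustrative.
async def ingest(embeddings, items) -> None:
    repo = VectorRepository(persist_dir="data/chroma", embedding_service=embeddings)
    if not repo.initialize():
        return
    added = await repo.add_items_async(items)    # write runs on a worker thread
    hits = repo.query_similar("Kashmir ceasefire violation", n_results=5)
    print(f"added {added} items; retrieved {len(hits)} related documents")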
src/westernfront/services/__init__.py CHANGED
@@ -2,16 +2,28 @@
2
 
3
  from westernfront.services.analysis import AnalysisService
4
  from westernfront.services.cache import CacheService
 
5
  from westernfront.services.embeddings import EmbeddingService
 
6
  from westernfront.services.newsapi import NewsApiService
 
7
  from westernfront.services.reddit import RedditService
 
8
  from westernfront.services.rss import RssService
 
 
9
 
10
  __all__ = [
 
11
  "AnalysisService",
 
12
  "CacheService",
 
13
  "EmbeddingService",
 
14
  "NewsApiService",
15
  "RedditService",
 
 
16
  "RssService",
17
  ]
 
2
 
3
  from westernfront.services.analysis import AnalysisService
4
  from westernfront.services.cache import CacheService
5
+ from westernfront.services.chain_analysis import ChainAnalysisService
6
  from westernfront.services.embeddings import EmbeddingService
7
+ from westernfront.services.http import HttpService
8
  from westernfront.services.newsapi import NewsApiService
9
+ from westernfront.services.parsing import ResponseParser
10
  from westernfront.services.reddit import RedditService
11
+ from westernfront.services.retrieval import RetrievalService
12
  from westernfront.services.rss import RssService
13
+ from westernfront.services.scheduler import AnalysisScheduler
14
+ from westernfront.services.validation import AnalysisValidator
15
 
16
  __all__ = [
17
+ "AnalysisScheduler",
18
  "AnalysisService",
19
+ "AnalysisValidator",
20
  "CacheService",
21
+ "ChainAnalysisService",
22
  "EmbeddingService",
23
+ "HttpService",
24
  "NewsApiService",
25
  "RedditService",
26
+ "ResponseParser",
27
+ "RetrievalService",
28
  "RssService",
29
  ]
src/westernfront/services/analysis.py CHANGED
@@ -1,362 +1,239 @@
1
- """AI-powered conflict analysis service with RAG enhancement."""
2
-
3
- import asyncio
4
- import json
5
- import re
6
- import uuid
7
- from datetime import datetime
8
- from typing import Any, Optional
9
-
10
- import google.generativeai as genai
11
- from loguru import logger
12
- from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
13
-
14
- from westernfront.config import Settings
15
- from westernfront.core.enums import AnalysisType, SourceType, TensionLevel, TensionTrend
16
- from westernfront.core.models import (
17
- ConflictAnalysis,
18
- KeyDevelopment,
19
- NewsItem,
20
- RegionalImplications,
21
- ReliabilityAssessment,
22
- )
23
- from westernfront.prompts.analysis import build_rag_prompt
24
- from westernfront.repositories.analysis import AnalysisRepository
25
- from westernfront.repositories.vectors import VectorRepository
26
- from westernfront.services.newsapi import NewsApiService
27
- from westernfront.services.reddit import RedditService
28
- from westernfront.services.rss import RssService
29
-
30
-
31
- # Topics for RAG retrieval - these are semantic queries, not keyword filters
32
- RAG_QUERY_TOPICS = [
33
- "India Pakistan military conflict border tensions ceasefire violation",
34
- "Kashmir territorial dispute LOC Line of Control",
35
- "India China LAC Ladakh Arunachal standoff",
36
- "Nepal Bangladesh Sri Lanka India bilateral relations",
37
- "South Asia terrorism cross-border insurgency",
38
- "India diplomatic relations regional geopolitics",
39
- "Military exercises defense buildup South Asia",
40
- ]
41
-
42
-
43
- class AnalysisService:
44
- """Service for generating AI-powered conflict analysis with RAG."""
45
-
46
- def __init__(
47
- self,
48
- gemini_api_key: str,
49
- reddit: RedditService,
50
- rss: RssService,
51
- newsapi: NewsApiService,
52
- repository: AnalysisRepository,
53
- vectors: VectorRepository,
54
- settings: Settings,
55
- ) -> None:
56
- """
57
- Initialize the analysis service.
58
-
59
- Args:
60
- gemini_api_key: API key for Google Gemini.
61
- reddit: Reddit service for fetching posts.
62
- rss: RSS service for fetching articles.
63
- newsapi: NewsAPI service for fetching news.
64
- repository: Repository for storing analyses.
65
- vectors: Vector repository for RAG retrieval.
66
- settings: Application settings.
67
- """
68
- self._api_key = gemini_api_key
69
- self._reddit = reddit
70
- self._rss = rss
71
- self._newsapi = newsapi
72
- self._repository = repository
73
- self._vectors = vectors
74
- self._settings = settings
75
- self._model: Optional[genai.GenerativeModel] = None
76
- self._update_task: Optional[asyncio.Task[None]] = None
77
-
78
- async def initialize(self) -> None:
79
- """Initialize the Gemini model and start background updates."""
80
- logger.info("Initializing Gemini AI")
81
- genai.configure(api_key=self._api_key)
82
- self._model = genai.GenerativeModel(
83
- "gemma-3-27b-it",
84
- generation_config={
85
- "temperature": 0.2,
86
- "top_p": 0.95,
87
- "top_k": 40,
88
- },
89
- )
90
- logger.info("Gemini AI initialized")
91
-
92
- self._update_task = asyncio.create_task(self._periodic_update())
93
-
94
- async def close(self) -> None:
95
- """Clean up resources."""
96
- if self._update_task:
97
- self._update_task.cancel()
98
- try:
99
- await self._update_task
100
- except asyncio.CancelledError:
101
- pass
102
- logger.info("Analysis service closed")
103
-
104
- @property
105
- def is_initialized(self) -> bool:
106
- """Check if the service is initialized."""
107
- return self._model is not None
108
-
109
- async def _periodic_update(self) -> None:
110
- """Background task for periodic analysis updates."""
111
- await asyncio.sleep(5)
112
- await self._run_update("startup")
113
-
114
- interval = self._settings.update_interval_minutes * 60
115
- while True:
116
- await asyncio.sleep(interval)
117
- await self._run_update("scheduled")
118
-
119
- async def _run_update(self, trigger: str) -> None:
120
- """Execute an analysis update."""
121
- try:
122
- logger.info(f"Starting analysis update (trigger: {trigger})")
123
- analysis = await self.generate_analysis(trigger=trigger)
124
- if analysis:
125
- logger.info(f"Analysis complete. Tension: {analysis.tension_level.value}")
126
- else:
127
- logger.warning("No analysis generated")
128
- except Exception as e:
129
- logger.error(f"Error in update: {e}")
130
-
131
- async def _ingest_all_news(self) -> int:
132
- """
133
- Ingest ALL news from all sources into vector store.
134
-
135
- No keyword filtering - we store everything and let vector search
136
- determine relevance.
137
-
138
- Returns:
139
- Number of new items added.
140
- """
141
- days = self._settings.analysis_days_back
142
-
143
- reddit_task = self._reddit.get_all_posts(days)
144
- rss_task = self._rss.get_all_articles(days)
145
- newsapi_task = self._newsapi.get_related_articles(days_back=days)
146
-
147
- results = await asyncio.gather(
148
- reddit_task, rss_task, newsapi_task,
149
- return_exceptions=True,
150
- )
151
-
152
- all_items: list[NewsItem] = []
153
- source_names = ["Reddit", "RSS", "NewsAPI"]
154
-
155
- for i, result in enumerate(results):
156
- if isinstance(result, Exception):
157
- logger.error(f"Error fetching from {source_names[i]}: {result}")
158
- else:
159
- all_items.extend(result)
160
-
161
- logger.info(f"Ingested {len(all_items)} total news items from all sources")
162
-
163
- if not self._vectors.is_initialized:
164
- logger.warning("Vector store not initialized, skipping ingestion")
165
- return 0
166
-
167
- stored = self._vectors.add_items(all_items)
168
- total_count = self._vectors.get_count()
169
- logger.info(f"Stored {stored} new items. Total in vector store: {total_count}")
170
- return stored
171
-
172
- def _retrieve_relevant_items(self, max_items: int = 30) -> list[dict]:
173
- """
174
- Retrieve relevant items from vector store using semantic search.
175
-
176
- This is the core of RAG - we query by topic and get the most
177
- semantically relevant items regardless of when they were published.
178
-
179
- Args:
180
- max_items: Maximum items to retrieve.
181
-
182
- Returns:
183
- List of relevant items with metadata.
184
- """
185
- if not self._vectors.is_initialized:
186
- logger.warning("Vector store not initialized")
187
- return []
188
-
189
- all_results = self._vectors.query_by_topics(
190
- RAG_QUERY_TOPICS,
191
- n_per_topic=max_items // len(RAG_QUERY_TOPICS) + 1,
192
- )
193
-
194
- results = all_results[:max_items]
195
- logger.info(f"Retrieved {len(results)} relevant items via vector search")
196
- return results
197
-
198
- @retry(wait=wait_exponential(min=2, max=60), stop=stop_after_attempt(3))
199
- async def _call_gemini(self, prompt: str) -> Optional[dict[str, Any]]:
200
- """Call Gemini API with retry logic."""
201
- if not self._model:
202
- raise RuntimeError("Gemini not initialized")
203
-
204
- logger.info("Calling Gemini API")
205
- response = await self._model.generate_content_async(prompt)
206
- text = response.text
207
-
208
- try:
209
- return json.loads(text)
210
- except json.JSONDecodeError:
211
- json_match = re.search(r"```(?:json)?\n(.*?)\n```", text, re.DOTALL)
212
- if json_match:
213
- return json.loads(json_match.group(1))
214
-
215
- json_match = re.search(r"\{.*\}", text, re.DOTALL)
216
- if json_match:
217
- return json.loads(json_match.group(0))
218
-
219
- logger.error(f"Failed to parse JSON: {text[:200]}...")
220
- raise ValueError("Could not parse JSON from response")
221
-
222
- def _parse_tension_level(self, value: str) -> TensionLevel:
223
- """Parse tension level from string."""
224
- value = value.upper()
225
- if "CRITICAL" in value:
226
- return TensionLevel.CRITICAL
227
- if "HIGH" in value:
228
- return TensionLevel.HIGH
229
- if "MEDIUM" in value:
230
- return TensionLevel.MEDIUM
231
- return TensionLevel.LOW
232
-
233
- def _parse_tension_trend(self, value: str) -> TensionTrend:
234
- """Parse tension trend from string."""
235
- value = value.upper()
236
- if "INCREASING" in value:
237
- return TensionTrend.INCREASING
238
- if "DECREASING" in value:
239
- return TensionTrend.DECREASING
240
- return TensionTrend.STABLE
241
-
242
- def _parse_analysis_type(self, value: str) -> AnalysisType:
243
- """Parse analysis type from string."""
244
- value = value.upper()
245
- if "MILITARY" in value:
246
- return AnalysisType.MILITARY
247
- if "DIPLOMATIC" in value:
248
- return AnalysisType.DIPLOMATIC
249
- if "INTERNAL" in value:
250
- return AnalysisType.INTERNAL_SECURITY
251
- if "POLITICAL" in value:
252
- return AnalysisType.POLITICAL
253
- return AnalysisType.OTHER
254
-
255
- def _parse_key_developments(self, data: list[dict]) -> list[KeyDevelopment]:
256
- """Parse key developments from response data."""
257
- developments = []
258
- for item in data:
259
- if not isinstance(item, dict):
260
- continue
261
- developments.append(
262
- KeyDevelopment(
263
- title=item.get("title", "Unnamed"),
264
- description=item.get("description", "No description"),
265
- sources=item.get("sources", []),
266
- timestamp=datetime.now(),
267
- )
268
- )
269
- return developments
270
-
271
- def _count_sources(self, items: list[dict]) -> dict[str, int]:
272
- """Count items by source type from retrieved results."""
273
- counts: dict[str, int] = {}
274
- for item in items:
275
- meta = item.get("metadata", {})
276
- key = meta.get("source_type", "unknown")
277
- counts[key] = counts.get(key, 0) + 1
278
- return counts
279
-
280
- async def generate_analysis(self, trigger: str = "scheduled") -> Optional[ConflictAnalysis]:
281
- """
282
- Generate a new conflict analysis using RAG.
283
-
284
- Flow:
285
- 1. Ingest ALL news into vector store
286
- 2. Retrieve most relevant items via semantic search
287
- 3. Send retrieved items to Gemini for analysis
288
-
289
- Args:
290
- trigger: What triggered this analysis.
291
-
292
- Returns:
293
- The generated analysis or None.
294
- """
295
- await self._ingest_all_news()
296
-
297
- retrieved_items = self._retrieve_relevant_items(
298
- max_items=self._settings.max_posts_for_analysis
299
- )
300
-
301
- if len(retrieved_items) < self._settings.min_posts_for_analysis:
302
- logger.warning(f"Insufficient data: {len(retrieved_items)} items")
303
- return None
304
-
305
- prompt = build_rag_prompt(retrieved_items, self._vectors.get_count())
306
-
307
- try:
308
- data = await self._call_gemini(prompt)
309
- except RetryError as e:
310
- logger.error(f"Gemini failed after retries: {e}")
311
- return None
312
-
313
- if not data:
314
- return None
315
-
316
- tension_score = 1
317
- raw_score = data.get("tension_score")
318
- if isinstance(raw_score, (int, float)):
319
- tension_score = max(1, min(10, int(raw_score)))
320
- elif isinstance(raw_score, str) and raw_score.isdigit():
321
- tension_score = max(1, min(10, int(raw_score)))
322
-
323
- key_entities = data.get("key_entities", [])
324
- if isinstance(key_entities, str):
325
- key_entities = [e.strip() for e in key_entities.split(",") if e.strip()]
326
-
327
- reliability_data = data.get("reliability_assessment", {})
328
- regional_data = data.get("regional_implications", {})
329
-
330
- analysis = ConflictAnalysis(
331
- analysis_id=str(uuid.uuid4()),
332
- generated_at=datetime.now(),
333
- latest_status=data.get("latest_status", "No status available"),
334
- situation_summary=data.get("situation_summary", "No summary available"),
335
- key_developments=self._parse_key_developments(data.get("key_developments", [])),
336
- reliability_assessment=ReliabilityAssessment(
337
- source_credibility=reliability_data.get("source_credibility", "Unknown"),
338
- information_gaps=reliability_data.get("information_gaps", "Unknown"),
339
- confidence_rating=reliability_data.get("confidence_rating", "LOW"),
340
- ),
341
- regional_implications=RegionalImplications(
342
- security=regional_data.get("security", "No assessment"),
343
- diplomatic=regional_data.get("diplomatic", "No assessment"),
344
- economic=regional_data.get("economic", "No assessment"),
345
- ),
346
- tension_level=self._parse_tension_level(data.get("tension_level", "LOW")),
347
- tension_rationale=data.get("tension_rationale", "No rationale"),
348
- tension_score=tension_score,
349
- tension_trend=self._parse_tension_trend(data.get("tension_trend", "STABLE")),
350
- analysis_type=self._parse_analysis_type(data.get("analysis_type", "OTHER")),
351
- key_entities=key_entities if isinstance(key_entities, list) else [],
352
- source_count=len(retrieved_items),
353
- source_breakdown=self._count_sources(retrieved_items),
354
- )
355
-
356
- await self._repository.save(analysis)
357
- logger.info(f"Generated RAG-enhanced analysis {analysis.analysis_id}")
358
- return analysis
359
-
360
- async def get_latest(self) -> Optional[ConflictAnalysis]:
361
- """Get the latest analysis from the repository."""
362
- return await self._repository.get_latest()
 
1
+ """AI-powered conflict analysis service with RAG enhancement and quality improvements."""
2
+
3
+ import asyncio
4
+ import uuid
5
+ from datetime import UTC, datetime
6
+ from typing import Any
7
+
8
+ import google.generativeai as genai
9
+ from loguru import logger
10
+ from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
11
+
12
+ from westernfront.config import Settings
13
+ from westernfront.core.constants import RAG_QUERY_TOPICS, SEARCH_KEYWORDS
14
+ from westernfront.core.exceptions import ServiceNotInitializedError
15
+ from westernfront.core.models import (
16
+ ConflictAnalysis,
17
+ NewsItem,
18
+ RegionalImplications,
19
+ ReliabilityAssessment,
20
+ )
21
+ from westernfront.prompts.analysis import build_rag_prompt
22
+ from westernfront.repositories.analysis import AnalysisRepository
23
+ from westernfront.repositories.vectors import VectorRepository
24
+ from westernfront.services.chain_analysis import ChainAnalysisService
25
+ from westernfront.services.newsapi import NewsApiService
26
+ from westernfront.services.parsing import ResponseParser
27
+ from westernfront.services.reddit import RedditService
28
+ from westernfront.services.retrieval import RetrievalService
29
+ from westernfront.services.rss import RssService
30
+ from westernfront.services.scheduler import AnalysisScheduler
31
+ from westernfront.services.validation import AnalysisValidator
32
+ from westernfront.utils import extract_json_from_response
33
+
34
+
35
+ class AnalysisService:
36
+ """Service for generating AI-powered conflict analysis with RAG and quality improvements."""
37
+
38
+ def __init__(
39
+ self,
40
+ gemini_api_key: str,
41
+ reddit: RedditService,
42
+ rss: RssService,
43
+ newsapi: NewsApiService,
44
+ repository: AnalysisRepository,
45
+ vectors: VectorRepository,
46
+ settings: Settings,
47
+ ) -> None:
48
+ self._api_key = gemini_api_key
49
+ self._reddit = reddit
50
+ self._rss = rss
51
+ self._newsapi = newsapi
52
+ self._repository = repository
53
+ self._vectors = vectors
54
+ self._settings = settings
55
+ self._model: genai.GenerativeModel | None = None
56
+ self._retrieval: RetrievalService | None = None
57
+ self._chain: ChainAnalysisService | None = None
58
+ self._parser = ResponseParser()
59
+ self._validator = AnalysisValidator()
60
+ self._scheduler: AnalysisScheduler | None = None
61
+
62
+ async def initialize(self) -> None:
63
+ """Initialize the Gemini model and start background updates."""
64
+ logger.info("Initializing Gemini AI")
65
+ genai.configure(api_key=self._api_key)
66
+ self._model = genai.GenerativeModel(
67
+ "gemma-3-27b-it",
68
+ generation_config={
69
+ "temperature": 0.2,
70
+ "top_p": 0.95,
71
+ "top_k": 40,
72
+ },
73
+ )
74
+ logger.info("Gemini AI initialized")
75
+
76
+ self._retrieval = RetrievalService(self._vectors)
77
+ self._chain = ChainAnalysisService(self._model)
78
+ logger.info("Quality services initialized (retrieval, chain analysis)")
79
+
80
+ self._scheduler = AnalysisScheduler(self, self._settings.update_interval_minutes)
81
+ self._scheduler.start()
82
+
83
+ async def close(self) -> None:
84
+ """Clean up resources."""
85
+ if self._scheduler:
86
+ await self._scheduler.stop()
87
+ logger.info("Analysis service closed")
88
+
89
+ @property
90
+ def is_initialized(self) -> bool:
91
+ """Check if the service is initialized."""
92
+ return self._model is not None
93
+
94
+ @property
95
+ def keywords(self) -> list[str]:
96
+ """Get the search keywords used for analysis."""
97
+ return SEARCH_KEYWORDS
98
+
99
+ async def _ingest_all_news(self) -> int:
100
+ """Ingest all news from all sources into vector store in parallel."""
101
+ days = self._settings.analysis_days_back
102
+
103
+ results = await asyncio.gather(
104
+ self._reddit.get_all_posts(days),
105
+ self._rss.get_all_articles(days),
106
+ self._newsapi.get_related_articles(days_back=days),
107
+ return_exceptions=True,
108
+ )
109
+
110
+ all_items: list[NewsItem] = []
111
+ source_names = ["Reddit", "RSS", "NewsAPI"]
112
+
113
+ for i, result in enumerate(results):
114
+ if isinstance(result, Exception):
115
+ logger.error(f"Error fetching from {source_names[i]}: {result}")
116
+ else:
117
+ all_items.extend(result)
118
+
119
+ logger.info(f"Ingested {len(all_items)} total news items from all sources")
120
+
121
+ if not self._vectors.is_initialized:
122
+ logger.warning("Vector store not initialized, skipping ingestion")
123
+ return 0
124
+
125
+ stored = await self._vectors.add_items_async(all_items)
126
+ total_count = self._vectors.get_count()
127
+ logger.info(f"Stored {stored} new items. Total in vector store: {total_count}")
128
+ return stored
129
+
130
+ def _retrieve_relevant_items(self, max_items: int = 40) -> list[dict]:
131
+ """Retrieve relevant items with quality weighting."""
132
+ if self._retrieval:
133
+ return self._retrieval.retrieve_with_quality(
134
+ RAG_QUERY_TOPICS,
135
+ max_items=max_items,
136
+ )
137
+
138
+ if not self._vectors.is_initialized:
139
+ logger.warning("Vector store not initialized")
140
+ return []
141
+
142
+ return self._vectors.query_by_topics(
143
+ RAG_QUERY_TOPICS,
144
+ n_per_topic=max_items // len(RAG_QUERY_TOPICS) + 1,
145
+ )[:max_items]
146
+
147
+ @retry(wait=wait_exponential(min=2, max=60), stop=stop_after_attempt(3))
148
+ async def _call_gemini(self, prompt: str) -> dict[str, Any] | None:
149
+ """Call Gemini API with retry logic."""
150
+ if not self._model:
151
+ raise ServiceNotInitializedError("Gemini model not initialized")
152
+
153
+ logger.info("Calling Gemini API")
154
+ response = await self._model.generate_content_async(prompt)
155
+ result = extract_json_from_response(response.text)
156
+ if not result:
157
+ raise ValueError("Could not parse JSON from response")
158
+ return result
159
+
160
+ async def generate_analysis(self, use_chain: bool = True) -> ConflictAnalysis | None:
161
+ """Generate a new conflict analysis using RAG with quality improvements."""
162
+ await self._ingest_all_news()
163
+
164
+ retrieved_items = self._retrieve_relevant_items(
165
+ max_items=self._settings.max_posts_for_analysis
166
+ )
167
+
168
+ if len(retrieved_items) < self._settings.min_posts_for_analysis:
169
+ logger.warning(f"Insufficient data: {len(retrieved_items)} items")
170
+ return None
171
+
172
+ total_count = self._vectors.get_count()
173
+
174
+ try:
175
+ if use_chain and self._chain and len(retrieved_items) >= 10:
176
+ data = await self._chain.run_chain_analysis(
177
+ retrieved_items,
178
+ total_in_memory=total_count,
179
+ use_full_chain=True,
180
+ )
181
+ else:
182
+ prompt = build_rag_prompt(retrieved_items, total_count)
183
+ data = await self._call_gemini(prompt)
184
+ except RetryError as e:
185
+ logger.error(f"AI analysis failed after retries: {e}")
186
+ return None
187
+
188
+ if not data:
189
+ return None
190
+
191
+ source_texts = [item.get("document", "") for item in retrieved_items]
192
+ is_valid, issues = self._validator.validate_analysis(data, source_texts)
193
+ if not is_valid:
194
+ logger.warning(f"Validation issues: {issues}")
195
+
196
+ reliability_data = data.get("reliability_assessment", {})
197
+ regional_data = data.get("regional_implications", {})
198
+
199
+ analysis = ConflictAnalysis(
200
+ analysis_id=str(uuid.uuid4()),
201
+ generated_at=datetime.now(UTC),
202
+ latest_status=data.get("latest_status", "No status available"),
203
+ situation_summary=data.get("situation_summary", "No summary available"),
204
+ key_developments=self._parser.parse_key_developments(
205
+ data.get("key_developments", [])
206
+ ),
207
+ reliability_assessment=ReliabilityAssessment(
208
+ source_credibility=reliability_data.get("source_credibility", "Unknown"),
209
+ information_gaps=reliability_data.get("information_gaps", "Unknown"),
210
+ confidence_rating=reliability_data.get("confidence_rating", "LOW"),
211
+ ),
212
+ regional_implications=RegionalImplications(
213
+ security=regional_data.get("security", "No assessment"),
214
+ diplomatic=regional_data.get("diplomatic", "No assessment"),
215
+ economic=regional_data.get("economic", "No assessment"),
216
+ ),
217
+ tension_level=self._parser.parse_tension_level(
218
+ data.get("tension_level", "LOW")
219
+ ),
220
+ tension_rationale=data.get("tension_rationale", "No rationale"),
221
+ tension_score=self._parser.parse_tension_score(data.get("tension_score")),
222
+ tension_trend=self._parser.parse_tension_trend(
223
+ data.get("tension_trend", "STABLE")
224
+ ),
225
+ analysis_type=self._parser.parse_analysis_type(
226
+ data.get("analysis_type", "OTHER")
227
+ ),
228
+ key_entities=self._parser.parse_key_entities(data.get("key_entities")),
229
+ source_count=len(retrieved_items),
230
+ source_breakdown=self._parser.count_sources(retrieved_items),
231
+ )
232
+
233
+ await self._repository.save(analysis)
234
+ logger.info(f"Generated quality-enhanced analysis {analysis.analysis_id}")
235
+ return analysis
236
+
237
+ async def get_latest(self) -> ConflictAnalysis | None:
238
+ """Get the latest analysis from the repository."""
239
+ return await self._repository.get_latest()
 
src/westernfront/services/cache.py CHANGED
@@ -1,73 +1,90 @@
1
- """Thread-safe async cache service."""
2
 
3
  import asyncio
4
- from typing import Any, Optional
5
 
6
  from cachetools import TTLCache
7
 
8
 
9
- class CacheService:
10
- """Async-safe caching with TTL support."""
11
 
12
- def __init__(self, ttl_seconds: int = 3600, max_size: int = 100) -> None:
13
- """
14
- Initialize the cache service.
 
15
 
16
- Args:
17
- ttl_seconds: Time-to-live for cache entries in seconds.
18
- max_size: Maximum number of entries in the cache.
19
- """
20
- self._cache: TTLCache[str, Any] = TTLCache(maxsize=max_size, ttl=ttl_seconds)
21
- self._lock = asyncio.Lock()
 
 
 
 
 
 
 
22
 
23
- async def get(self, key: str) -> Optional[Any]:
24
- """
25
- Get a value from the cache.
 
 
 
26
 
27
- Args:
28
- key: The cache key.
 
 
 
 
 
 
 
29
 
30
- Returns:
31
- The cached value or None if not found.
32
- """
33
- async with self._lock:
 
 
 
 
34
  return self._cache.get(key)
 
 
35
 
36
  async def set(self, key: str, value: Any) -> None:
37
- """
38
- Set a value in the cache.
39
-
40
- Args:
41
- key: The cache key.
42
- value: The value to cache.
43
- """
44
- async with self._lock:
45
  self._cache[key] = value
 
 
46
 
47
  async def delete(self, key: str) -> None:
48
- """
49
- Delete a value from the cache.
50
-
51
- Args:
52
- key: The cache key.
53
- """
54
- async with self._lock:
55
  self._cache.pop(key, None)
 
 
56
 
57
  async def clear(self) -> None:
58
  """Clear all entries from the cache."""
59
- async with self._lock:
 
60
  self._cache.clear()
 
 
61
 
62
  async def has(self, key: str) -> bool:
63
- """
64
- Check if a key exists in the cache.
65
-
66
- Args:
67
- key: The cache key.
68
-
69
- Returns:
70
- True if the key exists, False otherwise.
71
- """
72
- async with self._lock:
73
  return key in self._cache
 
 
 
1
+ """Thread-safe async cache service with read-write lock pattern."""
2
 
3
  import asyncio
4
+ from typing import Any
5
 
6
  from cachetools import TTLCache
7
 
8
 
9
+ class ReadWriteLock:
10
+ """Async read-write lock allowing concurrent reads but exclusive writes."""
11
 
12
+ def __init__(self) -> None:
13
+ self._readers = 0
14
+ self._writer = False
15
+ self._condition = asyncio.Condition()
16
 
17
+ async def acquire_read(self) -> None:
18
+ """Acquire read lock (allows concurrent readers)."""
19
+ async with self._condition:
20
+ while self._writer:
21
+ await self._condition.wait()
22
+ self._readers += 1
23
+
24
+ async def release_read(self) -> None:
25
+ """Release read lock."""
26
+ async with self._condition:
27
+ self._readers -= 1
28
+ if self._readers == 0:
29
+ self._condition.notify_all()
30
 
31
+ async def acquire_write(self) -> None:
32
+ """Acquire exclusive write lock."""
33
+ async with self._condition:
34
+ while self._writer or self._readers > 0:
35
+ await self._condition.wait()
36
+ self._writer = True
37
 
38
+ async def release_write(self) -> None:
39
+ """Release write lock."""
40
+ async with self._condition:
41
+ self._writer = False
42
+ self._condition.notify_all()
43
+
44
+
45
+ class CacheService:
46
+ """Async-safe caching with TTL support and read-write lock for concurrent reads."""
47
 
48
+ def __init__(self, ttl_seconds: int = 3600, max_size: int = 100) -> None:
49
+ self._cache: TTLCache[str, Any] = TTLCache(maxsize=max_size, ttl=ttl_seconds)
50
+ self._lock = ReadWriteLock()
51
+
52
+ async def get(self, key: str) -> Any | None:
53
+ """Get a value from the cache."""
54
+ await self._lock.acquire_read()
55
+ try:
56
  return self._cache.get(key)
57
+ finally:
58
+ await self._lock.release_read()
59
 
60
  async def set(self, key: str, value: Any) -> None:
61
+ """Set a value in the cache."""
62
+ await self._lock.acquire_write()
63
+ try:
 
 
 
 
 
64
  self._cache[key] = value
65
+ finally:
66
+ await self._lock.release_write()
67
 
68
  async def delete(self, key: str) -> None:
69
+ """Delete a value from the cache."""
70
+ await self._lock.acquire_write()
71
+ try:
 
 
 
 
72
  self._cache.pop(key, None)
73
+ finally:
74
+ await self._lock.release_write()
75
 
76
  async def clear(self) -> None:
77
  """Clear all entries from the cache."""
78
+ await self._lock.acquire_write()
79
+ try:
80
  self._cache.clear()
81
+ finally:
82
+ await self._lock.release_write()
83
 
84
  async def has(self, key: str) -> bool:
85
+ """Check if a key exists in the cache."""
86
+ await self._lock.acquire_read()
87
+ try:
 
 
 
 
 
 
 
88
  return key in self._cache
89
+ finally:
90
+ await self._lock.release_read()
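
A small demo of the lock semantics, sketch only: many get() calls may hold the read lock at once, while set() waits for an exclusive write slot.

import asyncio

from westernfront.services.cache import CacheService

async def demo() -> None:
    cache = CacheService(ttl_seconds=60, max_size=10)
    await cache.set("tension", "HIGH")
    # Fifty concurrent reads proceed in parallel under the shared read lock.
    values = await asyncio.gather(*[cache.get("tension") for _ in range(50)])
    assert all(v == "HIGH" for v in values)

asyncio.run(demo())

Note the lock is reader-preferring: acquire_read only waits on an active writer, not on writers queued behind readers, so a steady stream of reads can delay a pending write. That trade-off suits this read-heavy TTL cache.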
src/westernfront/services/chain_analysis.py ADDED
@@ -0,0 +1,108 @@
1
+ """Multi-pass chain analysis for improved quality."""
2
+
3
+ from typing import Any
4
+
5
+ import google.generativeai as genai
6
+ from loguru import logger
7
+ from tenacity import retry, stop_after_attempt, wait_exponential
8
+
9
+ from westernfront.prompts.analysis import (
10
+ build_extraction_prompt,
11
+ build_rag_prompt,
12
+ build_synthesis_prompt,
13
+ )
14
+ from westernfront.utils import extract_json_from_response
15
+
16
+
17
+ class ChainAnalysisService:
18
+ """Multi-pass analysis using chain-of-thought for better quality."""
19
+
20
+ def __init__(self, model: genai.GenerativeModel) -> None:
21
+ self._model = model
22
+
23
+ @retry(wait=wait_exponential(min=2, max=30), stop=stop_after_attempt(3))
24
+ async def _call_model(self, prompt: str) -> dict[str, Any] | None:
25
+ """Call model with retry and JSON parsing."""
26
+ response = await self._model.generate_content_async(prompt)
27
+ return extract_json_from_response(response.text)
28
+
29
+ async def extract_facts(self, items: list[dict]) -> list[dict]:
30
+ """Stage 1: Extract facts from sources."""
31
+ logger.info("Chain Analysis Stage 1: Extracting facts")
32
+ prompt = build_extraction_prompt(items)
33
+
34
+ result = await self._call_model(prompt)
35
+ if not result:
36
+ logger.warning("Fact extraction failed, continuing with empty facts")
37
+ return []
38
+
39
+ facts = result.get("facts", [])
40
+ logger.info(f"Extracted {len(facts)} facts from sources")
41
+ return facts
42
+
43
+ async def synthesize_facts(
44
+ self,
45
+ facts: list[dict],
46
+ historical_context: str = "",
47
+ ) -> dict:
48
+ """Stage 2: Synthesize facts into preliminary assessment."""
49
+ logger.info("Chain Analysis Stage 2: Synthesizing facts")
50
+ prompt = build_synthesis_prompt(facts, historical_context)
51
+
52
+ result = await self._call_model(prompt)
53
+ if not result:
54
+ logger.warning("Synthesis failed, using defaults")
55
+ return {
56
+ "significant_developments": [],
57
+ "preliminary_tension": "MEDIUM",
58
+ "tension_reasoning": "Unable to synthesize",
59
+ "trends": [],
60
+ "contradictions": [],
61
+ "gaps": ["Synthesis stage failed"],
62
+ }
63
+
64
+ logger.info(f"Synthesis complete: preliminary tension = {result.get('preliminary_tension')}")
65
+ return result
66
+
67
+ async def generate_final_report(
68
+ self,
69
+ items: list[dict],
70
+ synthesis: dict,
71
+ total_in_memory: int = 0,
72
+ ) -> dict[str, Any] | None:
73
+ """Stage 3: Generate final analysis report."""
74
+ logger.info("Chain Analysis Stage 3: Generating final report")
75
+
76
+ prompt = build_rag_prompt(items, total_in_memory)
77
+
78
+ synthesis_context = f"""
79
+ PRELIMINARY ASSESSMENT (from internal analysis):
80
+ - Tension Level Estimate: {synthesis.get('preliminary_tension', 'UNKNOWN')}
81
+ - Reasoning: {synthesis.get('tension_reasoning', 'N/A')}
82
+ - Key Trends: {', '.join(synthesis.get('trends', []))}
83
+ - Information Gaps: {', '.join(synthesis.get('gaps', []))}
84
+
85
+ Consider this preliminary assessment but verify against the source data.
86
+ """
87
+ enhanced_prompt = prompt + "\n\n" + synthesis_context
88
+
89
+ result = await self._call_model(enhanced_prompt)
90
+ if result:
91
+ logger.info("Final report generated successfully")
92
+ return result
93
+
94
+ async def run_chain_analysis(
95
+ self,
96
+ items: list[dict],
97
+ total_in_memory: int = 0,
98
+ use_full_chain: bool = True,
99
+ ) -> dict[str, Any] | None:
100
+ """Run complete chain analysis."""
101
+ if not use_full_chain or len(items) < 10:
102
+ logger.info("Using direct analysis (chain disabled or insufficient data)")
103
+ prompt = build_rag_prompt(items, total_in_memory)
104
+ return await self._call_model(prompt)
105
+
106
+ facts = await self.extract_facts(items)
107
+ synthesis = await self.synthesize_facts(facts)
108
+ return await self.generate_final_report(items, synthesis, total_in_memory)
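
A usage sketch of the full chain, assuming the configured genai.GenerativeModel from AnalysisService; below 10 items, run_chain_analysis already falls back to the direct RAG prompt on its own.

from westernfront.services.chain_analysis import ChainAnalysisService

async def analyze(model, retrieved_items: list[dict]) -> dict | None:
    chain = ChainAnalysisService(model)
    return await chain.run_chain_analysis(
        retrieved_items,
        total_in_memory=1_000,   # illustrative; normally vectors.get_count()
        use_full_chain=True,
    )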
src/westernfront/services/embeddings.py CHANGED
@@ -1,25 +1,21 @@
1
  """Local embedding service using sentence-transformers MiniLM model."""
2
 
 
3
  import os
4
  from pathlib import Path
5
- from typing import Optional
6
 
7
  from loguru import logger
8
 
 
9
 
10
  MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
11
 
12
 
13
  class EmbeddingService:
14
  """Service for generating text embeddings using local MiniLM model."""
15
 
16
- def __init__(self, model_cache_dir: Optional[str] = None) -> None:
17
- """
18
- Initialize the embedding service.
19
-
20
- Args:
21
- model_cache_dir: Directory to cache the model (defaults to server/models).
22
- """
23
  if model_cache_dir:
24
  self._cache_dir = Path(model_cache_dir)
25
  else:
@@ -32,28 +28,19 @@ class EmbeddingService:
32
  os.environ["TRANSFORMERS_CACHE"] = str(self._cache_dir)
33
 
34
  self._model = None
35
- self._dimension = 384
36
 
37
  def initialize(self) -> bool:
38
- """
39
- Initialize the embedding model.
40
-
41
- Returns:
42
- True if initialization was successful.
43
- """
44
- try:
45
- logger.info(f"Loading embedding model from {self._cache_dir}")
46
- from sentence_transformers import SentenceTransformer
47
-
48
- self._model = SentenceTransformer(
49
- MODEL_NAME,
50
- cache_folder=str(self._cache_dir),
51
- )
52
- logger.info(f"Embedding model loaded: {MODEL_NAME}")
53
- return True
54
- except Exception as e:
55
- logger.error(f"Failed to load embedding model: {e}")
56
- return False
57
 
58
  @property
59
  def is_initialized(self) -> bool:
@@ -65,34 +52,32 @@ class EmbeddingService:
65
  """Get the embedding dimension."""
66
  return self._dimension
67
 
68
- def embed(self, text: str) -> list[float]:
69
- """
70
- Generate embedding for a single text.
71
-
72
- Args:
73
- text: Text to embed.
74
-
75
- Returns:
76
- List of floats representing the embedding.
77
- """
78
  if not self._model:
79
- raise RuntimeError("Embedding model not initialized")
80
-
81
  embedding = self._model.encode(text, convert_to_numpy=True)
82
  return embedding.tolist()
83
 
84
- def embed_batch(self, texts: list[str]) -> list[list[float]]:
85
- """
86
- Generate embeddings for multiple texts.
87
-
88
- Args:
89
- texts: List of texts to embed.
90
-
91
- Returns:
92
- List of embeddings.
93
- """
94
  if not self._model:
95
- raise RuntimeError("Embedding model not initialized")
96
-
97
  embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
98
  return embeddings.tolist()
 
1
  """Local embedding service using sentence-transformers MiniLM model."""
2
 
3
+ import asyncio
4
  import os
5
  from pathlib import Path
 
6
 
7
  from loguru import logger
8
 
9
+ from westernfront.core.exceptions import ServiceNotInitializedError
10
 
11
  MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
+ EMBEDDING_DIMENSION = 384
13
 
14
 
15
  class EmbeddingService:
16
  """Service for generating text embeddings using local MiniLM model."""
17
 
18
+ def __init__(self, model_cache_dir: str | None = None) -> None:
 
 
 
 
 
 
19
  if model_cache_dir:
20
  self._cache_dir = Path(model_cache_dir)
21
  else:
 
28
  os.environ["TRANSFORMERS_CACHE"] = str(self._cache_dir)
29
 
30
  self._model = None
31
+ self._dimension = EMBEDDING_DIMENSION
32
 
33
  def initialize(self) -> bool:
34
+ """Initialize the embedding model."""
35
+ logger.info(f"Loading embedding model from {self._cache_dir}")
36
+ from sentence_transformers import SentenceTransformer
37
+
38
+ self._model = SentenceTransformer(
39
+ MODEL_NAME,
40
+ cache_folder=str(self._cache_dir),
41
+ )
42
+ logger.info(f"Embedding model loaded: {MODEL_NAME}")
43
+ return True
 
 
 
 
 
 
 
 
 
44
 
45
  @property
46
  def is_initialized(self) -> bool:
 
52
  """Get the embedding dimension."""
53
  return self._dimension
54
 
55
+ def _embed_sync(self, text: str) -> list[float]:
56
+ """Synchronous embedding generation."""
 
 
 
 
 
 
 
 
57
  if not self._model:
58
+ raise ServiceNotInitializedError("Embedding model not initialized")
 
59
  embedding = self._model.encode(text, convert_to_numpy=True)
60
  return embedding.tolist()
61
 
62
+ def _embed_batch_sync(self, texts: list[str]) -> list[list[float]]:
63
+ """Synchronous batch embedding generation."""
 
 
 
 
 
 
 
 
64
  if not self._model:
65
+ raise ServiceNotInitializedError("Embedding model not initialized")
 
66
  embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
67
  return embeddings.tolist()
68
+
69
+ def embed(self, text: str) -> list[float]:
70
+ """Generate embedding for a single text (sync version)."""
71
+ return self._embed_sync(text)
72
+
73
+ def embed_batch(self, texts: list[str]) -> list[list[float]]:
74
+ """Generate embeddings for multiple texts (sync version)."""
75
+ return self._embed_batch_sync(texts)
76
+
77
+ async def embed_async(self, text: str) -> list[float]:
78
+ """Generate embedding for a single text without blocking the event loop."""
79
+ return await asyncio.to_thread(self._embed_sync, text)
80
+
81
+ async def embed_batch_async(self, texts: list[str]) -> list[list[float]]:
82
+ """Generate embeddings for multiple texts without blocking the event loop."""
83
+ return await asyncio.to_thread(self._embed_batch_sync, texts)
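
The async wrappers keep model inference off the event loop via asyncio.to_thread. A hedged sketch contrasting the two entry points; the input strings are illustrative:

import asyncio

from westernfront.services.embeddings import EMBEDDING_DIMENSION, EmbeddingService

async def main() -> None:
    svc = EmbeddingService()
    svc.initialize()                 # downloads/loads the MiniLM model if needed
    # embed() would block the loop for the duration of inference; embed_async() does not.
    vec = await svc.embed_async("ceasefire violation reported along the LoC")
    assert len(vec) == EMBEDDING_DIMENSION   # 384 for all-MiniLM-L6-v2
    batch = await svc.embed_batch_async(["post one", "post two"])
    assert len(batch) == 2

asyncio.run(main())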
src/westernfront/services/http.py ADDED
@@ -0,0 +1,57 @@
1
+ """Shared HTTP client service with connection pooling."""
2
+
3
+ from collections.abc import AsyncGenerator
4
+ from contextlib import asynccontextmanager
5
+
6
+ import aiohttp
7
+ from loguru import logger
8
+
9
+ from westernfront.core.constants import HTTP_TIMEOUT_SECONDS
10
+
11
+
12
+ class HttpService:
13
+ """Shared HTTP client with connection pooling for all external requests."""
14
+
15
+ def __init__(self, timeout_seconds: int = HTTP_TIMEOUT_SECONDS) -> None:
16
+ self._timeout = aiohttp.ClientTimeout(total=timeout_seconds)
17
+ self._session: aiohttp.ClientSession | None = None
18
+
19
+ async def initialize(self) -> None:
20
+ """Initialize the shared HTTP session with connection pooling."""
21
+ if self._session is None or self._session.closed:
22
+ connector = aiohttp.TCPConnector(
23
+ limit=100,
24
+ limit_per_host=10,
25
+ ttl_dns_cache=300,
26
+ enable_cleanup_closed=True,
27
+ )
28
+ self._session = aiohttp.ClientSession(
29
+ connector=connector,
30
+ timeout=self._timeout,
31
+ )
32
+ logger.info("HTTP service initialized with connection pooling")
33
+
34
+ async def close(self) -> None:
35
+ """Close the HTTP session."""
36
+ if self._session and not self._session.closed:
37
+ await self._session.close()
38
+ self._session = None
39
+ logger.info("HTTP service closed")
40
+
41
+ @property
42
+ def session(self) -> aiohttp.ClientSession:
43
+ """Get the shared HTTP session."""
44
+ if self._session is None or self._session.closed:
45
+ raise RuntimeError("HTTP service not initialized")
46
+ return self._session
47
+
48
+ @asynccontextmanager
49
+ async def get(
50
+ self,
51
+ url: str,
52
+ params: dict | None = None,
53
+ headers: dict | None = None,
54
+ ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
55
+ """Perform a GET request with automatic error handling."""
56
+ async with self.session.get(url, params=params, headers=headers) as response:
57
+ yield response
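
A hedged usage sketch of the pooled client: one session shared across all outbound fetches, released once at shutdown. The URL and headers are illustrative.

import asyncio

from westernfront.services.http import HttpService

async def main() -> None:
    http = HttpService(timeout_seconds=30)
    await http.initialize()
    try:
        async with http.get("https://example.com", headers={"Accept": "text/html"}) as resp:
            body = await resp.text()
            print(resp.status, len(body))
    finally:
        await http.close()

asyncio.run(main())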
src/westernfront/services/newsapi.py CHANGED
@@ -1,41 +1,32 @@
1
- """NewsAPI integration service."""
2
 
 
3
  import hashlib
4
- from datetime import datetime, timedelta, timezone
5
- from typing import Optional
6
 
7
  import aiohttp
8
  from loguru import logger
9
 
 
10
  from westernfront.core.enums import SourceType
11
  from westernfront.core.models import NewsItem
12
  from westernfront.services.cache import CacheService
13
-
14
-
15
- NEWSAPI_BASE_URL = "https://newsapi.org/v2"
16
-
17
- INDIA_PAKISTAN_QUERIES = [
18
- "India Pakistan",
19
- "Kashmir conflict",
20
- "India Pakistan border",
21
- "LOC firing",
22
- "Indo-Pak",
23
- ]
24
 
25
 
26
  class NewsApiService:
27
- """Service for fetching news from NewsAPI.org."""
28
-
29
- def __init__(self, api_key: Optional[str], cache: CacheService) -> None:
30
- """
31
- Initialize the NewsAPI service.
32
 
33
- Args:
34
- api_key: NewsAPI API key (optional, service is disabled if None).
35
- cache: Cache service for storing results.
36
- """
 
 
37
  self._api_key = api_key
38
  self._cache = cache
 
 
39
 
40
  @property
41
  def is_enabled(self) -> bool:
@@ -49,14 +40,14 @@ class NewsApiService:
49
  raw = f"newsapi:{url}:{title}"
50
  return hashlib.sha256(raw.encode()).hexdigest()[:16]
51
 
52
- def _parse_date(self, date_str: Optional[str]) -> datetime:
53
  """Parse ISO date string from NewsAPI."""
54
  if not date_str:
55
- return datetime.now(timezone.utc)
56
  try:
57
  return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
58
  except ValueError:
59
- return datetime.now(timezone.utc)
60
 
61
  async def _search_news(self, query: str, days_back: int = 2) -> list[dict]:
62
  """Search NewsAPI for articles matching query."""
@@ -69,22 +60,25 @@ class NewsApiService:
69
  logger.debug(f"Cache hit for NewsAPI query: {query}")
70
  return cached
71
 
72
- from_date = (datetime.now(timezone.utc) - timedelta(days=days_back)).strftime("%Y-%m-%d")
73
-
74
- params = {
75
- "q": query,
76
- "from": from_date,
77
- "language": "en",
78
- "sortBy": "publishedAt",
79
- "pageSize": 50,
80
- }
81
-
82
- headers = {"X-Api-Key": self._api_key}
83
-
84
- try:
85
- async with aiohttp.ClientSession() as session:
86
- url = f"{NEWSAPI_BASE_URL}/everything"
87
- async with session.get(url, params=params, headers=headers) as response:
 
 
 
88
  if response.status == 401:
89
  logger.error("NewsAPI: Invalid API key")
90
  return []
@@ -101,37 +95,34 @@ class NewsApiService:
101
  logger.info(f"NewsAPI: Found {len(articles)} articles for '{query}'")
102
  return articles
103
 
104
- except aiohttp.ClientError as e:
105
- logger.error(f"NewsAPI HTTP error: {e}")
106
- return []
107
 
108
  async def get_related_articles(
109
  self,
110
- keywords: Optional[list[str]] = None,
111
  days_back: int = 2,
112
  ) -> list[NewsItem]:
113
- """
114
- Get articles related to India-Pakistan from NewsAPI.
115
-
116
- Args:
117
- keywords: Optional additional keywords (defaults to built-in queries).
118
- days_back: How many days back to search.
119
-
120
- Returns:
121
- List of news items.
122
- """
123
  if not self.is_enabled:
124
  logger.debug("NewsAPI service is disabled (no API key)")
125
  return []
126
 
127
- queries = keywords if keywords else INDIA_PAKISTAN_QUERIES
 
 
 
 
128
  seen_urls: set[str] = set()
129
- results: list[NewsItem] = []
130
 
131
- for query in queries[:3]:
132
- articles = await self._search_news(query, days_back)
 
 
133
 
134
- for article in articles:
135
  url = article.get("url", "")
136
  if not url or url in seen_urls:
137
  continue
@@ -154,9 +145,9 @@ class NewsApiService:
154
  reliability_score=0.9,
155
  author=article.get("author"),
156
  )
157
- results.append(item)
158
  seen_urls.add(url)
159
 
160
- results.sort(key=lambda i: (-i.published_at.timestamp(), -i.reliability_score))
161
- logger.info(f"Found {len(results)} articles from NewsAPI")
162
- return results
 
1
+ """NewsAPI integration service with parallel query fetching."""
2
 
3
+ import asyncio
4
  import hashlib
5
+ from datetime import UTC, datetime, timedelta
 
6
 
7
  import aiohttp
8
  from loguru import logger
9
 
10
+ from westernfront.core.constants import MAX_CONCURRENT_REQUESTS, NEWSAPI_BASE_URL, NEWSAPI_QUERIES
11
  from westernfront.core.enums import SourceType
12
  from westernfront.core.models import NewsItem
13
  from westernfront.services.cache import CacheService
14
+ from westernfront.services.http import HttpService
 
 
15
 
16
 
17
  class NewsApiService:
18
+ """Service for fetching news from NewsAPI.org with parallel query execution."""
 
 
 
 
19
 
20
+ def __init__(
21
+ self,
22
+ api_key: str | None,
23
+ cache: CacheService,
24
+ http: HttpService,
25
+ ) -> None:
26
  self._api_key = api_key
27
  self._cache = cache
28
+ self._http = http
29
+ self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
30
 
31
  @property
32
  def is_enabled(self) -> bool:
 
40
  raw = f"newsapi:{url}:{title}"
41
  return hashlib.sha256(raw.encode()).hexdigest()[:16]
42
 
43
+ def _parse_date(self, date_str: str | None) -> datetime:
44
  """Parse ISO date string from NewsAPI."""
45
  if not date_str:
46
+ return datetime.now(UTC)
47
  try:
48
  return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
49
  except ValueError:
50
+ return datetime.now(UTC)
51
 
52
  async def _search_news(self, query: str, days_back: int = 2) -> list[dict]:
53
  """Search NewsAPI for articles matching query."""
 
60
  logger.debug(f"Cache hit for NewsAPI query: {query}")
61
  return cached
62
 
63
+ async with self._semaphore:
64
+ from_date = (datetime.now(UTC) - timedelta(days=days_back)).strftime("%Y-%m-%d")
65
+
66
+ params = {
67
+ "q": query,
68
+ "from": from_date,
69
+ "language": "en",
70
+ "sortBy": "publishedAt",
71
+ "pageSize": 50,
72
+ }
73
+
74
+ headers = {"X-Api-Key": self._api_key}
75
+
76
+ try:
77
+ async with self._http.get(
78
+ f"{NEWSAPI_BASE_URL}/everything",
79
+ params=params,
80
+ headers=headers,
81
+ ) as response:
82
  if response.status == 401:
83
  logger.error("NewsAPI: Invalid API key")
84
  return []
 
95
  logger.info(f"NewsAPI: Found {len(articles)} articles for '{query}'")
96
  return articles
97
 
98
+ except aiohttp.ClientError as e:
99
+ logger.error(f"NewsAPI HTTP error: {e}")
100
+ return []
101
 
102
  async def get_related_articles(
103
  self,
104
+ keywords: list[str] | None = None,
105
  days_back: int = 2,
106
  ) -> list[NewsItem]:
107
+ """Get articles related to India-Pakistan from NewsAPI in parallel."""
 
 
108
  if not self.is_enabled:
109
  logger.debug("NewsAPI service is disabled (no API key)")
110
  return []
111
 
112
+ queries = keywords if keywords else NEWSAPI_QUERIES[:3]
113
+
114
+ tasks = [self._search_news(query, days_back) for query in queries]
115
+ results = await asyncio.gather(*tasks, return_exceptions=True)
116
+
117
  seen_urls: set[str] = set()
118
+ all_articles: list[NewsItem] = []
119
 
120
+ for i, result in enumerate(results):
121
+ if isinstance(result, Exception):
122
+ logger.error(f"NewsAPI query '{queries[i]}' failed: {result}")
123
+ continue
124
 
125
+ for article in result:
126
  url = article.get("url", "")
127
  if not url or url in seen_urls:
128
  continue
 
145
  reliability_score=0.9,
146
  author=article.get("author"),
147
  )
148
+ all_articles.append(item)
149
  seen_urls.add(url)
150
 
151
+ all_articles.sort(key=lambda i: (-i.published_at.timestamp(), -i.reliability_score))
152
+ logger.info(f"Found {len(all_articles)} articles from NewsAPI")
153
+ return all_articles
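The gather-plus-semaphore pattern above can be exercised in isolation; this sketch is illustrative only, with fetch() and the query list as hypothetical stand-ins for _search_news() and NEWSAPI_QUERIES:

import asyncio

semaphore = asyncio.Semaphore(5)  # stands in for MAX_CONCURRENT_REQUESTS


async def fetch(query: str) -> list[str]:
    async with semaphore:  # cap the number of in-flight requests
        await asyncio.sleep(0.1)  # placeholder for the real HTTP call
        return [f"article about {query}"]


async def main() -> None:
    queries = ["India Pakistan", "Kashmir conflict", "LOC firing"]
    results = await asyncio.gather(*(fetch(q) for q in queries), return_exceptions=True)
    for query, result in zip(queries, results):
        if isinstance(result, Exception):  # per-query failures do not abort the batch
            print(f"{query} failed: {result}")
        else:
            print(f"{query}: {len(result)} articles")


asyncio.run(main())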
src/westernfront/services/parsing.py ADDED
@@ -0,0 +1,88 @@
 
 
1
+ """Response parsing service for AI-generated analysis data."""
2
+
3
+ from datetime import UTC, datetime
4
+
5
+ from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
6
+ from westernfront.core.models import KeyDevelopment
7
+
8
+
9
+ class ResponseParser:
10
+ """Parses and validates AI-generated analysis response data."""
11
+
12
+ def parse_tension_level(self, value: str) -> TensionLevel:
13
+ """Parse tension level from string."""
14
+ value = value.upper()
15
+ if "CRITICAL" in value:
16
+ return TensionLevel.CRITICAL
17
+ if "HIGH" in value:
18
+ return TensionLevel.HIGH
19
+ if "MEDIUM" in value:
20
+ return TensionLevel.MEDIUM
21
+ return TensionLevel.LOW
22
+
23
+ def parse_tension_trend(self, value: str) -> TensionTrend:
24
+ """Parse tension trend from string."""
25
+ value = value.upper()
26
+ if "INCREASING" in value:
27
+ return TensionTrend.INCREASING
28
+ if "DECREASING" in value:
29
+ return TensionTrend.DECREASING
30
+ return TensionTrend.STABLE
31
+
32
+ def parse_analysis_type(self, value: str) -> AnalysisType:
33
+ """Parse analysis type from string."""
34
+ value = value.upper()
35
+ if "MILITARY" in value:
36
+ return AnalysisType.MILITARY
37
+ if "DIPLOMATIC" in value:
38
+ return AnalysisType.DIPLOMATIC
39
+ if "INTERNAL" in value:
40
+ return AnalysisType.INTERNAL_SECURITY
41
+ if "POLITICAL" in value:
42
+ return AnalysisType.POLITICAL
43
+ return AnalysisType.OTHER
44
+
45
+ def parse_key_developments(self, data: list[dict]) -> list[KeyDevelopment]:
46
+ """Parse key developments from response data."""
47
+ developments = []
48
+ for item in data:
49
+ if not isinstance(item, dict):
50
+ continue
51
+ developments.append(
52
+ KeyDevelopment(
53
+ title=item.get("title", "Unnamed"),
54
+ description=item.get("description", "No description"),
55
+ sources=item.get("sources", []),
56
+ timestamp=datetime.now(UTC),
57
+ )
58
+ )
59
+ return developments
60
+
61
+ def parse_tension_score(self, raw_score: int | float | str | None) -> int:
62
+ """Parse and clamp tension score to valid range [1, 10]."""
63
+ if raw_score is None:
64
+ return 1
65
+ if isinstance(raw_score, int | float):
66
+ return max(1, min(10, int(raw_score)))
67
+ if isinstance(raw_score, str) and raw_score.isdigit():
68
+ return max(1, min(10, int(raw_score)))
69
+ return 1
70
+
71
+ def parse_key_entities(self, entities: list[str] | str | None) -> list[str]:
72
+ """Parse key entities from response, handling string or list format."""
73
+ if entities is None:
74
+ return []
75
+ if isinstance(entities, str):
76
+ return [e.strip() for e in entities.split(",") if e.strip()]
77
+ if isinstance(entities, list):
78
+ return entities
79
+ return []
80
+
81
+ def count_sources(self, items: list[dict]) -> dict[str, int]:
82
+ """Count items by source type."""
83
+ counts: dict[str, int] = {}
84
+ for item in items:
85
+ meta = item.get("metadata", {})
86
+ key = meta.get("source_type", "unknown")
87
+ counts[key] = counts.get(key, 0) + 1
88
+ return counts
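ResponseParser is pure and stateless, so its branches are easy to check directly; expected outputs follow from the code above:

from westernfront.services.parsing import ResponseParser

parser = ResponseParser()
print(parser.parse_tension_level("HIGH alert"))      # TensionLevel.HIGH (substring match)
print(parser.parse_tension_score("12"))              # clamped to 10
print(parser.parse_tension_score(None))              # defaults to 1
print(parser.parse_key_entities("India, Pakistan"))  # ['India', 'Pakistan']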
src/westernfront/services/reddit.py CHANGED
@@ -1,7 +1,7 @@
1
- """Reddit data collection service using AsyncPRAW."""
2
 
3
- from datetime import datetime, timedelta, timezone
4
- from typing import Optional
5
  from urllib.parse import urlparse
6
 
7
  import aiohttp
@@ -10,40 +10,19 @@ import asyncprawcore
10
  from loguru import logger
11
  from tenacity import retry, stop_after_attempt, wait_exponential
12
 
 
13
  from westernfront.core.enums import SourceType
 
14
  from westernfront.core.models import NewsItem, SubredditSource
15
  from westernfront.services.cache import CacheService
16
 
17
 
18
- RELIABLE_DOMAINS = frozenset([
19
- "bbc.com", "reuters.com", "apnews.com", "aljazeera.com",
20
- "nytimes.com", "wsj.com", "ft.com", "economist.com",
21
- "thediplomat.com", "foreignpolicy.com", "foreignaffairs.com",
22
- "dawn.com", "timesofindia.indiatimes.com", "ndtv.com", "geo.tv",
23
- ])
24
-
25
- DEFAULT_SUBREDDITS = [
26
- # High-quality geopolitics sources
27
- SubredditSource(name="geopolitics", reliability_score=0.85),
28
- SubredditSource(name="CredibleDefense", reliability_score=0.9),
29
- SubredditSource(name="worldnews", reliability_score=0.8),
30
- SubredditSource(name="neutralnews", reliability_score=0.8),
31
- SubredditSource(name="DefenseNews", reliability_score=0.85),
32
- # South Asia focused
33
- SubredditSource(name="GeopoliticsIndia", reliability_score=0.75),
34
- SubredditSource(name="SouthAsia", reliability_score=0.7),
35
- SubredditSource(name="india", reliability_score=0.7),
36
- SubredditSource(name="pakistan", reliability_score=0.7),
37
- # Regional neighbors
38
- SubredditSource(name="Nepal", reliability_score=0.65),
39
- SubredditSource(name="bangladesh", reliability_score=0.65),
40
- SubredditSource(name="srilanka", reliability_score=0.65),
41
- SubredditSource(name="China", reliability_score=0.6),
42
- ]
43
-
44
-
45
  class RedditService:
46
- """Service for collecting posts from Reddit via AsyncPRAW."""
47
 
48
  def __init__(
49
  self,
@@ -52,30 +31,17 @@ class RedditService:
52
  user_agent: str,
53
  cache: CacheService,
54
  ) -> None:
55
- """
56
- Initialize the Reddit service.
57
-
58
- Args:
59
- client_id: Reddit API client ID.
60
- client_secret: Reddit API client secret.
61
- user_agent: User agent string for API requests.
62
- cache: Cache service for storing results.
63
- """
64
  self._client_id = client_id
65
  self._client_secret = client_secret
66
  self._user_agent = user_agent
67
  self._cache = cache
68
- self._reddit: Optional[asyncpraw.Reddit] = None
69
- self._session: Optional[aiohttp.ClientSession] = None
70
  self._sources = list(DEFAULT_SUBREDDITS)
 
71
 
72
  async def initialize(self) -> bool:
73
- """
74
- Initialize the Reddit API client.
75
-
76
- Returns:
77
- True if initialization was successful.
78
- """
79
  logger.info("Initializing Reddit service")
80
  self._session = aiohttp.ClientSession()
81
  self._reddit = asyncpraw.Reddit(
@@ -115,7 +81,7 @@ class RedditService:
115
  try:
116
  domain = urlparse(url).netloc
117
  return domain.replace("www.", "")
118
- except Exception:
119
  return ""
120
 
121
  def _calculate_reliability(self, post_url: str, score: int, base_score: float) -> float:
@@ -139,9 +105,9 @@ class RedditService:
139
  source: SubredditSource,
140
  limit: int = 100,
141
  ) -> list[NewsItem]:
142
- """Fetch posts from a subreddit with retry logic."""
143
  if not self._reddit:
144
- raise RuntimeError("Reddit service not initialized")
145
 
146
  cache_key = f"reddit_{source.name}_{limit}"
147
  cached = await self._cache.get(cache_key)
@@ -149,69 +115,62 @@ class RedditService:
149
  logger.debug(f"Cache hit for r/{source.name}")
150
  return cached
151
 
152
- logger.info(f"Fetching posts from r/{source.name}")
153
- subreddit = await self._reddit.subreddit(source.name)
154
- posts: list[NewsItem] = []
155
-
156
- async for submission in subreddit.new(limit=limit):
157
- content = f"{submission.title}\n{getattr(submission, 'selftext', '')}"
158
- author = str(submission.author) if submission.author else "[deleted]"
159
-
160
- post = NewsItem(
161
- id=submission.id,
162
- title=submission.title,
163
- content=content,
164
- url=submission.url,
165
- source_name=f"r/{source.name}",
166
- source_type=SourceType.REDDIT,
167
- published_at=datetime.fromtimestamp(submission.created_utc, tz=timezone.utc),
168
- reliability_score=self._calculate_reliability(
169
- submission.url, submission.score, source.reliability_score
170
- ),
171
- author=author,
172
- score=submission.score,
173
- )
174
- posts.append(post)
175
-
176
- await self._cache.set(cache_key, posts)
177
- logger.info(f"Fetched {len(posts)} posts from r/{source.name}")
178
- return posts
179
-
180
- async def get_all_posts(
181
- self,
182
- days_back: int = 2,
183
- ) -> list[NewsItem]:
184
- """
185
- Get ALL posts from all active subreddits without keyword filtering.
186
-
187
- Used for RAG ingestion where vector search determines relevance.
188
 
189
- Args:
190
- days_back: How many days back to search.
191
 
192
- Returns:
193
- List of all news items.
194
- """
195
- active_sources = [s for s in self._sources if s.is_active]
196
- cutoff = datetime.now(timezone.utc) - timedelta(days=days_back)
197
  seen_ids: set[str] = set()
198
- results: list[NewsItem] = []
199
-
200
- for source in active_sources:
201
- try:
202
- posts = await self._fetch_subreddit_posts(source)
203
- for post in posts:
204
- if post.published_at < cutoff:
205
- continue
206
- if post.id in seen_ids:
207
- continue
208
- results.append(post)
209
- seen_ids.add(post.id)
210
-
211
- except asyncprawcore.exceptions.RequestException as e:
212
- logger.error(f"Reddit API error for r/{source.name}: {e}")
213
- except Exception as e:
214
- logger.error(f"Error fetching r/{source.name}: {e}")
215
-
216
- logger.info(f"Collected {len(results)} total Reddit posts for ingestion")
217
- return results
 
 
1
+ """Reddit data collection service using AsyncPRAW with full parallelization."""
2
 
3
+ import asyncio
4
+ from datetime import UTC, datetime, timedelta
5
  from urllib.parse import urlparse
6
 
7
  import aiohttp
 
10
  from loguru import logger
11
  from tenacity import retry, stop_after_attempt, wait_exponential
12
 
13
+ from westernfront.core.constants import (
14
+ DEFAULT_SUBREDDITS,
15
+ MAX_CONCURRENT_REQUESTS,
16
+ RELIABLE_DOMAINS,
17
+ )
18
  from westernfront.core.enums import SourceType
19
+ from westernfront.core.exceptions import ServiceNotInitializedError
20
  from westernfront.core.models import NewsItem, SubredditSource
21
  from westernfront.services.cache import CacheService
22
 
23
 
 
24
  class RedditService:
25
+ """Service for collecting posts from Reddit via AsyncPRAW with parallel fetching."""
26
 
27
  def __init__(
28
  self,
 
31
  user_agent: str,
32
  cache: CacheService,
33
  ) -> None:
 
 
34
  self._client_id = client_id
35
  self._client_secret = client_secret
36
  self._user_agent = user_agent
37
  self._cache = cache
38
+ self._reddit: asyncpraw.Reddit | None = None
39
+ self._session: aiohttp.ClientSession | None = None
40
  self._sources = list(DEFAULT_SUBREDDITS)
41
+ self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
42
 
43
  async def initialize(self) -> bool:
44
+ """Initialize the Reddit API client."""
 
 
45
  logger.info("Initializing Reddit service")
46
  self._session = aiohttp.ClientSession()
47
  self._reddit = asyncpraw.Reddit(
 
81
  try:
82
  domain = urlparse(url).netloc
83
  return domain.replace("www.", "")
84
+ except (ValueError, AttributeError):
85
  return ""
86
 
87
  def _calculate_reliability(self, post_url: str, score: int, base_score: float) -> float:
 
105
  source: SubredditSource,
106
  limit: int = 100,
107
  ) -> list[NewsItem]:
108
+ """Fetch posts from a subreddit with retry logic and rate limiting."""
109
  if not self._reddit:
110
+ raise ServiceNotInitializedError("Reddit service not initialized")
111
 
112
  cache_key = f"reddit_{source.name}_{limit}"
113
  cached = await self._cache.get(cache_key)
 
115
  logger.debug(f"Cache hit for r/{source.name}")
116
  return cached
117
 
118
+ async with self._semaphore:
119
+ logger.info(f"Fetching posts from r/{source.name}")
120
+ subreddit = await self._reddit.subreddit(source.name)
121
+ posts: list[NewsItem] = []
122
+
123
+ async for submission in subreddit.new(limit=limit):
124
+ content = f"{submission.title}\n{getattr(submission, 'selftext', '')}"
125
+ author = str(submission.author) if submission.author else "[deleted]"
126
+
127
+ post = NewsItem(
128
+ id=submission.id,
129
+ title=submission.title,
130
+ content=content,
131
+ url=submission.url,
132
+ source_name=f"r/{source.name}",
133
+ source_type=SourceType.REDDIT,
134
+ published_at=datetime.fromtimestamp(submission.created_utc, tz=UTC),
135
+ reliability_score=self._calculate_reliability(
136
+ submission.url, submission.score, source.reliability_score
137
+ ),
138
+ author=author,
139
+ score=submission.score,
140
+ )
141
+ posts.append(post)
142
+
143
+ await self._cache.set(cache_key, posts)
144
+ logger.info(f"Fetched {len(posts)} posts from r/{source.name}")
145
+ return posts
146
+
147
+ async def get_all_posts(self, days_back: int = 2) -> list[NewsItem]:
148
+ """Get all posts from all active subreddits in parallel."""
149
+ active_sources = [s for s in self._sources if s.is_active]
150
+ cutoff = datetime.now(UTC) - timedelta(days=days_back)
 
 
 
151
 
152
+ tasks = [self._fetch_subreddit_posts(source) for source in active_sources]
153
+ results = await asyncio.gather(*tasks, return_exceptions=True)
154
 
 
155
  seen_ids: set[str] = set()
156
+ all_posts: list[NewsItem] = []
157
+
158
+ for i, result in enumerate(results):
159
+ if isinstance(result, Exception):
160
+ source_name = active_sources[i].name
161
+ if isinstance(result, asyncprawcore.exceptions.RequestException):
162
+ logger.error(f"Reddit API error for r/{source_name}: {result}")
163
+ else:
164
+ logger.error(f"Error fetching r/{source_name}: {result}")
165
+ continue
166
+
167
+ for post in result:
168
+ if post.published_at < cutoff:
169
+ continue
170
+ if post.id in seen_ids:
171
+ continue
172
+ all_posts.append(post)
173
+ seen_ids.add(post.id)
174
+
175
+ logger.info(f"Collected {len(all_posts)} total Reddit posts for ingestion")
176
+ return all_posts
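The tenacity imports above back the retry logic on _fetch_subreddit_posts; the exact policy parameters are not visible in this hunk, so the values below are assumptions. A deterministic sketch of retry-on-transient-failure:

import asyncio

from tenacity import retry, stop_after_attempt, wait_exponential

attempts = {"n": 0}


@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=0.1))
async def flaky_fetch() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:  # simulate two transient Reddit API failures
        raise ConnectionError("transient error")
    return "ok"


print(asyncio.run(flaky_fetch()))  # "ok" after two retried failures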
src/westernfront/services/retrieval.py ADDED
@@ -0,0 +1,101 @@
 
 
1
+ """Enhanced retrieval service with diversity and recency weighting."""
2
+
3
+ from datetime import UTC, datetime
4
+ from typing import TYPE_CHECKING
5
+
6
+ from loguru import logger
7
+
8
+ from westernfront.core.constants import RECENCY_BOOST, SOURCE_DIVERSITY_RULES
9
+
10
+ if TYPE_CHECKING:
11
+ from westernfront.repositories.vectors import VectorRepository
12
+
13
+
14
+ class RetrievalService:
15
+ """Enhanced retrieval with source diversity and temporal weighting."""
16
+
17
+ def __init__(self, vectors: "VectorRepository") -> None:
18
+ self._vectors = vectors
19
+
20
+ def _calculate_recency_score(self, published_at: str) -> float:
21
+ """Calculate recency boost based on publication time."""
22
+ try:
23
+ pub_date = datetime.fromisoformat(published_at.replace("Z", "+00:00"))
24
+ now = datetime.now(UTC)
25
+ hours_old = (now - pub_date).total_seconds() / 3600
26
+
27
+ if hours_old <= 24:
28
+ return RECENCY_BOOST["hours_24"]
29
+ if hours_old <= 48:
30
+ return RECENCY_BOOST["hours_48"]
31
+ if hours_old <= 168:
32
+ return RECENCY_BOOST["days_7"]
33
+ return RECENCY_BOOST["older"]
34
+ except (ValueError, TypeError):
35
+ return 1.0
36
+
37
+ def _apply_recency_boost(self, items: list[dict]) -> list[dict]:
38
+ """Apply recency boost to similarity scores."""
39
+ for item in items:
40
+ meta = item.get("metadata", {})
41
+ published_at = meta.get("published_at", "")
42
+ recency_mult = self._calculate_recency_score(published_at)
43
+ original_score = item.get("similarity_score", 0.5)
44
+ item["boosted_score"] = min(1.0, original_score * recency_mult)
45
+ item["recency_multiplier"] = recency_mult
46
+
47
+ items.sort(key=lambda x: x.get("boosted_score", 0), reverse=True)
48
+ return items
49
+
50
+ def _enforce_diversity(self, items: list[dict], max_items: int) -> list[dict]:
51
+ """Enforce source type diversity in results."""
52
+ by_source: dict[str, list[dict]] = {"reddit": [], "rss": [], "newsapi": []}
53
+
54
+ for item in items:
55
+ meta = item.get("metadata", {})
56
+ source_type = meta.get("source_type", "unknown")
57
+ if source_type in by_source:
58
+ by_source[source_type].append(item)
59
+
60
+ result = []
61
+ for source_type, rules in SOURCE_DIVERSITY_RULES.items():
62
+ min_count = int(max_items * rules["min_pct"])
63
+ available = by_source.get(source_type, [])
64
+ to_add = available[:min_count]
65
+ result.extend(to_add)
66
+ logger.debug(f"Diversity: Added {len(to_add)} from {source_type} (min: {min_count})")
67
+
68
+ seen_ids = {item["id"] for item in result}
69
+ remaining_items = [item for item in items if item["id"] not in seen_ids]
70
+
71
+ space_left = max_items - len(result)
72
+ for item in remaining_items[:space_left]:
73
+ meta = item.get("metadata", {})
74
+ source_type = meta.get("source_type", "unknown")
75
+ max_count = int(max_items * SOURCE_DIVERSITY_RULES.get(source_type, {}).get("max_pct", 1.0))
76
+ current_count = sum(1 for r in result if r.get("metadata", {}).get("source_type") == source_type)
77
+ if current_count < max_count:
78
+ result.append(item)
79
+
80
+ result.sort(key=lambda x: x.get("boosted_score", 0), reverse=True)
81
+ return result[:max_items]
82
+
83
+ def retrieve_with_quality(
84
+ self,
85
+ topics: list[str],
86
+ max_items: int = 40,
87
+ n_per_topic: int = 10,
88
+ ) -> list[dict]:
89
+ """Retrieve items with recency boost and source diversity."""
90
+ if not self._vectors.is_initialized:
91
+ logger.warning("Vector store not initialized")
92
+ return []
93
+
94
+ raw_results = self._vectors.query_by_topics(topics, n_per_topic=n_per_topic)
95
+ logger.info(f"Retrieved {len(raw_results)} raw items from vector store")
96
+
97
+ boosted = self._apply_recency_boost(raw_results)
98
+ diverse = self._enforce_diversity(boosted, max_items)
99
+ logger.info(f"After diversity enforcement: {len(diverse)} items")
100
+
101
+ return diverse
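A worked example of the recency boost above; the multiplier values here are illustrative assumptions standing in for RECENCY_BOOST in core/constants.py:

RECENCY_BOOST = {"hours_24": 1.3, "hours_48": 1.15, "days_7": 1.0, "older": 0.85}  # assumed values


def boosted(similarity: float, hours_old: float) -> float:
    if hours_old <= 24:
        mult = RECENCY_BOOST["hours_24"]
    elif hours_old <= 48:
        mult = RECENCY_BOOST["hours_48"]
    elif hours_old <= 168:  # 7 days
        mult = RECENCY_BOOST["days_7"]
    else:
        mult = RECENCY_BOOST["older"]
    return min(1.0, similarity * mult)  # capped, as in _apply_recency_boost


print(boosted(0.8, 12))   # 0.8 * 1.3 = 1.04, capped at 1.0
print(boosted(0.8, 200))  # 0.8 * 0.85 = 0.68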
src/westernfront/services/rss.py CHANGED
@@ -1,99 +1,29 @@
1
- """RSS feed collection service."""
2
 
 
3
  import hashlib
4
- from datetime import datetime, timezone
5
- from typing import Optional
6
  from email.utils import parsedate_to_datetime
7
 
8
  import aiohttp
9
  import feedparser
10
  from loguru import logger
11
 
 
12
  from westernfront.core.enums import SourceType
13
  from westernfront.core.models import NewsItem, RssFeed
14
  from westernfront.services.cache import CacheService
15
-
16
-
17
- DEFAULT_RSS_FEEDS = [
18
- # Tier 1: Pakistan
19
- RssFeed(
20
- name="Dawn (Pakistan)",
21
- url="https://www.dawn.com/feeds/home",
22
- reliability_score=0.85,
23
- ),
24
- RssFeed(
25
- name="Geo News",
26
- url="https://www.geo.tv/rss/1/1",
27
- reliability_score=0.8,
28
- ),
29
- RssFeed(
30
- name="Express Tribune",
31
- url="https://tribune.com.pk/feed/home",
32
- reliability_score=0.75,
33
- ),
34
- # India
35
- RssFeed(
36
- name="Times of India",
37
- url="https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
38
- reliability_score=0.75,
39
- ),
40
- RssFeed(
41
- name="NDTV India",
42
- url="https://feeds.feedburner.com/ndtvnews-india-news",
43
- reliability_score=0.8,
44
- ),
45
- RssFeed(
46
- name="The Hindu",
47
- url="https://www.thehindu.com/news/national/feeder/default.rss",
48
- reliability_score=0.85,
49
- ),
50
- RssFeed(
51
- name="Indian Express",
52
- url="https://indianexpress.com/section/india/feed/",
53
- reliability_score=0.85,
54
- ),
55
- # Tier 2: China/Nepal/Bangladesh
56
- RssFeed(
57
- name="South China Morning Post - Asia",
58
- url="https://www.scmp.com/rss/91/feed",
59
- reliability_score=0.85,
60
- ),
61
- RssFeed(
62
- name="Kathmandu Post",
63
- url="https://kathmandupost.com/rss",
64
- reliability_score=0.75,
65
- ),
66
- RssFeed(
67
- name="Dhaka Tribune",
68
- url="https://www.dhakatribune.com/rss",
69
- reliability_score=0.75,
70
- ),
71
- RssFeed(
72
- name="Daily Star Bangladesh",
73
- url="https://www.thedailystar.net/rss.xml",
74
- reliability_score=0.75,
75
- ),
76
- # Tier 3: Others
77
- RssFeed(
78
- name="Daily Mirror Sri Lanka",
79
- url="http://www.dailymirror.lk/RSS_Feeds/breaking-news",
80
- reliability_score=0.7,
81
- ),
82
- ]
83
 
84
 
85
  class RssService:
86
- """Service for collecting news from RSS feeds."""
87
-
88
- def __init__(self, cache: CacheService) -> None:
89
- """
90
- Initialize the RSS service.
91
 
92
- Args:
93
- cache: Cache service for storing results.
94
- """
95
  self._cache = cache
 
96
  self._feeds = list(DEFAULT_RSS_FEEDS)
 
97
 
98
  @property
99
  def feeds(self) -> list[RssFeed]:
@@ -112,11 +42,11 @@ class RssService:
112
  try:
113
  parsed = parsedate_to_datetime(entry[date_field])
114
  if parsed.tzinfo is None:
115
- return parsed.replace(tzinfo=timezone.utc)
116
  return parsed
117
  except (ValueError, TypeError):
118
  pass
119
- return datetime.now(timezone.utc)
120
 
121
  def _generate_id(self, entry: dict, feed_name: str) -> str:
122
  """Generate a unique ID for an entry."""
@@ -133,81 +63,73 @@ class RssService:
133
  logger.debug(f"Cache hit for RSS: {feed.name}")
134
  return cached
135
 
136
- logger.info(f"Fetching RSS feed: {feed.name}")
137
- items: list[NewsItem] = []
 
138
 
139
- try:
140
- async with aiohttp.ClientSession() as session:
141
- async with session.get(feed.url, timeout=30) as response:
142
  if response.status != 200:
143
  logger.warning(f"RSS feed {feed.name} returned {response.status}")
144
  return []
145
  content = await response.text()
146
 
147
- parsed = feedparser.parse(content)
148
-
149
- for entry in parsed.entries:
150
- title = entry.get("title", "").strip()
151
- description = entry.get("description", "") or entry.get("summary", "")
152
- link = entry.get("link", "")
153
-
154
- if not title or not link:
155
- continue
 
 
156
 
157
- item = NewsItem(
158
- id=self._generate_id(entry, feed.name),
159
- title=title,
160
- content=f"{title}\n{description}",
161
- url=link,
162
- source_name=feed.name,
163
- source_type=SourceType.RSS,
164
- published_at=self._parse_date(entry),
165
- reliability_score=feed.reliability_score,
166
- author=entry.get("author"),
167
- )
168
- items.append(item)
169
-
170
- await self._cache.set(cache_key, items)
171
- logger.info(f"Fetched {len(items)} items from {feed.name}")
172
-
173
- except aiohttp.ClientError as e:
174
- logger.error(f"HTTP error fetching {feed.name}: {e}")
175
- except Exception as e:
176
- logger.error(f"Error parsing {feed.name}: {e}")
177
-
178
- return items
179
-
180
- async def get_all_articles(
181
- self,
182
- days_back: int = 2,
183
- ) -> list[NewsItem]:
184
- """
185
- Get ALL articles from all active feeds without keyword filtering.
186
-
187
- Used for RAG ingestion where vector search determines relevance.
188
-
189
- Args:
190
- days_back: How many days back to search.
191
-
192
- Returns:
193
- List of all news items.
194
- """
195
- from datetime import timedelta
196
 
197
- active_feeds = [f for f in self._feeds if f.is_active]
198
- cutoff = datetime.now(timezone.utc) - timedelta(days=days_back)
199
  seen_urls: set[str] = set()
200
- results: list[NewsItem] = []
 
 
201
 
202
- for feed in active_feeds:
203
- items = await self._fetch_feed(feed)
204
- for item in items:
205
  if item.published_at < cutoff:
206
  continue
207
  if item.url in seen_urls:
208
  continue
209
- results.append(item)
210
  seen_urls.add(item.url)
211
 
212
- logger.info(f"Collected {len(results)} total RSS articles for ingestion")
213
- return results
 
1
+ """RSS feed collection service with parallel fetching."""
2
 
3
+ import asyncio
4
  import hashlib
5
+ from datetime import UTC, datetime, timedelta
 
6
  from email.utils import parsedate_to_datetime
7
 
8
  import aiohttp
9
  import feedparser
10
  from loguru import logger
11
 
12
+ from westernfront.core.constants import DEFAULT_RSS_FEEDS, MAX_CONCURRENT_REQUESTS
13
  from westernfront.core.enums import SourceType
14
  from westernfront.core.models import NewsItem, RssFeed
15
  from westernfront.services.cache import CacheService
16
+ from westernfront.services.http import HttpService
 
 
17
 
18
 
19
  class RssService:
20
+ """Service for collecting news from RSS feeds with parallel fetching."""
 
 
 
 
21
 
22
+ def __init__(self, cache: CacheService, http: HttpService) -> None:
 
 
23
  self._cache = cache
24
+ self._http = http
25
  self._feeds = list(DEFAULT_RSS_FEEDS)
26
+ self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
27
 
28
  @property
29
  def feeds(self) -> list[RssFeed]:
 
42
  try:
43
  parsed = parsedate_to_datetime(entry[date_field])
44
  if parsed.tzinfo is None:
45
+ return parsed.replace(tzinfo=UTC)
46
  return parsed
47
  except (ValueError, TypeError):
48
  pass
49
+ return datetime.now(UTC)
50
 
51
  def _generate_id(self, entry: dict, feed_name: str) -> str:
52
  """Generate a unique ID for an entry."""
 
63
  logger.debug(f"Cache hit for RSS: {feed.name}")
64
  return cached
65
 
66
+ async with self._semaphore:
67
+ logger.info(f"Fetching RSS feed: {feed.name}")
68
+ items: list[NewsItem] = []
69
 
70
+ try:
71
+ async with self._http.get(feed.url) as response:
 
72
  if response.status != 200:
73
  logger.warning(f"RSS feed {feed.name} returned {response.status}")
74
  return []
75
  content = await response.text()
76
 
77
+ parsed = feedparser.parse(content)
78
+
79
+ for entry in parsed.entries:
80
+ title = entry.get("title", "").strip()
81
+ description = entry.get("description", "") or entry.get("summary", "")
82
+ link = entry.get("link", "")
83
+
84
+ if not title or not link:
85
+ continue
86
+
87
+ item = NewsItem(
88
+ id=self._generate_id(entry, feed.name),
89
+ title=title,
90
+ content=f"{title}\n{description}",
91
+ url=link,
92
+ source_name=feed.name,
93
+ source_type=SourceType.RSS,
94
+ published_at=self._parse_date(entry),
95
+ reliability_score=feed.reliability_score,
96
+ author=entry.get("author"),
97
+ )
98
+ items.append(item)
99
+
100
+ await self._cache.set(cache_key, items)
101
+ logger.info(f"Fetched {len(items)} items from {feed.name}")
102
+
103
+ except aiohttp.ClientError as e:
104
+ logger.error(f"HTTP error fetching {feed.name}: {e}")
105
+ except feedparser.CharacterEncodingOverride as e:
106
+ logger.error(f"Parse error for {feed.name}: {e}")
107
+
108
+ return items
109
+
110
+ async def get_all_articles(self, days_back: int = 2) -> list[NewsItem]:
111
+ """Get all articles from all active feeds in parallel."""
112
+ active_feeds = [f for f in self._feeds if f.is_active]
113
+ cutoff = datetime.now(UTC) - timedelta(days=days_back)
114
 
115
+ tasks = [self._fetch_feed(feed) for feed in active_feeds]
116
+ results = await asyncio.gather(*tasks, return_exceptions=True)
 
 
117
 
 
 
118
  seen_urls: set[str] = set()
119
+ all_articles: list[NewsItem] = []
120
+
121
+ for i, result in enumerate(results):
122
+ if isinstance(result, Exception):
123
+ logger.error(f"Error fetching {active_feeds[i].name}: {result}")
124
+ continue
125
 
126
+ for item in result:
 
 
127
  if item.published_at < cutoff:
128
  continue
129
  if item.url in seen_urls:
130
  continue
131
+ all_articles.append(item)
132
  seen_urls.add(item.url)
133
 
134
+ logger.info(f"Collected {len(all_articles)} total RSS articles for ingestion")
135
+ return all_articles
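The feedparser flow inside _fetch_feed can be checked against an inline RSS document instead of a live feed; feedparser.parse accepts a raw XML string:

import feedparser

SAMPLE_RSS = """<?xml version="1.0"?>
<rss version="2.0"><channel><title>Demo</title>
<item><title>Headline</title><link>https://example.org/a</link>
<description>Body text</description></item>
</channel></rss>"""

parsed = feedparser.parse(SAMPLE_RSS)
for entry in parsed.entries:
    # same entry.get() accessors used in _fetch_feed
    print(entry.get("title"), entry.get("link"))  # Headline https://example.org/a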
src/westernfront/services/scheduler.py ADDED
@@ -0,0 +1,69 @@
 
 
1
+ """Background task scheduler for periodic analysis updates."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ from typing import TYPE_CHECKING, Protocol
6
+
7
+ from loguru import logger
8
+
9
+ if TYPE_CHECKING:
10
+ from westernfront.core.models import ConflictAnalysis
11
+
12
+
13
+ class AnalysisGenerator(Protocol):
14
+ """Protocol for analysis generation."""
15
+
16
+ async def generate_analysis(self) -> "ConflictAnalysis | None":
17
+ """Generate a new conflict analysis."""
18
+ ...
19
+
20
+
21
+ class AnalysisScheduler:
22
+ """Manages periodic background analysis updates."""
23
+
24
+ def __init__(
25
+ self,
26
+ generator: AnalysisGenerator,
27
+ interval_minutes: int = 60,
28
+ ) -> None:
29
+ self._generator = generator
30
+ self._interval_seconds = interval_minutes * 60
31
+ self._task: asyncio.Task[None] | None = None
32
+
33
+ def start(self) -> None:
34
+ """Start the periodic update task."""
35
+ if self._task is None:
36
+ self._task = asyncio.create_task(self._periodic_update())
37
+ logger.info(f"Scheduler started with {self._interval_seconds // 60}min interval")
38
+
39
+ async def stop(self) -> None:
40
+ """Stop the periodic update task."""
41
+ if self._task:
42
+ self._task.cancel()
43
+ with contextlib.suppress(asyncio.CancelledError):
44
+ await self._task
45
+ self._task = None
46
+ logger.info("Scheduler stopped")
47
+
48
+ @property
49
+ def is_running(self) -> bool:
50
+ """Check if the scheduler is running."""
51
+ return self._task is not None and not self._task.done()
52
+
53
+ async def _periodic_update(self) -> None:
54
+ """Background task for periodic analysis updates."""
55
+ await asyncio.sleep(5)
56
+ await self._run_update()
57
+
58
+ while True:
59
+ await asyncio.sleep(self._interval_seconds)
60
+ await self._run_update()
61
+
62
+ async def _run_update(self) -> None:
63
+ """Execute a single analysis update cycle."""
64
+ logger.info("Starting scheduled analysis update")
65
+ try:
+     analysis = await self._generator.generate_analysis()
+ except Exception:  # keep the periodic task alive across failed cycles
+     logger.exception("Scheduled analysis update failed; skipping this cycle")
+     return
66
+ if analysis:
67
+ logger.info(f"Scheduled analysis complete. Tension: {analysis.tension_level.value}")
68
+ else:
69
+ logger.warning("Scheduled analysis produced no result")
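Any object with an async generate_analysis() satisfies the AnalysisGenerator protocol; DummyGenerator below is a hypothetical stand-in used only for this sketch:

import asyncio

from westernfront.services.scheduler import AnalysisScheduler


class DummyGenerator:
    async def generate_analysis(self):
        return None  # a real implementation returns a ConflictAnalysis


async def main() -> None:
    scheduler = AnalysisScheduler(DummyGenerator(), interval_minutes=60)
    scheduler.start()
    await asyncio.sleep(6)  # let the 5-second-delayed initial update fire once
    await scheduler.stop()


asyncio.run(main())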
src/westernfront/services/validation.py ADDED
@@ -0,0 +1,119 @@
 
 
1
+ """Analysis validation service for quality assurance."""
2
+
3
+ from datetime import UTC, datetime
4
+
5
+ from westernfront.core.constants import TENSION_LEVEL_CRITERIA
6
+ from westernfront.core.enums import TensionLevel
7
+
8
+
9
+ class AnalysisValidator:
10
+ """Validates AI-generated analysis for quality and consistency."""
11
+
12
+ def validate_tension_consistency(
13
+ self,
14
+ level: TensionLevel,
15
+ score: int,
16
+ rationale: str,
17
+ ) -> tuple[bool, list[str]]:
18
+ """Validate tension level matches score and rationale."""
19
+ issues = []
20
+ criteria = TENSION_LEVEL_CRITERIA.get(level.value.upper(), {})
21
+ expected_range = criteria.get("score_range", (1, 10))
22
+
23
+ if not (expected_range[0] <= score <= expected_range[1]):
24
+ issues.append(
25
+ f"Score {score} inconsistent with {level.value} level "
26
+ f"(expected {expected_range[0]}-{expected_range[1]})"
27
+ )
28
+
29
+ level_keywords = {
30
+ TensionLevel.LOW: ["calm", "normal", "routine", "stable", "peaceful"],
31
+ TensionLevel.MEDIUM: ["elevated", "heightened", "tension", "concern", "monitoring"],
32
+ TensionLevel.HIGH: ["serious", "alert", "mobilization", "firing", "escalation"],
33
+ TensionLevel.CRITICAL: ["urgent", "imminent", "emergency", "active", "combat"],
34
+ }
35
+
36
+ keywords = level_keywords.get(level, [])
37
+ rationale_lower = rationale.lower()
38
+ if keywords and not any(kw in rationale_lower for kw in keywords):
39
+ issues.append(
40
+ f"Rationale may not justify {level.value} level "
41
+ f"(expected keywords like: {', '.join(keywords[:3])})"
42
+ )
43
+
44
+ return len(issues) == 0, issues
45
+
46
+ def validate_entities(
47
+ self,
48
+ entities: list[str],
49
+ source_texts: list[str],
50
+ ) -> tuple[bool, list[str]]:
51
+ """Validate that key entities are grounded in source texts."""
52
+ combined_text = " ".join(source_texts).lower()
53
+ ungrounded = []
54
+
55
+ for entity in entities:
56
+ entity_lower = entity.lower()
57
+ entity_words = entity_lower.split()
58
+ found = any(word in combined_text for word in entity_words if len(word) > 3)
59
+ if not found:
60
+ ungrounded.append(entity)
61
+
62
+ is_valid = len(ungrounded) <= 1
63
+ return is_valid, ungrounded
64
+
65
+ def validate_dates(
66
+ self,
67
+ developments: list[dict],
68
+ ) -> tuple[bool, list[str]]:
69
+ """Validate that development timestamps are reasonable."""
70
+ issues = []
71
+ now = datetime.now(UTC)
72
+
73
+ for dev in developments:
74
+ title = dev.get("title", "Unknown")
75
+ timestamp = dev.get("timestamp")
76
+ if timestamp and isinstance(timestamp, datetime) and timestamp > now:
77
+ issues.append(f"Future date in development: {title}")
78
+
79
+ return len(issues) == 0, issues
80
+
81
+ def validate_analysis(
82
+ self,
83
+ analysis_data: dict,
84
+ source_texts: list[str],
85
+ ) -> tuple[bool, list[str]]:
86
+ """Perform comprehensive validation on analysis."""
87
+ all_issues = []
88
+
89
+ level_str = analysis_data.get("tension_level", "LOW")
90
+ try:
91
+ level = TensionLevel(level_str)
92
+ except ValueError:
93
+ level = TensionLevel.LOW
94
+ all_issues.append(f"Invalid tension level: {level_str}")
95
+
96
+ score = analysis_data.get("tension_score", 1)
97
+ if not isinstance(score, int):
98
+ try:
99
+ score = int(score)
100
+ except (ValueError, TypeError):
101
+ score = 1
102
+ all_issues.append(f"Invalid tension score: {analysis_data.get('tension_score')}")
103
+
104
+ rationale = analysis_data.get("tension_rationale", "")
105
+ _, tension_issues = self.validate_tension_consistency(level, score, rationale)
106
+ all_issues.extend(tension_issues)
107
+
108
+ entities = analysis_data.get("key_entities", [])
109
+ if isinstance(entities, list):
110
+ _, ungrounded = self.validate_entities(entities, source_texts)
111
+ if ungrounded:
112
+ all_issues.append(f"Possibly ungrounded entities: {', '.join(ungrounded)}")
113
+
114
+ developments = analysis_data.get("key_developments", [])
115
+ if isinstance(developments, list):
116
+ _, date_issues = self.validate_dates(developments)
117
+ all_issues.extend(date_issues)
118
+
119
+ return len(all_issues) == 0, all_issues
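Feeding the validator a deliberately inconsistent payload shows the kind of issues it reports; this assumes TENSION_LEVEL_CRITERIA places HIGH well above a score of 2:

from westernfront.services.validation import AnalysisValidator

validator = AnalysisValidator()
ok, issues = validator.validate_analysis(
    {
        "tension_level": "HIGH",
        "tension_score": 2,  # inconsistent with HIGH
        "tension_rationale": "Routine patrols, situation calm.",
        "key_entities": ["Atlantis"],  # not present in the sources
        "key_developments": [],
    },
    source_texts=["Border firing reported along the LoC near Kashmir."],
)
print(ok)      # False
print(issues)  # e.g. score/level mismatch plus the ungrounded entity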
src/westernfront/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
1
+ """Utilities package exports."""
2
+
3
+ from westernfront.utils.json_parser import extract_json_from_response
4
+
5
+ __all__ = ["extract_json_from_response"]
src/westernfront/utils/json_parser.py ADDED
@@ -0,0 +1,42 @@
 
 
1
+ """Utility functions for parsing JSON from LLM responses."""
2
+
3
+ import json
4
+ import re
5
+ from typing import Any
6
+
7
+ from loguru import logger
8
+
9
+
10
+ def extract_json_from_response(text: str) -> dict[str, Any] | None:
11
+ """
12
+ Extract JSON from an LLM response that may contain markdown code blocks or raw JSON.
13
+
14
+ Tries multiple strategies:
15
+ 1. Direct JSON parsing
16
+ 2. Extract from markdown code block
17
+ 3. Find JSON object in text
18
+ """
19
+ # Try direct parsing first
20
+ try:
21
+ return json.loads(text)
22
+ except json.JSONDecodeError:
23
+ pass
24
+
25
+ # Try extracting from markdown code block
26
+ json_match = re.search(r"```(?:json)?\n(.*?)\n```", text, re.DOTALL)
27
+ if json_match:
28
+ try:
29
+ return json.loads(json_match.group(1))
30
+ except json.JSONDecodeError:
31
+ pass
32
+
33
+ # Try finding JSON object in text
34
+ json_match = re.search(r"\{.*\}", text, re.DOTALL)
35
+ if json_match:
36
+ try:
37
+ return json.loads(json_match.group(0))
38
+ except json.JSONDecodeError:
39
+ pass
40
+
41
+ logger.warning(f"Failed to parse JSON from response: {text[:200]}...")
42
+ return None
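The three strategies fire in order, demonstrated here on progressively messier inputs:

from westernfront.utils import extract_json_from_response

print(extract_json_from_response('{"a": 1}'))                   # strategy 1: direct parse
print(extract_json_from_response('```json\n{"a": 1}\n```'))     # strategy 2: fenced block
print(extract_json_from_response('noise {"a": 1} more noise'))  # strategy 3: embedded object
print(extract_json_from_response("no json here"))               # None (logs a warning)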
tests/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (188 Bytes). View file
 
tests/__pycache__/conftest.cpython-312-pytest-8.4.2.pyc ADDED
Binary file (1.16 kB). View file
 
tests/__pycache__/test_services.cpython-312-pytest-8.4.2.pyc ADDED
Binary file (13 kB). View file
 
tests/test_api.py ADDED
@@ -0,0 +1,71 @@
 
 
1
+ """Tests for API routes."""
2
+
3
+ import pytest
4
+ from httpx import ASGITransport, AsyncClient
5
+
6
+ from westernfront.main import app
7
+
8
+
9
+ @pytest.fixture
10
+ async def client():
11
+ """Create test client."""
12
+ transport = ASGITransport(app=app)
13
+ async with AsyncClient(transport=transport, base_url="http://test") as client:
14
+ yield client
15
+
16
+
17
+ class TestPublicEndpoints:
18
+ """Tests for public API endpoints."""
19
+
20
+ async def test_root_endpoint(self, client):
21
+ """Test root endpoint returns API info."""
22
+ response = await client.get("/")
23
+ assert response.status_code == 200
24
+ data = response.json()
25
+ assert data["name"] == "WesternFront API"
26
+ assert "version" in data
27
+ assert data["status"] in ["ready", "initializing"]
28
+
29
+ async def test_health_endpoint(self, client):
30
+ """Test health endpoint returns status."""
31
+ response = await client.get("/health")
32
+ assert response.status_code == 200
33
+ data = response.json()
34
+ assert data["status"] in ["healthy", "initializing"]
35
+ assert "version" in data
36
+ assert "timestamp" in data
37
+ assert "components" in data
38
+
39
+ async def test_health_head_endpoint(self, client):
40
+ """Test health HEAD request."""
41
+ response = await client.head("/health")
42
+ assert response.status_code == 200
43
+
44
+ async def test_tension_levels_endpoint(self, client):
45
+ """Test tension levels endpoint."""
46
+ response = await client.get(
47
+ "/tension-levels",
48
+ headers={"X-API-Key": "test-key"},
49
+ )
50
+ # Will return 401 if no valid key, or 200 with levels
51
+ assert response.status_code in [200, 401]
52
+
53
+
54
+ class TestProtectedEndpoints:
55
+ """Tests for protected API endpoints."""
56
+
57
+ async def test_analysis_requires_auth(self, client):
58
+ """Test analysis endpoint requires API key."""
59
+ response = await client.get("/analysis")
60
+ assert response.status_code == 401
61
+ assert "API key" in response.json()["detail"]
62
+
63
+ async def test_sources_requires_auth(self, client):
64
+ """Test sources endpoint requires API key."""
65
+ response = await client.get("/sources")
66
+ assert response.status_code == 401
67
+
68
+ async def test_keywords_requires_auth(self, client):
69
+ """Test keywords endpoint requires API key."""
70
+ response = await client.get("/keywords")
71
+ assert response.status_code == 401
tests/test_parsing.py ADDED
@@ -0,0 +1,111 @@
 
 
1
+ """Tests for ResponseParser service."""
2
+
3
+ import pytest
5
+
6
+ from westernfront.core.enums import AnalysisType, TensionLevel, TensionTrend
7
+ from westernfront.services.parsing import ResponseParser
8
+
9
+
10
+ class TestResponseParser:
11
+ """Tests for ResponseParser methods."""
12
+
13
+ @pytest.fixture
14
+ def parser(self):
15
+ """Create parser instance."""
16
+ return ResponseParser()
17
+
18
+ def test_parse_tension_level_low(self, parser):
19
+ """Test parsing LOW tension level."""
20
+ assert parser.parse_tension_level("low") == TensionLevel.LOW
21
+ assert parser.parse_tension_level("LOW") == TensionLevel.LOW
22
+ assert parser.parse_tension_level("unknown") == TensionLevel.LOW
23
+
24
+ def test_parse_tension_level_medium(self, parser):
25
+ """Test parsing MEDIUM tension level."""
26
+ assert parser.parse_tension_level("medium") == TensionLevel.MEDIUM
27
+ assert parser.parse_tension_level("MEDIUM") == TensionLevel.MEDIUM
28
+
29
+ def test_parse_tension_level_high(self, parser):
30
+ """Test parsing HIGH tension level."""
31
+ assert parser.parse_tension_level("high") == TensionLevel.HIGH
32
+ assert parser.parse_tension_level("HIGH") == TensionLevel.HIGH
33
+
34
+ def test_parse_tension_level_critical(self, parser):
35
+ """Test parsing CRITICAL tension level."""
36
+ assert parser.parse_tension_level("critical") == TensionLevel.CRITICAL
37
+ assert parser.parse_tension_level("CRITICAL") == TensionLevel.CRITICAL
38
+
39
+ def test_parse_tension_trend(self, parser):
40
+ """Test parsing tension trends."""
41
+ assert parser.parse_tension_trend("increasing") == TensionTrend.INCREASING
42
+ assert parser.parse_tension_trend("decreasing") == TensionTrend.DECREASING
43
+ assert parser.parse_tension_trend("stable") == TensionTrend.STABLE
44
+ assert parser.parse_tension_trend("unknown") == TensionTrend.STABLE
45
+
46
+ def test_parse_analysis_type(self, parser):
47
+ """Test parsing analysis types."""
48
+ assert parser.parse_analysis_type("military") == AnalysisType.MILITARY
49
+ assert parser.parse_analysis_type("diplomatic") == AnalysisType.DIPLOMATIC
50
+ assert parser.parse_analysis_type("internal security") == AnalysisType.INTERNAL_SECURITY
51
+ assert parser.parse_analysis_type("political") == AnalysisType.POLITICAL
52
+ assert parser.parse_analysis_type("unknown") == AnalysisType.OTHER
53
+
54
+ def test_parse_tension_score_valid(self, parser):
55
+ """Test parsing valid tension scores."""
56
+ assert parser.parse_tension_score(5) == 5
57
+ assert parser.parse_tension_score(5.7) == 5
58
+ assert parser.parse_tension_score("7") == 7
59
+
60
+ def test_parse_tension_score_clamped(self, parser):
61
+ """Test tension scores are clamped to valid range."""
62
+ assert parser.parse_tension_score(0) == 1
63
+ assert parser.parse_tension_score(-5) == 1
64
+ assert parser.parse_tension_score(15) == 10
65
+ assert parser.parse_tension_score(None) == 1
66
+
67
+ def test_parse_key_entities_list(self, parser):
68
+ """Test parsing entities from list."""
69
+ result = parser.parse_key_entities(["India", "Pakistan", "Kashmir"])
70
+ assert result == ["India", "Pakistan", "Kashmir"]
71
+
72
+ def test_parse_key_entities_string(self, parser):
73
+ """Test parsing entities from comma-separated string."""
74
+ result = parser.parse_key_entities("India, Pakistan, Kashmir")
75
+ assert result == ["India", "Pakistan", "Kashmir"]
76
+
77
+ def test_parse_key_entities_none(self, parser):
78
+ """Test parsing None returns empty list."""
79
+ assert parser.parse_key_entities(None) == []
80
+
81
+ def test_parse_key_developments_valid(self, parser):
82
+ """Test parsing key developments."""
83
+ data = [
84
+ {
85
+ "title": "Test Event",
86
+ "description": "Test description",
87
+ "sources": ["Military Activity"],
88
+ }
89
+ ]
90
+ result = parser.parse_key_developments(data)
91
+ assert len(result) == 1
92
+ assert result[0].title == "Test Event"
93
+ assert result[0].description == "Test description"
94
+ assert result[0].sources == ["Military Activity"]
95
+
96
+ def test_parse_key_developments_skips_non_dict(self, parser):
97
+ """Test that non-dict items are skipped."""
98
+ data = [{"title": "Valid"}, "invalid", None]
99
+ result = parser.parse_key_developments(data)
100
+ assert len(result) == 1
101
+
102
+ def test_count_sources(self, parser):
103
+ """Test source counting."""
104
+ items = [
105
+ {"metadata": {"source_type": "reddit"}},
106
+ {"metadata": {"source_type": "reddit"}},
107
+ {"metadata": {"source_type": "rss"}},
108
+ {"metadata": {}},
109
+ ]
110
+ result = parser.count_sources(items)
111
+ assert result == {"reddit": 2, "rss": 1, "unknown": 1}