KSvend Claude Happy committed on
Commit
f491e48
·
1 Parent(s): bf99ddb

feat: three-phase batch job worker (submit → poll → harvest)

Browse files

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>

Files changed (2) hide show
  1. app/worker.py +90 -8
  2. tests/test_worker.py +110 -0
app/worker.py CHANGED
@@ -3,6 +3,7 @@ import asyncio
3
  import json
4
  import logging
5
  import os
 
6
  import traceback
7
  from app.database import Database
8
  from app.indicators.base import IndicatorRegistry
@@ -17,6 +18,9 @@ from app.core.email import send_completion_email
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
20
 
21
  def _save_spatial_json(spatial, status_value: str, path: str) -> None:
22
  """Serialize spatial data to JSON for the frontend."""
@@ -51,19 +55,97 @@ async def process_job(job_id: str, db: Database, registry: IndicatorRegistry) ->
51
  return
52
  await db.update_job_status(job_id, JobStatus.PROCESSING)
53
  try:
54
- # Track spatial data per indicator for map generation
55
  spatial_cache = {}
56
 
 
 
 
57
  for indicator_id in job.request.indicator_ids:
58
- await db.update_job_progress(job_id, indicator_id, "processing")
59
  indicator = registry.get(indicator_id)
60
- result = await indicator.process(
61
- job.request.aoi,
62
- job.request.time_range,
63
- season_months=job.request.season_months(),
64
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # Capture spatial data before it's lost
67
  spatial = indicator.get_spatial_data()
68
  if spatial is not None:
69
  spatial_cache[indicator_id] = spatial
 
3
  import json
4
  import logging
5
  import os
6
+ import time
7
  import traceback
8
  from app.database import Database
9
  from app.indicators.base import IndicatorRegistry
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
+ BATCH_POLL_INTERVAL = 30 # seconds between status checks
22
+ BATCH_TIMEOUT = 1200 # 20 minutes maximum wait
23
+
24
 
25
  def _save_spatial_json(spatial, status_value: str, path: str) -> None:
26
  """Serialize spatial data to JSON for the frontend."""
 
55
  return
56
  await db.update_job_status(job_id, JobStatus.PROCESSING)
57
  try:
 
58
  spatial_cache = {}
59
 
60
+ # Separate batch vs non-batch indicators
61
+ batch_indicators = {}
62
+ process_indicators = []
63
  for indicator_id in job.request.indicator_ids:
 
64
  indicator = registry.get(indicator_id)
65
+ if indicator.uses_batch:
66
+ batch_indicators[indicator_id] = indicator
67
+ else:
68
+ process_indicators.append((indicator_id, indicator))
69
+
70
+ # -- Phase 1: Submit batch jobs --
71
+ batch_submissions = {}
72
+ fallback_ids = set()
73
+ for indicator_id, indicator in batch_indicators.items():
74
+ await db.update_job_progress(job_id, indicator_id, "submitting")
75
+ try:
76
+ jobs = await indicator.submit_batch(
77
+ job.request.aoi,
78
+ job.request.time_range,
79
+ season_months=job.request.season_months(),
80
+ )
81
+ batch_submissions[indicator_id] = jobs
82
+ await db.update_job_progress(job_id, indicator_id, "processing on CDSE")
83
+ except Exception as exc:
84
+ logger.warning("Batch submit failed for %s, will use fallback: %s", indicator_id, exc)
85
+ fallback_ids.add(indicator_id)
86
+
87
+ # -- Phase 2: Poll until all batch jobs finish --
88
+ poll_start = time.monotonic()
89
+ pending = dict(batch_submissions)
90
+
91
+ while pending:
92
+ # Check current statuses before sleeping
93
+ for indicator_id in list(pending.keys()):
94
+ jobs = pending[indicator_id]
95
+ statuses = [j.status() for j in jobs]
96
+ if all(s == "finished" for s in statuses):
97
+ logger.info("Batch jobs finished for %s", indicator_id)
98
+ del pending[indicator_id]
99
+ elif any(s in ("error", "canceled") for s in statuses):
100
+ logger.warning("Batch job failed for %s: %s", indicator_id, statuses)
101
+ del pending[indicator_id]
102
+
103
+ if not pending:
104
+ break
105
+
106
+ elapsed = time.monotonic() - poll_start
107
+ if elapsed >= BATCH_TIMEOUT:
108
+ logger.warning("Batch poll timeout after %.0fs, remaining: %s", elapsed, list(pending.keys()))
109
+ fallback_ids.update(pending.keys())
110
+ break
111
+
112
+ await asyncio.sleep(BATCH_POLL_INTERVAL)
113
+
114
+ # -- Phase 3: Harvest batch results + process non-batch indicators --
115
+ for indicator_id in job.request.indicator_ids:
116
+ indicator = registry.get(indicator_id)
117
+
118
+ if indicator_id in fallback_ids:
119
+ await db.update_job_progress(job_id, indicator_id, "processing")
120
+ result = await indicator.process(
121
+ job.request.aoi,
122
+ job.request.time_range,
123
+ season_months=job.request.season_months(),
124
+ )
125
+ elif indicator_id in batch_submissions:
126
+ await db.update_job_progress(job_id, indicator_id, "downloading")
127
+ try:
128
+ result = await indicator.harvest(
129
+ job.request.aoi,
130
+ job.request.time_range,
131
+ season_months=job.request.season_months(),
132
+ batch_jobs=batch_submissions[indicator_id],
133
+ )
134
+ except Exception as exc:
135
+ logger.warning("Harvest failed for %s, using fallback: %s", indicator_id, exc)
136
+ result = await indicator.process(
137
+ job.request.aoi,
138
+ job.request.time_range,
139
+ season_months=job.request.season_months(),
140
+ )
141
+ else:
142
+ await db.update_job_progress(job_id, indicator_id, "processing")
143
+ result = await indicator.process(
144
+ job.request.aoi,
145
+ job.request.time_range,
146
+ season_months=job.request.season_months(),
147
+ )
148
 
 
149
  spatial = indicator.get_spatial_data()
150
  if spatial is not None:
151
  spatial_cache[indicator_id] = spatial
tests/test_worker.py CHANGED
@@ -1,5 +1,6 @@
1
  import pytest
2
  from datetime import date
 
3
  from app.worker import process_job
4
  from app.database import Database
5
  from app.models import JobStatus, AOI, TimeRange, JobRequest, IndicatorResult, StatusLevel, TrendDirection, ConfidenceLevel
@@ -66,3 +67,112 @@ async def test_process_job_handles_unknown_indicator(temp_db_path):
66
  job = await db.get_job(job_id)
67
  assert job.status == JobStatus.FAILED
68
  assert "nonexistent" in job.error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
2
  from datetime import date
3
+ from unittest.mock import MagicMock
4
  from app.worker import process_job
5
  from app.database import Database
6
  from app.models import JobStatus, AOI, TimeRange, JobRequest, IndicatorResult, StatusLevel, TrendDirection, ConfidenceLevel
 
67
  job = await db.get_job(job_id)
68
  assert job.status == JobStatus.FAILED
69
  assert "nonexistent" in job.error
70
+
71
+
72
class MockBatchIndicator(BaseIndicator):
    """Fake batch-capable indicator for exercising the three-phase worker.

    ``submit_batch`` hands back mock CDSE jobs that report "finished"
    immediately, so the worker's polling phase exits on its first pass.
    ``harvest`` yields the "real" satellite-backed result, while
    ``process`` yields the low-confidence placeholder used as fallback.
    """

    id = "ndvi"
    name = "Vegetation (NDVI)"
    category = "D2"
    question = "Is vegetation cover declining?"
    estimated_minutes = 8
    uses_batch = True

    async def process(self, aoi, time_range, season_months=None):
        """Fallback path: return a placeholder result with LOW confidence."""
        return IndicatorResult(
            indicator_id="ndvi",
            headline="placeholder",
            status=StatusLevel.GREEN,
            trend=TrendDirection.STABLE,
            confidence=ConfidenceLevel.LOW,
            map_layer_path="",
            chart_data={"dates": ["2025"], "values": [0.3], "label": "NDVI"},
            data_source="placeholder",
            summary="Fallback.",
            methodology="Placeholder.",
            limitations=[],
        )

    async def submit_batch(self, aoi, time_range, season_months=None):
        """Return three handles to one already-finished mock batch job."""
        finished_job = MagicMock()
        finished_job.job_id = "j-test"
        finished_job.status.return_value = "finished"
        # Same mock repeated: the worker only inspects status() per handle.
        return [finished_job] * 3

    async def harvest(self, aoi, time_range, season_months=None, batch_jobs=None):
        """Batch path: return the HIGH-confidence "real" result."""
        return IndicatorResult(
            indicator_id="ndvi",
            headline="Real NDVI data",
            status=StatusLevel.GREEN,
            trend=TrendDirection.STABLE,
            confidence=ConfidenceLevel.HIGH,
            map_layer_path="",
            chart_data={"dates": ["2025-01"], "values": [0.45], "label": "NDVI"},
            data_source="satellite",
            summary="Real.",
            methodology="Sentinel-2.",
            limitations=[],
        )
106
+
107
+
108
@pytest.mark.asyncio
async def test_process_job_uses_batch_flow(temp_db_path):
    """Worker uses submit_batch -> poll -> harvest for batch indicators."""
    database = Database(temp_db_path)
    await database.init()

    registry = IndicatorRegistry()
    registry.register(MockBatchIndicator())

    request = JobRequest(
        aoi=AOI(name="Test", bbox=[32.45, 15.65, 32.65, 15.80]),
        time_range=TimeRange(start=date(2025, 3, 1), end=date(2026, 3, 1)),
        indicator_ids=["ndvi"],
        email="test@example.com",
    )
    job_id = await database.create_job(request)

    await process_job(job_id, database, registry)

    job = await database.get_job(job_id)
    assert job.status == JobStatus.COMPLETE
    assert len(job.results) == 1
    # The result must come from harvest(), not the process() fallback.
    ndvi = job.results[0]
    assert ndvi.data_source == "satellite"
    assert ndvi.headline == "Real NDVI data"
128
+
129
+
130
@pytest.mark.asyncio
async def test_process_job_mixes_batch_and_process(temp_db_path):
    """Worker handles batch and non-batch indicators in the same job."""
    database = Database(temp_db_path)
    await database.init()

    registry = IndicatorRegistry()
    registry.register(MockBatchIndicator())
    registry.register(MockFiresIndicator())

    request = JobRequest(
        aoi=AOI(name="Test", bbox=[32.45, 15.65, 32.65, 15.80]),
        time_range=TimeRange(start=date(2025, 3, 1), end=date(2026, 3, 1)),
        indicator_ids=["ndvi", "fires"],
        email="test@example.com",
    )
    job_id = await database.create_job(request)

    await process_job(job_id, database, registry)

    job = await database.get_job(job_id)
    assert job.status == JobStatus.COMPLETE
    assert len(job.results) == 2

    # Batch indicator went through harvest; non-batch went through process().
    ndvi = next(res for res in job.results if res.indicator_id == "ndvi")
    fires = next(res for res in job.results if res.indicator_id == "fires")
    assert ndvi.data_source == "satellite"
    assert fires.headline == "3 fire events detected"
154
+
155
+
156
@pytest.mark.asyncio
async def test_process_job_batch_submit_failure_falls_back(temp_db_path):
    """If submit_batch() fails, worker falls back to process()."""

    class FailingBatchIndicator(MockBatchIndicator):
        # Simulates CDSE being unreachable at submission time.
        async def submit_batch(self, aoi, time_range, season_months=None):
            raise ConnectionError("CDSE unreachable")

    database = Database(temp_db_path)
    await database.init()

    registry = IndicatorRegistry()
    registry.register(FailingBatchIndicator())

    request = JobRequest(
        aoi=AOI(name="Test", bbox=[32.45, 15.65, 32.65, 15.80]),
        time_range=TimeRange(start=date(2025, 3, 1), end=date(2026, 3, 1)),
        indicator_ids=["ndvi"],
        email="test@example.com",
    )
    job_id = await database.create_job(request)

    await process_job(job_id, database, registry)

    # Job still completes, but via the process() fallback result.
    job = await database.get_job(job_id)
    assert job.status == JobStatus.COMPLETE
    assert job.results[0].data_source == "placeholder"