Spaces:
Runtime error
Runtime error
add: pipeline nodes can now spawn one to many jobs via yield
Browse files
- pipeline.py +12 -12
- tests/test_pipeline.py +14 -11
pipeline.py
CHANGED
|
@@ -27,24 +27,24 @@ class Node:
|
|
| 27 |
job: Job = await self.input_queue.get()
|
| 28 |
self._jobs_dequeued += 1
|
| 29 |
if self.sequential_node == False:
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
else:
|
| 37 |
# ensure that jobs are processed in order
|
| 38 |
self.buffer[job.id] = job
|
| 39 |
while self.next_i in self.buffer:
|
| 40 |
job = self.buffer.pop(self.next_i)
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self.next_i += 1
|
| 43 |
-
if self.output_queue is not None:
|
| 44 |
-
await self.output_queue.put(job)
|
| 45 |
-
if self.job_sync is not None:
|
| 46 |
-
self.job_sync.append(job)
|
| 47 |
-
self._jobs_processed += 1
|
| 48 |
|
| 49 |
async def process_job(self, job: Job):
|
| 50 |
raise NotImplementedError()
|
|
|
|
| 27 |
job: Job = await self.input_queue.get()
|
| 28 |
self._jobs_dequeued += 1
|
| 29 |
if self.sequential_node == False:
|
| 30 |
+
async for job in self.process_job(job):
|
| 31 |
+
if self.output_queue is not None:
|
| 32 |
+
await self.output_queue.put(job)
|
| 33 |
+
if self.job_sync is not None:
|
| 34 |
+
self.job_sync.append(job)
|
| 35 |
+
self._jobs_processed += 1
|
| 36 |
else:
|
| 37 |
# ensure that jobs are processed in order
|
| 38 |
self.buffer[job.id] = job
|
| 39 |
while self.next_i in self.buffer:
|
| 40 |
job = self.buffer.pop(self.next_i)
|
| 41 |
+
async for job in self.process_job(job):
|
| 42 |
+
if self.output_queue is not None:
|
| 43 |
+
await self.output_queue.put(job)
|
| 44 |
+
if self.job_sync is not None:
|
| 45 |
+
self.job_sync.append(job)
|
| 46 |
+
self._jobs_processed += 1
|
| 47 |
self.next_i += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
async def process_job(self, job: Job):
|
| 50 |
raise NotImplementedError()
|
tests/test_pipeline.py
CHANGED
|
@@ -12,6 +12,7 @@ from pipeline import Pipeline, Node, Job
|
|
| 12 |
class Node1(Node):
|
| 13 |
async def process_job(self, job: Job):
|
| 14 |
job.data += f' (processed by node 1, worker {self.worker_id})'
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class Node2(Node):
|
|
@@ -19,12 +20,14 @@ class Node2(Node):
|
|
| 19 |
sleep_duration = 0.08 + 0.04 * random.random()
|
| 20 |
await asyncio.sleep(sleep_duration)
|
| 21 |
job.data += f' (processed by node 2, worker {self.worker_id})'
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class Node3(Node):
|
| 25 |
async def process_job(self, job: Job):
|
| 26 |
job.data += f' (processed by node 3, worker {self.worker_id})'
|
| 27 |
print(f'{job.id} - {job.data}')
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class TestPipeline(unittest.TestCase):
|
|
@@ -63,17 +66,17 @@ class TestPipeline(unittest.TestCase):
|
|
| 63 |
asyncio.run(self._test_pipeline_edge_cases())
|
| 64 |
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
|
| 78 |
|
| 79 |
if __name__ == '__main__':
|
|
|
|
| 12 |
class Node1(Node):
|
| 13 |
async def process_job(self, job: Job):
|
| 14 |
job.data += f' (processed by node 1, worker {self.worker_id})'
|
| 15 |
+
yield job
|
| 16 |
|
| 17 |
|
| 18 |
class Node2(Node):
|
|
|
|
| 20 |
sleep_duration = 0.08 + 0.04 * random.random()
|
| 21 |
await asyncio.sleep(sleep_duration)
|
| 22 |
job.data += f' (processed by node 2, worker {self.worker_id})'
|
| 23 |
+
yield job
|
| 24 |
|
| 25 |
|
| 26 |
class Node3(Node):
|
| 27 |
async def process_job(self, job: Job):
|
| 28 |
job.data += f' (processed by node 3, worker {self.worker_id})'
|
| 29 |
print(f'{job.id} - {job.data}')
|
| 30 |
+
yield job
|
| 31 |
|
| 32 |
|
| 33 |
class TestPipeline(unittest.TestCase):
|
|
|
|
| 66 |
asyncio.run(self._test_pipeline_edge_cases())
|
| 67 |
|
| 68 |
|
| 69 |
+
def test_pipeline_keeps_order(self):
|
| 70 |
+
self.pipeline = Pipeline()
|
| 71 |
+
self.job_sync = []
|
| 72 |
+
num_jobs = 100
|
| 73 |
+
start_time = time.time()
|
| 74 |
+
asyncio.run(self._test_pipeline(num_jobs))
|
| 75 |
+
end_time = time.time()
|
| 76 |
+
print(f"Pipeline processed in {end_time - start_time} seconds.")
|
| 77 |
+
self.assertEqual(len(self.job_sync), num_jobs)
|
| 78 |
+
for i, job in enumerate(self.job_sync):
|
| 79 |
+
self.assertEqual(i, job.id)
|
| 80 |
|
| 81 |
|
| 82 |
if __name__ == '__main__':
|