File size: 27,369 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
import asyncio
import atexit
import datetime
import secrets
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from multiprocessing.connection import Client, Listener
from pathlib import Path
from queue import Queue
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import torch
from janus import Queue as AsyncQueue

from ._utils import mpi_rank, mpi_world_size
from .bindings import executor as tllm
from .hlapi.mpi_session import (MpiPoolSession, MpiSession,
                                external_mpi_comm_available, find_free_port,
                                need_spawn_mpi_workers)
from .hlapi.utils import (ContextManager, SamplingParams, exception_handler,
                          print_traceback_on_error)


def has_event_loop() -> bool:
    """Report whether the calling thread currently has a running asyncio loop.

    Returns:
        True if ``asyncio.get_running_loop()`` succeeds, False when it raises
        ``RuntimeError`` (no loop is running in this thread).
    """
    try:
        loop_running = asyncio.get_running_loop() is not None
    except RuntimeError:
        loop_running = False
    return loop_running


class GenerationRequest:
    """One generation request: prompt token ids plus its sampling settings.

    Args:
        prompt_token_ids: Prompt tokens as a ``list``, ``torch.Tensor`` or
            ``np.ndarray``. Tensors and arrays are converted with ``.tolist()``;
            lists are stored as given (not copied).
        sampling_params: Sampling settings used to build the executor request.
        streaming: Whether the request should be executed in streaming mode.

    Raises:
        TypeError: If ``prompt_token_ids`` is none of the accepted types.
    """

    def __init__(
        self,
        prompt_token_ids: Union[torch.Tensor, np.ndarray, list],
        sampling_params: SamplingParams,
        streaming: bool = False,
    ):
        if isinstance(prompt_token_ids, (torch.Tensor, np.ndarray)):
            token_ids = prompt_token_ids.tolist()
        elif isinstance(prompt_token_ids, list):
            token_ids = prompt_token_ids
        else:
            raise TypeError(
                f"prompt_token_ids ({prompt_token_ids}) should be an instance of torch.Tensor, np.ndarray or list"
            )

        self.prompt_token_ids = token_ids
        self.sampling_params = sampling_params
        self.streaming = streaming
        # -1 until the executor assigns a real id via set_id().
        self.id = -1

    def set_id(self, id):
        """Store the executor-assigned request id and return ``self`` for chaining."""
        self.id = id
        return self

    def as_executor_request(self) -> tllm.Request:
        """Translate this request into a ``tllm.Request`` for the C++ executor."""
        params = self.sampling_params
        return tllm.Request(
            input_token_ids=self.prompt_token_ids,
            max_new_tokens=params.max_new_tokens,
            streaming=self.streaming,
            sampling_config=params._get_sampling_config(),
            end_id=params.end_id,
            pad_id=params.pad_id,
            output_config=params._get_output_config(),
            # The following options in the Executor API are not yet exposed by the HLAPI:
            # https://jirasw.nvidia.com/browse/TRTLLM-489
            bad_words=params._get_bad_words(),
            stop_words=params._get_stop_words(),
            embedding_bias=params.embedding_bias,
            external_draft_tokens_config=params.external_draft_tokens_config,
            prompt_tuning_config=params.prompt_tuning_config,
            lora_config=params.lora_config,
            logits_post_processor_name=params.logits_post_processor_name,
        )


@dataclass(slots=True)
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index (int): The index of the output in the request.
        text (str): The generated output text.
        token_ids (List[int]): The token ids of the generated output text.
        cumulative_logprob (float): The cumulative log probability of the generated output text.
        logprobs (List[float]): The log probabilities of the top probability words at each position if the logprobs are requested.
        generation_logits (torch.Tensor): The logits on the generated output token ids.
    """
    index: int
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    cumulative_logprob: Optional[float] = None
    logprobs: List[float] = field(default_factory=list)
    generation_logits: Optional[torch.Tensor] = field(default=None, repr=False)
    _last_text: str = field(default="", init=False, repr=False)

    @property
    def length(self):
        return len(self.token_ids)

    @property
    def text_diff(self) -> str:
        diff = self.text[len(self._last_text):]
        self._last_text = self.text
        return diff


class GenerationResult:
    """Future-like handle for the outputs of one submitted GenerationRequest.

    Messages for the request arrive on ``self.queue`` (backed by a janus queue
    whose async side is ``self.aqueue`` when an asyncio loop was running at
    construction time) as ``(request_id, tensors, is_final, error)`` tuples.
    Each message is folded into ``self.outputs`` (one CompletionOutput per
    beam). The handle can be waited on (``result()``), awaited
    (``await result`` / ``aresult()``), or iterated synchronously or
    asynchronously to observe each intermediate step.
    """

    def __init__(self, generation_request: GenerationRequest) -> None:
        self._done = False
        self._cancelled = False
        # Error message of a failed request, kept so that later result() /
        # exception() calls keep reporting the failure.
        self._error: Optional[str] = None
        self._generation_request = generation_request

        if has_event_loop():
            # One janus queue exposes both a sync and an async view.
            aqueue = AsyncQueue()
            self.queue = aqueue.sync_q
            self.aqueue = aqueue.async_q
        else:
            self.queue = Queue()
            self.aqueue = None

        self.outputs: List[CompletionOutput] = [
            CompletionOutput(i) for i in range(self.beam_width)
        ]
        self.context_logits: Optional[torch.Tensor] = None

    @property
    def request_id(self) -> int:
        """Id assigned to the underlying request."""
        return self._generation_request.id

    @property
    def prompt_token_ids(self) -> List[int]:
        """Prompt token ids of the underlying request."""
        return self._generation_request.prompt_token_ids

    @property
    def finished(self) -> bool:
        """True once the final (or an error) message has been processed."""
        return self._done

    @property
    def streaming(self):
        """Whether the underlying request was submitted in streaming mode."""
        return self._generation_request.streaming

    @property
    def beam_width(self):
        """Beam width from the request's sampling params (number of outputs)."""
        return self._generation_request.sampling_params.beam_width

    def handle_generation_msg(self, tensors: tuple, error: str):
        """Fold one executor message into ``self.outputs``.

        Args:
            tensors: ``(output_token_ids, context_logits, generation_logits,
                log_probs, cum_log_probs)``; every entry except
                ``output_token_ids`` may be None.
            error: Executor error message, or a falsy value on success.

        Raises:
            RuntimeError: If ``error`` is set. The message is also remembered
                so that later ``result()`` calls raise it again.
        """
        if error:
            self._error = error
            raise RuntimeError(error)

        output_token_ids, context_logits, generation_logits, log_probs, cum_log_probs = tensors

        for i, beam_ids in enumerate(output_token_ids):
            beam_output = self.outputs[i]
            beam_output.token_ids.extend(beam_ids)
            if cum_log_probs is not None:
                beam_output.cumulative_logprob = cum_log_probs[i]
            if log_probs is not None:
                beam_output.logprobs = log_probs[i]
                assert len(beam_output.logprobs) == beam_output.length
            if generation_logits is not None:
                beam_output.generation_logits = generation_logits[
                    i, :beam_output.length]

        sampling_params = self._generation_request.sampling_params
        if self.finished and not sampling_params.include_stop_str_in_output:
            # Guard against a None return when no stop words are configured.
            self._strip_stop_words(sampling_params._get_stop_words() or [])

        if context_logits is not None:
            self.context_logits = context_logits

    def _strip_stop_words(self, stop_words) -> None:
        """Remove the first matching trailing stop sequence from each beam."""
        for beam_output in self.outputs:
            for stop_ids in stop_words:
                # Empty sequences are skipped: ``ids[-0:]`` is the whole list.
                if stop_ids and beam_output.token_ids[-len(stop_ids):] == stop_ids:
                    beam_output.token_ids = beam_output.token_ids[:-len(
                        stop_ids)]
                    break

    def _apply_message(self, tensors, done, error) -> None:
        """Update the done flag from one message, then fold the message in."""
        # An error message carries is_final=None. No further messages are
        # expected for that request after an error (the worker sends none), so
        # treat it as terminal; otherwise result()/iteration would block
        # forever. NOTE(review): relies on executor errors being terminal.
        self._done = bool(done or error)
        self.handle_generation_msg(tensors, error)

    def result_step(self, timeout: Optional[float] = None):
        """Block for the next message (up to ``timeout`` seconds) and apply it.

        Raises:
            RuntimeError: If the message carries an executor error.
            The queue's Empty exception propagates if ``timeout`` elapses.
        """
        _, tensors, done, error = self.queue.get(timeout=timeout)
        self._apply_message(tensors, done, error)

    async def aresult_step(self):
        """Async counterpart of ``result_step`` (requires ``self.aqueue``)."""
        assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
        _, tensors, done, error = await self.aqueue.get()
        self._apply_message(tensors, done, error)

    def result(self, timeout: Optional[float] = None) -> "GenerationResult":
        """Block until the request finishes and return ``self``.

        Raises:
            RuntimeError: If the request failed, on this and every later call.
        """
        while not self._done:
            self.result_step(timeout)
        if self._error:
            raise RuntimeError(self._error)
        return self

    async def aresult(self) -> "GenerationResult":
        """Await completion of the request and return ``self``.

        Raises:
            RuntimeError: If the request failed, on this and every later call.
        """
        while not self._done:
            await self.aresult_step()
        if self._error:
            raise RuntimeError(self._error)
        return self

    def __await__(self):
        return self.aresult().__await__()

    def __iter__(self):
        return self

    def __next__(self):
        # Each iteration consumes one message (one streaming step).
        if self._done:
            raise StopIteration

        self.result_step()
        return self

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration

        await self.aresult_step()
        return self

    def running(self) -> bool:
        return not self._done

    def cancelled(self) -> bool:
        return self._cancelled

    def cancel(self):
        raise NotImplementedError

    def done(self) -> bool:
        return self._done

    def exception(self, timeout: Optional[float] = None):
        """Wait for completion; return the RuntimeError if it failed, else None."""
        try:
            self.result(timeout)
        except RuntimeError as e:
            return e
        return None

    def _repr_fields(self):
        return ['request_id', 'prompt_token_ids', 'outputs', 'finished']

    def __repr__(self) -> str:
        parts = []
        for name in self._repr_fields():
            value = getattr(self, name)
            if isinstance(value, str):
                parts.append(f"{name}={value!r}")
            else:
                parts.append(f"{name}={value}")
        return f"{self.__class__.__name__}({', '.join(parts)})"


class GenerationExecutor(ABC):
    """Abstract front-end for submitting generation requests.

    Subclasses implement ``submit`` (enqueue one GenerationRequest and return a
    future-like GenerationResult) and ``shutdown``. This base class provides
    the ``generate`` / ``generate_async`` conveniences, request id generation,
    the iteration-stats queue plumbing and the ``create`` factory.
    """
    # Reserved request id; generate_id() never hands it out.
    TERMINATE_REQUEST_ID = 0

    def __init__(self):
        self.id_counter = GenerationExecutor.TERMINATE_REQUEST_ID + 1
        self._stats = None
        self.stats_queue = None
        # Initialised so aget_stats() reports its assertion message rather than
        # an AttributeError when called before create_stats_queue().
        self.stats_aqueue = None

        exception_handler.register(self)
        atexit.register(self.shutdown)

    def generate_id(self) -> int:
        """Return the next request id, cycling through 1 .. 2**64 - 1."""
        gen_id = self.id_counter

        # The underlying C type is uint64, so wrap modulo 2**64 (a modulus of
        # 2**64 - 1 would never produce the valid id 2**64 - 1).
        self.id_counter = (self.id_counter + 1) % (2**64)

        if self.id_counter == GenerationExecutor.TERMINATE_REQUEST_ID:
            self.id_counter += 1

        return gen_id

    @abstractmethod
    def submit(self, request: GenerationRequest) -> GenerationResult:
        pass

    def generate_async(
        self,
        prompt_token_ids: List[int],
        sampling_params: SamplingParams,
        streaming: bool = False,
    ) -> GenerationResult:
        """Generate output for the given prompt token ids in the asynchronous mode.
        Asynchronous generation accepts single prompt only.
        """
        assert isinstance(prompt_token_ids[0], int)
        assert isinstance(sampling_params, SamplingParams)
        result = self.submit(
            GenerationRequest(prompt_token_ids,
                              sampling_params=sampling_params,
                              streaming=streaming))
        return result

    def generate(
        self, prompt_token_ids: Union[List[int], List[List[int]]],
        sampling_params: Union[SamplingParams, List[SamplingParams]]
    ) -> Union[GenerationResult, List[GenerationResult]]:
        """Generate output for the given prompt token ids in the synchronous mode.
        Synchronous generation accepts either single prompt or batched prompts.

        Args:
            prompt_token_ids: One prompt (list of ints) or a batch of prompts.
            sampling_params: One SamplingParams shared by all prompts, or a list
                with (at least) one entry per prompt.

        Returns:
            A finished GenerationResult for a single prompt, else a list of them.

        Raises:
            ValueError: If a list of sampling params is shorter than the batch.
        """
        unbatched = isinstance(prompt_token_ids[0], int)

        if unbatched:
            prompt_token_ids = [prompt_token_ids]

        if isinstance(sampling_params, list):
            # Validate before submitting anything: failing midway would leave
            # the earlier requests running with nobody waiting on them.
            if len(sampling_params) < len(prompt_token_ids):
                raise ValueError(
                    f"Got {len(sampling_params)} sampling params for "
                    f"{len(prompt_token_ids)} prompts")
            params_per_prompt = sampling_params
        else:
            params_per_prompt = [sampling_params] * len(prompt_token_ids)

        futures = [
            self.generate_async(p, sampling_params=sp, streaming=False)
            for p, sp in zip(prompt_token_ids, params_per_prompt)
        ]

        for future in futures:
            future.result()

        return futures[0] if unbatched else futures

    @abstractmethod
    def shutdown(self):
        pass

    def create_stats_queue(self):
        # Stats queue is created during first submission to ensure event loop exists if it is needed.
        if not self._stats:
            if has_event_loop():
                self._stats = AsyncQueue()
                self.stats_queue = self._stats.sync_q
                self.stats_aqueue = self._stats.async_q
            else:
                self._stats = Queue()
                self.stats_queue = self._stats
                self.stats_aqueue = None

    def get_stats(self):
        """Block until the next stats entry is available on ``stats_queue`` and return it."""
        return self.stats_queue.get()

    async def aget_stats(self):
        """Await the next stats entry (requires an asyncio loop at queue creation)."""
        assert self.stats_aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
        return await self.stats_aqueue.get()

    @staticmethod
    def create(
        engine_dir: Path,
        executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1),
        model_world_size: int = 1,
        world_size: int = 0,
        mpi_session: Optional[MpiSession] = None,
        reuse_mpi_comm: bool = False,
    ) -> Union["ExecutorBindingsProxy", "ExecutorBindingsWorker"]:
        """Build an executor for ``engine_dir``.

        Returns an ExecutorBindingsProxy when MPI workers must be spawned, or
        when an mpirun-provided communicator should be reused; otherwise an
        in-process ExecutorBindingsWorker. ``world_size == 0`` means "use
        mpi_world_size()".

        Raises:
            RuntimeError: If running on more than one rank but on fewer ranks
                than ``model_world_size``.
        """

        if world_size == 0:
            world_size = mpi_world_size()

        if world_size > 1 and world_size < model_world_size:
            raise RuntimeError(
                "Cannot instantiate Generator for engine built "
                f"for {model_world_size} ranks, while currently running "
                f"on {world_size} ranks.")

        worker_kwargs = {
            "engine_dir": engine_dir,
            "executor_config": executor_config,
        }

        # The case where the Python main process is launched by mpirun
        mpirun_launch = external_mpi_comm_available(model_world_size)
        # The case where the Python main process utilizes mpi4py to spawn MPI workers
        spawn_workers = need_spawn_mpi_workers(model_world_size)
        if spawn_workers or (mpirun_launch and reuse_mpi_comm):
            if reuse_mpi_comm:
                assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"
            return ExecutorBindingsProxy(worker_kwargs,
                                         model_world_size=model_world_size,
                                         mpi_session=mpi_session)

        return ExecutorBindingsWorker(**worker_kwargs)


class ExecutorBindingsWorker(GenerationExecutor):
    """GenerationExecutor that drives a ``tllm.Executor`` in this process.

    Used directly in single-process mode and as the per-rank worker inside
    ``ExecutorBindingsProxy.workers_main``. On the first ``submit`` two daemon
    threads are started (only if the engine reports it can enqueue requests):
    ``awaiter_loop`` polls the engine for responses and routes them to a
    result queue, and ``stats_loop`` polls iteration stats into the stats queue.
    """

    class WorkerExit(GeneratorExit):
        # Raised by block_subordinates() on non-zero ranks; __exit__ suppresses it.
        pass

    def __init__(
        self,
        engine_dir: Path,
        executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1),
    ) -> None:
        super().__init__()

        self.engine = None
        # Results by request id; used for routing when no result_queue is set.
        self._results: Dict[int, GenerationResult] = {}
        # Ids of submitted requests that have not received a final/error response.
        self._pending: set = set()
        # Serialises request registration in submit() with response routing in
        # the awaiter thread: a response may otherwise arrive before submit()
        # has stored the GenerationResult, killing the thread with a KeyError.
        self._lock = threading.Lock()
        self.result_queue = None
        self.rank = mpi_rank()

        self.engine = tllm.Executor(engine_dir,
                                    tllm.ModelType.DECODER_ONLY,
                                    executor_config=executor_config)
        self.awaiter_stop_event = threading.Event()
        self.awaiter_thread = threading.Thread(target=self.awaiter_loop,
                                               daemon=True)
        self.stats_thread = threading.Thread(target=self.stats_loop,
                                             daemon=True)

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        self.result_queue = queue

    def set_stats_queue(self, queue):
        """In multi-gpu mode, stats_queue will be set here to communicate between the proxy and the worker 0 process."""
        self._stats = queue
        self.stats_queue = self._stats
        self.stats_aqueue = None

    def return_queue(self, req_id: int):
        """ If a centralized result queue is registered (used for communication with the proxy)
            send the message there.
            Otherwise, push the result directly in the GenerationResult queue.
        """

        if self.result_queue is not None:
            return self.result_queue
        return self._results[req_id].queue

    def start_awaiter_thread(self):
        if self.engine.can_enqueue_requests(
        ) and not self.awaiter_thread.is_alive():
            self.awaiter_thread.start()

    def start_stats_thread(self):
        if self.engine.can_enqueue_requests(
        ) and not self.stats_thread.is_alive():
            self.stats_thread.start()

    def awaiter_loop(self):
        """ Gets responses from executor and places in the return queue."""
        while not self.awaiter_stop_event.is_set():
            # Poll with a timeout so the stop event is re-checked regularly.
            for response in self.engine.await_responses(
                    timeout=datetime.timedelta(milliseconds=100)):
                with self._lock:
                    self._route_response(response)

    def _route_response(self, response) -> None:
        """Send one engine response as ``(req_id, tensors, is_final, error)``."""
        req_id = response.request_id
        if response.has_error():
            self.return_queue(req_id).put(
                (req_id, None, None, response.error_msg))
            # No final response follows an error, so stop tracking the request;
            # otherwise wait_first_completed() would wait on it forever.
            self._pending.discard(req_id)
            return

        result = response.result
        tensors = (
            result.output_token_ids,
            result.context_logits,
            result.generation_logits,
            result.log_probs,
            result.cum_log_probs,
        )
        self.return_queue(req_id).put((req_id, tensors, result.is_final, None))
        if result.is_final:
            self._pending.discard(req_id)

    def stats_loop(self):
        while not self.awaiter_stop_event.is_set():
            time.sleep(0.1)
            # Get stats and place in queue.
            for stats in self.engine.get_latest_iteration_stats():
                # Drop the oldest entries if the (bounded) queue is full.
                while hasattr(self.stats_queue,
                              "full") and self.stats_queue.full():
                    self.stats_queue.get()
                self.stats_queue.put(stats.to_json_str())

    def start(self):
        self.create_stats_queue()
        self.start_awaiter_thread()
        self.start_stats_thread()

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """
            Low-level API to the executor. Return a "future" GenerationResult which can be waited.
        """
        self.start()

        if self.rank != 0:
            raise NotImplementedError("Only rank 0 can submit requests.")
        executor_request = request.as_executor_request()
        # Register the result under the lock so the awaiter thread cannot route
        # a response for this id before it is known.
        with self._lock:
            req_id = self.engine.enqueue_request(executor_request)
            request.set_id(req_id)

            result = GenerationResult(request)
            self._results[req_id] = result
            self._pending.add(req_id)
        return result

    def shutdown(self):
        # getattr: __del__ may run on an object whose __init__ failed early.
        if getattr(self, "engine", None) is None:
            return
        self.awaiter_stop_event.set()
        if self.engine.can_enqueue_requests():
            if self.awaiter_thread.is_alive():
                self.awaiter_thread.join()
            if self.stats_thread.is_alive():
                self.stats_thread.join()
        self.engine.shutdown()
        self.engine = None

    def block_subordinates(self):
        if self.rank != 0:
            self.shutdown()
            raise self.WorkerExit(
                "block_subordinates() should be used in a `with ExecutorBindingsWorker() as ...:` block"
            )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == ExecutorBindingsWorker.WorkerExit

    def __del__(self):
        self.shutdown()

    def wait_first_completed(
        self, futures: List[GenerationResult]
    ) -> Generator[GenerationResult, None, None]:
        """Yield each future's result as soon as its request completes.

        Futures already marked done are yielded first. A request counts as
        completed once its final (or error) response has been routed. Each
        request id is yielded at most once.
        """
        wait_set = set(f.request_id for f in futures)

        # clear already-finished requests
        for f in futures:
            if f._done and f.request_id in wait_set:
                wait_set.remove(f.request_id)
                yield f

        # wait remaining active requests
        while wait_set:
            completed = {
                req_id
                for req_id in wait_set if req_id not in self._pending
            }
            if not completed:
                # Back off briefly instead of spinning a CPU core.
                time.sleep(0.01)
                continue
            wait_set -= completed
            for req_id in completed:
                yield self._results[req_id]


class Fifo:
    """Lazily connected, authenticated point-to-point channel over TCP.

    ``address`` is ``(host, port, authkey)``. The server side creates a
    ``multiprocessing.connection.Listener`` immediately. The connection itself
    (``accept`` on the server, ``Client`` on the client) is opened on the first
    ``put``/``get``. Objects are sent with ``Connection.send`` (pickled), so
    only trusted peers should hold the authkey.
    """

    def __init__(self, address: Tuple[str, int, bytes], *, is_server: bool):
        self.address = (address[0], address[1])
        self.authkey = address[2]
        self.is_server = is_server
        self.conn = None
        if is_server:
            self.listener = Listener(self.address,
                                     'AF_INET',
                                     authkey=self.authkey)

    def setup(self):
        """Open the connection: accept one client (server) or connect (client)."""
        if not self.is_server:
            self.conn = Client(self.address, authkey=self.authkey)
            return
        self.conn = self.listener.accept()

    def _connection(self):
        """Return the connection, opening it on first use."""
        if self.conn is None:
            self.setup()
        return self.conn

    def put(self, obj: Any):
        """Send ``obj`` to the peer."""
        self._connection().send(obj)

    def get(self) -> Any:
        """Block until an object arrives from the peer and return it."""
        return self._connection().recv()


class ExecutorBindingsProxy(GenerationExecutor):
    """GenerationExecutor that runs ExecutorBindingsWorker in MPI worker processes.

    This process is the server side of four local, authenticated ``Fifo``
    channels to worker rank 0:

    * ``request_queue``: GenerationRequest objects to submit (None = stop);
    * ``request_id_queue``: the executor-assigned id of each submitted request;
    * ``result_queue``: first the worker init status, then
      ``(request_id, tensors, is_final, error)`` messages (None = stop);
    * ``mp_stats_queue``: iteration stats (None = stop).

    A dispatcher thread forwards result messages to the matching
    GenerationResult, and a stats thread forwards stats into ``stats_queue``.
    Workers are launched lazily on the first ``submit``.
    """

    def __init__(
        self,
        workers_kwargs,
        model_world_size: int = 1,
        mpi_session: Optional[MpiSession] = None,
    ) -> None:
        super().__init__()

        self.workers_started = False

        request_queue_addr = ("127.0.0.1", find_free_port(),
                              secrets.token_bytes(512))
        self.request_queue = Fifo(request_queue_addr, is_server=True)

        # Return request id back to dispatcher
        request_id_queue_addr = ("127.0.0.1", find_free_port(),
                                 secrets.token_bytes(512))
        self.request_id_queue = Fifo(request_id_queue_addr, is_server=True)

        result_queue_addr = ("127.0.0.1", find_free_port(),
                             secrets.token_bytes(512))
        self.result_queue = Fifo(result_queue_addr, is_server=True)

        stats_queue_addr = ("127.0.0.1", find_free_port(),
                            secrets.token_bytes(512))
        self.mp_stats_queue = Fifo(stats_queue_addr, is_server=True)

        # Unfinished results by request id; entries are dropped after their
        # final message has been dispatched.
        self._results: Dict[int, GenerationResult] = {}
        # Wake-up signal from submit() to the dispatcher when a new id is registered.
        self._request_id_dispatcher_queue = Queue()

        if mpi_session is None:
            self.mpi_session = MpiPoolSession(n_workers=model_world_size)
        else:
            self.mpi_session = mpi_session
        self.model_world_size = model_world_size

        self.workers_kwargs = workers_kwargs
        self.workers_kwargs.update({
            "request_queue_addr": request_queue_addr,
            "request_id_queue_addr": request_id_queue_addr,
            "result_queue_addr": result_queue_addr,
            "stats_queue_addr": stats_queue_addr,
        })
        self.workers_init_ok = False
        self.dispatcher = threading.Thread(target=self.dispatcher_thread,
                                           daemon=True)
        self.stats_thread = threading.Thread(target=self.stats_main,
                                             daemon=True)

    @print_traceback_on_error
    @staticmethod
    def workers_main(
        engine_dir: Path,
        request_queue_addr: Tuple[str, int, bytes],
        request_id_queue_addr: Tuple[str, int, bytes],
        result_queue_addr: Tuple[str, int, bytes],
        stats_queue_addr: Tuple[str, int, bytes],
        executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1)
    ) -> None:
        """Entry point executed on every MPI worker rank.

        Rank 0 connects to the proxy's channels, reports the init status on
        ``result_queue``, then serves requests until it receives None. Other
        ranks enter ``block_subordinates``.
        """
        result_queue = None

        if mpi_rank() == 0:
            request_queue = Fifo(request_queue_addr, is_server=False)
            request_id_queue = Fifo(request_id_queue_addr, is_server=False)
            result_queue = Fifo(result_queue_addr, is_server=False)
            mp_stats_queue = Fifo(stats_queue_addr, is_server=False)

        # Only the failure on rank0 can be captured here. All the non-rank0 process will hang once the executor runtime
        # is successfully initialized, that is controlled within cpp runtime.
        # To capture the failure on all the ranks, more work should be done in the cpp runtime.
        # TODO[chunweiy]: fix the non-rank0 process failure
        init_ok = True
        try:
            executor = ExecutorBindingsWorker(engine_dir, executor_config)
        except Exception:
            init_ok = False
            raise
        finally:
            if mpi_rank() == 0:
                result_queue.put(init_ok)

        with ContextManager(executor) as executor:
            if mpi_rank() == 0:
                executor.set_result_queue(result_queue)
                executor.set_stats_queue(mp_stats_queue)
                while (req := request_queue.get()) is not None:
                    result = executor.submit(req)
                    request_id_queue.put(result.request_id)

                # Stop the executor (joining its awaiter and stats threads)
                # before sending the sentinels. Those threads also send on
                # result_queue / mp_stats_queue, and concurrent send() calls on
                # one Connection could interleave. A second shutdown() (e.g. on
                # context exit) is a no-op once the engine is None.
                executor.shutdown()
                result_queue.put(None)
                mp_stats_queue.put(None)
            else:
                executor.block_subordinates()

    def dispatcher_thread(self):
        """ Collect centralized results from result queue and dispatch them in the
            correct GenerationResult queues. """

        while (res := self.result_queue.get()) is not None:
            req_id, *_ = res
            # A result can arrive before submit() has registered its
            # GenerationResult; each submit() posts to the dispatcher queue, so
            # block on it until req_id shows up.
            while req_id not in self._results:
                self._request_id_dispatcher_queue.get()
            self._results[req_id].queue.put(res)
            if res[2]:
                # is_final: nothing more is routed for this id, so drop the
                # proxy's reference instead of retaining every result forever.
                self._results.pop(req_id, None)
            # Discard wake-up signals that are no longer needed.
            while not self._request_id_dispatcher_queue.empty():
                self._request_id_dispatcher_queue.get()

    def stats_main(self):
        """Forward stats from the worker into stats_queue until None arrives."""
        while (stats := self.mp_stats_queue.get()) is not None:
            time.sleep(0.1)
            # Drop the oldest entries if the queue reports full.
            while self.stats_queue.full():
                self.stats_queue.get()
            self.stats_queue.put(stats)

    def start(self):
        """Launch the workers, wait for their init status and start helper threads.

        Raises:
            RuntimeError: If worker rank 0 reports an initialization failure.
        """
        self.mpi_futures = self.mpi_session.submit(
            ExecutorBindingsProxy.workers_main, **self.workers_kwargs)
        self.workers_started = True
        self.workers_init_ok = self.result_queue.get()
        if not self.workers_init_ok:
            raise RuntimeError("worker initialization failed")
        self.dispatcher.start()
        self.create_stats_queue()
        self.stats_thread.start()

    def shutdown(self):
        """Stop the workers and wait for the helper threads to finish.

        Safe to call more than once (e.g. explicitly and again from atexit or
        __del__), and on an object whose __init__ did not complete.
        """
        if not getattr(self, "workers_started", False):
            return
        # Mark stopped first so a failure below, or a repeated call, does not
        # resend the stop request to workers that have already exited.
        self.workers_started = False
        if self.workers_init_ok:
            self.request_queue.put(None)
        for f in self.mpi_futures:
            f.result()
        # workers_main sends None on result_queue and mp_stats_queue before it
        # returns, and that ends both threads. A put() from this side would go
        # to the worker over the duplex connection, not to these threads, and
        # can fail once the worker has exited.
        if self.dispatcher.is_alive():
            self.dispatcher.join()
        if self.stats_thread.is_alive():
            self.stats_thread.join()

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """
            Low-level API to the executor. Return a "future" GenerationResult which can be waited.
            Forwards the request to the workers through the request queue.
        """
        if not self.workers_started:
            self.start()

        self.request_queue.put(request)

        # Await req id.
        req_id = self.request_id_queue.get()
        request.set_id(req_id)

        result = GenerationResult(request)
        self._results[req_id] = result
        self._request_id_dispatcher_queue.put(req_id)

        return result

    def __del__(self):
        self.shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()
        return False