File size: 1,789 Bytes
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import inspect
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union

import pandas as pd
import ray

from graphgen.utils import CURRENT_LOGGER_VAR, set_logger


class BaseOperator(ABC):
    """Abstract base for callable batch operators that own a dedicated logger.

    Subclasses implement :meth:`process`. Calling an instance runs ``process``
    on the given batch while ``CURRENT_LOGGER_VAR`` is set to this operator's
    logger, and yields whatever ``process`` produced.
    """

    def __init__(self, working_dir: str = "cache", op_name: str = None):
        """Name the operator and create its logger.

        Args:
            working_dir: Base directory; the log file path is built under its
                ``logs`` subdirectory.
            op_name: Name used in the log file name and the logger name. When
                falsy, the concrete class name is used instead.
        """
        logs_root = os.path.join(working_dir, "logs")

        if op_name:
            self.op_name = op_name
        else:
            self.op_name = self.__class__.__name__

        # Build a short tag identifying where this operator lives: the last six
        # characters of the actor id (or of the worker id when no actor id is
        # reported), "driver" when neither is available, and "local" when the
        # Ray lookup itself raises.
        try:
            runtime_ctx = ray.get_runtime_context()
            full_id = runtime_ctx.get_actor_id() or runtime_ctx.get_worker_id()
            if full_id:
                worker_tag = full_id[-6:]
            else:
                worker_tag = "driver"
        except Exception as exc:
            # Best effort: the logger does not exist yet, so report on stdout.
            print(
                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
                exc,
            )
            worker_tag = "local"

        # Resulting path looks like cache/logs/ChunkService_a1b2c3.log
        log_path = os.path.join(logs_root, f"{self.op_name}_{worker_tag}.log")
        logger_name = f"{self.op_name}.{worker_tag}"

        self.logger = set_logger(log_file=log_path, name=logger_name, force=True)
        self.logger.info(
            "[%s] Operator initialized on Worker %s", self.op_name, worker_tag
        )

    def __call__(
        self, batch: pd.DataFrame
    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
        """Run :meth:`process` on ``batch`` and yield its output.

        Because this method contains ``yield``, calling it always returns a
        generator; nothing runs until that generator is iterated. If
        ``process`` returns a generator, its items are yielded one by one;
        any other return value is yielded as a single item.

        ``CURRENT_LOGGER_VAR`` is set to ``self.logger`` when iteration starts
        and reset in a ``finally`` clause, i.e. when the generator is
        exhausted, closed, or an exception propagates.
        """
        token = CURRENT_LOGGER_VAR.set(self.logger)
        try:
            output = self.process(batch)
            if not inspect.isgenerator(output):
                yield output
                return
            yield from output
        finally:
            CURRENT_LOGGER_VAR.reset(token)

    @abstractmethod
    def process(self, batch):
        """Transform one batch; must be overridden by subclasses.

        May return a single result or a generator of results (see
        :meth:`__call__`).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement the process method.")

    def get_logger(self):
        """Return the logger created for this operator in ``__init__``."""
        return self.logger