Spaces:
Running
Running
File size: 1,789 Bytes
31086ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import inspect
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union
import pandas as pd
import ray
from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
class BaseOperator(ABC):
    """Abstract base class for batch operators that log to a per-worker file.

    Each instance builds its own logger whose file name combines the operator
    name with a short worker identifier obtained from the Ray runtime context.
    Calling the instance delegates to :meth:`process` while the instance
    logger is published through ``CURRENT_LOGGER_VAR``.
    """

    def __init__(
        self, working_dir: str = "cache", op_name: Union[str, None] = None
    ) -> None:
        """Set up the operator name and its per-worker logger.

        Args:
            working_dir: Base directory; logs go to ``<working_dir>/logs``.
            op_name: Name used in log file and logger names. Falls back to
                the concrete class name when empty or ``None``.
        """
        # NOTE(review): log_dir is not created here; presumably set_logger or
        # the caller makes sure it exists -- confirm.
        log_dir = os.path.join(working_dir, "logs")
        self.op_name = op_name or self.__class__.__name__

        # Best-effort lookup: any failure while querying Ray falls back to
        # the "local" suffix instead of aborting construction.
        try:
            ctx = ray.get_runtime_context()
            worker_id = ctx.get_actor_id() or ctx.get_worker_id()
            worker_id_short = worker_id[-6:] if worker_id else "driver"
        except Exception as e:
            print(
                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
                e,
            )
            worker_id_short = "local"

        # e.g. cache/logs/ChunkService_a1b2c3.log
        log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log")
        self.logger = set_logger(
            log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True
        )
        self.logger.info(
            "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short
        )

    def __call__(
        self, batch: pd.DataFrame
    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
        """Run :meth:`process` on ``batch`` and yield its output.

        This method is itself a generator: when ``process`` returns a
        generator its items are yielded one by one; any other return value
        (including lists) is yielded as a single item.

        ``CURRENT_LOGGER_VAR`` is set to this operator's logger for the
        duration of the call and restored once the generator finishes or is
        closed.
        """
        logger_token = CURRENT_LOGGER_VAR.set(self.logger)
        try:
            result = self.process(batch)
            if inspect.isgenerator(result):
                yield from result
            else:
                yield result
        finally:
            # A generator can be closed or finalized under a different
            # contextvars.Context than the one in which the token was created
            # (assuming CURRENT_LOGGER_VAR is a ContextVar). reset() raises
            # ValueError in that case; letting it escape from cleanup would
            # mask the real outcome of the call.
            try:
                CURRENT_LOGGER_VAR.reset(logger_token)
            except ValueError:
                self.logger.warning(
                    "[%s] Could not reset logger context variable: "
                    "generator was finalized in a different context",
                    self.op_name,
                )

    @abstractmethod
    def process(
        self, batch: pd.DataFrame
    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
        """Transform one batch; must be overridden by subclasses.

        Implementations may return a single result or be written as a
        generator function to stream several results.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement the process method.")

    def get_logger(self):
        """Return the logger created for this operator instance."""
        return self.logger
|