Spaces:
Paused
Paused
| # | |
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import base64 | |
| import datetime | |
| import io | |
| import json | |
| import os | |
| import pickle | |
| import socket | |
| import time | |
| import uuid | |
| import requests | |
| from enum import Enum, IntEnum | |
| import importlib | |
| from Cryptodome.PublicKey import RSA | |
| from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 | |
| from filelock import FileLock | |
| from . import file_utils | |
# Default configuration file name; used as the default ``conf_name`` of the
# config helpers in this module and resolved under the project's ``conf/`` dir.
SERVICE_CONF = "service_conf.yaml"
def conf_realpath(conf_name):
    """Return the path of ``conf/<conf_name>`` under the project base directory.

    The base directory is whatever ``file_utils.get_project_base_directory()``
    returns. No check is made that the resulting file exists.
    """
    relative_path = f"conf/{conf_name}"
    base_dir = file_utils.get_project_base_directory()
    return os.path.join(base_dir, relative_path)
def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
    """Look up ``key`` in the service configuration.

    Two YAML files under ``conf/`` are consulted: ``local.<conf_name>`` (only
    if it exists) and ``<conf_name>``. A key found in the local file wins
    immediately; otherwise the main file is loaded, overlaid with the local
    values, and queried.

    Args:
        key: Configuration key to read, or ``None`` to get the whole merged
            configuration dict.
        default: Value returned when ``key`` is absent. When it is ``None``
            and a key is given, the environment variable named
            ``key.upper()`` is used as the default instead.
        conf_name: Base name of the configuration file.

    Returns:
        The configured value for ``key`` (any YAML type), the default, or the
        merged configuration dict when ``key`` is ``None``.

    Raises:
        ValueError: If either configuration file does not parse to a dict.
    """
    # Only derive the default from the environment when there is a key to
    # derive it from; ``key=None`` means "return the whole config" and used
    # to crash here on ``None.upper()``.
    if default is None and key is not None:
        default = os.environ.get(key.upper())

    local_config = {}
    local_path = conf_realpath(f'local.{conf_name}')
    if os.path.exists(local_path):
        local_config = file_utils.load_yaml_conf(local_path)
        if not isinstance(local_config, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')
        # A local override short-circuits reading the main file.
        if key is not None and key in local_config:
            return local_config[key]

    config_path = conf_realpath(conf_name)
    config = file_utils.load_yaml_conf(config_path)
    if not isinstance(config, dict):
        raise ValueError(f'Invalid config file: "{config_path}".')

    # Local values take precedence over the main configuration.
    config.update(local_config)
    if key is None:
        return config
    return config.get(key, default)
# Read once at import time; when truthy, deserialize_b64() unpickles through
# restricted_loads() instead of plain pickle.loads(). Defaults to False.
use_deserialize_safe_module = get_base_config(
    'use_deserialize_safe_module',
    False,
)
class CoordinationCommunicationProtocol:
    """String constants naming the supported coordination protocols."""

    GRPC = "grpc"
    HTTP = "http"
class BaseType:
    """Base class offering dictionary views of an instance's attributes."""

    def to_dict(self):
        """Return the instance attributes as a flat dict.

        Leading underscores are stripped from attribute names; values are
        returned as-is (no recursion).
        """
        return {name.lstrip("_"): value for name, value in self.__dict__.items()}

    def to_dict_with_type(self):
        """Return a recursive, type-annotated description of this object.

        Every value is wrapped as ``{"type", "data", "module"}``. ``module``
        is the defining module for ``BaseType`` instances and ``None`` for
        everything else. Lists/tuples and dicts are descended into.
        """

        def _describe(obj):
            origin_module = None
            if issubclass(obj.__class__, BaseType):
                payload = {
                    attr.lstrip("_"): _describe(value)
                    for attr, value in obj.__dict__.items()
                }
                origin_module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                # Tuples are flattened to lists, like the elements of a list.
                payload = [_describe(item) for item in obj]
            elif isinstance(obj, dict):
                payload = {k: _describe(item) for k, item in obj.items()}
            else:
                payload = obj
            return {
                "type": obj.__class__.__name__,
                "data": payload,
                "module": origin_module,
            }

        return _describe(self)
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that handles dates, enums, sets, ``BaseType`` and classes.

    Accepts one extra keyword argument, ``with_type`` (default ``False``):
    when true, ``BaseType`` objects are encoded via ``to_dict_with_type()``
    instead of ``to_dict()``.
    """

    def __init__(self, **kwargs):
        # Remove our private option before the stock encoder sees the kwargs.
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        """Convert objects the stock encoder cannot serialize."""
        # datetime must be tested before date: datetime is a date subclass.
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        if isinstance(obj, datetime.timedelta):
            return str(obj)

        obj_type = type(obj)
        if issubclass(obj_type, Enum) or issubclass(obj_type, IntEnum):
            return obj.value
        if isinstance(obj, set):
            return list(obj)
        if issubclass(obj_type, BaseType):
            return obj.to_dict_with_type() if self._with_type else obj.to_dict()
        if isinstance(obj, type):
            return obj.__name__
        # Anything else: let the base class raise its usual TypeError.
        return super().default(obj)
def rag_uuid():
    """Return a new version-1 UUID as a 32-character hexadecimal string."""
    identifier = uuid.uuid1()
    return identifier.hex
def string_to_bytes(string):
    """Return ``string`` as bytes: UTF-8 encoded, or unchanged if already bytes."""
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
def bytes_to_string(byte):
    """Decode ``byte`` as UTF-8 and return the resulting text."""
    text = byte.decode(encoding="utf-8")
    return text
def json_dumps(src, byte=False, indent=None, with_type=False):
    """Serialize ``src`` to JSON using ``CustomJSONEncoder``.

    Args:
        src: Object to serialize.
        byte: When true, return UTF-8 bytes instead of ``str``.
        indent: Passed through to ``json.dumps``.
        with_type: Forwarded to the encoder; selects the type-annotated
            form for ``BaseType`` objects.
    """
    encoded = json.dumps(
        src, indent=indent, cls=CustomJSONEncoder, with_type=with_type
    )
    return string_to_bytes(encoded) if byte else encoded
def json_loads(src, object_hook=None, object_pairs_hook=None):
    """Parse JSON from ``src``; ``bytes`` input is decoded as UTF-8 first.

    ``object_hook`` and ``object_pairs_hook`` are passed to ``json.loads``.
    """
    text = bytes_to_string(src) if isinstance(src, bytes) else src
    return json.loads(
        text,
        object_hook=object_hook,
        object_pairs_hook=object_pairs_hook,
    )
def current_timestamp():
    """Return the current time as integer milliseconds since the epoch."""
    millis = time.time() * 1000
    return int(millis)
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    """Format a millisecond timestamp as a local-time string.

    Args:
        timestamp: Milliseconds since the epoch. Any falsy value (``None``,
            ``0``, ``""``) means "now".
        format_string: ``time.strftime`` format for the result.

    Returns:
        The formatted local date/time string.
    """
    if not timestamp:
        # time.time() is in seconds; scale to milliseconds so the division
        # below yields the current time rather than a date in January 1970.
        timestamp = time.time() * 1000
    seconds = int(timestamp) / 1000
    return time.strftime(format_string, time.localtime(seconds))
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    """Parse a local-time string and return milliseconds since the epoch.

    ``time_str`` is parsed with ``time.strptime`` using ``format_string`` and
    converted with ``time.mktime``, i.e. interpreted as local time.
    """
    parsed = time.strptime(time_str, format_string)
    epoch_seconds = time.mktime(parsed)
    return int(epoch_seconds * 1000)
def serialize_b64(src, to_str=False):
    """Pickle ``src`` and return the result base64-encoded.

    Returns bytes by default, or a ``str`` when ``to_str`` is true.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    if to_str:
        return bytes_to_string(encoded)
    return encoded
def deserialize_b64(src):
    """Decode base64 ``src`` (``str`` or bytes) and unpickle the payload.

    When the module-level ``use_deserialize_safe_module`` flag is truthy the
    payload goes through ``restricted_loads``; otherwise plain
    ``pickle.loads`` is used.

    NOTE(review): the ``pickle.loads`` path can execute arbitrary code while
    unpickling, so it must never be fed data from an untrusted source.
    """
    raw = string_to_bytes(src) if isinstance(src, str) else src
    payload = base64.b64decode(raw)
    if not use_deserialize_safe_module:
        return pickle.loads(payload)
    return restricted_loads(payload)
# Top-level packages whose globals RestrictedUnpickler is allowed to resolve.
safe_module = {'numpy', 'rag_flow'}
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals from allow-listed packages."""

    def find_class(self, module, name):
        """Resolve ``module.name`` if the module's top-level package is allowed.

        Raises:
            pickle.UnpicklingError: If the top-level package of ``module`` is
                not in ``safe_module``.
        """
        top_level_package = module.split('.')[0]
        if top_level_package not in safe_module:
            # Forbid everything else.
            raise pickle.UnpicklingError(
                "global '%s.%s' is forbidden" % (module, name))
        resolved_module = importlib.import_module(module)
        return getattr(resolved_module, name)
def restricted_loads(src):
    """Unpickle the bytes ``src`` like ``pickle.loads``, via ``RestrictedUnpickler``."""
    stream = io.BytesIO(src)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()
def get_lan_ip():
    """Return this host's IPv4 address as a string (``''`` if none is found).

    The address is first resolved from the host's fully qualified name. If
    that yields a loopback address (``127.*``) on a non-Windows system, a
    fixed list of common interface names is probed and the first one that
    answers supplies the address.
    """
    not_windows = os.name != "nt"

    if not_windows:
        # fcntl does not exist on Windows, hence the conditional import.
        import fcntl
        import struct

        def get_interface_ip(ifname):
            """Return the IPv4 address bound to interface ``ifname``."""
            # The socket only serves as a handle for the ioctl; close it
            # afterwards so repeated probing does not leak descriptors.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                # 0x8915 is SIOCGIFADDR on Linux. The request buffer starts
                # with the interface name, truncated to 15 bytes.
                reply = fcntl.ioctl(
                    s.fileno(),
                    0x8915,
                    struct.pack('256s', string_to_bytes(ifname[:15])))
            # Bytes 20..23 of the reply hold the packed IPv4 address.
            return socket.inet_ntoa(reply[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and not_windows:
        interfaces = (
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        )
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
            except OSError:
                # Interface missing or without an address: try the next one.
                continue
            break
    return ip or ''
def from_dict_hook(in_dict: dict):
    """Rebuild objects from ``{"type", "data", "module"}`` wrapper dicts.

    Intended as a ``json.loads`` ``object_hook`` for output produced with
    type annotations.

    * A dict without all three keys is returned unchanged, so ordinary JSON
      objects that merely contain "type" and "data" no longer raise
      ``KeyError`` on the missing "module" key.
    * ``module is None`` marks a plain value: ``in_dict["data"]`` is returned.
    * Otherwise the class ``type`` is looked up in ``module`` and called with
      ``data`` expanded as keyword arguments.

    NOTE(review): the last case imports a module and calls a constructor
    chosen by the input, so only use this hook on trusted data.
    """
    is_wrapper = all(field in in_dict for field in ("type", "data", "module"))
    if not is_wrapper:
        return in_dict

    module_name = in_dict["module"]
    if module_name is None:
        return in_dict["data"]

    target_cls = getattr(importlib.import_module(module_name), in_dict["type"])
    return target_cls(**in_dict["data"])
def decrypt_database_password(password):
    """Decrypt ``password`` if password encryption is enabled in the config.

    Reads three configuration entries: ``encrypt_password`` (on/off switch),
    ``encrypt_module`` (``"<module path>#<function name>"`` naming the
    decryption function) and ``private_key``.

    Args:
        password: The stored password value.

    Returns:
        ``password`` unchanged when it is empty or encryption is disabled;
        otherwise the result of calling the configured function with
        ``(private_key, password)``.

    Raises:
        ValueError: If encryption is enabled but ``private_key`` is missing,
            or ``encrypt_module`` is missing or not of the form
            ``module#function``.
    """
    encrypt_password = get_base_config("encrypt_password", False)
    encrypt_module = get_base_config("encrypt_module", False)
    private_key = get_base_config("private_key", None)

    if not password or not encrypt_password:
        return password

    if not private_key:
        raise ValueError("No private key")

    # The default for encrypt_module is False; fail with a clear message
    # rather than an AttributeError/IndexError from split() and indexing.
    if not encrypt_module or not isinstance(encrypt_module, str):
        raise ValueError("No encrypt module")
    module_fun = encrypt_module.split("#")
    if len(module_fun) < 2:
        raise ValueError(
            f'Invalid encrypt module "{encrypt_module}", '
            'expected "module#function".')

    pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])
    return pwdecrypt_fun(private_key, password)
def decrypt_database_config(database=None, passwd_key="password", name="database"):
    """Return a database config whose password entry has been decrypted.

    Args:
        database: Config dict to process. When falsy, the ``name`` section of
            the service configuration is loaded instead (``{}`` if absent).
        passwd_key: Key of the password entry inside the config dict.
        name: Configuration section to load when ``database`` is not given.

    Returns:
        The same dict object, updated in place.
    """
    settings = database if database else get_base_config(name, {})
    settings[passwd_key] = decrypt_database_password(settings[passwd_key])
    return settings
def update_config(key, value, conf_name=SERVICE_CONF):
    """Set ``key`` to ``value`` in a configuration file and rewrite the file.

    The load-modify-rewrite sequence runs while holding a ``FileLock`` on a
    ``.lock`` file next to the configuration file. An empty configuration is
    treated as ``{}``.
    """
    conf_path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(conf_path):
        base_dir = file_utils.get_project_base_directory()
        conf_path = os.path.join(base_dir, conf_path)

    lock_path = os.path.join(os.path.dirname(conf_path), ".lock")
    with FileLock(lock_path):
        current = file_utils.load_yaml_conf(conf_path=conf_path)
        config = current or {}
        config[key] = value
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid():
    """Return a new version-1 UUID as a 32-character hexadecimal string."""
    new_id = uuid.uuid1()
    return new_id.hex
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
    """Return a new naive ``datetime`` with the same fields down to the second.

    Microseconds (and any tzinfo) of ``date_time`` are not carried over.
    """
    date_part = (date_time.year, date_time.month, date_time.day)
    time_part = (date_time.hour, date_time.minute, date_time.second)
    return datetime.datetime(*date_part, *time_part)
def get_format_time() -> datetime.datetime:
    """Return the current local time passed through ``datetime_format``."""
    now = datetime.datetime.now()
    return datetime_format(now)
def str2date(date_time: str):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime`` at midnight."""
    date_format = '%Y-%m-%d'
    return datetime.datetime.strptime(date_time, date_format)
def elapsed2time(elapsed):
    """Format a duration given in milliseconds as ``HH:MM:SS``.

    Fractions of a second are truncated; hours are not wrapped at 24.
    """
    total_seconds = elapsed / 1000
    total_minutes, secs = divmod(total_seconds, 60)
    hours, mins = divmod(total_minutes, 60)
    return '%02d:%02d:%02d' % (hours, mins, secs)
def decrypt(line):
    """Decrypt a base64-encoded, RSA PKCS#1 v1.5 encrypted string.

    The private key is read from ``conf/private.pem`` under the project base
    directory.

    Args:
        line: Base64 text of the ciphertext.

    Returns:
        The decrypted plaintext decoded as UTF-8.

    Raises:
        ValueError: If the ciphertext cannot be decrypted with the key.
    """
    key_path = os.path.join(
        file_utils.get_project_base_directory(), "conf", "private.pem")
    # Close the key file deterministically instead of leaking the handle.
    with open(key_path) as key_file:
        # NOTE(review): the key passphrase is hard-coded here.
        rsa_key = RSA.importKey(key_file.read(), "Welcome")

    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    # decrypt() hands back the sentinel on failure. A unique object lets us
    # detect that reliably; the previous str sentinel crashed in .decode().
    failure = object()
    plaintext = cipher.decrypt(base64.b64decode(line), failure)
    if plaintext is failure:
        raise ValueError("Fail to decrypt password!")
    return plaintext.decode('utf-8')
def download_img(url, timeout=None):
    """Fetch ``url`` and return its content as a ``data:`` URI.

    Args:
        url: Address of the image. A falsy value short-circuits to ``""``.
        timeout: Optional timeout in seconds (or a ``(connect, read)``
            tuple) passed to ``requests.get``. The default ``None`` waits
            indefinitely, as before; callers fetching from untrusted or slow
            hosts should set it.

    Returns:
        ``"data:<content type>;base64,<payload>"``, where the content type
        comes from the response's ``Content-Type`` header (``image/jpg`` if
        absent), or ``""`` when no URL is given.

    NOTE(review): the HTTP status is not checked, so an error page would be
    encoded as well.
    """
    if not url:
        return ""
    response = requests.get(url, timeout=timeout)
    content_type = response.headers.get('Content-Type', 'image/jpg')
    payload = base64.b64encode(response.content).decode("utf-8")
    return "data:" + content_type + ";" + "base64," + payload