| | import asyncio |
| | import concurrent.futures |
| | import contextvars |
| | import functools |
| | import inspect |
| | import logging |
| | import os |
| | import textwrap |
| | import threading |
| | from enum import Enum |
| | from typing import Optional, Type, get_origin, get_args |
| |
|
| |
|
class TypeTracker:
    """Collect the types seen during stub generation so imports can be emitted for them."""

    def __init__(self):
        # Maps a type's ``__name__`` to ``(module, qualname)``.
        self.discovered_types = {}
        # Names that never need a dedicated import line.
        self.builtin_types = {
            "Any",
            "Dict",
            "List",
            "Optional",
            "Tuple",
            "Union",
            "Set",
            "Sequence",
            "cast",
            "NamedTuple",
            "str",
            "int",
            "float",
            "bool",
            "None",
            "bytes",
            "object",
            "type",
            "dict",
            "list",
            "tuple",
            "set",
        }
        # Names some other import statement already covers; filled in by callers.
        self.already_imported = set()

    def track_type(self, annotation):
        """Record where ``annotation`` is defined, unless it needs no import.

        Nothing is recorded for ``None``/``NoneType``, objects without a
        ``__name__``, names listed in ``builtin_types`` or ``already_imported``,
        anything from ``typing``, ``builtins`` or ``__main__``, and the
        ``UnionType``/``GenericAlias`` helpers from ``types``.
        """
        if annotation is None or annotation is type(None):
            return

        type_name = getattr(annotation, "__name__", None)
        if not type_name:
            # Without a name there is nothing that could be imported.
            return
        if type_name in self.builtin_types or type_name in self.already_imported:
            return

        module = getattr(annotation, "__module__", None)
        if not module or module == "typing" or module in ["builtins", "__main__"]:
            return
        if module == "types" and type_name in ("UnionType", "GenericAlias"):
            return

        qualname = getattr(annotation, "__qualname__", type_name or "")
        self.discovered_types[type_name] = (module, qualname)

    def get_imports(self, main_module_name: str) -> list[str]:
        """Return ``from <module> import <names>`` lines for the recorded types.

        Types whose module equals ``main_module_name`` are left out. Lines are
        ordered by module name, and names within a line are sorted.
        """
        names_by_module = {}
        for type_name, (module, _qualname) in sorted(self.discovered_types.items()):
            if main_module_name and module == main_module_name:
                continue
            names = names_by_module.setdefault(module, [])
            if type_name not in names:
                names.append(type_name)

        return [
            f"from {module} import {', '.join(sorted(set(names)))}"
            for module, names in sorted(names_by_module.items())
        ]
| |
|
| |
|
class AsyncToSyncConverter:
    """
    Provides utilities to convert async classes to sync classes with proper type hints.

    Two features live here:

    * ``create_sync_class`` builds a wrapper class whose methods run the wrapped
      object's coroutine methods on a shared thread pool and block for the result.
    * ``generate_stub_file`` writes a ``.pyi`` stub describing that wrapper.

    ``TypeTracker`` is referenced through string annotations so the method
    signatures do not depend on definition order.
    """

    # Shared executor for all sync wrappers; created lazily by get_thread_pool().
    _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    _thread_pool_lock = threading.Lock()
    _thread_pool_initialized = False

    @classmethod
    def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
        """Return the shared thread pool, creating it on first use.

        ``max_workers`` is only used by the call that actually creates the
        pool; later calls return the existing pool unchanged.
        """
        pool = cls._thread_pool
        if cls._thread_pool_initialized and pool is not None:
            return pool

        with cls._thread_pool_lock:
            # Re-check under the lock: another thread may have created the pool.
            if not cls._thread_pool_initialized or cls._thread_pool is None:
                cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers, thread_name_prefix="async_to_sync_"
                )
                cls._thread_pool_initialized = True
            return cls._thread_pool

    @classmethod
    def run_async_in_thread(cls, coro_func, *args, **kwargs):
        """
        Run an async function in a separate thread from the thread pool.

        Blocks until the async function completes and returns its result. An
        exception raised by the coroutine is re-raised in the calling thread.
        The caller's contextvars are copied and used while the coroutine runs.
        """
        context = contextvars.copy_context()

        def run_in_thread():
            # Each invocation gets a private event loop on the worker thread.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:

                async def run_with_context():
                    return await coro_func(*args, **kwargs)

                return context.run(loop.run_until_complete, run_with_context())
            finally:
                try:
                    # Cancel whatever the coroutine left behind and let it unwind.
                    pending = asyncio.all_tasks(loop)
                    for task in pending:
                        task.cancel()
                    if pending:
                        loop.run_until_complete(
                            asyncio.gather(*pending, return_exceptions=True)
                        )
                    loop.run_until_complete(loop.shutdown_asyncgens())
                except Exception:
                    # Cleanup is best-effort; never mask the coroutine's outcome.
                    logging.debug(
                        "Ignoring error during event loop cleanup", exc_info=True
                    )
                finally:
                    loop.close()
                    asyncio.set_event_loop(None)

        # Future.result() returns the value or re-raises the worker's exception.
        return cls.get_thread_pool().submit(run_in_thread).result()

    @staticmethod
    def _collect_annotations(klass: Type) -> dict:
        """Merge ``__annotations__`` along the MRO; subclasses override bases."""
        merged = {}
        for base_class in reversed(inspect.getmro(klass)):
            if hasattr(base_class, "__annotations__"):
                merged.update(base_class.__annotations__)
        return merged

    @staticmethod
    def _safe_type_hints(func) -> dict:
        """Return ``typing.get_type_hints(func)`` or ``{}`` when resolution fails."""
        try:
            from typing import get_type_hints

            return get_type_hints(func)
        except Exception:
            return {}

    @classmethod
    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
        """
        Creates a new class with synchronous versions of all async methods.

        Args:
            async_class: The async class to convert
            thread_pool_size: Size of thread pool to use (only used if the
                shared pool has not been created yet)

        Returns:
            A new class with sync versions of all async methods
        """
        sync_class_name = "ComfyAPISyncStub"
        cls.get_thread_pool(thread_pool_size)

        sync_class_dict = {
            "__doc__": async_class.__doc__,
            "__module__": async_class.__module__,
            "__qualname__": sync_class_name,
            "__orig_class__": async_class,
        }

        def __init__(self, *args, **kwargs):
            self._async_instance = async_class(*args, **kwargs)
            instance = self._async_instance

            def wrap_singleton(target_name, async_obj):
                # Expose async_obj through its own sync wrapper without running
                # the wrapper's __init__; fall back to the raw object on failure.
                try:
                    wrapper_class = cls.create_sync_class(async_obj.__class__)
                    wrapper = object.__new__(wrapper_class)
                    wrapper._async_instance = async_obj
                    setattr(self, target_name, wrapper)
                except Exception:
                    setattr(self, target_name, async_obj)

            for attr_name, attr_type in cls._collect_annotations(async_class).items():
                if hasattr(instance, attr_name):
                    # Annotated attribute already present on the async instance.
                    attr = getattr(instance, attr_name)
                    from comfy_api.internal.singleton import ProxiedSingleton

                    if isinstance(attr, ProxiedSingleton):
                        wrap_singleton(attr_name, attr)
                    else:
                        setattr(self, attr_name, attr)
                    continue

                # Annotated but missing: create it when the annotation names a
                # class that is also an attribute of async_class.
                if not isinstance(attr_type, type):
                    continue
                if not hasattr(async_class, attr_type.__name__):
                    continue
                inner_class = getattr(async_class, attr_type.__name__)
                from comfy_api.internal.singleton import ProxiedSingleton

                try:
                    if issubclass(inner_class, ProxiedSingleton):
                        async_instance = inner_class.get_instance()
                    else:
                        async_instance = inner_class()

                    sync_attr_class = cls.create_sync_class(inner_class)
                    sync_attr = object.__new__(sync_attr_class)
                    sync_attr._async_instance = async_instance
                    setattr(self, attr_name, sync_attr)
                    # Mirror the new object onto the wrapped async instance.
                    setattr(self._async_instance, attr_name, async_instance)
                except Exception as e:
                    logging.warning(
                        f"Failed to create instance for {attr_name}: {e}"
                    )

            # Pick up remaining public, non-primitive members that are singletons.
            for name, attr in inspect.getmembers(instance):
                if name.startswith("_") or hasattr(self, name):
                    continue
                if isinstance(attr, (str, int, float, bool, list, dict, tuple)):
                    continue
                from comfy_api.internal.singleton import ProxiedSingleton

                if isinstance(attr, ProxiedSingleton):
                    wrap_singleton(name, attr)

        sync_class_dict["__init__"] = __init__

        def make_sync_method(method_name, original):
            # Blocking wrapper around a coroutine method of the async instance.
            @functools.wraps(original)
            def sync_method(self, *args, **kwargs):
                async_method = getattr(self._async_instance, method_name)
                return cls.run_async_in_thread(async_method, *args, **kwargs)

            return sync_method

        def make_proxy_method(method_name, original):
            # Plain pass-through for methods that are already synchronous.
            @functools.wraps(original)
            def proxy_method(self, *args, **kwargs):
                target = getattr(self._async_instance, method_name)
                return target(*args, **kwargs)

            return proxy_method

        def make_property(prop_name, prop_obj):
            def getter(self):
                value = getattr(self._async_instance, prop_name)
                if inspect.iscoroutinefunction(value):
                    # A property returning a coroutine function yields a
                    # blocking callable instead.
                    def sync_fn(*args, **kwargs):
                        return cls.run_async_in_thread(value, *args, **kwargs)

                    return sync_fn
                return value

            def setter(self, value):
                setattr(self._async_instance, prop_name, value)

            return property(getter, setter if prop_obj.fset else None)

        for name, method in inspect.getmembers(
            async_class, predicate=inspect.isfunction
        ):
            if name.startswith("_"):
                continue
            if inspect.iscoroutinefunction(method):
                sync_class_dict[name] = make_sync_method(name, method)
            else:
                sync_class_dict[name] = make_proxy_method(name, method)

        for name, prop in inspect.getmembers(
            async_class, lambda x: isinstance(x, property)
        ):
            sync_class_dict[name] = make_property(name, prop)

        return type(sync_class_name, (object,), sync_class_dict)

    @staticmethod
    def _format_literal_value(value, type_tracker: Optional["TypeTracker"] = None) -> str:
        """Render one ``Literal[...]`` argument as source text."""
        if isinstance(value, Enum):
            if type_tracker:
                type_tracker.track_type(type(value))
            return f"{type(value).__name__}.{value.name}"
        return repr(value)

    @classmethod
    def _format_type_annotation(
        cls, annotation, type_tracker: Optional["TypeTracker"] = None
    ) -> str:
        """Convert a type annotation to its string representation for stub files.

        Missing annotations become ``Any``. ``Ellipsis`` becomes ``...``, plain
        lists (the parameter list of a ``Callable``) are rendered as ``[...]``,
        and ``Literal`` arguments are rendered as values rather than as types.
        Every type encountered is reported to ``type_tracker`` when one is given.
        """
        if (
            annotation is inspect.Parameter.empty
            or annotation is inspect.Signature.empty
        ):
            return "Any"
        if annotation is type(None):
            return "None"
        if annotation is Ellipsis:
            return "..."
        if isinstance(annotation, list):
            inner = ", ".join(
                cls._format_type_annotation(item, type_tracker) for item in annotation
            )
            return f"[{inner}]"

        if type_tracker:
            type_tracker.track_type(annotation)

        try:
            origin = get_origin(annotation)
            args = get_args(annotation)

            if origin is not None:
                from typing import Literal

                if type_tracker:
                    type_tracker.track_type(origin)

                origin_name = getattr(origin, "__name__", str(origin))
                if "." in origin_name:
                    origin_name = origin_name.split(".")[-1]

                # ``X | Y`` unions are written as ``Union[X, Y]`` in the stub.
                if (
                    str(origin) == "<class 'types.UnionType'>"
                    or origin_name == "UnionType"
                ):
                    origin_name = "Union"

                if not args:
                    return origin_name
                if origin is Literal:
                    formatted_args = [
                        cls._format_literal_value(arg, type_tracker) for arg in args
                    ]
                else:
                    formatted_args = [
                        cls._format_type_annotation(arg, type_tracker) for arg in args
                    ]
                return f"{origin_name}[{', '.join(formatted_args)}]"
        except (AttributeError, TypeError):
            # Fall through to the attribute-based handling below.
            pass

        # Generic-like objects that get_origin() did not recognise.
        if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
            origin = annotation.__origin__
            origin_name = (
                origin.__name__
                if hasattr(origin, "__name__")
                else str(origin).split("'")[1]
            )
            args = [
                cls._format_type_annotation(arg, type_tracker)
                for arg in annotation.__args__
            ]
            return f"{origin_name}[{', '.join(args)}]"

        if hasattr(annotation, "__name__"):
            return annotation.__name__

        if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
            return annotation.__qualname__

        # Last resort: tidy up the string form.
        type_str = str(annotation)
        if type_str.startswith("<class '") and type_str.endswith("'>"):
            type_str = type_str[8:-2]

        for prefix in ["typing.", "builtins.", "types."]:
            if type_str.startswith(prefix):
                type_str = type_str[len(prefix) :]

        if type_str in ("_empty", "inspect._empty"):
            return "None"
        if type_str == "NoneType":
            return "None"
        return type_str

    @classmethod
    def _extract_coroutine_return_type(cls, annotation):
        """Extract the actual return type from a Coroutine annotation.

        Only ``Coroutine[Y, S, R]`` is unwrapped (to ``R``). Any other
        annotation, including other generics with three or more arguments, is
        returned unchanged.
        """
        import collections.abc

        if get_origin(annotation) is collections.abc.Coroutine:
            args = get_args(annotation)
            if len(args) > 2:
                return args[2]
        return annotation

    @classmethod
    def _format_parameter_default(cls, default_value) -> str:
        """Format a parameter's default value for stub files.

        Returns ``""`` when there is no default, otherwise `` = <value>``.
        Strings and bytes are emitted with ``repr`` so they stay quoted, and
        enum members as ``ClassName.MEMBER``.
        """
        if default_value is inspect.Parameter.empty:
            return ""
        if default_value is None:
            return " = None"
        if isinstance(default_value, Enum):
            # Checked before str so that ``str``-mixin enums are not repr()'d.
            member_name = default_value.name
            if isinstance(member_name, str) and member_name.isidentifier():
                return f" = {type(default_value).__name__}.{member_name}"
        elif isinstance(default_value, (str, bytes)):
            return f" = {default_value!r}"
        elif isinstance(default_value, dict) and not default_value:
            return " = {}"
        elif isinstance(default_value, list) and not default_value:
            return " = []"
        return f" = {default_value}"

    @classmethod
    def _format_method_parameters(
        cls,
        sig: inspect.Signature,
        skip_self: bool = True,
        type_hints: Optional[dict] = None,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Format method parameters for stub files.

        Args:
            sig: Signature to render.
            skip_self: When true, a leading ``self`` is emitted without annotation.
            type_hints: Resolved hints that take precedence over the raw
                annotations in ``sig``.
            type_tracker: Optional tracker notified of every annotation.

        Returns:
            The comma-separated parameter list, including ``*``/``**`` prefixes
            and the ``/`` and bare ``*`` separators where the signature has them.
        """
        hints = type_hints or {}
        kinds = inspect.Parameter
        params = []
        positional_only_open = False
        star_emitted = False

        for index, (param_name, param) in enumerate(sig.parameters.items()):
            kind = param.kind

            if positional_only_open and kind is not kinds.POSITIONAL_ONLY:
                params.append("/")
                positional_only_open = False
            if kind is kinds.KEYWORD_ONLY and not star_emitted:
                # Keyword-only parameters without a preceding *args need a bare '*'.
                params.append("*")
                star_emitted = True
            if kind is kinds.POSITIONAL_ONLY:
                positional_only_open = True
            elif kind is kinds.VAR_POSITIONAL:
                star_emitted = True

            if index == 0 and param_name == "self" and skip_self:
                params.append("self")
                continue

            annotation = hints.get(param_name, param.annotation)
            type_str = cls._format_type_annotation(annotation, type_tracker)
            default_str = cls._format_parameter_default(param.default)

            if kind is kinds.VAR_POSITIONAL:
                prefix = "*"
            elif kind is kinds.VAR_KEYWORD:
                prefix = "**"
            else:
                prefix = ""
            params.append(f"{prefix}{param_name}: {type_str}{default_str}")

        if positional_only_open:
            params.append("/")

        return ", ".join(params)

    @classmethod
    def _generate_method_signature(
        cls,
        method_name: str,
        method,
        is_async: bool = False,
        type_tracker: Optional["TypeTracker"] = None,
    ) -> str:
        """Generate a complete method signature for stub files.

        The result is a single ``def name(params) -> type: ...`` line. A method
        without a return annotation is given ``-> None``.

        Raises:
            ValueError, TypeError: propagated from ``inspect.signature``.
        """
        sig = inspect.signature(method)
        type_hints = cls._safe_type_hints(method)

        return_annotation = type_hints.get("return", sig.return_annotation)
        if is_async and inspect.iscoroutinefunction(method):
            return_annotation = cls._extract_coroutine_return_type(return_annotation)

        params_str = cls._format_method_parameters(
            sig, type_hints=type_hints, type_tracker=type_tracker
        )

        if return_annotation is inspect.Signature.empty:
            return_type = "None"
        else:
            return_type = cls._format_type_annotation(return_annotation, type_tracker)

        return f"def {method_name}({params_str}) -> {return_type}: ..."

    @classmethod
    def _generate_imports(
        cls, async_class: Type, type_tracker: "TypeTracker"
    ) -> list[str]:
        """Generate import statements for the stub file.

        The list contains the fixed ``typing`` import, an import of
        ``async_class`` (plus any NamedTuple/Enum classes found in its module),
        the imports reported by ``type_tracker``, and an ``import <top-level
        package>`` line when no ``from <package>...`` import is already present.
        """
        imports = [
            "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
        ]
        module = inspect.getmodule(async_class)

        if async_class.__module__ != "builtins":
            additional_types = []

            if module:
                module_all = getattr(module, "__all__", None)

                for name, obj in sorted(inspect.getmembers(module)):
                    if not isinstance(obj, type):
                        continue
                    # Names hidden by __all__ are only kept if a signature used them.
                    if (
                        module_all is not None
                        and name not in module_all
                        and name not in type_tracker.discovered_types
                    ):
                        continue

                    is_namedtuple = issubclass(obj, tuple) and hasattr(obj, "_fields")
                    is_enum = issubclass(obj, Enum) and name != "Enum"
                    if is_namedtuple or is_enum:
                        additional_types.append(name)
                        type_tracker.already_imported.add(name)

            type_imports = ", ".join([async_class.__name__] + additional_types)
            imports.append(f"from {async_class.__module__} import {type_imports}")

        imports.extend(
            type_tracker.get_imports(main_module_name=async_class.__module__)
        )

        if hasattr(module, "__name__"):
            module_name = module.__name__
            if "." in module_name:
                base_module = module_name.split(".")[0]
                if not any(imp.startswith(f"from {base_module}") for imp in imports):
                    imports.append(f"import {base_module}")

        return imports

    @classmethod
    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
        """Extract class attributes that are classes themselves.

        Returns ``(name, class)`` pairs for public members that are classes,
        and ``(name, annotation)`` pairs for other members whose annotation on
        ``async_class`` is a class.
        """
        class_attributes = []
        annotations = getattr(async_class, "__annotations__", None)

        for name, attr in sorted(inspect.getmembers(async_class)):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
            elif annotations is not None and name in annotations:
                annotation = annotations[name]
                if isinstance(annotation, type):
                    class_attributes.append((name, annotation))

        return class_attributes

    @classmethod
    def _generate_inner_class_stub(
        cls,
        name: str,
        attr: Type,
        indent: str = " ",
        type_tracker: Optional["TypeTracker"] = None,
    ) -> list[str]:
        """Generate stub for an inner class.

        Emits a ``class <name>Sync:`` header at ``indent`` followed by the class
        docstring, ``__init__`` and every public function, one level deeper.
        Docstrings are emitted as string literals ahead of each ``def`` line.
        """
        body_indent = f"{indent} "
        stub_lines = [f"{indent}class {name}Sync:"]

        if hasattr(attr, "__doc__") and attr.__doc__:
            stub_lines.extend(cls._format_docstring_for_stub(attr.__doc__, body_indent))

        if hasattr(attr, "__init__"):
            try:
                init_method = getattr(attr, "__init__")
                init_sig = inspect.signature(init_method)
                init_hints = cls._safe_type_hints(init_method)

                params_str = cls._format_method_parameters(
                    init_sig, type_hints=init_hints, type_tracker=type_tracker
                )
                if hasattr(init_method, "__doc__") and init_method.__doc__:
                    stub_lines.extend(
                        cls._format_docstring_for_stub(init_method.__doc__, body_indent)
                    )
                stub_lines.append(
                    f"{indent} def __init__({params_str}) -> None: ..."
                )
            except (ValueError, TypeError):
                # Signature not introspectable: emit a permissive fallback.
                stub_lines.append(
                    f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
                )

        has_methods = False
        for method_name, method in sorted(
            inspect.getmembers(attr, predicate=inspect.isfunction)
        ):
            if method_name.startswith("_"):
                continue

            has_methods = True
            try:
                if method.__doc__:
                    stub_lines.extend(
                        cls._format_docstring_for_stub(method.__doc__, body_indent)
                    )

                method_sig = cls._generate_method_signature(
                    method_name, method, is_async=True, type_tracker=type_tracker
                )
                stub_lines.append(f"{indent} {method_sig}")
            except (ValueError, TypeError):
                stub_lines.append(
                    f"{indent} def {method_name}(self, *args, **kwargs): ..."
                )

        if not has_methods:
            stub_lines.append(f"{indent} pass")

        return stub_lines

    @classmethod
    def _format_docstring_for_stub(
        cls, docstring: str, indent: str = " "
    ) -> list[str]:
        """Format a docstring for inclusion in a stub file with proper indentation.

        The text is dedented, stripped and wrapped in triple double quotes on
        their own lines. Backslashes and embedded triple quotes are escaped so
        the emitted literal always parses. Blank lines stay empty.
        """
        if not docstring:
            return []

        dedented = textwrap.dedent(docstring).strip()

        result = [f'{indent}"""']
        for line in dedented.split("\n"):
            if line.strip():
                safe_line = line.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
                result.append(f"{indent}{safe_line}")
            else:
                result.append("")
        result.append(f'{indent}"""')
        return result

    @classmethod
    def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
        """Post-process stub content to fix any remaining issues.

        Import lines pass through untouched. A one-line ``def ...: ...`` without
        a return annotation gets ``-> None`` added.
        """
        processed = []

        for line in stub_content:
            if line.startswith(("from ", "import ")):
                processed.append(line)
                continue

            stripped = line.strip()
            if (
                stripped.startswith("def ")
                and stripped.endswith(": ...")
                and ") -> " not in line
            ):
                line = line.replace(": ...", " -> None: ...")

            processed.append(line)

        return processed

    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.

        The stub is written as UTF-8 to ``generated/<sync class name>.pyi`` next
        to the module that defines ``async_class``. Nothing is written for
        classes defined in ``__main__`` or in modules without a file. Any
        failure is logged and swallowed; this method does not raise.
        """
        try:
            if async_class.__module__ == "__main__":
                return

            module = inspect.getmodule(async_class)
            if not module:
                return

            # Built-in/extension modules may not define __file__ at all.
            module_path = getattr(module, "__file__", None)
            if not module_path:
                return

            stub_dir = os.path.join(os.path.dirname(module_path), "generated")
            os.makedirs(stub_dir, exist_ok=True)
            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            type_tracker = TypeTracker()

            # Slot 0 is a placeholder; the imports are only known once every
            # signature has been formatted and its types tracked.
            stub_content = [""]

            stub_content.append(f"class {sync_class.__name__}:")
            if async_class.__doc__:
                stub_content.extend(
                    cls._format_docstring_for_stub(async_class.__doc__, " ")
                )

            # Constructor
            try:
                init_method = async_class.__init__
                init_signature = inspect.signature(init_method)
                init_hints = cls._safe_type_hints(init_method)

                params_str = cls._format_method_parameters(
                    init_signature, type_hints=init_hints, type_tracker=type_tracker
                )
                if hasattr(init_method, "__doc__") and init_method.__doc__:
                    stub_content.extend(
                        cls._format_docstring_for_stub(init_method.__doc__, " ")
                    )
                stub_content.append(f" def __init__({params_str}) -> None: ...")
            except (ValueError, TypeError):
                stub_content.append(
                    " def __init__(self, *args, **kwargs) -> None: ..."
                )
            stub_content.append("")

            # Inner classes
            class_attributes = cls._get_class_attributes(async_class)
            for name, attr in class_attributes:
                stub_content.extend(
                    cls._generate_inner_class_stub(
                        name, attr, type_tracker=type_tracker
                    )
                )
                stub_content.append("")

            # Public methods
            for name, method in sorted(
                inspect.getmembers(async_class, predicate=inspect.isfunction)
            ):
                if name.startswith("_"):
                    continue

                try:
                    method_sig = cls._generate_method_signature(
                        name, method, is_async=True, type_tracker=type_tracker
                    )
                    if method.__doc__:
                        stub_content.extend(
                            cls._format_docstring_for_stub(method.__doc__, " ")
                        )
                    stub_content.append(f" {method_sig}")
                    stub_content.append("")
                except (ValueError, TypeError):
                    stub_content.append(f" def {name}(self, *args, **kwargs): ...")
                    stub_content.append("")

            # Properties
            for name, prop in sorted(
                inspect.getmembers(async_class, lambda x: isinstance(x, property))
            ):
                stub_content.append(" @property")
                stub_content.append(f" def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f" @{name}.setter")
                    stub_content.append(
                        f" def {name}(self, value: Any) -> None: ..."
                    )
                stub_content.append("")

            # Map each inner class to the annotated attribute that holds it, so
            # the stub declares ``<attribute>: <Class>Sync``.
            attribute_mappings = {}
            all_annotations = cls._collect_annotations(async_class)
            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    if (
                        attr_type == class_type
                        or (
                            hasattr(attr_type, "__name__")
                            and attr_type.__name__ == class_name
                        )
                        or (isinstance(attr_type, str) and attr_type == class_name)
                    ):
                        attribute_mappings[class_name] = attr_name

            for class_name, _class_type in class_attributes:
                attr_name = attribute_mappings.get(class_name, class_name)
                stub_content.append(f" {attr_name}: {class_name}Sync")

            stub_content.append("")

            # Imports, de-duplicated while preserving order.
            seen = set()
            unique_imports = []
            for imp in cls._generate_imports(async_class, type_tracker):
                if imp in seen:
                    logging.warning(f"Duplicate import detected: {imp}")
                    continue
                seen.add(imp)
                unique_imports.append(imp)

            stub_content[0:1] = unique_imports
            stub_content = cls._post_process_stub_content(stub_content)

            # Explicit encoding: docstrings may hold non-ASCII text and the
            # platform default encoding is not guaranteed to handle it.
            with open(sync_stub_path, "w", encoding="utf-8") as f:
                f.write("\n".join(stub_content))

            logging.info(f"Generated stub file: {sync_stub_path}")

        except Exception as e:
            # Stub generation is a convenience; report the problem and carry on.
            logging.error(
                f"Error generating stub file for {sync_class.__name__}: {str(e)}"
            )
            import traceback

            logging.error(traceback.format_exc())
| |
|
| |
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """
    Build the synchronous counterpart of an async class.

    Module-level convenience wrapper that forwards both arguments to
    ``AsyncToSyncConverter.create_sync_class``.

    Args:
        async_class: The async class to convert
        thread_pool_size: Size of thread pool to use

    Returns:
        A new class with sync versions of all async methods
    """
    converter = AsyncToSyncConverter
    sync_class = converter.create_sync_class(async_class, thread_pool_size)
    return sync_class
| |
|