import dataclasses import logging import types import typing from typing import Any, TypeVar, cast, get_args, get_origin, get_type_hints from surrealdb import ( BlockingHttpSurrealConnection, BlockingWsSurrealConnection, Value, ) from ..definitions import Object RecordType = TypeVar("RecordType") logger = logging.getLogger(__name__) def _coerce_value(value: Any, target_type: Any) -> Any: # pyright: ignore[reportExplicitAny, reportAny] """Recursively coerce SurrealDB-returned values (dict/list) into typed values. Intended primarily for nested dataclass graphs (e.g. dict -> dataclass, list[dataclass], etc.). """ if target_type is Any or target_type is object: return value # pyright: ignore[reportAny] if value is None: return None origin = get_origin(target_type) # pyright: ignore[reportAny] args = get_args(target_type) # Optional[T] / Union[...] if target_type is None or target_type is type(None): return None if origin in (typing.Union, types.UnionType): # pyright: ignore[reportDeprecated] # Best-effort: try each option and return the first successful conversion. for opt in args: # pyright: ignore[reportAny] if opt is type(None): # noqa: E721 continue try: return _coerce_value(value, opt) # pyright: ignore[reportAny] except Exception: continue return value # pyright: ignore[reportAny] # dataclass types if isinstance(target_type, type) and dataclasses.is_dataclass(target_type): if isinstance(value, target_type): return value if isinstance(value, dict): return _coerce_dataclass(value, target_type) # pyright: ignore[reportAny, reportUnknownArgumentType] return value # pyright: ignore[reportAny] # Containers if origin in (list, tuple, set): inner_t = args[0] if len(args) >= 1 else Any if isinstance(value, list): coerced_list = [_coerce_value(v, inner_t) for v in value] # pyright: ignore[reportUnknownVariableType] if origin is tuple: return tuple(coerced_list) if origin is set: return set(coerced_list) return coerced_list return value # pyright: ignore[reportAny] if origin is dict: key_t, val_t = args if len(args) == 2 else (Any, Any) if isinstance(value, dict): return { _coerce_value(k, key_t): _coerce_value(v, val_t) for k, v in value.items() # pyright: ignore[reportUnknownVariableType] } return value # pyright: ignore[reportAny] # Primitive / passthrough if isinstance(target_type, type) and isinstance(value, target_type): return value # pyright: ignore[reportAny] return value # pyright: ignore[reportAny] def _coerce_dataclass(data: dict[str, Any], cls: type[Any]) -> Any: # pyright: ignore[reportExplicitAny, reportAny] """Build dataclass `cls` from dict `data`, recursively coercing nested fields.""" type_hints = get_type_hints(cls) kwargs: dict[str, Any] = {} # pyright: ignore[reportExplicitAny] for field in dataclasses.fields(cls): name = field.name if name not in data: continue kwargs[name] = _coerce_value(data[name], type_hints.get(name, Any)) return cls(**kwargs) # pyright: ignore[reportAny] def parse_time(time: str) -> float: r""" Examples: - "123.456µs" => 0.123456 - "1.939083ms" => 1.939083 - "1ms" => 1 - "1.2345s" => 1234.5 """ import re regex = re.compile(r"(\d+\.?\d*)s") match = regex.match(time) if match: return float(match.group(1)) * 1000 regex = re.compile(r"(\d+\.?\d*)ms") match = regex.match(time) if match: return float(match.group(1)) regex = re.compile(r"(\d+\.?\d*)µs") match = regex.match(time) if match: return float(match.group(1)) / 1000 raise ValueError(f"Invalid time format: {time}") def _query_aux( client: BlockingWsSurrealConnection | BlockingHttpSurrealConnection, query: str, vars: Object, ) -> Value: try: response = client.query(query, cast(dict[str, Value], vars)) logger.debug(f"Query: {query} with {vars}, Response: {response}") except Exception as e: logger.error(f"Query execution error: {query} with {vars}, Error: {e}") raise e return response def query( client: BlockingWsSurrealConnection | BlockingHttpSurrealConnection, query: str, vars: Object, record_type: type[RecordType], ) -> list[RecordType]: response = _query_aux(client, query, vars) if isinstance(response, list): if dataclasses.is_dataclass(record_type) and hasattr( record_type, "from_dict" ): cast_fn = getattr(record_type, "from_dict") # pyright: ignore[reportAny] casted: list[RecordType] = [cast_fn.__call__(x) for x in response] # pyright: ignore[reportAny] assert all(isinstance(x, record_type) for x in casted) return casted if dataclasses.is_dataclass(record_type): casted = [_coerce_value(x, record_type) for x in response] assert all(isinstance(x, record_type) for x in casted) # return cast(list[RecordType], casted) return casted else: return [record_type(**x) for x in response] else: raise TypeError(f"Unexpected response type: {type(response)}") def query_one( client: BlockingWsSurrealConnection | BlockingHttpSurrealConnection, query: str, vars: Object, record_type: type[RecordType], ) -> RecordType | None: response = _query_aux(client, query, vars) if response is None: return None elif not isinstance(response, list): if dataclasses.is_dataclass(record_type) and hasattr( record_type, "from_dict" ): casted = getattr(record_type, "from_dict").__call__(response) # pyright: ignore[reportAny] assert isinstance(casted, record_type) return casted if dataclasses.is_dataclass(record_type) and isinstance(response, dict): casted = _coerce_value(response, record_type) # pyright: ignore[reportAny] assert isinstance(casted, record_type) return casted elif isinstance(response, dict): try: return record_type(**response) except Exception as e: print(f"Error creating record: {e}. Response: {response}") raise e raise TypeError(f"Unexpected response type: {type(response)}")