Spaces:
Paused
Paused
| # Licensed to the Apache Software Foundation (ASF) under one | |
| # or more contributor license agreements. See the NOTICE file | |
| # distributed with this work for additional information | |
| # regarding copyright ownership. The ASF licenses this file | |
| # to you under the Apache License, Version 2.0 (the | |
| # "License"); you may not use this file except in compliance | |
| # with the License. You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an | |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| # KIND, either express or implied. See the License for the | |
| # specific language governing permissions and limitations | |
| # under the License. | |
| """Example of invisibly propagating a request ID with middleware.""" | |
| import argparse | |
| import sys | |
| import threading | |
| import uuid | |
| import pyarrow as pa | |
| import pyarrow.flight as flight | |
class TraceContext:
    """Thread-local holder for the trace ID of the work running on a thread.

    Each thread sees its own ``trace_id``. If none is set on the current
    thread, ``current_trace_id`` generates a random hex UUID, stores it for
    that thread and returns it.
    """

    # threading.local gives every thread an independent attribute namespace.
    _locals = threading.local()
    # This only initialises the attribute for the thread executing the class
    # body; other threads rely on the getattr() default in current_trace_id.
    _locals.trace_id = None

    @classmethod
    def current_trace_id(cls):
        """Return this thread's trace ID, generating and storing one if unset.

        Returns:
            The stored trace ID, or a new ``uuid4().hex`` string when the
            stored value is missing or falsy.
        """
        if not getattr(cls._locals, "trace_id", None):
            cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id

    @classmethod
    def set_trace_id(cls, trace_id):
        """Store ``trace_id`` as the current thread's trace ID."""
        cls._locals.trace_id = trace_id
# Header name used to carry the trace ID on client requests and server
# responses (and from a server to its delegate).
TRACE_HEADER: str = "x-tracing-id"
class TracingServerMiddleware(flight.ServerMiddleware):
    """Per-call server middleware that echoes the call's trace ID back.

    Instances are created by ``TracingServerMiddlewareFactory.start_call``
    with the trace ID chosen for the call.
    """

    def __init__(self, trace_id):
        # Read by FlightServer.do_action via context.get_middleware("trace").
        self.trace_id = trace_id

    def sending_headers(self):
        """Return the response headers: the trace ID under TRACE_HEADER."""
        outgoing = {TRACE_HEADER: self.trace_id}
        return outgoing
class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a TracingServerMiddleware for every incoming call.

    The trace ID is taken from the request's TRACE_HEADER when present and
    non-empty; otherwise a fresh random ID is minted for this call. The
    chosen ID is stored in TraceContext and handed to the middleware, which
    echoes it back to the client.
    """

    def start_call(self, info, headers):
        """Pick the trace ID for a new call and return its middleware.

        Args:
            info: Call information, printed for demonstration purposes.
            headers: Incoming request headers; each value is indexed like a
                list and only the first entry is used.

        Returns:
            A TracingServerMiddleware carrying the call's trace ID.
        """
        print("Starting new call:", info)
        values = headers.get(TRACE_HEADER)
        # Guard against an empty value list instead of raising IndexError.
        trace_id = values[0] if values else None
        if trace_id is not None:
            print("Found trace header with value:", trace_id)
        if not trace_id:
            # Mint a new ID rather than reusing whatever an earlier call left
            # in this worker thread's TraceContext, which would make
            # unrelated requests share a trace ID.
            trace_id = uuid.uuid4().hex
        TraceContext.set_trace_id(trace_id)
        return TracingServerMiddleware(trace_id)
class TracingClientMiddleware(flight.ClientMiddleware):
    """Per-call client middleware that attaches the current trace ID.

    The outgoing ID comes from TraceContext on the calling thread. An ID
    echoed back by the server is only logged, never adopted.
    """

    def sending_headers(self):
        """Return request headers carrying this thread's trace ID."""
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        # Use the shared constant so client and server agree on the name.
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        """Log the trace ID the server sent back, if any.

        Args:
            headers: Response headers; each value is indexed like a list and
                only the first entry is logged.
        """
        values = headers.get(TRACE_HEADER)
        if values:
            print("Found trace header with value:", values[0])
        # Don't overwrite our trace ID: the client's own ID stays
        # authoritative for this thread.
class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Attach a new TracingClientMiddleware to each outgoing call."""

    def start_call(self, info):
        """Log the call and return a fresh tracing middleware for it."""
        print("Starting new call:", info)
        middleware = TracingClientMiddleware()
        return middleware
class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single ``get-trace-id`` action.

    When ``delegate`` is given, ``get-trace-id`` is forwarded to that
    location through a tracing client, which demonstrates the trace ID
    propagating from server to server. Otherwise this server answers with
    its own current trace ID.
    """

    def __init__(self, delegate, **kwargs):
        super().__init__(**kwargs)
        # Outbound client used to forward actions, or None to answer locally.
        self.delegate = (
            flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),))
            if delegate
            else None
        )

    def list_actions(self, context):
        """Advertise the only supported action."""
        return [("get-trace-id", "Get the trace context ID.")]

    def do_action(self, context, action):
        """Handle ``get-trace-id`` locally or via the delegate.

        Raises:
            KeyError: For any action type other than ``get-trace-id``.
        """
        tracer = context.get_middleware("trace")
        if tracer:
            # Make the call's trace ID visible to this thread, so that a
            # delegated call's client middleware sends the same ID onward.
            TraceContext.set_trace_id(tracer.trace_id)

        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")

        if self.delegate:
            yield from self.delegate.do_action(action)
            return

        encoded = TraceContext.current_trace_id().encode("utf-8")
        print("Returning trace ID:", encoded)
        yield flight.Result(pa.py_buffer(encoded))
def main():
    """Parse the command line and run the example server or client.

    Sub-commands:
        server  Serve ``get-trace-id`` on ``--listen``, optionally forwarding
                each call to ``--delegate``.
        client  Call ``get-trace-id`` on ``server`` using ``--request-id``
                (or ``"client-chosen-id"``) as the trace ID and print each
                result body.

    Returns:
        1 if no sub-command was given (help is printed), otherwise None.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    # Distinct names keep the sub-parsers apart from the FlightServer and
    # FlightClient objects created below.
    client_parser = subparsers.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    server_parser = subparsers.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    # add_subparsers(dest="command") always defines the attribute (None when
    # no sub-command was given), so getattr() is unnecessary.
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()})
        server.serve()
    elif args.command == "client":
        # Close the client connection when done instead of leaking it.
        with flight.connect(
                args.server,
                middleware=(TracingClientMiddlewareFactory(),)) as client:
            # The client middleware reads this ID when sending headers.
            TraceContext.set_trace_id(args.request_id or "client-chosen-id")
            action = flight.Action("get-trace-id", b"")
            for result in client.do_action(action):
                print(result.body.to_pybytes())
if __name__ == "__main__":
    # main() returns None on success or 1 when help was printed.
    exit_status = main()
    sys.exit(exit_status or 0)