Spaces:
Paused
Paused
| # Licensed to the Apache Software Foundation (ASF) under one | |
| # or more contributor license agreements. See the NOTICE file | |
| # distributed with this work for additional information | |
| # regarding copyright ownership. The ASF licenses this file | |
| # to you under the Apache License, Version 2.0 (the | |
| # "License"); you may not use this file except in compliance | |
| # with the License. You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an | |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| # KIND, either express or implied. See the License for the | |
| # specific language governing permissions and limitations | |
| # under the License. | |
| """An example Flight CLI client.""" | |
| import argparse | |
| import sys | |
| import pyarrow | |
| import pyarrow.flight | |
| import pyarrow.csv as csv | |
def list_flights(args, client, connection_args=None):
    """Print every flight and every action advertised by the server.

    For each flight the descriptor (path or command), record count, byte
    count, number of endpoints and schema are printed, followed by a
    ``---`` separator.  The server's actions are listed afterwards.

    ``args`` and ``connection_args`` are accepted so that all sub-command
    handlers share one call signature; neither is read here.
    ``connection_args`` defaults to ``None`` rather than a shared mutable
    ``{}``.
    """
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        kind = descriptor.descriptor_type
        if kind == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif kind == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        # A negative count is reported as "Unknown" rather than printed.
        records = flight.total_records
        print("Total records:", records if records >= 0 else "Unknown")
        size = flight.total_bytes
        print("Total bytes:", size if size >= 0 else "Unknown")

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')
def do_action(args, client, connection_args=None):
    """Run the action named by ``args.action_type`` and print each result.

    The action is sent with an empty (zero-length) body.  A
    ``pyarrow.lib.ArrowIOError`` raised while building or running the
    action is reported on stdout instead of propagating.

    ``connection_args`` is accepted so that all sub-command handlers share
    one call signature; it is not read here.  It defaults to ``None``
    rather than a shared mutable ``{}``.
    """
    try:
        empty_body = pyarrow.allocate_buffer(0)
        action = pyarrow.flight.Action(args.action_type, empty_body)
        print('Running action', args.action_type)
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except pyarrow.lib.ArrowIOError as e:
        print("Error calling action:", e)
def push_data(args, client, connection_args=None):
    """Read the CSV file ``args.file`` and upload it to the server.

    The file is parsed into an Arrow table, its row count and first rows
    are printed, and the table is sent with DoPut under a path descriptor
    equal to the file name.

    ``connection_args`` is accepted so that all sub-command handlers share
    one call signature; it is not read here.  It defaults to ``None``
    rather than a shared mutable ``{}``.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())

    descriptor = pyarrow.flight.FlightDescriptor.for_path(args.file)
    writer, _ = client.do_put(descriptor, my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Always close the stream, even if the upload fails part-way.
        writer.close()
def get_flight(args, client, connection_args=None):
    """Look up a flight and print the data behind each of its endpoints.

    The descriptor is built from ``args.path`` when given, otherwise from
    ``args.command``.  For every endpoint the ticket is printed; then, for
    every location of that endpoint, the location is printed, a new client
    is opened to it using ``connection_args`` and the ticket's data is
    printed as a pandas DataFrame.

    An endpoint that lists no locations is fetched through ``client``
    itself: per the Flight protocol an empty location list means the
    ticket can only be redeemed on the service that issued it.

    ``connection_args`` defaults to ``None`` (treated as no extra options)
    rather than a shared mutable ``{}``.
    """
    if connection_args is None:
        connection_args = {}

    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)

        if not endpoint.locations:
            # No explicit location: redeem the ticket on the originating
            # service instead of silently printing nothing.
            print(client.do_get(endpoint.ticket).read_pandas())
            continue

        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            print(reader.read_pandas())
def _add_common_arguments(parser):
    """Register the connection options shared by every sub-command.

    Adds ``--tls``, ``--tls-roots``, ``--mtls CERTFILE KEYFILE`` and the
    positional ``host`` argument to ``parser``.
    """
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    # The original help text here was a copy of the --tls help.
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable mutual TLS using the given client "
                             "certificate and private key files")
    parser.add_argument('host', type=str,
                        help="Address or hostname to connect to, "
                             "in HOST:PORT form")
def _build_parser():
    """Create the parser with the ``list``, ``do``, ``put`` and ``get`` sub-commands."""
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")
    return parser


def _split_host_port(address):
    """Split ``HOST:PORT`` into ``(host, port)`` with ``port`` as an int.

    The split is made on the last colon so that a bracketed IPv6 host such
    as ``[::1]:5005`` is accepted.  Raises ``ValueError`` when the colon is
    missing or the port is not an integer.
    """
    host, separator, port = address.rpartition(':')
    if not separator:
        raise ValueError(f"expected HOST:PORT, got {address!r}")
    try:
        return host, int(port)
    except ValueError:
        raise ValueError(
            f"invalid port {port!r} in {address!r}") from None


def _load_connection_settings(args):
    """Return ``(scheme, connection_args)`` derived from the TLS options."""
    scheme = "grpc+tcp"
    connection_args = {}
    if args.tls:
        scheme = "grpc+tls"
        # --tls-roots is only consulted when --tls is given.
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    # NOTE(review): --mtls does not switch the scheme to grpc+tls by itself;
    # presumably it is meant to be combined with --tls -- confirm.
    if args.mtls:
        with open(args.mtls[0], "rb") as cert_file:
            connection_args["cert_chain"] = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            connection_args["private_key"] = key_file.read()
    return scheme, connection_args


def _wait_until_ready(client):
    """Block until a ``healthcheck`` action on ``client`` succeeds.

    Each attempt uses a one-second call timeout.  Any
    ``pyarrow.ArrowIOError`` leads to a retry; errors that are not
    deadline timeouts are printed and followed by a short pause so the
    loop neither hides the cause nor spins at full speed.
    """
    import time

    action = pyarrow.flight.Action("healthcheck", b"")
    options = pyarrow.flight.FlightCallOptions(timeout=1)
    while True:
        try:
            list(client.do_action(action, options=options))
            return
        except pyarrow.ArrowIOError as e:
            if "Deadline" in str(e):
                print("Server is not ready, waiting...")
            else:
                print("Health check failed, retrying:", e)
                time.sleep(1)


def main(argv=None):
    """Parse the command line, connect to the server and run a sub-command.

    ``argv`` is the list of arguments to parse; ``None`` (the default)
    means ``sys.argv[1:]``.  When no sub-command is given, the help text is
    printed and the process exits with status 1.  A malformed ``host``
    argument is reported as a usage error.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if not hasattr(args, 'action'):
        parser.print_help()
        sys.exit(1)

    try:
        host, port = _split_host_port(args.host)
    except ValueError as e:
        # Report a usage error instead of an unhandled traceback.
        parser.error(str(e))

    commands = {
        'list': list_flights,
        'do': do_action,
        'get': get_flight,
        'put': push_data,
    }

    scheme, connection_args = _load_connection_settings(args)
    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)
    _wait_until_ready(client)
    commands[args.action](args, client, connection_args)


if __name__ == '__main__':
    main()