Spaces: Runtime error
Dit-document-layout-analysis/unilm/decoding/IAD/fairseq/examples/simultaneous_translation/eval/evaluate.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from agents import build_agent | |
| from client import SimulSTEvaluationService, SimulSTLocalEvaluationService | |
| from fairseq.registry import REGISTRIES | |
# Fallback server address used when --hostname / --port are not supplied.
DEFAULT_HOSTNAME, DEFAULT_PORT = "localhost", 12321
def get_args(argv=None):
    """Build the evaluator's argument parser and parse the command line.

    Parsing happens in two passes:

    1. ``parse_known_args`` reads only the options declared below and
       tolerates extra options that belong to a not-yet-known agent/scorer.
    2. For every fairseq registry whose selector is set on that namespace,
       the selected class may extend the parser through its ``add_args``
       hook. The full argument list is then parsed strictly.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        The ``argparse.Namespace`` produced by the second (strict) pass.

    Raises:
        SystemExit: via ``parser.error`` on invalid options, including a
            registry selector that names a class which is not registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
    )
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help="server port number"
    )
    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
    parser.add_argument("--scorer-type", default="text", help="Scorer type")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index of the sentence to evaluate",
    )
    # argparse does not apply `type` to non-string defaults, so when the
    # option is omitted the value stays float("inf"), meaning "no upper bound".
    parser.add_argument(
        "--end-idx",
        type=int,
        default=float("inf"),
        help="End index of the sentence to evaluate",
    )
    parser.add_argument(
        "--scores", action="store_true", help="Request scores from server"
    )
    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
    parser.add_argument(
        "--num-threads", type=int, default=10, help="Number of threads used by agent"
    )
    parser.add_argument(
        "--local", action="store_true", default=False, help="Local evaluation"
    )

    # First pass: only the options above are known here. Unknown ones may
    # belong to the agent/scorer class selected below.
    args, _ = parser.parse_known_args(argv)

    # NOTE(review): relies on each REGISTRIES key matching the argparse dest
    # of its selector option (e.g. "agent_type" for --agent-type).
    for registry_name, registry in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is None:
            continue
        registered = registry["registry"]
        if choice not in registered:
            # Report through argparse (usage + exit code 2) instead of
            # crashing with a bare KeyError traceback.
            parser.error(
                "invalid choice {!r} for {} (choose from: {})".format(
                    choice,
                    registry_name,
                    ", ".join(sorted(str(name) for name in registered)),
                )
            )
        cls = registered[choice]
        if hasattr(cls, "add_args"):
            cls.add_args(parser)

    # Second pass: strict parse now that class-specific options are declared.
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = get_args()

    # Use the local evaluation service when --local is given; otherwise
    # connect to the evaluation server at --hostname / --port.
    if args.local:
        session = SimulSTLocalEvaluationService(args)
    else:
        session = SimulSTEvaluationService(args.hostname, args.port)

    if args.reset_server:
        session.new_session()

    # --agent-type defaults to "simul_trans_text", so from the command line
    # this branch is always taken.
    if args.agent_type is not None:
        agent = build_agent(args)
        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)

    # Fetch the scores exactly once and print them. Previously --scores
    # triggered an extra session.get_scores() call whose result was thrown
    # away, costing a redundant request to the evaluation service.
    # NOTE(review): --scores is still accepted for CLI compatibility but no
    # longer changes behaviour; decide whether printing should be gated on it.
    print(session.get_scores())