| import argparse | |
| from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration | |
def main(args):
    """Convert a Flax T5 checkpoint to PyTorch and TensorFlow formats in place.

    The checkpoint in ``args.model_dir`` is loaded from its Flax weights into
    a PyTorch model and saved back to the same directory. The PyTorch weights
    just written are then loaded into a TensorFlow model, which is saved to
    that directory as well.

    Args:
        args: Parsed command-line namespace; only ``args.model_dir`` is read.
    """
    checkpoint_dir = args.model_dir

    # Step 1: Flax -> PyTorch. This must run first, because the TensorFlow
    # load below reads the PyTorch weights this step writes.
    torch_model = T5ForConditionalGeneration.from_pretrained(
        checkpoint_dir, from_flax=True
    )
    torch_model.save_pretrained(checkpoint_dir)

    # Step 2: PyTorch -> TensorFlow, written to the same directory.
    keras_model = TFT5ForConditionalGeneration.from_pretrained(
        checkpoint_dir, from_pt=True
    )
    keras_model.save_pretrained(checkpoint_dir)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--model_dir', type=str, default='.') |