"""Populate `text_encoder/` with the flan-t5-large encoder for upload.

MiniT2I conditions on the encoder of `google/flan-t5-large`. This script saves
just the encoder weights + tokenizer into this folder so it can be uploaded to
`<user>/text_encoder` and loaded with `T5EncoderModel.from_pretrained(repo)`
(pure transformers, no diffusers). Weights are intentionally not committed to the
source repo; run this once before uploading:

    python minit2i_hf/text_encoder/prepare_text_encoder.py
"""
| import argparse |
| from pathlib import Path |
|
|
| from transformers import AutoTokenizer, T5EncoderModel |
|
|
# Absolute directory containing this script; main() writes the saved
# encoder + tokenizer files here.
HERE = Path(__file__).resolve().parents[0]
|
|
|
|
def main(argv=None):
    """Download the text encoder + tokenizer and save them to a local folder.

    Loads the tokenizer and the ``T5EncoderModel`` for the ``--source`` model
    id, then writes both into ``--output-dir`` via ``save_pretrained``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so a plain ``main()`` call
            behaves exactly as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", default="google/flan-t5-large", help="Source HF model id")
    # Defaults to the script's own folder, which preserves the original
    # behaviour when the flag is omitted.
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=HERE,
        help="Directory to save the encoder + tokenizer into (default: this script's folder)",
    )
    args = parser.parse_args(argv)
    target_dir = args.output_dir

    print(f"Downloading {args.source} encoder + tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained(args.source)
    # T5EncoderModel is the encoder-only class, so only those weights are
    # written out by save_pretrained below.
    encoder = T5EncoderModel.from_pretrained(args.source)

    tokenizer.save_pretrained(target_dir)
    encoder.save_pretrained(target_dir)
    print(f"Saved encoder + tokenizer -> {target_dir}")
|
|
|
|
# Script entry point: run the export only when executed directly, not on import.
if "__main__" == __name__:
    main()
|
|