# Build a tiny ViT (encoder) and a tiny GPT-2 (decoder), then save both to
# disk so they can be combined into a vision-encoder/decoder model below.
from transformers import ViTConfig, FlaxViTModel, GPT2Config, FlaxGPT2Model, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer

# Deliberately tiny ViT hyper-parameters — fast to construct for a demo.
hidden_size = 8
num_hidden_layers = 2
num_attention_heads = 2
intermediate_size = 16

# Matching tiny GPT-2 hyper-parameters.
n_embd = 8
n_layer = 2
n_head = 2
n_inner = 16

encoder_config = ViTConfig(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head, n_inner=n_inner)

encoder = FlaxViTModel(encoder_config)
decoder = FlaxGPT2Model(decoder_config)

# Persist both halves; from_encoder_decoder_pretrained() below loads from
# exactly these paths.
encoder.save_pretrained("./encoder-decoder/encoder")
decoder.save_pretrained("./encoder-decoder/decoder")
# Stitch the saved encoder and decoder into one vision-to-text model,
# round-trip it through disk, and reload it through the Auto class.
encoder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./encoder-decoder/encoder",
    "./encoder-decoder/decoder",
)
encoder_decoder.save_pretrained("./encoder-decoder")
encoder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder")

# Combined config; the special-token ids on it are resolved below.
config = encoder_decoder.config
def _resolve_token_id(name):
    """Return special-token id ``name`` from the top-level config, falling
    back to the decoder sub-config when unset.

    Compares with ``is None`` rather than truthiness: the original
    ``if not token_id`` check wrongly treated a legitimate token id of 0
    as "missing" and overwrote it with the fallback.
    """
    value = getattr(config, name, None)
    if value is None and getattr(config, "decoder", None) is not None:
        value = getattr(config.decoder, name, None)
    return value

decoder_start_token_id = _resolve_token_id("decoder_start_token_id")
bos_token_id = _resolve_token_id("bos_token_id")
eos_token_id = _resolve_token_id("eos_token_id")
pad_token_id = _resolve_token_id("pad_token_id")

# Fallbacks: generation starts at BOS when no explicit start token exists,
# and EOS doubles as padding (GPT-2 ships without a pad token).
if decoder_start_token_id is None:
    decoder_start_token_id = bos_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id

# Write the resolved ids back onto the top-level config ...
config.decoder_start_token_id = decoder_start_token_id
config.bos_token_id = bos_token_id
config.eos_token_id = eos_token_id
config.pad_token_id = pad_token_id

# ... and mirror them onto the decoder sub-config so both stay in agreement.
if getattr(config, "decoder", None) is not None:
    config.decoder.decoder_start_token_id = decoder_start_token_id
    config.decoder.bos_token_id = bos_token_id
    config.decoder.eos_token_id = eos_token_id
    config.decoder.pad_token_id = pad_token_id
# Fetch the preprocessing components that match each half of the model and
# save each next to the weights it belongs to.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
feature_extractor.save_pretrained("./encoder-decoder/encoder")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 has no pad token by default; reuse the token behind the pad id
# resolved on the config above.
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
tokenizer.save_pretrained("./encoder-decoder/decoder")
# Demo: tokenize target captions into fixed-length label ids.
targets = ['i love dog', 'you cat is very cute']

tokenize_kwargs = dict(max_length=8, padding="max_length", truncation=True, return_tensors="np")
# Target-tokenizer context: tokenize these strings as decoder labels.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(targets, **tokenize_kwargs)

print(labels)