| import jax |
# Show how many JAX devices this process can address.
print(jax.local_device_count())
| import jax.numpy as jnp |
|
|
| import flax |
| import flax.linen as nn |
| from flax.core.frozen_dict import FrozenDict, unfreeze |
| from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key |
|
|
| from transformers import GPTNeoConfig |
| from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel |
| from transformers import GPT2Tokenizer |
|
|
| from datasets import load_dataset |
| import pandas as pd |
|
|
# Number of candidate answers per example (option1 and option2 are read below).
num_choices = 2

# Winogrande, "winogrande_xl" configuration, all available splits.
dataset = load_dataset("winogrande", "winogrande_xl")
|
|
def preprocess(example):
    """Expand one example into aligned per-choice sentence/option lists.

    Two keys are added to ``example`` (which is mutated in place and also
    returned):

    * ``first_sentence``  -- ``example['sentence']`` repeated ``num_choices``
      times.
    * ``second_sentence`` -- the values of ``option1`` .. ``option<num_choices>``.

    Both lists are sized from the module-level ``num_choices`` so they always
    have the same length. A missing ``option<i>`` key raises ``KeyError``.
    """
    example['first_sentence'] = [example['sentence']] * num_choices
    # Option keys are 1-based: option1, option2, ...
    example['second_sentence'] = [
        example[f'option{i}'] for i in range(1, num_choices + 1)
    ]
    return example
|
|
# Only this many rows of the test split are evaluated.
len_test_dataset = 100

# Preprocess the whole test split first, then keep rows 0..len_test_dataset-1.
test_dataset = dataset['test'].map(preprocess).select(range(len_test_dataset))

# GPT-Neo tokenizer, with the end-of-text token registered as the pad token.
tokenizer = GPT2Tokenizer.from_pretrained(
    'EleutherAI/gpt-neo-1.3B',
    pad_token='<|endoftext|>',
)

# Column names as they stand before tokenization; used later to drop them.
remove_col = test_dataset.column_names
|
|
def tokenize(examples):
    """Tokenize the sentence/option pairs in ``examples``.

    ``first_sentence`` and ``second_sentence`` are handed to the module-level
    ``tokenizer`` as its two text inputs. Sequences are truncated and padded
    to ``max_length=256`` and JAX tensors are requested. The tokenizer's
    return value is passed back unchanged.
    """
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        max_length=256,
        truncation=True,
        padding='max_length',
        return_tensors='jax',
    )
|
|
# Tokenize every row, then drop the pre-tokenization columns recorded in
# `remove_col`.
test_dataset = test_dataset.map(tokenize).remove_columns(remove_col)

# Filled by glue_test_data_loader with the row indices of each yielded batch.
list1 = []
|
|
def glue_test_data_loader(rng, dataset, batch_size):
    """Yield shuffled, fixed-size batches of ``dataset`` as dicts of JAX arrays.

    Args:
        rng: JAX PRNG key used to draw the random permutation of row indices.
        dataset: Collection supporting ``len()`` and indexing with an array of
            row indices, returning a mapping of column name -> values.
        batch_size: Number of rows in every yielded batch.

    Yields:
        ``{column_name: jnp.ndarray}`` for each full batch. Rows left over
        after the last full batch are dropped; if ``dataset`` has fewer than
        ``batch_size`` rows nothing is yielded.

    Side effects:
        Appends each batch's index array to the module-level ``list1`` so the
        caller can line predictions up with the rows they came from.
    """
    # Size the permutation from the dataset that was actually passed in rather
    # than from a module-level constant, so the indices are always in range.
    num_rows = len(dataset)
    steps_per_epoch = num_rows // batch_size

    perms = jax.random.permutation(rng, num_rows)
    # Trim to a whole number of batches, then lay out one batch per row.
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        list1.append(perm)
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield batch
|
|
# Fixed seed: every run derives the same PRNG keys.
seed = 0
rng = jax.random.PRNGKey(seed)

# One key per local device.
# NOTE(review): not referenced again in the visible part of this script.
dropout_rngs = jax.random.split(rng, jax.local_device_count())

# The whole tokenized test set as JAX arrays.
# NOTE(review): neither array is referenced again in the visible part of this
# script; the inference loop reads batches from the data loader instead.
input_id = jnp.array(test_dataset['input_ids'])
att_mask = jnp.array(test_dataset['attention_mask'])

# Rows per batch handed to the model.
total_batch_size = 16
|
|
| from model_file import FlaxGPTNeoForMultipleChoice |
|
|
# Load the published checkpoint into the multiple-choice model class.
model = FlaxGPTNeoForMultipleChoice.from_pretrained(
    'Vivek/gptneo_winogrande',
    input_shape=(1, num_choices, 1),
)

# Split off a dedicated key for shuffling; `rng` itself is advanced.
rng, input_rng = jax.random.split(rng)
test_batches = glue_test_data_loader(input_rng, test_dataset, total_batch_size)

# One entry per batch: the argmax index along the model output's last axis
# (presumably the per-choice scores -- confirm against model_file).
restored_output = []
for idx, batch in enumerate(test_batches):
    outputs = model(batch['input_ids'], batch['attention_mask'])
    final_output = jnp.argmax(outputs, axis=-1)
    restored_output.append(final_output)

# Each DataFrame row pairs one batch's predictions with the row indices
# (`list1`, filled by the loader) that produced them.
finall = pd.DataFrame(
    {
        'predictions': restored_output,
        'permutation': list1,
    }
)
finall.to_csv('./winogrande_predictions.csv')