Vivek committed on
Commit
61622dc
·
1 Parent(s): 137f64a

test file

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. src/test_hellaswag.py +79 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
src/test_hellaswag.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ print(jax.local_device_count())
3
+ import jax.numpy as jnp
4
+
5
+ import flax
6
+ import flax.linen as nn
7
+ from flax.core.frozen_dict import FrozenDict, unfreeze
8
+ from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key
9
+
10
+ from transformers import GPTNeoConfig
11
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel
12
+ from transformers import GPT2Tokenizer
13
+
14
+ from datasets import load_dataset
15
# Each HellaSwag example offers four candidate endings.
num_choices = 4

# Pull the HellaSwag benchmark from the Hugging Face hub.
dataset = load_dataset("hellaswag")
19
def preprocess(example):
    """Expand one HellaSwag example into four (context, candidate ending) pairs.

    Adds two parallel lists to the example dict:
      - 'first_sentence': the shared context `ctx_a`, repeated once per choice
      - 'second_sentence': `ctx_b` joined with each of the four endings
    Returns the mutated example (datasets.map convention).
    """
    context = example['ctx_a']
    lead_in = example['ctx_b']
    endings = example['endings']
    example['first_sentence'] = [context, context, context, context]
    example['second_sentence'] = [f"{lead_in} {endings[i]}" for i in range(4)]
    return example
23
# Build the four sentence-pair columns for every test example.
test_dataset = dataset['test'].map(preprocess)

# Evaluate on a small fixed slice of the test split.
len_test_dataset = 100
test_dataset = test_dataset.select(range(len_test_dataset))

# GPT-Neo reuses the GPT-2 tokenizer; it has no pad token, so reuse EOS.
tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B', pad_token='<|endoftext|>')

# Remember the pre-tokenization columns so they can be dropped later.
remove_col = test_dataset.column_names
34
def tokenize(examples):
    """Tokenize the sentence pairs into fixed-length (256) jax tensors.

    Relies on the module-level `tokenizer`; pads/truncates every pair to
    max_length=256 so batches stack cleanly.
    """
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='jax',
    )
38
# Tokenize every example, then drop the raw-text columns so only the
# model inputs (input_ids / attention_mask) remain.
test_dataset = test_dataset.map(tokenize)
test_dataset = test_dataset.remove_columns(remove_col)
42
def glue_test_data_loader(rng, dataset, batch_size):
    """Yield device-sharded batches of `dataset` in a random order.

    Args:
        rng: jax PRNGKey driving the shuffle permutation.
        dataset: indexable, tokenized dataset (supports len() and array indexing).
        batch_size: total batch size across all local devices; trailing
            examples that do not fill a complete batch are dropped.

    Yields:
        dicts of jnp arrays with a leading device axis (via flax's `shard`).
    """
    # Bug fix: derive the size from the dataset argument instead of the
    # module-level `len_test_dataset` global, so the loader works for any
    # dataset it is handed.
    num_examples = len(dataset)
    steps_per_epoch = num_examples // batch_size
    perms = jax.random.permutation(rng, num_examples)
    perms = perms[:steps_per_epoch * batch_size]  # drop the incomplete tail
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)  # split the leading axis across local devices
        yield batch
54
# Deterministic PRNG setup; one dropout key per local device.
seed = 0
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

# Materialize the tokenized tensors (kept for inspection/debugging).
input_id = jnp.array(test_dataset['input_ids'])
att_mask = jnp.array(test_dataset['attention_mask'])

total_batch_size = 32

from model_file import FlaxGPTNeoForMultipleChoice

# input_shape=(1, num_choices, 1): one example with num_choices candidate
# sequences, as the multiple-choice head expects.
model = FlaxGPTNeoForMultipleChoice.from_pretrained(
    'EleutherAI/gpt-neo-1.3B',
    input_shape=(1, num_choices, 1),
)
67
# Bug fix: `pd` was used below without pandas ever being imported in this
# file, so the script crashed with NameError before writing predictions.
import pandas as pd

# Run the model over the shuffled test batches and collect the argmax
# choice index for every example.
restored_output = []
rng, input_rng = jax.random.split(rng)
for idx, batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
    outputs = model(batch['input_ids'], batch['attention_mask'])
    # Collapse the device/sharding axes back to (batch, num_choices).
    logits = outputs['logits'].reshape(total_batch_size, -1)
    final_output = jnp.argmax(logits, axis=-1)
    restored_output.append(final_output)

# Persist one row per batch of predicted choice indices.
finall = pd.DataFrame({'predictions': restored_output})
finall.to_csv('../predictions.csv')