eshan13 commited on
Commit
e343fdd
·
verified ·
1 Parent(s): 1d8e6d3

Upload 5 files

Browse files

Uploading Streamlit app, model files & dataset

Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +95 -0
  3. load_data.py +84 -0
  4. pred.py +345 -0
  5. seq2seq_checkpoint.pt +3 -0
  6. test_data.csv +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test_data.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np # linear algebra
2
+ import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
3
+ import torch
4
+ from transformers import GPT2Tokenizer
5
+ from pathlib import Path
6
+ import streamlit as st
7
+ from typing import List, Dict, Any, Callable
8
+ from pred import *
9
+ from load_data import *
10
+
11
def main():
    """Streamlit UI: load the seq2seq checkpoint and generate football commentary.

    Builds the GPT-2 tokenizer and the Encoder/Decoder/Seq2Seq stack, restores
    weights from ./seq2seq_checkpoint.pt, then renders a sidebar for choosing
    the decoding strategy and a main panel for sampling from the test set.
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; EOS is reused

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Sizes must match the checkpoint: h=64 hidden, n=2 LSTM layers, e=64
    # embedding, a=4 heads; encoder emits 64-dim features, decoder projects
    # to the 50257-token GPT-2 vocabulary.
    encoder = Encoder(h=64,n=2, e=64, a=4, o=64).to(device)
    decoder = Decoder(h=64,n=2, e=64, a=4, o=50257).to(device)
    model = Seq2Seq(encoder, decoder).to(device)

    checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    st.title("Footy Commentary Generator")
    # Sidebar for configuration
    st.sidebar.header("Configuration")
    # Tab selection
    # NOTE(review): "Custom Input" is offered, but no branch below handles it —
    # selecting it renders nothing. Confirm whether this is unfinished work.
    tab_selection = st.sidebar.radio(
        "Select Input Method:",
        ["Random Sample from Test Set", "Custom Input"]
    )
    # Decoding configuration section
    st.sidebar.header("Decoding Configuration")
    # NOTE(review): the "top-k" label here does not match genOp's 'top_k'
    # branch in pred.py, so selecting top-k makes genOp fall through and
    # return None — confirm and align the two strings.
    st.session_state.decoding_mode = st.sidebar.selectbox(
        "Decoding Mode",
        ["greedy", "sample", "top-k", "diverse-beam-search", "min-bayes-risk"]
    )
    # Parameters based on decoding mode; max_len and temperature always apply,
    # the rest are shown only for the modes that consume them.
    st.session_state.decoding_params = {}
    st.session_state.decoding_params['max_len'] = st.sidebar.slider('Max length', 1, 500, 50)
    st.session_state.decoding_params['temperature'] = st.sidebar.slider('Temperature', 0.0, 1.0, 0.1)
    if st.session_state.decoding_mode == "top-k":
        st.session_state.decoding_params["k"] = st.sidebar.slider("k value", 1, 100, 5)
    elif st.session_state.decoding_mode == "diverse-beam-search":
        st.session_state.decoding_params["beam_width"] = st.sidebar.slider("beam width", 1, 10, 1)
        st.session_state.decoding_params["diversity_penalty"] = st.sidebar.slider("diversity penalty", 0.0, 1.0, 0.1)
    elif st.session_state.decoding_mode == "min-bayes-risk":
        st.session_state.decoding_params["num_candidates"] = st.sidebar.slider("Number of candidates", 1, 30, 4)

    if tab_selection == "Random Sample from Test Set":
        st.header("Generate from Test Dataset")

        col1, col2 = st.columns([3, 1])

        with col1:
            # Number of samples in the test dataset (hard-coded; matches the
            # batch_size=5000 loader in load_data.py)
            st.write(f"Test dataset contains 5000 samples")

        with col2:
            # Button to generate a random sample
            # NOTE(review): randint(1, 5000) never picks index 0 — confirm intended.
            if st.button("Generate Random Sample"):
                random_idx = np.random.randint(1, 5000)
                st.session_state.random_idx = random_idx
                st.session_state.ip, st.session_state.ip_mask, st.session_state.tg, st.session_state.tg_mask = get_sample(random_idx)

        # Display the selected sample (persists across reruns via session_state)
        if hasattr(st.session_state, 'random_idx'):
            st.subheader(f"Sample #{st.session_state.random_idx}")
            # ip is (1, seq), tg is a flat (seq,) tensor — hence the different decodes
            st.session_state.x = tokenizer.decode(st.session_state.ip.tolist()[0])
            st.session_state.y = tokenizer.decode(st.session_state.tg.tolist())
            # Display sample details in a table
            df = pd.DataFrame.from_dict({'X': [st.session_state.x], 'y': [st.session_state.y]})
            st.dataframe(df.T.reset_index(), width=800)

        # Generate output
        if st.button("Generate Sequence"):
            with st.spinner("Generating sequence..."):
                print(f'Ip: {st.session_state.ip} | Mask: {st.session_state.ip_mask} \n Mode: {st.session_state.decoding_mode} | Params: {st.session_state.decoding_params}')
                st.session_state.tok_output = genOp(
                    encoder, decoder, device,
                    st.session_state.ip,
                    st.session_state.ip_mask,
                    mode=st.session_state.decoding_mode,
                    **st.session_state.decoding_params
                )
                print(f'\n\n\nOutput: {st.session_state.tok_output} \n')
                st.session_state.output = tokenizer.decode(st.session_state.tok_output)

        # Display output
        if hasattr(st.session_state, 'output'):
            st.subheader("Generated Sequence")
            st.write(st.session_state.output)

if __name__ == "__main__":
    main()
95
+ 1
load_data.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd
import torch
from transformers import GPT2Tokenizer

# GPT-2 tokenizer with a BOS token prepended to every encoding. GPT-2 ships
# no dedicated pad token, so EOS (id 50256) is reused for padding — the same
# convention app.py and pred.py rely on.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
tokenizer.pad_token = tokenizer.eos_token

# Pre-split test data with raw text columns 'X' (match facts) and 'y'
# (commentary); the CSV itself is stored via git-lfs.
test_df = pd.read_csv('./test_data.csv')
test_df = test_df.reset_index(drop=True)
26
+
27
class TextDataset(torch.utils.data.Dataset):
    """Dataset of raw (input text, target text) pairs.

    Tokenization is deliberately deferred to the DataLoader's collate
    function; this class only pairs up the two aligned sequences.
    """

    def __init__(self, X, y):
        # Keep both sequences as given; they are assumed index-aligned.
        self.X = X
        self.y = y

    def __len__(self):
        # Length is driven by the inputs; y is assumed to be the same length.
        return len(self.X)

    def __getitem__(self, idx):
        # Return the raw text pair for one example.
        return self.X[idx], self.y[idx]
40
+
41
def collate_fn(batch):
    """Tokenize a batch of raw (X, y) text pairs into padded id/mask tensors.

    Inputs and targets are padded independently, each to the longest tokenized
    sequence on its own side of the batch.

    Args:
        batch: list of (input_text, target_text) pairs from TextDataset.

    Returns:
        (input_ids, input_mask, target_ids, target_mask) — four stacked
        tensors, each of shape (batch, padded_len).

    Fixes vs original: removed the unused `lenX` local and the `maylen` typo;
    replaced the manual append loops with comprehensions. Behavior unchanged.
    """
    X = [pair[0] for pair in batch]
    y = [pair[1] for pair in batch]

    # Per-side maximum token length drives the padding width.
    max_len_x = max(len(tokenizer.tokenize(text)) for text in X)
    max_len_y = max(len(tokenizer.tokenize(text)) for text in y)

    inputs = [tokenizer(text, max_length=max_len_x, padding='max_length', truncation=True, return_tensors='pt', return_attention_mask=True) for text in X]
    targets = [tokenizer(text, max_length=max_len_y, padding='max_length', truncation=True, return_tensors='pt', return_attention_mask=True) for text in y]

    input_ids = [enc['input_ids'] for enc in inputs]
    input_mask = [enc['attention_mask'] for enc in inputs]
    target_ids = [enc['input_ids'] for enc in targets]
    target_mask = [enc['attention_mask'] for enc in targets]

    return (torch.vstack(input_ids), torch.vstack(input_mask), torch.vstack(target_ids), torch.vstack(target_mask))
64
+
65
# Wrap the full test split. batch_size=5000 makes the loader yield the entire
# split as one batch (get_sample indexes into that single batch); shuffle=False
# keeps sample indices stable across calls.
val_ds = TextDataset(test_df['X'].values, test_df['y'].values)
valloader = torch.utils.data.DataLoader(val_ds, batch_size=5000, shuffle=False, collate_fn=collate_fn)
67
+
68
+ # print(test_df.head())
69
+
70
def get_sample(i, device='cpu'):
    """Return the i-th test sample as (input_ids, input_mask, target_ids, target_mask).

    input_ids/input_mask gain a leading batch dimension; the mask is cast to
    float32, mirroring the original (presumably what the decoder's attention
    expects — TODO confirm against Decoder.forward's key_padding_mask usage).

    Fix vs original: the original ran `next(iter(valloader))` on every call,
    re-tokenizing the whole 5000-sample split each time a sample was drawn.
    The materialized batch is now computed once and cached on the function.
    """
    if not hasattr(get_sample, '_batch'):
        # One-time cost: tokenize and pad the entire split (batch_size=5000).
        get_sample._batch = next(iter(valloader))
    ids, mask, tgt, tgt_mask = get_sample._batch
    return (ids[i].unsqueeze(dim=0).to(device),
            mask[i].unsqueeze(dim=0).type(torch.float32).to(device),
            tgt[i].to(device),
            tgt_mask[i].to(device))
78
+
79
+
80
+ # X, y, tok_X, tok_y = get_sample(1)
81
+ # print(f'X: {X} \n y: {y}')
82
+ # print(type(tok_X))
83
+ # print(tok_X)
84
+ # print(tok_X.shape, tok_y.shape)
pred.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np # linear algebra
2
+ import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
3
+ import torch
4
+ import string
5
+ import pandas as pd
6
+ import numpy as np
7
+ from torch import nn
8
+ from sklearn.model_selection import train_test_split
9
+ # from gensim.models import Word2Vec
10
+ from torch.nn.utils.rnn import pack_padded_sequence
11
+ from pathlib import Path
12
+ import argparse
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Trainer, TrainingArguments, AdamW, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
14
+ from transformers import GPTNeoForCausalLM, GPT2Tokenizer ,GPTNeoConfig
15
+ from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel,BertTokenizer
16
+ from transformers import GPT2TokenizerFast
17
+ # from peft import LoraModel, LoraConfig
18
+ from pathlib import Path
19
+ import datetime
20
+ from tqdm import tqdm
21
+ import random
22
+ from tqdm import tqdm
23
+ from torch.cuda.amp import autocast, GradScaler
24
+ import gc
25
+ import matplotlib.pyplot as plt
26
+
27
class Encoder(torch.nn.Module):
    """Bidirectional LSTM encoder with residual self-attention on top.

    Args:
        h: LSTM hidden size per direction (outputs are 2*h wide).
        n: number of stacked LSTM layers.
        e: token-embedding dimension.
        a: number of attention heads (must divide 2*h).
        o: output feature size produced by the final MLP.
    """
    def __init__(self,h=128,n=8, e=64, a=4, o=1280):
        super(Encoder, self).__init__()
        # 50257 = GPT-2 vocabulary size, matching the GPT2Tokenizer upstream.
        self.embed = nn.Embedding(50257,e)
        self.lstm = nn.LSTM(input_size=e,hidden_size=h,num_layers=n, batch_first=True, bidirectional=True)
        # Self-attention over the 2*h-wide bidirectional LSTM outputs.
        self.sa = nn.MultiheadAttention(h*2, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2*h, h//2),
            nn.ReLU(),
            nn.Linear(h//2 , o),
        )

    def forward(self, X):
        """Encode token ids X of shape (batch, seq).

        Returns:
            logits: (batch, seq, o) per-position features for the decoder's
                cross-attention.
            hidden, cell: final LSTM states, each (2*n, batch, h) — passed to
                the decoder as its initial state.
        """
        emb = self.embed(X) # (batch, seq, e)
        enc, (hidden, cell) = self.lstm(emb) # enc: (batch, seq, 2*h)
        # batch_first=True, so the (batch, seq, 2*h) tensor is used directly
        # as query/key/value — no transpose needed.
        query = enc
        atOp , atW = self.sa(query, query, query)
        # Residual connection: attention output added back onto the LSTM
        # features before the output MLP.
        logits = self.op(atOp + enc)
        return logits , hidden , cell
63
+
64
+
65
class Decoder(torch.nn.Module):
    """Single-step LSTM decoder with cross-attention over encoder features.

    Args:
        h: LSTM hidden size per direction.
        n: number of stacked LSTM layers.
        e: token-embedding dimension.
        a: number of attention heads.
        o: output (vocabulary) size — 50257 for GPT-2.

    Fix vs original: `init_state` read `self.h`/`self.n` that `__init__` never
    stored and an undefined module-level `device`, so any call raised
    AttributeError/NameError. The sizes are now stored and the device is a
    parameter.
    """
    def __init__(self, h=128, n=8, e=64, a=4, o=50257):
        super(Decoder, self).__init__()
        # Stored so init_state can shape the zero states.
        self.h = h
        self.n = n
        self.embed = nn.Embedding(50257, e)
        self.lstm = nn.LSTM(input_size=e, hidden_size=h, num_layers=n, batch_first=True, bidirectional=True)
        # NOTE(review): attention embed_dim is h, but the query is the e-dim
        # embedding and keys/values are the encoder's o-dim features — this
        # only lines up when e == h == encoder_o (app.py uses 64/64/64).
        self.sa = nn.MultiheadAttention(h, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2*h + e, h//2),
            nn.ReLU(),
            nn.Linear(h//2, o),
        )

    def forward(self, ip, ho, co, enc, mask):
        """One decoding step.

        Args:
            ip: (batch, 1) token ids for the current step.
            ho, co: LSTM states, each (2*n, batch, h).
            enc: (batch, src_len, enc_o) encoder features.
            mask: key-padding mask for the encoder positions.

        Returns:
            (logits of shape (batch, o), new hidden state, new cell state).
        """
        emb = self.embed(ip)                      # (batch, 1, e)
        dec, (ho, co) = self.lstm(emb, (ho, co))  # (batch, 1, 2*h)
        # Cross-attention: the embedded input token attends over encoder features.
        atOp, atW = self.sa(emb, enc, enc, key_padding_mask=mask)  # (batch, 1, h)
        # Concatenate LSTM output and attention context (step length is 1).
        op = torch.cat([dec.squeeze(dim=1), atOp.squeeze(dim=1)], dim=1)  # (batch, 2*h + e)
        logits = self.op(op)                      # (batch, o)
        return logits, ho, co

    def init_state(self, batch_size, device='cpu'):
        """Zero (hidden, cell) states, each (2*n, batch_size, h), for the bi-LSTM."""
        shape = (2 * self.n, batch_size, self.h)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))
107
+
108
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper trained with full teacher forcing."""

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, seq_ip, ip_mask, seq_tg):
        """Teacher-forced forward pass.

        Args:
            seq_ip: (batch, src_len) source token ids.
            ip_mask: key-padding mask forwarded to the decoder's attention.
            seq_tg: (batch, tgt_len) target ids; seq_tg[:, 0] seeds decoding
                and the gold token is fed at every subsequent step.

        Returns:
            (batch, tgt_len - 1, vocab) logits tensor.

        Fix vs original: `torch.stack(outputs, dim=1)` was computed but its
        result discarded, so the raw Python list of per-step logits was
        returned; the stacked tensor is now returned as clearly intended.
        """
        enc, hidden, cell = self.encoder(seq_ip)
        outputs = []
        len_tg = seq_tg.shape[1]
        dec_ip = seq_tg[:, 0].unsqueeze(dim=-1)
        for t in range(1, len_tg):  # teacher forcing: gold token each step
            op, hidden, cell = self.decoder(dec_ip, hidden, cell, enc, ip_mask)
            outputs.append(op)
            dec_ip = seq_tg[:, t].unsqueeze(dim=-1)
        return torch.stack(outputs, dim=1)
126
+
127
def diverse_beam_search(decoder, encoder_output, ip_mask, hidden, cell, device, beam_width=5, diversity_penalty=0.7, max_len=100):
    """Beam search with a rank-based diversity penalty; returns a token-id list.

    Each beam is (cumulative log-prob score, token sequence, hidden, cell).
    Token 50256 doubles as the start marker and EOS (GPT-2's only special id).

    NOTE(review): the "diversity" term subtracts diversity_penalty * rank
    within a single beam's own top-k expansion — not a cross-beam penalty as
    in Vijayakumar et al.'s diverse beam search — and scores are not
    length-normalized, which biases toward short sequences. Confirm intended.
    """
    dec_ip = torch.tensor([50256]).type(torch.int64).to(device) # Start token
    beams = [(0.0, [dec_ip.item()], hidden.clone(), cell.clone())] # (score, sequence, hidden, cell)
    # count distinguishes the seed 50256 (start) from a later 50256 (EOS):
    # the EOS check is skipped on the very first step.
    count = 0
    for _ in range(max_len):
        all_candidates = []
        for score, seq, h, c in beams:
            if seq[-1] == 50256 and count > 0: # EOS reached: carry beam forward unchanged
                all_candidates.append((score, seq, h, c))
                continue
            # One decoder step conditioned on this beam's last token and state.
            dec_out, h_new, c_new = decoder(
                torch.tensor([seq[-1]]).unsqueeze(0).to(device), h, c, encoder_output, ip_mask
            )
            log_probs = torch.nn.functional.log_softmax(dec_out, dim=-1) # Shape: [1, vocab_size]
            top_k_log_probs, top_k_tokens = torch.topk(log_probs, beam_width, dim=-1)

            for i in range(beam_width):
                # Rank penalty: lower-ranked expansions pay a larger deduction.
                new_score = score + top_k_log_probs[0, i].item() - (diversity_penalty * i) # Diversity penalty
                new_seq = seq + [top_k_tokens[0, i].item()]
                all_candidates.append((new_score, new_seq, h_new.clone(), c_new.clone()))
        count = 1
        # Select top beam_width candidates
        beams = sorted(all_candidates, key=lambda x: x[0], reverse=True)[:beam_width]
        if all(seq[-1] == 50256 for _, seq, _, _ in beams): # All beams ended
            break

    return beams[0][1] # Return highest-scoring sequence
154
+
155
def mbr_decoding(decoder, encoder_output, ip_mask, hidden, cell, device, num_candidates=10, max_len=100):
    """Minimum-Bayes-risk decoding: sample candidates, return the consensus one.

    Generates `num_candidates` sequences via top-k sampling, then returns the
    candidate with the highest average per-position token overlap with the
    other candidates.

    Fixes vs original:
      * num_candidates == 1 no longer raises ZeroDivisionError (the app's
        slider allows 1); the lone candidate is returned directly.
      * candidates are compared by index instead of value equality, so
        duplicate samples are no longer silently excluded from the consensus
        score (which also made the `len - 1` divisor wrong).
    """
    candidates = []
    for _ in range(num_candidates):
        dec_ip = torch.tensor([50256]).type(torch.int64).to(device)  # GPT-2 BOS/EOS id
        seq = [dec_ip.item()]
        h, c = hidden.clone(), cell.clone()
        for _ in range(max_len):
            dec_out, h, c = decoder(dec_ip.unsqueeze(0), h, c, encoder_output, ip_mask)
            dec_ip = top_k_sampling(dec_out, k=5).unsqueeze(dim=0)  # top-k keeps candidates diverse
            seq.append(dec_ip.item())
            if dec_ip.item() == 50256:  # EOS
                break
        candidates.append(seq)

    if len(candidates) == 1:
        return candidates[0]

    # Score each candidate by its mean positional token overlap with the rest.
    best_seq, best_score = None, float('-inf')
    for i, cand in enumerate(candidates):
        overlap = sum(
            sum(1 for t1, t2 in zip(cand, other) if t1 == t2)
            for j, other in enumerate(candidates) if j != i
        )
        score = overlap / (len(candidates) - 1)
        if score > best_score:
            best_score, best_seq = score, cand
    return best_seq
178
+
179
def top_k_sampling(logits, k=10, temperature=1.0):
    """Sample one token id from the temperature-scaled top-k distribution.

    Softmaxes logits / temperature, keeps the k most probable entries, and
    draws one of them proportionally to probability. Returns the chosen
    vocabulary index as a 0-dim tensor.
    """
    scaled = logits / temperature  # temperature < 1 sharpens, > 1 flattens
    probs = torch.nn.functional.softmax(scaled, dim=-1)
    top_probs, top_indices = torch.topk(probs, k, dim=-1)
    # multinomial renormalizes the truncated weights internally.
    choice = torch.multinomial(top_probs, num_samples=1)
    return top_indices[0, choice.item()]
185
+
186
def genOp(encoder, decoder, device, ip, ip_mask, mode='greedy', temperature=1.0, k=13, beam_width=5, diversity_penalty=0.7, num_candidates=10, max_len=100):
    """Generate a token-id sequence from the seq2seq model.

    Args:
        encoder, decoder: trained modules (put into eval; run under no_grad).
        device: device for the start token.
        ip: (1, src_len) encoder input ids.
        ip_mask: key-padding mask forwarded to the decoder's attention.
        mode: 'greedy', 'sample', 'top-k' (alias 'top_k'),
              'diverse-beam-search', or 'min-bayes-risk'.
        temperature: softmax temperature for 'sample'.
        k: candidate pool size for top-k sampling.
        beam_width, diversity_penalty: diverse-beam-search parameters.
        num_candidates: candidate count for MBR decoding.
        max_len: generation cap (the loops emit at most max_len + 1 tokens).

    Returns:
        list[int] of generated token ids; 50256 is both the seed and EOS.

    Raises:
        ValueError: on an unrecognized mode (the original silently returned None).

    Fix vs original: the Streamlit UI passes mode 'top-k', but only 'top_k'
    was matched, so that branch was unreachable from the app and genOp
    returned None. The hyphenated spelling is now accepted as an alias.
    """
    encoder.eval()
    decoder.eval()
    if mode == 'top-k':
        mode = 'top_k'
    with torch.no_grad():
        enc, hidden, cell = encoder(ip)
        if mode == 'greedy':
            # Always take the arg-max token; deterministic.
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec_ip = torch.argmax(dec, dim=-1)
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:
                    print('Self terminated !!!')
                    break
            return outputs
        elif mode == 'sample':
            # Sample from the full temperature-scaled distribution.
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec = dec / temperature
                dec = torch.nn.functional.softmax(dec, dim=-1)
                dec_ip = torch.multinomial(input=dec, num_samples=1, replacement=True).squeeze(0)
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:
                    print('Self terminated !!!')
                    break
            return outputs
        elif mode == 'top_k':
            # Sample only among the k most probable tokens each step.
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec = torch.nn.functional.softmax(dec, dim=-1)
                top_k_probs, top_k_indices = torch.topk(dec, k, dim=-1)
                # Draw a rank within the top-k, then map it back to a vocab id.
                dec_ip = torch.multinomial(input=top_k_probs, num_samples=1, replacement=True).squeeze(0)
                dec_ip = top_k_indices[0, dec_ip.item()].unsqueeze(dim=0)
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:
                    print('Self terminated !!!')
                    break
            return outputs
        elif mode == 'diverse-beam-search':
            return diverse_beam_search(decoder, enc, ip_mask, hidden, cell, device, beam_width=beam_width, diversity_penalty=diversity_penalty)
        elif mode == 'min-bayes-risk':
            return mbr_decoding(decoder, enc, ip_mask, hidden, cell, device, num_candidates=num_candidates, max_len=max_len)
        else:
            raise ValueError(f'unknown decoding mode: {mode!r}')
254
+
255
+ # ip = torch.tensor([[50256, 11195, 318, 13837, 11, 8272, 318, 2688, 4345, 1578,
256
+ # 11, 4475, 318, 3909, 11, 3035, 767, 11, 1941, 318,
257
+ # 4793, 11, 2435, 357, 315, 66, 8, 318, 1478, 25,
258
+ # 405, 11, 1078, 437, 590, 318, 3126, 11, 2931, 23,
259
+ # 11, 4080, 318, 24880, 10499, 11, 3576, 11, 4492, 11,
260
+ # 19316, 318, 4793, 12, 12726, 37985, 9952, 4041, 11, 6057,
261
+ # 62, 13376, 318, 19446, 11, 30408, 448, 318, 10352, 11,
262
+ # 11195, 62, 26675, 318, 657, 11, 8272, 62, 26675, 318,
263
+ # 352, 11, 11195, 62, 79, 49809, 47, 310, 318, 5598,
264
+ # 7441, 8272, 62, 79, 49809, 47, 310, 318, 4570, 7441,
265
+ # 11195, 62, 20910, 22093, 318, 1542, 357, 1314, 828, 8272,
266
+ # 62, 20910, 22093, 318, 718, 357, 20, 828, 11195, 62,
267
+ # 69, 42033, 6935, 2175, 318, 838, 13, 15, 11, 8272,
268
+ # 62, 69, 42033, 6935, 2175, 318, 1315, 13, 15, 11,
269
+ # 11195, 62, 36022, 34, 1371, 318, 657, 13, 15, 11,
270
+ # 8272, 62, 36022, 34, 1371, 318, 352, 13, 15, 11,
271
+ # 11195, 62, 445, 34, 1371, 318, 657, 13, 15, 11,
272
+ # 8272, 62, 445, 34, 1371, 318, 657, 13, 15, 11,
273
+ # 11195, 62, 8210, 1460, 318, 657, 13, 15, 11, 8272,
274
+ # 62, 8210, 1460, 318, 604, 13, 15, 11, 11195, 62,
275
+ # 26502, 41389, 364, 318, 1478, 13, 15, 11, 8272, 62,
276
+ # 26502, 41389, 364, 318, 352, 13, 15, 11, 11195, 62,
277
+ # 82, 3080, 318, 642, 13, 15, 11, 8272, 62, 82,
278
+ # 3080, 318, 1596, 13, 15, 11, 11195, 62, 1161, 318,
279
+ # 16185, 11, 8272, 62, 1161, 318, 16185, 11, 24623, 318,
280
+ # 3594, 9952, 4041, 11, 16060, 62, 15592, 318, 449, 641,
281
+ # 29921, 9038, 11, 17121, 7096, 292, 11, 42, 14057, 9852,
282
+ # 2634, 11, 10161, 18713, 12119, 280, 2634, 11, 35389, 26689,
283
+ # 75, 1012, 488, 88, 11, 30847, 11979, 406, 73, 2150,
284
+ # 3900, 11, 13787, 292, 10018, 17479, 11, 40747, 32371, 23720,
285
+ # 11, 15309, 38142, 81, 367, 293, 65, 11, 34, 3798,
286
+ # 376, 24247, 65, 2301, 292, 11, 10161, 18713, 1215, 1765,
287
+ # 323, 273, 11, 5124, 2731, 978, 6199, 544, 11, 49680,
288
+ # 68, 311, 2194, 418, 11, 41, 21356, 48590, 18226, 12523,
289
+ # 11, 4826, 280, 6031, 3930, 11, 31579, 44871, 12104, 324,
290
+ # 13235, 11, 32, 1014, 62, 15592, 318, 5199, 3469, 11,
291
+ # 22946, 292, 3169, 359, 11, 20191, 44677, 11, 13217, 261,
292
+ # 44312, 11, 14731, 14006, 11, 24338, 9740, 9860, 11, 25372,
293
+ # 20017, 9557, 11, 45, 47709, 797, 78, 12, 34, 11020,
294
+ # 11, 9704, 20833, 11, 33, 11369, 38343, 5799, 11, 26886,
295
+ # 418, 1665, 33425, 11, 32027, 21298, 11, 31306, 6559, 19574,
296
+ # 1040, 11, 30365, 13058, 273, 11, 25596, 271, 3248, 64,
297
+ # 10788, 68, 11, 42, 538, 64, 11, 7575, 318, 4153,
298
+ # 6]])
299
+ # ip_mask = torch.tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
300
+ # True, True, True, True, True, True, True, True, True, True, True, True,
301
+ # True, True, True, True, True, True, True, True, True, True, True, True,
302
+ # True, True, True, True, True, True, True, True, True, True, True, True,
303
+ # True, True, True, True, True, True, True, True, True, True, True, True,
304
+ # True, True, True, True, True, True, True, True, True, True, True, True,
305
+ # True, True, True, True, True, True, True, True, True, True, True, True,
306
+ # True, True, True, True, True, True, True, True, True, True, True, True,
307
+ # True, True, True, True, True, True, True, True, True, True, True, True,
308
+ # True, True, True, True, True, True, True, True, True, True, True, True,
309
+ # True, True, True, True, True, True, True, True, True, True, True, True,
310
+ # True, True, True, True, True, True, True, True, True, True, True, True,
311
+ # True, True, True, True, True, True, True, True, True, True, True, True,
312
+ # True, True, True, True, True, True, True, True, True, True, True, True,
313
+ # True, True, True, True, True, True, True, True, True, True, True, True,
314
+ # True, True, True, True, True, True, True, True, True, True, True, True,
315
+ # True, True, True, True, True, True, True, True, True, True, True, True,
316
+ # True, True, True, True, True, True, True, True, True, True, True, True,
317
+ # True, True, True, True, True, True, True, True, True, True, True, True,
318
+ # True, True, True, True, True, True, True, True, True, True, True, True,
319
+ # True, True, True, True, True, True, True, True, True, True, True, True,
320
+ # True, True, True, True, True, True, True, True, True, True, True, True,
321
+ # True, True, True, True, True, True, True, True, True, True, True, True,
322
+ # True, True, True, True, True, True, True, True, True, True, True, True,
323
+ # True, True, True, True, True, True, True, True, True, True, True, True,
324
+ # True, True, True, True, True, True, True, True, True, True, True, True,
325
+ # True, True, True, True, True, True, True, True, True, True, True, True,
326
+ # True, True, True, True, True, True, True, True, True, True, True, True,
327
+ # True, True, True, True, True, True, True, True, True, True, True, True,
328
+ # True, True, True, True, True, True, True, True, True, True, True, True,
329
+ # True, True, True, True, True, True, True, True, True, True, True, True,
330
+ # True, True, True, True, True, True, True, True, True, True, True, True,
331
+ # True, True, True, True, True, True, True, True, True, True, True, True,
332
+ # True, True, True, True, True, True, True, True, True, True, True, True,
333
+ # True, True, True, True, True, True, True, True, True, True, True, True,
334
+ # True, True, True, True, True, True, True, True, True, True, True]])
335
+
336
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
337
+
338
+ # encoder = Encoder(h=64,n=2, e=64, a=4, o=64).to(device)
339
+ # decoder = Decoder(h=64,n=2, e=64, a=4, o=50257).to(device)
340
+ # model = Seq2Seq(encoder, decoder).to(device)
341
+
342
+ # # checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
343
+
344
+ # # model.load_state_dict(checkpoint['model_state_dict'])
345
+ # print(genOp(model.encoder, model.decoder, device, ip, ip_mask, mode='greedy', temperature=1.0, k=13, beam_width=5, diversity_penalty=0.7, num_candidates=10, max_len=100))
seq2seq_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f17ff3f8e345cc56e5bb5dcdb329832e43d68f10043bcca999b831b14ac7926
3
+ size 102274136
test_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:838b1509bd748deb39f9cb52bd4c1d7e733e4195a29d5e6da19ed4cf641c97cc
3
+ size 13032378