Upload 5 files
Browse files — Uploading Streamlit app, model files & dataset
- .gitattributes +1 -0
- app.py +95 -0
- load_data.py +84 -0
- pred.py +345 -0
- seq2seq_checkpoint.pt +3 -0
- test_data.csv +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
test_data.csv filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np # linear algebra
|
| 2 |
+
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import GPT2Tokenizer
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from typing import List, Dict, Any, Callable
|
| 8 |
+
from pred import *
|
| 9 |
+
from load_data import *
|
| 10 |
+
|
| 11 |
+
def main():
    """Streamlit UI: load the seq2seq checkpoint and generate footy commentary.

    Builds the encoder/decoder, restores weights from ./seq2seq_checkpoint.pt,
    and exposes decoding-mode controls in the sidebar.
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
    # GPT-2 has no pad token; reuse EOS so padding round-trips through decode.
    tokenizer.pad_token = tokenizer.eos_token

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    encoder = Encoder(h=64, n=2, e=64, a=4, o=64).to(device)
    decoder = Decoder(h=64, n=2, e=64, a=4, o=50257).to(device)
    model = Seq2Seq(encoder, decoder).to(device)

    checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    st.title("Footy Commentary Generator")
    # Sidebar for configuration
    st.sidebar.header("Configuration")
    tab_selection = st.sidebar.radio(
        "Select Input Method:",
        ["Random Sample from Test Set", "Custom Input"]
    )

    # Decoding configuration section.
    st.sidebar.header("Decoding Configuration")
    # FIX: 'top_k' (not 'top-k') so the value matches the branch genOp
    # dispatches on; the hyphenated spelling matched no branch, genOp returned
    # None, and tokenizer.decode(None) crashed.
    st.session_state.decoding_mode = st.sidebar.selectbox(
        "Decoding Mode",
        ["greedy", "sample", "top_k", "diverse-beam-search", "min-bayes-risk"]
    )

    # Parameters based on decoding mode.
    st.session_state.decoding_params = {}
    st.session_state.decoding_params['max_len'] = st.sidebar.slider('Max length', 1, 500, 50)
    # FIX: lower bound 0.01 (was 0.0) — temperature divides the logits, so a
    # user selecting 0.0 caused a division by zero in 'sample' mode.
    st.session_state.decoding_params['temperature'] = st.sidebar.slider('Temperature', 0.01, 1.0, 0.1)
    if st.session_state.decoding_mode == "top_k":
        st.session_state.decoding_params["k"] = st.sidebar.slider("k value", 1, 100, 5)
    elif st.session_state.decoding_mode == "diverse-beam-search":
        st.session_state.decoding_params["beam_width"] = st.sidebar.slider("beam width", 1, 10, 1)
        st.session_state.decoding_params["diversity_penalty"] = st.sidebar.slider("diversity penalty", 0.0, 1.0, 0.1)
    elif st.session_state.decoding_mode == "min-bayes-risk":
        st.session_state.decoding_params["num_candidates"] = st.sidebar.slider("Number of candidates", 1, 30, 4)

    if tab_selection == "Random Sample from Test Set":
        st.header("Generate from Test Dataset")

        col1, col2 = st.columns([3, 1])

        with col1:
            # Number of samples in the test dataset.
            st.write("Test dataset contains 5000 samples")

        with col2:
            # Button to generate a random sample.
            if st.button("Generate Random Sample"):
                random_idx = np.random.randint(1, 5000)
                st.session_state.random_idx = random_idx
                # FIX: fetch tensors on the model's device — previously this
                # defaulted to CPU, breaking generation when CUDA is available.
                st.session_state.ip, st.session_state.ip_mask, st.session_state.tg, st.session_state.tg_mask = get_sample(random_idx, device)

        # Display the selected sample.
        if 'random_idx' in st.session_state:
            st.subheader(f"Sample #{st.session_state.random_idx}")
            st.session_state.x = tokenizer.decode(st.session_state.ip.tolist()[0])
            st.session_state.y = tokenizer.decode(st.session_state.tg.tolist())
            # Display sample details in a table.
            df = pd.DataFrame.from_dict({'X': [st.session_state.x], 'y': [st.session_state.y]})
            st.dataframe(df.T.reset_index(), width=800)

            # Generate output.
            if st.button("Generate Sequence"):
                with st.spinner("Generating sequence..."):
                    st.session_state.tok_output = genOp(
                        encoder, decoder, device,
                        st.session_state.ip,
                        st.session_state.ip_mask,
                        mode=st.session_state.decoding_mode,
                        **st.session_state.decoding_params
                    )
                    st.session_state.output = tokenizer.decode(st.session_state.tok_output)

        # Display output.
        if 'output' in st.session_state:
            st.subheader("Generated Sequence")
            st.write(st.session_state.output)


if __name__ == "__main__":
    main()
|
load_data.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
# import numpy as np
|
| 4 |
+
# import torch
|
| 5 |
+
from transformers import GPT2Tokenizer
|
| 6 |
+
|
| 7 |
+
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 8 |
+
# train_size = 50000
|
| 9 |
+
# test_size = 2500
|
| 10 |
+
# val_size = 2500
|
| 11 |
+
|
| 12 |
+
# Shared GPT-2 tokenizer; GPT-2 ships no pad token, so EOS doubles as PAD.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
tokenizer.pad_token = tokenizer.eos_token

# One-off preprocessing that produced test_data.csv (kept for provenance):
# df = pd.read_csv('./prepro_data.csv')

# train_df = df[:train_size]
# test_df = df[train_size:train_size + test_size + val_size]

# test_df.to_csv('test_data.csv')
# print('Test df saved...')

# Held-out split, loaded once at import time (test_data.csv tracked via git-lfs).
test_df = pd.read_csv('./test_data.csv')
test_df = test_df.reset_index(drop=True)
# print(test_df.index)
| 26 |
+
|
| 27 |
+
class TextDataset(torch.utils.data.Dataset):
    """Paired-text dataset: item *idx* is the (source, target) string pair."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
|
| 40 |
+
|
| 41 |
+
def collate_fn(batch):
    """Tokenize and pad a batch of (X, y) string pairs into stacked tensors.

    Returns (input_ids, input_mask, target_ids, target_mask), each of shape
    (batch, seq) after vstack-ing the per-sample (1, seq) encodings.
    """
    X = [i[0] for i in batch]
    y = [i[1] for i in batch]

    lenX = []  # NOTE(review): never used — kept as-is
    # Pad each side of the batch to its longest tokenized sequence.
    # NOTE(review): every string is tokenized twice (once here for the length,
    # once below for the encoding), and tokenizer.tokenize() may not count the
    # BOS token that add_bos_token=True inserts — verify truncation does not
    # silently drop the last token of the longest sample.
    maxlen = max([len(tokenizer.tokenize(i)) for i in X])
    maylen = max([len(tokenizer.tokenize(i)) for i in y])  # max target length ('maylen' is the original name)

    # print(f'maxlen: {maxlen} | maylen: {maylen}')

    inputs = [tokenizer(i, max_length=maxlen, padding='max_length', truncation=True, return_tensors='pt', return_attention_mask=True) for i in X]
    targets = [tokenizer(i, max_length=maylen, padding='max_length', truncation=True, return_tensors='pt', return_attention_mask=True) for i in y]

    # Collect per-sample (1, seq) tensors before stacking.
    input_ids, input_mask = [], []
    for i in inputs:
        input_ids.append(i['input_ids'])
        input_mask.append(i['attention_mask'])
    target_ids, target_mask = [], []
    for i in targets:
        target_ids.append(i['input_ids'])
        target_mask.append(i['attention_mask'])

    return (torch.vstack(input_ids), torch.vstack(input_mask), torch.vstack(target_ids), torch.vstack(target_mask))
|
| 64 |
+
|
| 65 |
+
# Loader over the full test split; batch_size=5000 means the single batch
# tokenizes every row at once — get_sample() below pulls that one batch.
val_ds = TextDataset(test_df['X'].values, test_df['y'].values)
valloader = torch.utils.data.DataLoader(val_ds, batch_size=5000, shuffle=False, collate_fn=collate_fn)

# print(test_df.head())
| 69 |
+
|
| 70 |
+
def get_sample(i, device='cpu'):
    """Return (input_ids, input_mask, target_ids, target_mask) for test row *i*.

    input_ids / input_mask come back batched as (1, seq); the mask is cast to
    float32 as the downstream attention expects. target tensors stay 1-D.

    FIX: the batch is tokenized once and memoized on the function object — the
    original re-ran `next(iter(valloader))`, i.e. re-tokenized all ~5000 rows,
    on every single call.
    """
    if not hasattr(get_sample, '_batch'):
        get_sample._batch = next(iter(valloader))
    ids, mask, tg_ids, tg_mask = get_sample._batch
    return (ids[i].unsqueeze(dim=0).to(device),
            mask[i].unsqueeze(dim=0).type(torch.float32).to(device),
            tg_ids[i].to(device),
            tg_mask[i].to(device))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# X, y, tok_X, tok_y = get_sample(1)
|
| 81 |
+
# print(f'X: {X} \n y: {y}')
|
| 82 |
+
# print(type(tok_X))
|
| 83 |
+
# print(tok_X)
|
| 84 |
+
# print(tok_X.shape, tok_y.shape)
|
pred.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np # linear algebra
|
| 2 |
+
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
|
| 3 |
+
import torch
|
| 4 |
+
import string
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from torch import nn
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
# from gensim.models import Word2Vec
|
| 10 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import argparse
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Trainer, TrainingArguments, AdamW, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
|
| 14 |
+
from transformers import GPTNeoForCausalLM, GPT2Tokenizer ,GPTNeoConfig
|
| 15 |
+
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel,BertTokenizer
|
| 16 |
+
from transformers import GPT2TokenizerFast
|
| 17 |
+
# from peft import LoraModel, LoraConfig
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import datetime
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import random
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 24 |
+
import gc
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
|
| 27 |
+
class Encoder(torch.nn.Module):
    """BiLSTM encoder with a residual self-attention pass over its states.

    forward(X) maps token ids (bs, seq) to (logits, hidden, cell) where
    logits is (bs, seq, o) and hidden/cell are the LSTM final states.
    """

    def __init__(self, h=128, n=8, e=64, a=4, o=1280):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(50257, e)  # GPT-2 vocab size
        self.lstm = nn.LSTM(input_size=e, hidden_size=h, num_layers=n,
                            batch_first=True, bidirectional=True)
        # Bidirectional LSTM doubles the feature dim, hence embed_dim = 2*h.
        self.sa = nn.MultiheadAttention(h * 2, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2 * h, h // 2),
            nn.ReLU(),
            nn.Linear(h // 2, o),
        )

    def forward(self, X):
        embedded = self.embed(X)                        # (bs, seq, e)
        states, (hidden, cell) = self.lstm(embedded)    # (bs, seq, 2h)
        attended, _ = self.sa(states, states, states)   # self-attention over states
        # Residual connection before projecting to the output dimension.
        return self.op(attended + states), hidden, cell
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Decoder(torch.nn.Module):
    """One-step decoder: BiLSTM over the current token plus cross-attention
    into the encoder output; projects the concatenation to vocab logits.
    """

    def __init__(self, h=128, n=8, e=64, a=4, o=50257):
        super(Decoder, self).__init__()
        # FIX: keep h/n so init_state() can size the LSTM states — the
        # original read self.h / self.n without ever assigning them.
        self.h = h
        self.n = n
        self.embed = nn.Embedding(50257, e)
        self.lstm = nn.LSTM(input_size=e, hidden_size=h, num_layers=n, batch_first=True, bidirectional=True)
        # NOTE(review): forward() queries with the embedding (dim e) while this
        # attention has embed_dim h, and self.op expects 2*h + e features from
        # concatenating dec (2h) with atOp (h) — both only line up when e == h,
        # as used in app.py (h=e=64). Confirm before changing dims.
        self.sa = nn.MultiheadAttention(h, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2 * h + e, h // 2),
            nn.ReLU(),
            nn.Linear(h // 2, o),
        )

    def forward(self, ip, ho, co, enc, mask):
        """Decode one step.

        ip:   (bs, 1) token ids;  ho/co: LSTM states (2*n, bs, h)
        enc:  encoder output used as attention key/value; mask: key padding mask
        Returns (logits (bs, o), ho, co).
        """
        emb = self.embed(ip)                               # (bs, 1, e)
        dec, (ho, co) = self.lstm(emb, (ho, co))           # (bs, 1, 2h)
        # Cross-attend: current-token embedding queries the encoder output.
        atOp, atW = self.sa(emb, enc, enc, key_padding_mask=mask)
        op = torch.cat([dec.squeeze(dim=1), atOp.squeeze(dim=1)], dim=1)  # (bs, 2h + e)
        logits = self.op(op)                               # (bs, o)
        return logits, ho, co

    def init_state(self, batch_size, device='cpu'):
        """Zero (hidden, cell) states, each (2*num_layers, batch, h).

        FIX: previously referenced an undefined module-level `device` and the
        unset self.n / self.h; device is now an explicit keyword argument.
        """
        shape = (2 * self.n, batch_size, self.h)
        return (torch.zeros(shape).to(device), torch.zeros(shape).to(device))
|
| 107 |
+
|
| 108 |
+
class Seq2Seq(nn.Module):
    """Bundles an encoder and a one-step decoder; forward runs teacher forcing."""

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, seq_ip, ip_mask, seq_tg):
        """Teacher-forced decode of the target sequence.

        Returns a list of per-step logits with len(seq_tg)-1 entries — the
        first target token is only ever fed as input, never predicted.
        """
        enc, hidden, cell = self.encoder(seq_ip)
        outputs = []
        dec_ip = seq_tg[:, 0].unsqueeze(dim=-1)  # start from the first target token
        for t in range(1, seq_tg.shape[1]):
            op, hidden, cell = self.decoder(dec_ip, hidden, cell, enc, ip_mask)
            outputs.append(op)
            dec_ip = seq_tg[:, t].unsqueeze(dim=-1)  # teacher forcing: feed gold token
        # FIX: removed a stray `torch.stack(outputs, dim=1)` whose result was
        # discarded; callers still receive the plain list, unchanged.
        return outputs
|
| 126 |
+
|
| 127 |
+
def diverse_beam_search(decoder, encoder_output, ip_mask, hidden, cell, device, beam_width=5, diversity_penalty=0.7, max_len=100):
    """Diverse beam search over the decoder's output distribution.

    Each beam is (score, token-id sequence, hidden, cell); within a step the
    i-th ranked expansion of a beam is penalized by diversity_penalty * i so
    beams don't all chase the same continuation. Returns the highest-scoring
    token-id list. 50256 is GPT-2's <|endoftext|>, used as both BOS and EOS.
    """
    dec_ip = torch.tensor([50256]).type(torch.int64).to(device)  # Start token
    beams = [(0.0, [dec_ip.item()], hidden.clone(), cell.clone())]  # (score, sequence, hidden, cell)
    # `count` guards the EOS test below: the very first token is 50256 (BOS),
    # which must not be mistaken for termination before any step has run.
    count = 0
    for _ in range(max_len):
        all_candidates = []
        for score, seq, h, c in beams:
            if seq[-1] == 50256 and count > 0:  # EOS reached — carry beam forward unchanged
                all_candidates.append((score, seq, h, c))
                continue
            dec_out, h_new, c_new = decoder(
                torch.tensor([seq[-1]]).unsqueeze(0).to(device), h, c, encoder_output, ip_mask
            )
            log_probs = torch.nn.functional.log_softmax(dec_out, dim=-1)  # Shape: [1, vocab_size]
            top_k_log_probs, top_k_tokens = torch.topk(log_probs, beam_width, dim=-1)

            for i in range(beam_width):
                # Rank-based diversity penalty: lower-ranked expansions pay more.
                new_score = score + top_k_log_probs[0, i].item() - (diversity_penalty * i)
                new_seq = seq + [top_k_tokens[0, i].item()]
                # States are cloned per candidate so beams evolve independently.
                all_candidates.append((new_score, new_seq, h_new.clone(), c_new.clone()))
        count = 1
        # Select top beam_width candidates.
        beams = sorted(all_candidates, key=lambda x: x[0], reverse=True)[:beam_width]
        if all(seq[-1] == 50256 for _, seq, _, _ in beams):  # All beams ended
            break

    return beams[0][1]  # Return highest-scoring sequence
|
| 154 |
+
|
| 155 |
+
def mbr_decoding(decoder, encoder_output, ip_mask, hidden, cell, device, num_candidates=10, max_len=100):
    """Minimum-Bayes-risk decoding: sample candidates, return the consensus one.

    Candidates are drawn with top-k sampling (k=5) for diversity; each is then
    scored by its mean positional token overlap with every other candidate,
    and the best-agreeing sequence wins. 50256 is GPT-2's <|endoftext|>.
    """
    # Generate candidate sequences using top-k sampling.
    candidates = []
    for _ in range(num_candidates):
        dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
        seq = [dec_ip.item()]
        h, c = hidden.clone(), cell.clone()
        for _ in range(max_len):
            dec_out, h, c = decoder(dec_ip.unsqueeze(0), h, c, encoder_output, ip_mask)
            dec_ip = top_k_sampling(dec_out, k=5).unsqueeze(dim=0)  # Use top-k for diversity
            seq.append(dec_ip.item())
            if dec_ip.item() == 50256:  # EOS
                break
        candidates.append(seq)

    # FIX: with a single candidate there is nothing to compare against and the
    # original divided by zero below.
    if len(candidates) == 1:
        return candidates[0]

    best_seq, best_score = None, float('-inf')
    for cand in candidates:
        # Consensus score: mean positional overlap with every OTHER candidate.
        # FIX: the original compared by value (`other != cand`), which excluded
        # duplicate candidates from each other's score and so under-scored
        # exactly the most agreed-upon output; identity comparison keeps
        # duplicates in the consensus as MBR intends.
        score = sum(sum(1 for t1, t2 in zip(cand, other) if t1 == t2)
                    for other in candidates if other is not cand) / (len(candidates) - 1)
        if score > best_score:
            best_score, best_seq = score, cand
    return best_seq
|
| 178 |
+
|
| 179 |
+
def top_k_sampling(logits, k=10, temperature=1.0):
    """Sample one token id from the k most probable entries of *logits*.

    Expects logits shaped (1, vocab). temperature < 1 sharpens the
    distribution, > 1 flattens it. Returns a 0-d tensor holding the chosen
    vocabulary index.
    """
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    top_values, top_indices = torch.topk(probs, k, dim=-1)
    # Draw within the truncated distribution, then map back to vocab space.
    choice = torch.multinomial(top_values, num_samples=1).item()
    return top_indices[0, choice]
|
| 185 |
+
|
| 186 |
+
def _sample_loop(decoder, enc, ip_mask, hidden, cell, device, max_len, pick):
    # Shared autoregressive loop for greedy/sample/top-k: `pick` maps the
    # step's logits (1, vocab) to the next input token tensor of shape (1,).
    outputs = []
    dec_ip = torch.tensor([50256]).type(torch.int64).to(device)  # BOS == EOS for GPT-2
    for _ in range(max_len + 1):
        dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
        dec_ip = pick(dec)
        outputs.append(dec_ip.item())
        if dec_ip.item() == 50256:
            print('Self terminated !!!')
            break
    return outputs


def genOp(encoder, decoder, device, ip, ip_mask, mode='greedy', temperature=1.0, k=13, beam_width=5, diversity_penalty=0.7, num_candidates=10, max_len=100):
    """Generate a token-id list for input `ip` under the chosen decoding mode.

    Modes: 'greedy', 'sample' (temperature), 'top_k'/'top-k' (k),
    'diverse-beam-search' (beam_width, diversity_penalty),
    'min-bayes-risk' (num_candidates).

    FIXES: the UI sends 'top-k' but only 'top_k' was matched, so that mode
    fell through and returned None (crashing tokenizer.decode); both spellings
    are now accepted and an unknown mode raises ValueError. max_len is also
    forwarded to diverse_beam_search, which previously ignored it.
    """
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        enc, hidden, cell = encoder(ip)
        if mode == 'greedy':
            return _sample_loop(decoder, enc, ip_mask, hidden, cell, device, max_len,
                                lambda dec: torch.argmax(dec, dim=-1))
        if mode == 'sample':
            def pick(dec):
                probs = torch.nn.functional.softmax(dec / temperature, dim=-1)
                return torch.multinomial(input=probs, num_samples=1, replacement=True).squeeze(0)
            return _sample_loop(decoder, enc, ip_mask, hidden, cell, device, max_len, pick)
        if mode in ('top_k', 'top-k'):
            def pick(dec):
                probs = torch.nn.functional.softmax(dec, dim=-1)
                top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
                choice = torch.multinomial(input=top_k_probs, num_samples=1, replacement=True).squeeze(0)
                return top_k_indices[0, choice.item()].unsqueeze(dim=0)
            return _sample_loop(decoder, enc, ip_mask, hidden, cell, device, max_len, pick)
        if mode == 'diverse-beam-search':
            return diverse_beam_search(decoder, enc, ip_mask, hidden, cell, device,
                                       beam_width=beam_width, diversity_penalty=diversity_penalty,
                                       max_len=max_len)
        if mode == 'min-bayes-risk':
            return mbr_decoding(decoder, enc, ip_mask, hidden, cell, device,
                                num_candidates=num_candidates, max_len=max_len)
        raise ValueError(f'Unknown decoding mode: {mode!r}')
|
| 254 |
+
|
| 255 |
+
# ip = torch.tensor([[50256, 11195, 318, 13837, 11, 8272, 318, 2688, 4345, 1578,
|
| 256 |
+
# 11, 4475, 318, 3909, 11, 3035, 767, 11, 1941, 318,
|
| 257 |
+
# 4793, 11, 2435, 357, 315, 66, 8, 318, 1478, 25,
|
| 258 |
+
# 405, 11, 1078, 437, 590, 318, 3126, 11, 2931, 23,
|
| 259 |
+
# 11, 4080, 318, 24880, 10499, 11, 3576, 11, 4492, 11,
|
| 260 |
+
# 19316, 318, 4793, 12, 12726, 37985, 9952, 4041, 11, 6057,
|
| 261 |
+
# 62, 13376, 318, 19446, 11, 30408, 448, 318, 10352, 11,
|
| 262 |
+
# 11195, 62, 26675, 318, 657, 11, 8272, 62, 26675, 318,
|
| 263 |
+
# 352, 11, 11195, 62, 79, 49809, 47, 310, 318, 5598,
|
| 264 |
+
# 7441, 8272, 62, 79, 49809, 47, 310, 318, 4570, 7441,
|
| 265 |
+
# 11195, 62, 20910, 22093, 318, 1542, 357, 1314, 828, 8272,
|
| 266 |
+
# 62, 20910, 22093, 318, 718, 357, 20, 828, 11195, 62,
|
| 267 |
+
# 69, 42033, 6935, 2175, 318, 838, 13, 15, 11, 8272,
|
| 268 |
+
# 62, 69, 42033, 6935, 2175, 318, 1315, 13, 15, 11,
|
| 269 |
+
# 11195, 62, 36022, 34, 1371, 318, 657, 13, 15, 11,
|
| 270 |
+
# 8272, 62, 36022, 34, 1371, 318, 352, 13, 15, 11,
|
| 271 |
+
# 11195, 62, 445, 34, 1371, 318, 657, 13, 15, 11,
|
| 272 |
+
# 8272, 62, 445, 34, 1371, 318, 657, 13, 15, 11,
|
| 273 |
+
# 11195, 62, 8210, 1460, 318, 657, 13, 15, 11, 8272,
|
| 274 |
+
# 62, 8210, 1460, 318, 604, 13, 15, 11, 11195, 62,
|
| 275 |
+
# 26502, 41389, 364, 318, 1478, 13, 15, 11, 8272, 62,
|
| 276 |
+
# 26502, 41389, 364, 318, 352, 13, 15, 11, 11195, 62,
|
| 277 |
+
# 82, 3080, 318, 642, 13, 15, 11, 8272, 62, 82,
|
| 278 |
+
# 3080, 318, 1596, 13, 15, 11, 11195, 62, 1161, 318,
|
| 279 |
+
# 16185, 11, 8272, 62, 1161, 318, 16185, 11, 24623, 318,
|
| 280 |
+
# 3594, 9952, 4041, 11, 16060, 62, 15592, 318, 449, 641,
|
| 281 |
+
# 29921, 9038, 11, 17121, 7096, 292, 11, 42, 14057, 9852,
|
| 282 |
+
# 2634, 11, 10161, 18713, 12119, 280, 2634, 11, 35389, 26689,
|
| 283 |
+
# 75, 1012, 488, 88, 11, 30847, 11979, 406, 73, 2150,
|
| 284 |
+
# 3900, 11, 13787, 292, 10018, 17479, 11, 40747, 32371, 23720,
|
| 285 |
+
# 11, 15309, 38142, 81, 367, 293, 65, 11, 34, 3798,
|
| 286 |
+
# 376, 24247, 65, 2301, 292, 11, 10161, 18713, 1215, 1765,
|
| 287 |
+
# 323, 273, 11, 5124, 2731, 978, 6199, 544, 11, 49680,
|
| 288 |
+
# 68, 311, 2194, 418, 11, 41, 21356, 48590, 18226, 12523,
|
| 289 |
+
# 11, 4826, 280, 6031, 3930, 11, 31579, 44871, 12104, 324,
|
| 290 |
+
# 13235, 11, 32, 1014, 62, 15592, 318, 5199, 3469, 11,
|
| 291 |
+
# 22946, 292, 3169, 359, 11, 20191, 44677, 11, 13217, 261,
|
| 292 |
+
# 44312, 11, 14731, 14006, 11, 24338, 9740, 9860, 11, 25372,
|
| 293 |
+
# 20017, 9557, 11, 45, 47709, 797, 78, 12, 34, 11020,
|
| 294 |
+
# 11, 9704, 20833, 11, 33, 11369, 38343, 5799, 11, 26886,
|
| 295 |
+
# 418, 1665, 33425, 11, 32027, 21298, 11, 31306, 6559, 19574,
|
| 296 |
+
# 1040, 11, 30365, 13058, 273, 11, 25596, 271, 3248, 64,
|
| 297 |
+
# 10788, 68, 11, 42, 538, 64, 11, 7575, 318, 4153,
|
| 298 |
+
# 6]])
|
| 299 |
+
# ip_mask = torch.tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
|
| 300 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 301 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 302 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 303 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 304 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 305 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 306 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 307 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 308 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 309 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 310 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 311 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 312 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 313 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 314 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 315 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 316 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 317 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 318 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 319 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 320 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 321 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 322 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 323 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 324 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 325 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 326 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 327 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 328 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 329 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 330 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 331 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 332 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 333 |
+
# True, True, True, True, True, True, True, True, True, True, True, True,
|
| 334 |
+
# True, True, True, True, True, True, True, True, True, True, True]])
|
| 335 |
+
|
| 336 |
+
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 337 |
+
|
| 338 |
+
# encoder = Encoder(h=64,n=2, e=64, a=4, o=64).to(device)
|
| 339 |
+
# decoder = Decoder(h=64,n=2, e=64, a=4, o=50257).to(device)
|
| 340 |
+
# model = Seq2Seq(encoder, decoder).to(device)
|
| 341 |
+
|
| 342 |
+
# # checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
|
| 343 |
+
|
| 344 |
+
# # model.load_state_dict(checkpoint['model_state_dict'])
|
| 345 |
+
# print(genOp(model.encoder, model.decoder, device, ip, ip_mask, mode='greedy', temperature=1.0, k=13, beam_width=5, diversity_penalty=0.7, num_candidates=10, max_len=100))
|
seq2seq_checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f17ff3f8e345cc56e5bb5dcdb329832e43d68f10043bcca999b831b14ac7926
|
| 3 |
+
size 102274136
|
test_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:838b1509bd748deb39f9cb52bd4c1d7e733e4195a29d5e6da19ed4cf641c97cc
|
| 3 |
+
size 13032378
|