import os
import sys
import time
import datetime
import random

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import AdamW, get_linear_schedule_with_warmup  # only needed for training, kept for parity with the fine-tuning script

import pytz

IST = pytz.timezone('Asia/Kolkata')
print(datetime.datetime.now(IST).strftime("%c"))
|
# Load the GPT-2 tokenizer and register the special tokens used during fine-tuning
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

# Resize the embedding matrix to cover the newly added special tokens
model.resize_token_embeddings(len(tokenizer))

# Use the GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

print('Model loaded to', device)
print(datetime.datetime.now(IST).strftime("%c"))
|
output_dir = '/media/data_dump/Ritwik/ggpt/model_save/pytorch_save_files/'

print('Loading fine-tuned weights')
# The fine-tuned checkpoint replaces the base GPT-2 weights loaded above
model = GPT2LMHeadModel.from_pretrained(output_dir).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(output_dir)

print('Model and tokenizer loaded!')
print(datetime.datetime.now(IST).strftime("%c"))

model.eval()
|
prompt_list = ['<|startoftext|> Regarding Kashmir I am very confident to say that']

for prompt in prompt_list:

    print(prompt)

    # Encode the prompt and add a batch dimension
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)

    print(generated)

    # Sample three continuations with top-k / nucleus sampling
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=50,
        max_length=500,
        top_p=0.95,
        num_return_sequences=3,
        pad_token_id=tokenizer.pad_token_id,  # use the registered pad token to avoid the default padding warning
    )

    for i, sample_output in enumerate(sample_outputs):
        print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

    print(datetime.datetime.now(IST).strftime("%c"))
    print('\n')