| | import argparse |
| | from transformers import GPT2LMHeadModel, GPT2Tokenizer |
| |
|
| |
|
def test_model(prompt, model_path='results', model_name='distilgpt2'):
    """Generate and print sampled continuations of *prompt* from a fine-tuned GPT-2 model.

    Loads the tokenizer for *model_name* and the fine-tuned weights from
    *model_path*, samples 3 continuations (top-k=50, top-p=0.95, max 50
    tokens), prints each one, and returns them.

    Args:
        prompt: Text to condition generation on.
        model_path: Directory holding the fine-tuned model weights.
            NOTE(review): the original default was a machine-specific
            absolute path under /Users/raghul.v/...; normalized to
            'results', matching the only visible call site.
        model_name: Base model identifier used to load the tokenizer.

    Returns:
        list[str]: The decoded generated texts (also printed to stdout).
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    model.eval()  # inference only: disable dropout

    # Tokenize via __call__ so we also get an attention mask; passing the
    # mask and an explicit pad_token_id silences transformers' warnings
    # (GPT-2 defines no pad token, so EOS is the conventional stand-in).
    encoded = tokenizer(prompt, return_tensors="pt")

    sample_outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=3,
        pad_token_id=tokenizer.eos_token_id,
    )

    decoded_texts = []
    for idx, sample_output in enumerate(sample_outputs):
        decoded_output = tokenizer.decode(sample_output, skip_special_tokens=True)
        print(f"Generated Text {idx}: {decoded_output}")
        decoded_texts.append(decoded_output)
    return decoded_texts
| |
|
| |
|
if __name__ == '__main__':
    # CLI entry point: --prompt is required; model location/name are
    # optional and default to the same values the script previously
    # hard-coded in the call below.
    parser = argparse.ArgumentParser(description="Enter the prompt for the model.")
    parser.add_argument('--prompt', type=str, required=True, help='Prompt for the model')
    parser.add_argument('--model-path', type=str, default='results',
                        help='Directory containing the fine-tuned model weights')
    parser.add_argument('--model-name', type=str, default='distilgpt2',
                        help='Base model name used to load the tokenizer')
    args = parser.parse_args()

    test_model(args.prompt, model_path=args.model_path, model_name=args.model_name)