Enabled control of generation parameters; created README.md
Browse files- README.md +85 -0
- rnnlm_model/pipeline_rnnlm.py +40 -2
- rnnlm_model/tokenization_rnnlm.py +6 -3
- rnnlm_model/tokenization_utils.py +3 -3
- test_model.py +136 -0
- test_prompts.json +12 -0
README.md
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Creative Help
|
| 2 |
+
|
| 3 |
+
This is the model repo for Creative Help, a legacy app for AI-based writing assistance powered by an RNN language model. It was developed in 2016 as one of the first demonstrations of the use of a language model for helping people write stories. For more information, see the following research papers:
|
| 4 |
+
|
| 5 |
+
[Automated Assistance for Creative Writing with an RNN Language Model.](https://roemmele.github.io/publications/creative-help-demo.pdf) Melissa Roemmele and Andrew Gordon. Demo at IUI 2018.
|
| 6 |
+
|
| 7 |
+
[Linguistic Features of Helpfulness in Automated Support for Creative Writing.](https://roemmele.github.io/publications/creative-help-evaluation.pdf) Melissa Roemmele and Andrew Gordon. Storytelling Workshop at NAACL 2018.
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
pip install transformers torch
|
| 13 |
+
python -m spacy download en_core_web_sm
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Loading the Model
|
| 17 |
+
|
| 18 |
+
This model uses a custom architecture and tokenizer (which has been semi-automatically adapted from the original implementation [here](https://github.com/roemmele/narrative-prediction)). Load it with `trust_remote_code=True`:
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 22 |
+
from rnnlm_model import (
|
| 23 |
+
RNNLMConfig,
|
| 24 |
+
RNNLMForCausalLM,
|
| 25 |
+
RNNLMTokenizer,
|
| 26 |
+
RNNLMTextGenerationPipeline,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
AutoConfig.register("rnnlm", RNNLMConfig)
|
| 30 |
+
AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
|
| 31 |
+
|
| 32 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 33 |
+
"path/to/model",
|
| 34 |
+
trust_remote_code=True,
|
| 35 |
+
)
|
| 36 |
+
tokenizer = RNNLMTokenizer.from_pretrained("path/to/model")
|
| 37 |
+
pipe = RNNLMTextGenerationPipeline(model=model, tokenizer=tokenizer)
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Usage Examples
|
| 41 |
+
|
| 42 |
+
Generation uses a base configuration of `max_new_tokens=50`, `do_sample=True`, and `temperature=1.0` unless overridden.
|
| 43 |
+
|
| 44 |
+
### Basic Generation (Default Parameters)
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
output = pipe("The storm came", max_new_tokens=50, do_sample=True, temperature=1.0)
|
| 48 |
+
print(output[0]["generated_text"])
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Limiting by Sentences (`max_new_sents`)
|
| 52 |
+
|
| 53 |
+
Limit the decoded output to a specific number of sentences:
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
# At most 1 sentence
|
| 57 |
+
output = pipe(
|
| 58 |
+
"Sarah closed her laptop and stared out the window.",
|
| 59 |
+
max_new_tokens=50,
|
| 60 |
+
max_new_sents=1,
|
| 61 |
+
)
|
| 62 |
+
print(output[0]["generated_text"])
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Inference API
|
| 66 |
+
|
| 67 |
+
When using the Hugging Face Inference API or Inference Endpoints, pass parameters in the request body:
|
| 68 |
+
|
| 69 |
+
```json
|
| 70 |
+
{
|
| 71 |
+
"inputs": "The storm came",
|
| 72 |
+
"parameters": {
|
| 73 |
+
"max_new_tokens": 50,
|
| 74 |
+
"do_sample": true,
|
| 75 |
+
"temperature": 1.0,
|
| 76 |
+
"max_new_sents": 2
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Test Script
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python test_model.py --model_path . --seed 0
|
| 85 |
+
```
|
rnnlm_model/pipeline_rnnlm.py
CHANGED
|
@@ -5,6 +5,13 @@ from transformers.pipelines.text_generation import TextGenerationPipeline
|
|
| 5 |
from transformers.pipelines.text_generation import ReturnType
|
| 6 |
from transformers import GenerationConfig
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class RNNLMTextGenerationPipeline(TextGenerationPipeline):
|
| 10 |
"""
|
|
@@ -14,6 +21,18 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
|
|
| 14 |
|
| 15 |
When the tokenizer has generalize_ents=True, entities are extracted from the
|
| 16 |
prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
assistant_model = None # Class default for transformers compatibility (assisted decoding)
|
| 19 |
assistant_tokenizer = None
|
|
@@ -28,17 +47,28 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
|
|
| 28 |
if not hasattr(self, "generation_config") or self.generation_config is None:
|
| 29 |
self.generation_config = GenerationConfig(
|
| 30 |
pad_token_id=getattr(self.tokenizer, "pad_token_id", 0),
|
| 31 |
-
max_new_tokens=
|
| 32 |
do_sample=True,
|
| 33 |
-
temperature=
|
| 34 |
)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def postprocess(
|
| 37 |
self,
|
| 38 |
model_outputs,
|
| 39 |
return_type=ReturnType.NEW_TEXT,
|
| 40 |
clean_up_tokenization_spaces=False,
|
| 41 |
continue_final_message=None,
|
|
|
|
| 42 |
):
|
| 43 |
generated_sequence = model_outputs["generated_sequence"][0]
|
| 44 |
input_ids = model_outputs["input_ids"]
|
|
@@ -94,6 +124,14 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
|
|
| 94 |
decode_kw.update(
|
| 95 |
adapt_ents=True, capitalize_ents=True, ents=[ents])
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# Decode only the generated token IDs, then append to saved prompt
|
| 98 |
prompt_len = 0
|
| 99 |
if input_ids is not None:
|
|
|
|
| 5 |
from transformers.pipelines.text_generation import ReturnType
|
| 6 |
from transformers import GenerationConfig
|
| 7 |
|
| 8 |
+
# Names of RNNLMTokenizer.decode() keyword arguments that callers may pass
# directly when invoking the pipeline; they are intercepted in
# _sanitize_parameters and routed to postprocess instead of generate().
DECODE_PARAM_NAMES = frozenset((
    "begin_sentence",
    "skip_special_tokens",
    "clean_up_tokenization_spaces",
    "ents",
    "adapt_ents",
    "detokenize",
    "capitalize_ents",
    "max_new_sents",
    "eos_tokens",
))
|
| 14 |
+
|
| 15 |
|
| 16 |
class RNNLMTextGenerationPipeline(TextGenerationPipeline):
|
| 17 |
"""
|
|
|
|
| 21 |
|
| 22 |
When the tokenizer has generalize_ents=True, entities are extracted from the
|
| 23 |
prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
|
| 24 |
+
|
| 25 |
+
Decode parameters (from RNNLMTokenizer.decode) can be controlled when calling the
|
| 26 |
+
pipeline. Pass any of these as kwargs to override defaults:
|
| 27 |
+
- begin_sentence (bool): Whether generated text starts a new sentence
|
| 28 |
+
- skip_special_tokens (bool): Skip special tokens in output
|
| 29 |
+
- clean_up_tokenization_spaces (bool): Clean up extra spaces
|
| 30 |
+
- detokenize (bool): Apply detokenization (capitalization, punctuation)
|
| 31 |
+
- adapt_ents (bool): Replace ENT_* tokens with entities from context
|
| 32 |
+
- capitalize_ents (bool): Capitalize adapted entity names
|
| 33 |
+
- max_new_sents (int): Maximum number of sentences to include in decoded output
|
| 34 |
+
- eos_tokens (list): Token IDs treated as end-of-sequence
|
| 35 |
+
- ents (dict or list): Custom entity mapping(s) for adaptation
|
| 36 |
"""
|
| 37 |
assistant_model = None # Class default for transformers compatibility (assisted decoding)
|
| 38 |
assistant_tokenizer = None
|
|
|
|
| 47 |
if not hasattr(self, "generation_config") or self.generation_config is None:
|
| 48 |
self.generation_config = GenerationConfig(
|
| 49 |
pad_token_id=getattr(self.tokenizer, "pad_token_id", 0),
|
| 50 |
+
max_new_tokens=50,
|
| 51 |
do_sample=True,
|
| 52 |
+
temperature=1.0,
|
| 53 |
)
|
| 54 |
|
| 55 |
+
def _sanitize_parameters(self, **kwargs):
    """Route RNNLM decode kwargs into postprocess_params so users can set them.

    Any kwarg whose name appears in DECODE_PARAM_NAMES is intercepted here;
    the remaining kwargs are handled by the parent pipeline's
    _sanitize_parameters (and would otherwise be forwarded to generate()).
    """
    # Intercept decode params before the parent sees the kwargs; iterate over
    # a snapshot of the keys because we mutate kwargs while looping.
    decode_params = {}
    for name in tuple(kwargs):
        if name in DECODE_PARAM_NAMES:
            decode_params[name] = kwargs.pop(name)
    preprocess_params, forward_params, postprocess_params = (
        super()._sanitize_parameters(**kwargs))
    postprocess_params["decode_params"] = decode_params
    return preprocess_params, forward_params, postprocess_params
|
| 64 |
+
|
| 65 |
def postprocess(
|
| 66 |
self,
|
| 67 |
model_outputs,
|
| 68 |
return_type=ReturnType.NEW_TEXT,
|
| 69 |
clean_up_tokenization_spaces=False,
|
| 70 |
continue_final_message=None,
|
| 71 |
+
decode_params=None,
|
| 72 |
):
|
| 73 |
generated_sequence = model_outputs["generated_sequence"][0]
|
| 74 |
input_ids = model_outputs["input_ids"]
|
|
|
|
| 124 |
decode_kw.update(
|
| 125 |
adapt_ents=True, capitalize_ents=True, ents=[ents])
|
| 126 |
|
| 127 |
+
# Apply user-provided decode params (from pipeline call)
|
| 128 |
+
user_decode = decode_params or {}
|
| 129 |
+
for k, v in user_decode.items():
|
| 130 |
+
if k == "ents":
|
| 131 |
+
decode_kw["ents"] = [v] if isinstance(v, dict) else v
|
| 132 |
+
else:
|
| 133 |
+
decode_kw[k] = v
|
| 134 |
+
|
| 135 |
# Decode only the generated token IDs, then append to saved prompt
|
| 136 |
prompt_len = 0
|
| 137 |
if input_ids is not None:
|
rnnlm_model/tokenization_rnnlm.py
CHANGED
|
@@ -153,14 +153,17 @@ class RNNLMTokenizer(PreTrainedTokenizer):
|
|
| 153 |
adapt_ents=True,
|
| 154 |
detokenize=True,
|
| 155 |
capitalize_ents=True,
|
| 156 |
-
|
| 157 |
eos_tokens=None,
|
| 158 |
**kwargs,
|
| 159 |
):
|
| 160 |
"""Decode token IDs to string. When adapt_ents=True and ents is provided,
|
| 161 |
replaces generic ENT_* tokens in the output with entities from the input context.
|
| 162 |
ents should be a list of dicts (one per sequence) mapping entity name to type
|
| 163 |
-
(e.g. {"John": "PERSON_0"} from number_ents(get_ents(...))).
|
|
|
|
|
|
|
|
|
|
| 164 |
if isinstance(token_ids[0], (list, tuple)):
|
| 165 |
seqs = token_ids
|
| 166 |
else:
|
|
@@ -177,7 +180,7 @@ class RNNLMTokenizer(PreTrainedTokenizer):
|
|
| 177 |
self._lexicon_lookup,
|
| 178 |
self.unk_token,
|
| 179 |
seqs,
|
| 180 |
-
|
| 181 |
eos_tokens=eos_tokens or [],
|
| 182 |
detokenize=detokenize,
|
| 183 |
ents=ents or [],
|
|
|
|
| 153 |
adapt_ents=True,
|
| 154 |
detokenize=True,
|
| 155 |
capitalize_ents=True,
|
| 156 |
+
max_new_sents=None,
|
| 157 |
eos_tokens=None,
|
| 158 |
**kwargs,
|
| 159 |
):
|
| 160 |
"""Decode token IDs to string. When adapt_ents=True and ents is provided,
|
| 161 |
replaces generic ENT_* tokens in the output with entities from the input context.
|
| 162 |
ents should be a list of dicts (one per sequence) mapping entity name to type
|
| 163 |
+
(e.g. {"John": "PERSON_0"} from number_ents(get_ents(...))).
|
| 164 |
+
When max_new_sents is None, output length is determined by the token sequence
|
| 165 |
+
(i.e. by max_new_tokens from generation); when set, output is truncated to that
|
| 166 |
+
many sentences."""
|
| 167 |
if isinstance(token_ids[0], (list, tuple)):
|
| 168 |
seqs = token_ids
|
| 169 |
else:
|
|
|
|
| 180 |
self._lexicon_lookup,
|
| 181 |
self.unk_token,
|
| 182 |
seqs,
|
| 183 |
+
max_new_sents=max_new_sents,
|
| 184 |
eos_tokens=eos_tokens or [],
|
| 185 |
detokenize=detokenize,
|
| 186 |
ents=ents or [],
|
rnnlm_model/tokenization_utils.py
CHANGED
|
@@ -120,7 +120,7 @@ def replace_ents_in_seq(encoder, seq):
|
|
| 120 |
return seq
|
| 121 |
|
| 122 |
|
| 123 |
-
def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs,
|
| 124 |
detokenize=False, ents=[], capitalize_ents=False, adapt_ents=False,
|
| 125 |
sub_ent_probs=None, begin_sentence=True):
|
| 126 |
if not seqs:
|
|
@@ -184,8 +184,8 @@ def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, n_sents_per_seq=Non
|
|
| 184 |
seq = " ".join(seq)
|
| 185 |
if eos_tokens: # if filter_n_sents is a number, filter generated sequence to only the first N=filter_n_sents sentences
|
| 186 |
seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
|
| 187 |
-
elif n_sents_per_seq:
|
| 188 |
-
seq = filter_gen_seq(encoder, seq, n_sents=n_sents_per_seq)
|
| 189 |
decoded_seqs.append(seq)
|
| 190 |
return decoded_seqs
|
| 191 |
|
|
|
|
| 120 |
return seq
|
| 121 |
|
| 122 |
|
| 123 |
+
def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None, eos_tokens=[],
|
| 124 |
detokenize=False, ents=[], capitalize_ents=False, adapt_ents=False,
|
| 125 |
sub_ent_probs=None, begin_sentence=True):
|
| 126 |
if not seqs:
|
|
|
|
| 184 |
seq = " ".join(seq)
|
| 185 |
if eos_tokens: # if filter_n_sents is a number, filter generated sequence to only the first N=filter_n_sents sentences
|
| 186 |
seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
|
| 187 |
+
elif max_new_sents:
|
| 188 |
+
seq = filter_gen_seq(encoder, seq, n_sents=max_new_sents)
|
| 189 |
decoded_seqs.append(seq)
|
| 190 |
return decoded_seqs
|
| 191 |
|
test_model.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def set_seed(seed: int):
    """Seed every available RNG source (random, numpy, torch) for reproducibility.

    NumPy and PyTorch are optional dependencies here; if either is not
    installed, its seeding step is silently skipped.
    """
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        np.random.seed(seed)
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None:
        torch.manual_seed(seed)
        # Seed all CUDA devices too when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Smoke-test the converted RNNLM model from the command line.

    Loads the model and tokenizer from --model_path, reads a list of prompt
    strings from a JSON file, and runs six generation configurations
    (defaults, shorter token budget, sentence limits, greedy decoding, low
    temperature) so the outputs can be inspected for sanity.

    Exits with status 1 if the model directory or the prompts file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    parser.add_argument(
        "--prompts", "-p", default="test_prompts.json",
        # Fixed: the old help text claimed a default of
        # hf_conversion/test_prompts.json, but the actual default is
        # test_prompts.json resolved from the current working directory.
        help="Path to JSON file with list of prompt strings (default: test_prompts.json)")
    parser.add_argument(
        "--seed", "-s", type=int, default=0,
        # Fixed: the old help text claimed "default: None, non-deterministic",
        # but the actual default is 0 (deterministic).
        help="Random seed for reproducible generation (default: 0)")
    parser.add_argument(
        "--max_new_tokens", type=int, default=None,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()

    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")

    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist.")
        sys.exit(1)

    prompts_path = args.prompts
    # The argparse default is a relative path, so the original
    # `if prompts_path is None` fallback was dead code; fall back to the file
    # next to this script whenever no usable path was supplied.
    if not prompts_path:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)

    print("Loading model and tokenizer...")
    from transformers import AutoModelForCausalLM

    # Register custom model and load tokenizer directly (AutoTokenizer doesn't know RNNLMTokenizer)
    model_path = os.path.abspath(args.model_path)
    from rnnlm_model import (
        RNNLMConfig,
        RNNLMForCausalLM,
        RNNLMTokenizer,
        RNNLMTextGenerationPipeline,
    )
    from transformers import AutoConfig
    AutoConfig.register("rnnlm", RNNLMConfig)
    AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer = RNNLMTokenizer.from_pretrained(model_path)

    print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
    pipe = RNNLMTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
    )

    with open(prompts_path) as f:
        test_prompts = json.load(f)

    # Shared generation kwargs; CLI flags override the 50-token default.
    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents

    def run_tests(kwargs):
        # Generate once per prompt with the given kwargs and print the result.
        for i, prompt in enumerate(test_prompts):
            print(f"\n [{i + 1}/{len(test_prompts)}]")
            print(f" PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f" GENERATED: ``{output[0]['generated_text']}``")

    # Test 1: Basic generation with default params
    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)

    # Test 2: max_new_tokens=20
    print("\n--- Test 2: max_new_tokens=20 ---")
    short_kwargs = {**base_kwargs, "max_new_tokens": 20}
    run_tests(short_kwargs)

    # Test 3: max_new_sents=2
    print("\n--- Test 3: max_new_sents=2 ---")
    sents_kwargs = {**base_kwargs, "max_new_sents": 2}
    run_tests(sents_kwargs)

    # Test 4: max_new_sents=1
    print("\n--- Test 4: max_new_sents=1 ---")
    sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
    run_tests(sents1_kwargs)

    # Test 5: do_sample=False (greedy decoding)
    print("\n--- Test 5: do_sample=False ---")
    greedy_kwargs = {**base_kwargs, "do_sample": False}
    run_tests(greedy_kwargs)

    # Test 6: temperature=0.3
    print("\n--- Test 6: temperature=0.3 ---")
    low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
    run_tests(low_temp_kwargs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Script entry point: run the smoke tests only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
test_prompts.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"She",
|
| 3 |
+
"The old",
|
| 4 |
+
"The storm came",
|
| 5 |
+
"Marcus opened",
|
| 6 |
+
"The door creaked",
|
| 7 |
+
"Sarah closed her laptop and stared out the window. The email from her editor had been clear: the manuscript needed major revisions, and she had two weeks.",
|
| 8 |
+
"The detective studied the crime scene photos spread across his desk. Three victims, three different cities, and one impossible connection that made no sense.",
|
| 9 |
+
"The ancient library had been sealed for centuries, but the earthquake had cracked the stone. Now dust motes danced in the first light it had seen in",
|
| 10 |
+
"When the power went out across the city, nobody panicked at first.\nIt was only when the lights stayed off for the second day that people began to worry",
|
| 11 |
+
"Marcus and Elena walked through the forbidden forest, their torches raised and hearts pounding.\n\nThey had heard rumors of something dark moving among the trees."
|
| 12 |
+
]
|