roemmele committed
Commit c886682 · 1 Parent(s): 9d0d6e1

Enabled control of generation parameters; created README.md
README.md ADDED
@@ -0,0 +1,85 @@
+ # Creative Help
+
+ This is the model repo for Creative Help, a legacy app for AI-based writing assistance powered by an RNN language model. It was developed in 2016 as one of the first demonstrations of using a language model to help people write stories. For more information, see the following research papers:
+
+ [Automated Assistance for Creative Writing with an RNN Language Model.](https://roemmele.github.io/publications/creative-help-demo.pdf) Melissa Roemmele and Andrew Gordon. Demo at IUI 2018.
+
+ [Linguistic Features of Helpfulness in Automated Support for Creative Writing.](https://roemmele.github.io/publications/creative-help-evaluation.pdf) Melissa Roemmele and Andrew Gordon. Storytelling Workshop at NAACL 2018.
+
+ ## Installation
+
+ ```bash
+ pip install transformers torch spacy
+ python -m spacy download en_core_web_sm
+ ```
+
+ ## Loading the Model
+
+ This model uses a custom architecture and tokenizer (semi-automatically adapted from the original implementation [here](https://github.com/roemmele/narrative-prediction)). Load it with `trust_remote_code=True`:
+
+ ```python
+ from transformers import AutoConfig, AutoModelForCausalLM
+ from rnnlm_model import (
+     RNNLMConfig,
+     RNNLMForCausalLM,
+     RNNLMTokenizer,
+     RNNLMTextGenerationPipeline,
+ )
+
+ AutoConfig.register("rnnlm", RNNLMConfig)
+ AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "path/to/model",
+     trust_remote_code=True,
+ )
+ tokenizer = RNNLMTokenizer.from_pretrained("path/to/model")
+ pipe = RNNLMTextGenerationPipeline(model=model, tokenizer=tokenizer)
+ ```
+
+ ## Usage Examples
+
+ Generation uses a base configuration of `max_new_tokens=50`, `do_sample=True`, and `temperature=1.0` unless overridden.
+
+ ### Basic Generation (Default Parameters)
+
+ ```python
+ output = pipe("The storm came", max_new_tokens=50, do_sample=True, temperature=1.0)
+ print(output[0]["generated_text"])
+ ```
+
+ ### Limiting by Sentences (`max_new_sents`)
+
+ Limit the decoded output to a specific number of sentences:
+
+ ```python
+ # At most 1 sentence
+ output = pipe(
+     "Sarah closed her laptop and stared out the window.",
+     max_new_tokens=50,
+     max_new_sents=1,
+ )
+ print(output[0]["generated_text"])
+ ```
+
+ ## Inference API
+
+ When using the Hugging Face Inference API or Inference Endpoints, pass parameters in the request body:
+
+ ```json
+ {
+     "inputs": "The storm came",
+     "parameters": {
+         "max_new_tokens": 50,
+         "do_sample": true,
+         "temperature": 1.0,
+         "max_new_sents": 2
+     }
+ }
+ ```
+
+ ## Test Script
+
+ ```bash
+ python test_model.py --model_path . --seed 0
+ ```
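The "unless overridden" behavior of the README's base configuration amounts to a dict merge: per-call parameters take precedence over the base generation settings. A minimal sketch of that idea (the parameter names come from the README; `effective_params` is an illustrative helper, not part of the repo):

```python
# Base generation settings from the README; per-call kwargs win on conflict.
BASE_GENERATION = {"max_new_tokens": 50, "do_sample": True, "temperature": 1.0}

def effective_params(**overrides):
    # Later keys in a dict merge override earlier ones, so overrides win.
    return {**BASE_GENERATION, **overrides}

print(effective_params(temperature=0.3, max_new_sents=1))
# {'max_new_tokens': 50, 'do_sample': True, 'temperature': 0.3, 'max_new_sents': 1}
```

So a call like `pipe(prompt, temperature=0.3)` samples with the lower temperature while keeping the default token budget.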
rnnlm_model/pipeline_rnnlm.py CHANGED
@@ -5,6 +5,13 @@ from transformers.pipelines.text_generation import TextGenerationPipeline
  from transformers.pipelines.text_generation import ReturnType
  from transformers import GenerationConfig
 
+ # Decode parameters from RNNLMTokenizer.decode() that users can control via the pipeline
+ DECODE_PARAM_NAMES = frozenset({
+     "begin_sentence", "skip_special_tokens", "clean_up_tokenization_spaces",
+     "ents", "adapt_ents", "detokenize", "capitalize_ents",
+     "max_new_sents", "eos_tokens",
+ })
+
 
  class RNNLMTextGenerationPipeline(TextGenerationPipeline):
      """
@@ -14,6 +21,18 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
 
      When the tokenizer has generalize_ents=True, entities are extracted from the
      prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
+
+     Decode parameters (from RNNLMTokenizer.decode) can be controlled when calling the
+     pipeline. Pass any of these as kwargs to override defaults:
+     - begin_sentence (bool): Whether generated text starts a new sentence
+     - skip_special_tokens (bool): Skip special tokens in output
+     - clean_up_tokenization_spaces (bool): Clean up extra spaces
+     - detokenize (bool): Apply detokenization (capitalization, punctuation)
+     - adapt_ents (bool): Replace ENT_* tokens with entities from context
+     - capitalize_ents (bool): Capitalize adapted entity names
+     - max_new_sents (int): Maximum number of sentences to include in decoded output
+     - eos_tokens (list): Token IDs treated as end-of-sequence
+     - ents (dict or list): Custom entity mapping(s) for adaptation
      """
      assistant_model = None  # Class default for transformers compatibility (assisted decoding)
      assistant_tokenizer = None
@@ -28,17 +47,28 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
          if not hasattr(self, "generation_config") or self.generation_config is None:
              self.generation_config = GenerationConfig(
                  pad_token_id=getattr(self.tokenizer, "pad_token_id", 0),
-                 max_new_tokens=256,
+                 max_new_tokens=50,
                  do_sample=True,
-                 temperature=0.7,
+                 temperature=1.0,
              )
 
+     def _sanitize_parameters(self, **kwargs):
+         """Extract RNNLM decode parameters into postprocess_params so users can control them."""
+         # Pull out decode params before passing to parent (they would otherwise go to forward/generate)
+         decode_params = {k: kwargs.pop(k)
+                          for k in list(kwargs.keys()) if k in DECODE_PARAM_NAMES}
+         preprocess_params, forward_params, postprocess_params = (
+             super()._sanitize_parameters(**kwargs))
+         postprocess_params["decode_params"] = decode_params
+         return preprocess_params, forward_params, postprocess_params
+
      def postprocess(
          self,
          model_outputs,
          return_type=ReturnType.NEW_TEXT,
          clean_up_tokenization_spaces=False,
          continue_final_message=None,
+         decode_params=None,
      ):
          generated_sequence = model_outputs["generated_sequence"][0]
          input_ids = model_outputs["input_ids"]
@@ -94,6 +124,14 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
              decode_kw.update(
                  adapt_ents=True, capitalize_ents=True, ents=[ents])
 
+         # Apply user-provided decode params (from pipeline call)
+         user_decode = decode_params or {}
+         for k, v in user_decode.items():
+             if k == "ents":
+                 decode_kw["ents"] = [v] if isinstance(v, dict) else v
+             else:
+                 decode_kw[k] = v
+
          # Decode only the generated token IDs, then append to saved prompt
          prompt_len = 0
          if input_ids is not None:
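The kwarg routing done by `_sanitize_parameters` can be sketched in isolation: decode-time parameters are popped out of the call kwargs so that only generation parameters are forwarded to `generate()`. The name set below mirrors `DECODE_PARAM_NAMES` from the diff; the `split_decode_kwargs` helper itself is illustrative, not the actual pipeline method.

```python
# Mirrors the frozenset added in the diff above.
DECODE_PARAM_NAMES = frozenset({
    "begin_sentence", "skip_special_tokens", "clean_up_tokenization_spaces",
    "ents", "adapt_ents", "detokenize", "capitalize_ents",
    "max_new_sents", "eos_tokens",
})

def split_decode_kwargs(kwargs):
    # Pop decode params out of the call kwargs; whatever remains is what
    # _sanitize_parameters hands on to the parent pipeline (and generate()).
    decode_params = {k: kwargs.pop(k)
                     for k in list(kwargs.keys()) if k in DECODE_PARAM_NAMES}
    return decode_params, kwargs

decode, gen = split_decode_kwargs({"max_new_tokens": 50, "max_new_sents": 1})
print(decode)  # {'max_new_sents': 1}
print(gen)     # {'max_new_tokens': 50}
```

Without this split, `max_new_sents` would reach `generate()`, which does not recognize it, and raise an error.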
rnnlm_model/tokenization_rnnlm.py CHANGED
@@ -153,14 +153,17 @@ class RNNLMTokenizer(PreTrainedTokenizer):
          adapt_ents=True,
          detokenize=True,
          capitalize_ents=True,
-         n_sents_per_seq=1,
+         max_new_sents=None,
          eos_tokens=None,
          **kwargs,
      ):
          """Decode token IDs to string. When adapt_ents=True and ents is provided,
          replaces generic ENT_* tokens in the output with entities from the input context.
          ents should be a list of dicts (one per sequence) mapping entity name to type
-         (e.g. {"John": "PERSON_0"} from number_ents(get_ents(...)))."""
+         (e.g. {"John": "PERSON_0"} from number_ents(get_ents(...))).
+         When max_new_sents is None, output length is determined by the token sequence
+         (i.e. by max_new_tokens from generation); when set, output is truncated to that
+         many sentences."""
          if isinstance(token_ids[0], (list, tuple)):
              seqs = token_ids
          else:
@@ -177,7 +180,7 @@
              self._lexicon_lookup,
              self.unk_token,
              seqs,
-             n_sents_per_seq=n_sents_per_seq,
+             max_new_sents=max_new_sents,
              eos_tokens=eos_tokens or [],
              detokenize=detokenize,
              ents=ents or [],
rnnlm_model/tokenization_utils.py CHANGED
@@ -120,7 +120,7 @@ def replace_ents_in_seq(encoder, seq):
      return seq
 
 
- def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, n_sents_per_seq=None, eos_tokens=[],
+ def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None, eos_tokens=[],
                      detokenize=False, ents=[], capitalize_ents=False, adapt_ents=False,
                      sub_ent_probs=None, begin_sentence=True):
      if not seqs:
@@ -184,8 +184,8 @@ def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None
          seq = " ".join(seq)
          if eos_tokens:  # truncate the generated sequence at the first end-of-sequence token
              seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
-         elif n_sents_per_seq:
-             seq = filter_gen_seq(encoder, seq, n_sents=n_sents_per_seq)
+         elif max_new_sents:  # otherwise keep only the first max_new_sents sentences
+             seq = filter_gen_seq(encoder, seq, n_sents=max_new_sents)
          decoded_seqs.append(seq)
      return decoded_seqs
 
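The truncation logic in `decode_num_seqs` has two mutually exclusive modes: an explicit `eos_tokens` cut is applied first, and `max_new_sents` only applies when no end-of-sequence tokens are given. A self-contained sketch of that precedence, operating on a space-joined token sequence like the decoded `seq` above (`truncate_generated` is a hypothetical stand-in for `filter_gen_seq`, not repo code):

```python
def truncate_generated(seq, eos_tokens=(), max_new_sents=None):
    # Hypothetical stand-in for filter_gen_seq's two modes: cut at the first
    # end-of-sequence token if any are given; otherwise keep tokens up to the
    # max_new_sents-th sentence-final punctuation mark.
    tokens = seq.split()
    if eos_tokens:
        for i, tok in enumerate(tokens):
            if tok in eos_tokens:
                return " ".join(tokens[:i])
    elif max_new_sents:
        seen = 0
        for i, tok in enumerate(tokens):
            if tok in {".", "!", "?"}:
                seen += 1
                if seen == max_new_sents:
                    return " ".join(tokens[: i + 1])
    return seq

print(truncate_generated("she ran . he fell . they stopped .", max_new_sents=2))
# she ran . he fell .
```

Note that when both are supplied, `eos_tokens` wins and `max_new_sents` is ignored, matching the `if/elif` in the diff.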
test_model.py ADDED
@@ -0,0 +1,136 @@
+ #!/usr/bin/env python3
+
+ import argparse
+ import json
+ import os
+ import random
+ import sys
+
+
+ def set_seed(seed: int):
+     """Set random seeds for reproducibility."""
+     random.seed(seed)
+     try:
+         import numpy as np
+         np.random.seed(seed)
+     except ImportError:
+         pass
+     try:
+         import torch
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed_all(seed)
+     except ImportError:
+         pass
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_path", "-m", default=".", help="Path to converted model")
+     parser.add_argument(
+         "--prompts", "-p", default=None,
+         help="Path to JSON file with list of prompt strings (default: test_prompts.json next to this script)")
+     parser.add_argument(
+         "--seed", "-s", type=int, default=0,
+         help="Random seed for reproducible generation (default: 0)")
+     parser.add_argument(
+         "--max_new_tokens", type=int, default=None,
+         help="Max tokens to generate (default: 50)")
+     parser.add_argument(
+         "--max_new_sents", type=int, default=None,
+         help="Max sentences in decoded output (default: pipeline default)")
+     args = parser.parse_args()
+
+     if args.seed is not None:
+         set_seed(args.seed)
+         print(f"Random seed set to {args.seed} for reproducibility")
+
+     if not os.path.isdir(args.model_path):
+         print(f"Error: Model path {args.model_path} does not exist.")
+         sys.exit(1)
+
+     prompts_path = args.prompts
+     if prompts_path is None:
+         prompts_path = os.path.join(os.path.dirname(
+             os.path.abspath(__file__)), "test_prompts.json")
+     if not os.path.isfile(prompts_path):
+         print(f"Error: Prompts file {prompts_path} does not exist.")
+         sys.exit(1)
+
+     print("Loading model and tokenizer...")
+     from transformers import AutoConfig, AutoModelForCausalLM
+
+     # Register custom model and load tokenizer directly (AutoTokenizer doesn't know RNNLMTokenizer)
+     model_path = os.path.abspath(args.model_path)
+     from rnnlm_model import (
+         RNNLMConfig,
+         RNNLMForCausalLM,
+         RNNLMTokenizer,
+         RNNLMTextGenerationPipeline,
+     )
+     AutoConfig.register("rnnlm", RNNLMConfig)
+     AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path, trust_remote_code=True)
+     tokenizer = RNNLMTokenizer.from_pretrained(model_path)
+
+     print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
+     pipe = RNNLMTextGenerationPipeline(
+         model=model,
+         tokenizer=tokenizer,
+     )
+
+     with open(prompts_path) as f:
+         test_prompts = json.load(f)
+
+     base_kwargs = dict(
+         max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
+         do_sample=True,
+         temperature=1.0,
+         pad_token_id=tokenizer.pad_token_id,
+     )
+     if args.max_new_sents is not None:
+         base_kwargs["max_new_sents"] = args.max_new_sents
+
+     def run_tests(kwargs):
+         for i, prompt in enumerate(test_prompts):
+             print(f"\n [{i + 1}/{len(test_prompts)}]")
+             print(f" PROMPT: ``{prompt}``")
+             output = pipe(prompt, **kwargs)
+             print(f" GENERATED: ``{output[0]['generated_text']}``")
+
+     # Test 1: Basic generation with default params
+     print("\n--- Test 1: Basic generation (default params) ---")
+     run_tests(base_kwargs)
+
+     # Test 2: max_new_tokens=20
+     print("\n--- Test 2: max_new_tokens=20 ---")
+     short_kwargs = {**base_kwargs, "max_new_tokens": 20}
+     run_tests(short_kwargs)
+
+     # Test 3: max_new_sents=2
+     print("\n--- Test 3: max_new_sents=2 ---")
+     sents_kwargs = {**base_kwargs, "max_new_sents": 2}
+     run_tests(sents_kwargs)
+
+     # Test 4: max_new_sents=1
+     print("\n--- Test 4: max_new_sents=1 ---")
+     sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
+     run_tests(sents1_kwargs)
+
+     # Test 5: do_sample=False (greedy decoding)
+     print("\n--- Test 5: do_sample=False ---")
+     greedy_kwargs = {**base_kwargs, "do_sample": False}
+     run_tests(greedy_kwargs)
+
+     # Test 6: temperature=0.3
+     print("\n--- Test 6: temperature=0.3 ---")
+     low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
+     run_tests(low_temp_kwargs)
+
+
+ if __name__ == "__main__":
+     main()
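The `--seed` option works because `set_seed` pins every random number generator the pipeline might draw from, so repeated runs make identical sampling decisions. A minimal illustration using only the stdlib RNG (the script also seeds numpy and torch when available):

```python
import random

def sample_run(seed):
    # Re-seeding before each run makes the draws deterministic.
    random.seed(seed)
    return [random.randint(0, 9) for _ in range(5)]

print(sample_run(0) == sample_run(0))  # True: identical draws for the same seed
```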
test_prompts.json ADDED
@@ -0,0 +1,12 @@
+ [
+     "She",
+     "The old",
+     "The storm came",
+     "Marcus opened",
+     "The door creaked",
+     "Sarah closed her laptop and stared out the window. The email from her editor had been clear: the manuscript needed major revisions, and she had two weeks.",
+     "The detective studied the crime scene photos spread across his desk. Three victims, three different cities, and one impossible connection that made no sense.",
+     "The ancient library had been sealed for centuries, but the earthquake had cracked the stone. Now dust motes danced in the first light it had seen in",
+     "When the power went out across the city, nobody panicked at first.\nIt was only when the lights stayed off for the second day that people began to worry",
+     "Marcus and Elena walked through the forbidden forest, their torches raised and hearts pounding.\n\nThey had heard rumors of something dark moving among the trees."
+ ]