roemmele committed on
Commit
edbfc07
·
verified ·
1 Parent(s): a0070bf

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RNNLMForCausalLM"
4
+ ],
5
+ "embedding_dim": 300,
6
+ "hidden_size": 500,
7
+ "model_type": "rnnlm",
8
+ "n_feature_nodes": 100,
9
+ "n_pos_embedding_nodes": 25,
10
+ "n_pos_nodes": 100,
11
+ "n_pos_tags": 59,
12
+ "num_hidden_layers": 2,
13
+ "pad_token_id": 0,
14
+ "tie_word_embeddings": false,
15
+ "torch_dtype": "float32",
16
+ "transformers_version": "4.46.3",
17
+ "unk_token_id": 1,
18
+ "use_cache": true,
19
+ "use_features": false,
20
+ "use_pos": false,
21
+ "vocab_size": 64986
22
+ }
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.46.3"
5
+ }
handler.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """
3
+ Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.
4
+
5
+ Implements EndpointHandler as described in:
6
+ https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler
7
+
8
+ The handler loads the RNNLM model with entity adaptation support and serves
9
+ text generation requests via the Inference API.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ from typing import Any, Dict, List, Union
15
+
16
+
17
class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and custom pipeline once at startup (__init__)
    and serves generation requests in __call__.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler (called once when the Endpoint starts).

        :param path: Path to the model repository (weights, config, tokenizer).
                     Defaults to the current directory when empty.
        """
        self.path = os.path.abspath(path or ".")

        # Make the bundled rnnlm_model package importable.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        # Imported here (not at module top) so the module can be imported
        # without transformers/rnnlm_model installed.
        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        # Register the custom architecture so AutoModel can resolve "rnnlm".
        AutoConfig.register("rnnlm", RNNLMConfig)
        AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

        # Load model and tokenizer from the repo path.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)

        # Pipeline applies RNNLM detokenization / entity adaptation on decode.
        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle one inference request.

        :param data: Payload with "inputs" (prompt string or list of strings) and
                     optional "parameters" dict. For backward compatibility,
                     generation parameters may also be passed flat at the top level.
        :return: List of dicts with "generated_text" key(s), or an {"error": ...}
                 dict on failure (kept as a dict for client compatibility).
        """
        # Work on a shallow copy so the caller's payload is not mutated
        # (the original implementation popped keys out of the request dict).
        data = dict(data)
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}

        # Nested "parameters" wins; otherwise the remaining top-level keys
        # are treated as the parameters (flat style).
        parameters = data.pop("parameters", data) or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Default generation parameters; pad_token_id comes from the tokenizer.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Pass through any other user-supplied params (top_p, top_k, etc.).
        for key, value in parameters.items():
            if key not in gen_kwargs:
                gen_kwargs[key] = value

        # Run generation; surface failures to the client instead of crashing.
        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as e:
            return {"error": str(e)}

        # The Inference API expects a list of records for batch compatibility.
        if isinstance(result, list):
            return result
        return [result] if isinstance(result, dict) else [{"generated_text": str(result)}]
lexicon_lookup.json ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8ba1140559355d5160d133f9b243db038758bf3922520e2e5aab6b08fe55f07
3
+ size 219043380
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom dependencies for RNNLM (creative-help) Inference Endpoint
2
+ # Base stack (torch, transformers) is provided by the Inference Endpoints container
3
+
4
+ # RNNLM tokenizer uses spaCy for tokenization and entity extraction
5
+ spacy>=3.0
6
+ # English spaCy model - required for RNNLMTokenizer (entity recognition, tokenization)
7
+ # Install from GitHub release (pip cannot install spacy models via python -m spacy download in container)
8
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0-py3-none-any.whl
9
+
10
+ # NumPy (used by tokenization_utils)
11
+ numpy
rnnlm_model/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """RNNLM model for HuggingFace Transformers."""
3
+
4
+ from .configuration_rnnlm import RNNLMConfig
5
+ from .modeling_rnnlm import RNNLMForCausalLM
6
+ from .tokenization_rnnlm import RNNLMTokenizer
7
+ from .pipeline_rnnlm import RNNLMTextGenerationPipeline
8
+
9
+ __all__ = [
10
+ "RNNLMConfig",
11
+ "RNNLMForCausalLM",
12
+ "RNNLMTokenizer",
13
+ "RNNLMTextGenerationPipeline",
14
+ ]
rnnlm_model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (508 Bytes). View file
 
rnnlm_model/__pycache__/configuration_rnnlm.cpython-38.pyc ADDED
Binary file (1.46 kB). View file
 
rnnlm_model/__pycache__/modeling_rnnlm.cpython-38.pyc ADDED
Binary file (9.04 kB). View file
 
rnnlm_model/__pycache__/pipeline_rnnlm.cpython-38.pyc ADDED
Binary file (2.81 kB). View file
 
rnnlm_model/__pycache__/tokenization_rnnlm.cpython-38.pyc ADDED
Binary file (9.78 kB). View file
 
rnnlm_model/__pycache__/tokenization_utils.cpython-38.pyc ADDED
Binary file (11.8 kB). View file
 
rnnlm_model/configuration_rnnlm.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """RNN Language Model configuration for HuggingFace Transformers."""
3
+
4
+ try:
5
+ from transformers import PreTrainedConfig
6
+ except ImportError:
7
+ try:
8
+ from transformers.configuration_utils import PreTrainedConfig
9
+ except ImportError:
10
+ from transformers.configuration_utils import PretrainedConfig as PreTrainedConfig
11
+
12
+
13
class RNNLMConfig(PreTrainedConfig):
    """Configuration for the RNNLM (GRU-based) causal language model."""

    model_type = "rnnlm"

    def __init__(
        self,
        vocab_size=50000,
        embedding_dim=300,
        hidden_size=250,
        num_hidden_layers=1,
        pad_token_id=0,
        unk_token_id=1,
        bos_token_id=None,
        eos_token_id=None,
        use_pos=False,
        use_features=False,
        n_pos_tags=59,
        n_pos_embedding_nodes=25,
        n_pos_nodes=100,
        n_feature_nodes=100,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        # Core model dimensions and special-token ids, followed by the
        # optional POS/feature branch sizes.
        for name, value in (
            ("vocab_size", vocab_size),
            ("embedding_dim", embedding_dim),
            ("hidden_size", hidden_size),
            ("num_hidden_layers", num_hidden_layers),
            ("unk_token_id", unk_token_id),
            ("bos_token_id", bos_token_id),
            ("eos_token_id", eos_token_id),
            ("use_pos", use_pos),
            ("use_features", use_features),
            ("n_pos_tags", n_pos_tags),
            ("n_pos_embedding_nodes", n_pos_embedding_nodes),
            ("n_pos_nodes", n_pos_nodes),
            ("n_feature_nodes", n_feature_nodes),
        ):
            setattr(self, name, value)
        # Fixed flags: caching is required for generation, and the RNNLM keeps
        # separate input-embedding and output-projection matrices.
        self.use_cache = True
        self.tie_word_embeddings = False
rnnlm_model/modeling_rnnlm.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """RNN Language Model for HuggingFace Transformers - PyTorch implementation."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ try:
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.generation import LogitsProcessor, LogitsProcessorList
10
+ except ImportError:
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+ try:
14
+ from transformers.generation import LogitsProcessor, LogitsProcessorList
15
+ except ImportError:
16
+ from transformers.generation_utils import LogitsProcessor, LogitsProcessorList
17
+
18
+ from .configuration_rnnlm import RNNLMConfig
19
+
20
+
21
class PreventUnkLogitsProcessor(LogitsProcessor):
    """
    Suppress the pad (default 0) and unk (default 1) tokens during sampling by
    forcing their logits to -1e8, mirroring the original Keras model's
    ``prevent_unk`` behavior.

    NOTE: the original docstring claimed probability is "redistributed"; no
    explicit redistribution happens — masking the two logits simply lets the
    softmax renormalize over the remaining tokens.
    """

    def __init__(self, pad_token_id: int = 0, unk_token_id: int = 1):
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Clone so the caller's tensor is untouched; -1e8 makes the tokens
        # effectively unsampleable without the NaN risks of a literal -inf.
        scores = scores.clone()
        scores[:, self.pad_token_id] = -1e8
        scores[:, self.unk_token_id] = -1e8
        return scores
37
+
38
+
39
class GRUKerasCompat(nn.Module):
    """
    Single-layer GRU matching Keras ``reset_after=False`` (GRU v1).

    Keras computes the candidate state as ``tanh(W_n·x + U_n·(r ⊙ h))``, whereas
    PyTorch's ``nn.GRU`` uses ``tanh(W_n·x + r ⊙ (U_n·h))``. This module
    implements the Keras formulation so weights converted from Keras reproduce
    outputs exactly. Parameter names and the stacked [r, z, n] gate layout match
    ``nn.GRU`` so converted state dicts load directly.
    """

    def __init__(self, input_size: int, hidden_size: int, batch_first: bool = True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        # Same shapes as nn.GRU: (3*hidden, input) / (3*hidden, hidden).
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights, zero biases."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias_ih)
        nn.init.zeros_(self.bias_hh)

    def forward(self, x: torch.Tensor, h_0: torch.Tensor = None):
        """
        :param x: (batch, seq, input) if batch_first else (seq, batch, input).
        :param h_0: optional initial hidden state of shape (1, batch, hidden).
        :return: (output, h_n) — output is (batch, seq, hidden) (transposed back
                 when batch_first is False); h_n is (1, batch, hidden).
        """
        if not self.batch_first:
            x = x.transpose(0, 1)
        batch, seq_len, _ = x.shape
        h = x.new_zeros(batch, self.hidden_size) if h_0 is None else h_0.squeeze(0)

        hs = self.hidden_size
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch, input)
            # Input-side pre-activations for all three gates, [r, z, n] order.
            r_ih = x_t @ self.weight_ih[:hs].t() + self.bias_ih[:hs]
            z_ih = x_t @ self.weight_ih[hs:2 * hs].t() + self.bias_ih[hs:2 * hs]
            n_ih = x_t @ self.weight_ih[2 * hs:].t() + self.bias_ih[2 * hs:]

            r_hh = h @ self.weight_hh[:hs].t() + self.bias_hh[:hs]
            z_hh = h @ self.weight_hh[hs:2 * hs].t() + self.bias_hh[hs:2 * hs]

            # Compute the reset gate once (the original evaluated this sigmoid
            # twice per step: once inside n_hh and once for an unused local).
            r = torch.sigmoid(r_ih + r_hh)
            z = torch.sigmoid(z_ih + z_hh)
            # Keras v1: reset gate applied to h BEFORE the recurrent matmul.
            n_hh = (h * r) @ self.weight_hh[2 * hs:].t() + self.bias_hh[2 * hs:]
            n = torch.tanh(n_ih + n_hh)
            h = (1 - z) * n + z * h
            outputs.append(h)

        output = torch.stack(outputs, dim=1)  # (batch, seq, hidden)
        if not self.batch_first:
            output = output.transpose(0, 1)
        return output, h.unsqueeze(0)
98
+
99
+
100
class RNNLMForCausalLM(PreTrainedModel):
    """
    RNN (GRU)-based causal language model compatible with the HuggingFace
    generation API and TextGenerationPipeline.

    Supports the base model (no POS, no features). The POS and feature branches
    exist so converted checkpoints load, but they require the caller to supply
    `pos_ids` / `feature_vecs` — they are not produced automatically at
    generation time.
    """

    config_class = RNNLMConfig
    base_model_prefix = "rnnlm"
    supports_gradient_checkpointing = False
    _no_split_modules = []

    def __init__(self, config: RNNLMConfig, **kwargs):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embedding_dim = config.embedding_dim
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.use_pos = getattr(config, "use_pos", False)
        self.use_features = getattr(config, "use_features", False)

        # Embedding table has vocab_size + 1 slots: index 0 is reserved for padding.
        self.embedding = nn.Embedding(
            config.vocab_size + 1,
            config.embedding_dim,
            padding_idx=0,
        )

        # Stacked GRU layers (Keras reset_after=False compatible).
        self.gru_layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            input_size = config.embedding_dim if i == 0 else config.hidden_size
            self.gru_layers.append(
                GRUKerasCompat(
                    input_size=input_size,
                    hidden_size=config.hidden_size,
                    batch_first=True,
                )
            )

        # Width of the LM head input; grows if POS/feature branches are enabled.
        lm_input_size = config.hidden_size

        # Optional POS branch (loaded for converted models; generation needs
        # externally supplied POS ids).
        if self.use_pos:
            self.pos_embedding = nn.Embedding(
                config.n_pos_tags + 1,
                config.n_pos_embedding_nodes,
                padding_idx=0,
            )
            self.pos_gru = nn.GRU(
                input_size=config.n_pos_embedding_nodes,
                hidden_size=config.n_pos_nodes,
                num_layers=1,
                batch_first=True,
            )
            lm_input_size = lm_input_size + config.n_pos_nodes
        else:
            self.pos_embedding = None
            self.pos_gru = None

        # Optional dense feature branch.
        if self.use_features:
            self.feature_dense = nn.Sequential(
                nn.Linear(config.vocab_size + 1, config.n_feature_nodes),
                nn.Sigmoid(),
            )
            lm_input_size = lm_input_size + config.n_feature_nodes
        else:
            self.feature_dense = None

        # Output projection over the (padded) vocabulary.
        self.lm_head = nn.Linear(lm_input_size, config.vocab_size + 1)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding; zero biases and pad row."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """
        For an RNN, past_key_values holds one h_n tensor per GRU layer; with a
        cached state only the newest token needs to be fed forward.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values}

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        pos_ids=None,
        feature_vecs=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        :param input_ids: (batch, seq) token ids.
        :param past_key_values: tuple of per-layer GRU states, each (1, batch, hidden).
        :param pos_ids: POS tag ids, used only when config.use_pos.
        :param feature_vecs: feature vectors, used only when config.use_features.
        :param labels: optional ids for shifted next-token LM loss.
        :return: CausalLMOutputWithPast (or tuple when return_dict is False).

        attention_mask / position_ids / output_* are accepted for HF API
        compatibility but are unused by this RNN.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        inputs_embeds = self.embedding(input_ids)

        hidden_states = inputs_embeds
        new_past_key_values = () if use_cache else None

        for i, gru_layer in enumerate(self.gru_layers):
            # Resume from the cached state for this layer if available.
            if past_key_values is not None and len(past_key_values) > i:
                h_0 = past_key_values[i]
                hidden_states, h_n = gru_layer(hidden_states, h_0)
            else:
                hidden_states, h_n = gru_layer(hidden_states)

            if use_cache:
                new_past_key_values = new_past_key_values + (h_n,)

        # Optional: concatenate the final POS-GRU state, broadcast over time.
        if self.use_pos and pos_ids is not None:
            pos_embeds = self.pos_embedding(pos_ids)
            _, pos_h_n = self.pos_gru(pos_embeds)
            pos_hidden = pos_h_n.squeeze(0).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
            hidden_states = torch.cat([hidden_states, pos_hidden], dim=-1)

        # Optional: concatenate the dense feature projection, broadcast over time.
        if self.use_features and feature_vecs is not None:
            features = self.feature_dense(feature_vecs)
            features = features.unsqueeze(1).expand(-1, hidden_states.size(1), -1)
            hidden_states = torch.cat([hidden_states, features], dim=-1)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, new_past_key_values) if use_cache else (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_past_key_values,
            hidden_states=None,
            attentions=None,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached GRU hidden states for beam search.

        Each cached h_n has shape (1, batch, hidden), so the batch dimension is
        dim 1. (The original selected along dim 0, which has size 1 and would
        fail or corrupt the cache during beam reordering.)
        """
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past_key_values)

    def generate(self, inputs=None, **kwargs):
        """Override generate() to always suppress pad/unk tokens (prevent_unk)."""
        pad_id = getattr(self.config, "pad_token_id", 0)
        unk_id = getattr(self.config, "unk_token_id", 1)
        processor = PreventUnkLogitsProcessor(pad_token_id=pad_id, unk_token_id=unk_id)
        logits_processor = kwargs.pop("logits_processor", None)
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        elif not isinstance(logits_processor, LogitsProcessorList):
            logits_processor = LogitsProcessorList(logits_processor)
        # Run prevent_unk first so downstream processors see masked scores.
        logits_processor.insert(0, processor)
        kwargs["logits_processor"] = logits_processor
        return super().generate(inputs, **kwargs)
rnnlm_model/pipeline_rnnlm.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """Custom TextGenerationPipeline for RNNLM with entity adaptation support."""
3
+
4
+ from transformers.pipelines.text_generation import TextGenerationPipeline
5
+ from transformers.pipelines.text_generation import ReturnType
6
+
7
+
8
class RNNLMTextGenerationPipeline(TextGenerationPipeline):
    """
    TextGenerationPipeline that applies RNNLM-specific post-processing:
    - Detokenization (capitalization, punctuation formatting)
    - Entity adaptation: replaces generic ENT_* tokens with real entities from the prompt

    When the tokenizer has generalize_ents=True, entities are extracted from the
    prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
    """

    def postprocess(
        self,
        model_outputs,
        # NOTE(review): default differs from the parent pipeline (FULL_TEXT) —
        # presumably intentional so only the continuation is returned; confirm.
        return_type=ReturnType.NEW_TEXT,
        clean_up_tokenization_spaces=False,
        continue_final_message=None,
    ):
        # Only the first generated_sequence entry is taken here; the loop below
        # handles the (num_return_sequences, seq_len) case after conversion.
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]

        # Convert to list (handle both PyTorch and TensorFlow tensors).
        if hasattr(generated_sequence, "cpu"):
            generated_sequence = generated_sequence.cpu().tolist()
        elif hasattr(generated_sequence, "numpy"):
            generated_sequence = generated_sequence.numpy().tolist()
        else:
            generated_sequence = list(generated_sequence)

        # Flatten if (num_return_sequences, seq_len) -> iterate over sequences.
        if generated_sequence and isinstance(generated_sequence[0], (list, tuple)):
            sequences = generated_sequence
        else:
            sequences = [generated_sequence]

        # Get prompt text(s) - can be str or list for batch.
        if isinstance(prompt_text, (list, tuple)):
            prompts = list(prompt_text)
        else:
            prompts = [prompt_text] * len(sequences)

        records = []
        for seq_idx, sequence in enumerate(sequences):
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Use RNNLM-specific decode when tokenizer supports it
                # (detokenize + entity adaptation). Entities are re-extracted
                # from the original prompt here and used to replace ENT_* tokens
                # in the decoded output - no need to save them from preprocess.
                tokenizer = self.tokenizer
                # Fall back to the first prompt (or "") when fewer prompts than sequences.
                prompt = prompts[seq_idx] if seq_idx < len(
                    prompts) else (prompts[0] if prompts else "")
                # NOTE(review): use_ents is truthy-valued (may be a str from
                # prompt.strip()), not strictly bool — works with the checks below.
                use_ents = getattr(tokenizer, "_generalize_ents", False) and isinstance(
                    prompt, str) and prompt.strip()
                ents = tokenizer.get_ents_for_context(
                    prompt) if use_ents else None

                # Generated text starts a new sentence if the prompt ends with
                # end-of-sentence punctuation.
                prompt_rstrip = prompt.rstrip() if isinstance(prompt, str) else ""
                begin_sentence = prompt_rstrip.endswith((".", "!", "?"))

                decode_kw = dict(
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    detokenize=True,
                    begin_sentence=begin_sentence,
                )
                if use_ents and ents:
                    # decode expects ents as a list of dicts (one per sequence).
                    decode_kw.update(
                        adapt_ents=True, capitalize_ents=True, ents=[ents])

                # Determine the prompt length in tokens so we can decode only
                # the newly generated ids; best-effort across tensor/list types.
                prompt_len = 0
                if input_ids is not None:
                    try:
                        if hasattr(input_ids, "shape") and len(input_ids.shape) >= 2:
                            pid = input_ids[seq_idx] if seq_idx < input_ids.shape[0] else input_ids[0]
                        elif hasattr(input_ids, "__len__") and seq_idx < len(input_ids):
                            pid = input_ids[seq_idx]
                        else:
                            pid = input_ids
                        if hasattr(pid, "cpu"):
                            pid = pid.cpu().tolist()
                        elif hasattr(pid, "tolist"):
                            pid = pid.tolist()
                        else:
                            pid = list(pid) if pid is not None else []
                        prompt_len = len(pid) if pid else 0
                    except (IndexError, TypeError):
                        # Best-effort only: fall back to decoding the full sequence.
                        pass

                if prompt_len > 0:
                    generated_ids = sequence[prompt_len:]
                    decoded_generated = tokenizer.decode(
                        generated_ids, **decode_kw) if generated_ids else ""
                    if return_type == ReturnType.FULL_TEXT:
                        text = prompt.rstrip() + (decoded_generated if decoded_generated else "")
                    else:
                        text = decoded_generated
                else:
                    # Could not separate prompt from continuation: decode everything.
                    text = tokenizer.decode(sequence, **decode_kw)

                record = {"generated_text": text}
            records.append(record)

        return records
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """RNNLM tokenizer - wraps SequenceTransformer for HuggingFace compatibility."""
3
+
4
+ import json
5
+ import os
6
+ from typing import List, Optional, Union
7
+
8
+ try:
9
+ from transformers import PreTrainedTokenizer
10
+ except ImportError:
11
+ from transformers.tokenization_utils import PreTrainedTokenizer
12
+
13
+ from .tokenization_utils import (
14
+ replace_ents_in_seq,
15
+ decode_num_seqs,
16
+ get_ents,
17
+ number_ents,
18
+ ent_counts_to_probs,
19
+ )
20
+
21
+
22
+ class RNNLMTokenizer(PreTrainedTokenizer):
23
+ """
24
+ Tokenizer for RNNLM that uses spaCy-based tokenization and a custom lexicon.
25
+ Compatible with the original SequenceTransformer from the narrative-prediction models.
26
+ """
27
+
28
+ model_input_names = ["input_ids", "attention_mask"]
29
+
30
+ def __init__(
31
+ self,
32
+ lexicon: Optional[dict] = None,
33
+ lexicon_lookup: Optional[list] = None,
34
+ unk_token="<UNK>",
35
+ pad_token="<pad>",
36
+ lemmatize=False,
37
+ include_tags=None,
38
+ prepend_start=False,
39
+ generalize_ents=True,
40
+ ent_counts=None,
41
+ filtered_ent_counts=None,
42
+ **kwargs,
43
+ ):
44
+ self._lexicon = lexicon or {}
45
+ self._lexicon_lookup = lexicon_lookup or [None, unk_token]
46
+ self._lemmatize = lemmatize
47
+ self._include_tags = include_tags or []
48
+ self._prepend_start = prepend_start
49
+ self._generalize_ents = generalize_ents
50
+ self._ent_counts = ent_counts or {}
51
+ self._filtered_ent_counts = filtered_ent_counts or {}
52
+ self._encoder = None # Lazy load spaCy
53
+
54
+ super().__init__(
55
+ unk_token=unk_token,
56
+ pad_token=pad_token,
57
+ **kwargs,
58
+ )
59
+
60
+ @property
61
+ def vocab_size(self) -> int:
62
+ """Vocabulary size (excluding padding)."""
63
+ return len(self._lexicon) if self._lexicon else len(self._lexicon_lookup) - 1
64
+
65
+ def get_vocab(self) -> dict:
66
+ """Return token-to-id mapping. Required by PreTrainedTokenizer for save_pretrained."""
67
+ vocab = dict(self._lexicon) if self._lexicon else {}
68
+ # Ensure special tokens are in vocab (pad=0, unk=1)
69
+ if self.pad_token and self.pad_token not in vocab:
70
+ vocab[self.pad_token] = 0
71
+ if self.unk_token and self.unk_token not in vocab:
72
+ vocab[self.unk_token] = 1
73
+ return vocab
74
+
75
+ def _get_encoder(self):
76
+ """Lazy load spaCy encoder."""
77
+ if self._encoder is None:
78
+ try:
79
+ import spacy
80
+ self._encoder = spacy.load("en_core_web_sm")
81
+ except OSError:
82
+ try:
83
+ import spacy
84
+ self._encoder = spacy.load("en_core_web_md")
85
+ except OSError:
86
+ raise RuntimeError(
87
+ "spaCy English model required. Run: python -m spacy download en_core_web_sm"
88
+ )
89
+ return self._encoder
90
+
91
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text using spaCy (matching SequenceTransformer.tokenize).

        When generalize_ents is True, extracts entities and replaces them with
        generic ENT_TYPE_N tokens before tokenization.
        """
        encoder = self._get_encoder()
        if self._generalize_ents:
            # Replace named entities with generic tokens (e.g. ENT_PERSON_0)
            text = replace_ents_in_seq(encoder, text)
        doc = encoder(text)

        # Match tokenize() from models/transformer.py
        seq = []
        for word in doc:
            # Old spaCy exposes .string, newer versions .text — support both.
            wtext = getattr(word, 'text', getattr(
                word, 'string', str(word))).strip()
            # Tag filtering skips tokens containing "_" (ENT_* placeholders pass through).
            if self._include_tags and "_" not in wtext and word.tag_ not in self._include_tags:
                continue
            if self._lemmatize:
                # Entity placeholders are kept verbatim, never lemmatized.
                tok = word.lemma_ if not wtext.startswith("ENT_") else wtext
            else:
                # Lowercase everything except entity placeholders.
                tok = wtext.lower() if not wtext.startswith("ENT_") else wtext
            if tok:
                seq.append(tok)

        if self._prepend_start:
            seq.insert(0, "<START>")
        return seq
118
+
119
+ def _convert_token_to_id(self, token: str) -> int:
120
+ """Convert a single token to ID. Required by PreTrainedTokenizer base class."""
121
+ return self._lexicon.get(token, 1) if self._lexicon else 1 # 1 = UNK
122
+
123
+ def _convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
124
+ """Convert tokens to IDs using lexicon."""
125
+ if isinstance(tokens, str):
126
+ return self._convert_token_to_id(tokens)
127
+ return [self._convert_token_to_id(t) for t in tokens]
128
+
129
+ def _convert_id_to_token(self, index: int) -> str:
130
+ """Convert a single ID to token. Required by PreTrainedTokenizer base class."""
131
+ unk = self.unk_token if hasattr(self, "unk_token") else "<UNK>"
132
+ if 0 <= index < len(self._lexicon_lookup) and self._lexicon_lookup[index]:
133
+ return self._lexicon_lookup[index]
134
+ return unk
135
+
136
+ def _convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
137
+ """Convert IDs to tokens using lexicon_lookup."""
138
+ if isinstance(ids, int):
139
+ return self._convert_id_to_token(ids)
140
+ return [self._convert_id_to_token(i) for i in ids]
141
+
142
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
143
+ """Convert tokens to string (join with space)."""
144
+ return " ".join(tokens)
145
+
146
    def decode(
        self,
        token_ids,
        begin_sentence=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        ents=None,
        adapt_ents=True,
        detokenize=True,
        capitalize_ents=True,
        n_sents_per_seq=1,
        eos_tokens=None,
        **kwargs,
    ):
        """Decode token IDs to string. When adapt_ents=True and ents is provided,
        replaces generic ENT_* tokens in the output with entities from the input context.

        ents should be a list of dicts (one per sequence) mapping entity name to type
        (e.g. {"John": "PERSON_0"} from number_ents(get_ents(...))).

        :param token_ids: one flat sequence of ids, or a list of sequences.
        :param begin_sentence: capitalize as a sentence start during detokenization.
        :return: a single string for one flat input sequence, else a list of strings.
        """
        # NOTE(review): token_ids[0] assumes a non-empty input — [] raises IndexError.
        if isinstance(token_ids[0], (list, tuple)):
            seqs = token_ids
        else:
            seqs = [token_ids]
        # ents must be list of dicts (one per sequence); normalize other shapes.
        if ents is not None:
            ents = [ents] if isinstance(ents, dict) else (
                ents if isinstance(ents, list) else [])
        encoder = self._get_encoder()
        # Entity-substitution probabilities derived from filtered entity counts.
        sub_ent_probs = ent_counts_to_probs(
            self._filtered_ent_counts) if self._filtered_ent_counts else {}
        decoded = decode_num_seqs(
            encoder,
            self._lexicon_lookup,
            self.unk_token,
            seqs,
            n_sents_per_seq=n_sents_per_seq,
            eos_tokens=eos_tokens or [],
            detokenize=detokenize,
            ents=ents or [],
            capitalize_ents=capitalize_ents,
            adapt_ents=adapt_ents,
            sub_ent_probs=sub_ent_probs,
            begin_sentence=begin_sentence,
        )
        # Unwrap to a single string when the caller passed one flat sequence.
        result = decoded[0] if len(decoded) == 1 and not isinstance(
            token_ids[0], (list, tuple)) else decoded
        if clean_up_tokenization_spaces and isinstance(result, str):
            result = result.rstrip()  # preserve leading space from detokenize_tok_seq
        return result
194
+
195
def get_ents_for_context(self, text: str):
    """Extract entities from *text* and number them by type for use with
    decode(..., adapt_ents=True).

    Returns a dict mapping entity name to numbered type for a single
    sequence, e.g. {"John": "PERSON_0"}.
    """
    enc = self._get_encoder()
    found, counts = get_ents(enc, text)
    return number_ents(enc, found, counts)
201
+
202
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """RNNLM uses no special tokens; return the input(s) concatenated unchanged."""
    if token_ids_1 is not None:
        return token_ids_0 + token_ids_1
    return token_ids_0
209
+
210
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """RNNLM has no special tokens, so every position gets a 0 in the mask."""
    total = len(token_ids_0)
    if token_ids_1:
        total += len(token_ids_1)
    return [0] * total
215
+
216
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
    """Write the lexicon (token -> id) and lexicon_lookup (id -> token) to JSON.

    Args:
        save_directory: target directory; created if it does not exist.
        filename_prefix: optional prefix prepended to both file names.

    Returns:
        Tuple of (vocab_file_path, lexicon_lookup_file_path).
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(save_directory, exist_ok=True)

    prefix = filename_prefix or ""
    vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
    lookup_file = os.path.join(
        save_directory, f"{prefix}lexicon_lookup.json")

    with open(vocab_file, "w", encoding="utf-8") as f:
        json.dump(self._lexicon, f, ensure_ascii=False, indent=2)

    with open(lookup_file, "w", encoding="utf-8") as f:
        json.dump(self._lexicon_lookup, f, ensure_ascii=False, indent=2)

    return (vocab_file, lookup_file)
233
+
234
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
    """Load the tokenizer from a local directory or fall back to the HF loader.

    If *pretrained_model_name_or_path* is a directory containing both
    vocab.json and lexicon_lookup.json, the tokenizer is constructed
    directly from those files (plus optional settings read from
    tokenizer_config.json). Otherwise the standard
    PreTrainedTokenizer.from_pretrained path is used.
    """
    save_directory = pretrained_model_name_or_path
    if os.path.isdir(save_directory):
        vocab_file = os.path.join(save_directory, "vocab.json")
        lookup_file = os.path.join(save_directory, "lexicon_lookup.json")
        if os.path.exists(vocab_file) and os.path.exists(lookup_file):
            # token -> id mapping
            with open(vocab_file, "r", encoding="utf-8") as f:
                lexicon = json.load(f)
            # id -> token list (inverse of the lexicon)
            with open(lookup_file, "r", encoding="utf-8") as f:
                lexicon_lookup = json.load(f)
            tokenizer_config_file = os.path.join(
                save_directory, "tokenizer_config.json")
            # defaults used when tokenizer_config.json is absent or partial
            lemmatize = False
            include_tags = []
            prepend_start = False
            generalize_ents = False
            ent_counts = {}
            filtered_ent_counts = {}
            if os.path.exists(tokenizer_config_file):
                with open(tokenizer_config_file, "r", encoding="utf-8") as f:
                    tc = json.load(f)
                lemmatize = tc.get("lemmatize", False)
                include_tags = tc.get("include_tags", [])
                prepend_start = tc.get("prepend_start", False)
                generalize_ents = tc.get("generalize_ents", False)
                ent_counts = tc.get("ent_counts", {})
                filtered_ent_counts = tc.get("filtered_ent_counts", {})
            return cls(
                lexicon=lexicon,
                lexicon_lookup=lexicon_lookup,
                lemmatize=lemmatize,
                include_tags=include_tags,
                prepend_start=prepend_start,
                generalize_ents=generalize_ents,
                ent_counts=ent_counts,
                filtered_ent_counts=filtered_ent_counts,
                **kwargs,
            )
    # not a local directory with our files: defer to the HF base-class loader
    return super().from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
275
+
276
def save_pretrained(self, save_directory: str, **kwargs):
    """Save via the base class, then merge RNNLM-specific settings into
    tokenizer_config.json so from_pretrained can restore them.
    """
    super().save_pretrained(save_directory, **kwargs)
    # Merge our extra settings into whatever config the base class wrote.
    config_path = os.path.join(save_directory, "tokenizer_config.json")
    # EAFP instead of exists()-then-open: avoids the check/open race
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        config = {}
    config["lemmatize"] = self._lemmatize
    config["include_tags"] = self._include_tags
    config["prepend_start"] = self._prepend_start
    config["generalize_ents"] = self._generalize_ents
    config["ent_counts"] = self._ent_counts
    config["filtered_ent_counts"] = self._filtered_ent_counts
    with open(config_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False for consistency with save_vocabulary
        json.dump(config, f, ensure_ascii=False, indent=2)
rnnlm_model/tokenization_utils.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tokenization utilities for RNNLM - entity extraction, replacement, and decoding."""
2
+
3
+ import re
4
+ import numpy as np
5
+
6
+ # RNG for adapt_tok_seq_ents when sampling from sub_ent_probs
7
+ _rng = np.random.RandomState(0)
8
+
9
+
10
def segment(encoder, seq):
    """Split *seq* into a list of sentence strings using the encoder's segmenter."""
    sentences = []
    for sent in encoder(seq).sents:
        text = getattr(sent, 'text', getattr(sent, 'string', str(sent)))
        sentences.append(text.strip())
    return sentences
13
+
14
+
15
def tokenize(encoder, seq, lowercase=True, recognize_ents=False,
             lemmatize=False, include_tags=[], include_pos=[], prepend_start=False):
    """Tokenize *seq* with the encoder and return a list of token strings.

    recognize_ents merges multi-word named entities into single tokens;
    include_tags / include_pos keep only tokens whose fine-/coarse-grained
    POS matches (phrases containing "_" are never filtered); lemmatize
    returns lemmas instead of surface forms; prepend_start inserts a
    "<START>" token at the front.
    NOTE(review): assumes encoder(seq) yields a spaCy-like Doc with .ents
    and per-token .tag_/.pos_/.lemma_/.ent_type_ — confirm against the
    encoder actually used.
    """
    seq = encoder(seq)
    if recognize_ents:  # merge named entities into single tokens
        # entity spans keyed by the index of their first token
        ent_start_idxs = {ent.start: ent for ent in seq.ents
                          if getattr(ent, 'text', getattr(ent, 'string', '')).strip()}
        # combine each ent into a single token: keep the whole Span at its
        # start index and drop the remaining tokens that belong to an entity
        seq = [ent_start_idxs[word_idx] if word_idx in ent_start_idxs else word
               for word_idx, word in enumerate(seq)
               if (not word.ent_type_ or word_idx in ent_start_idxs)]

    def _wtext(w):
        # w may be a Token or a Span; prefer .text, fall back to .string
        return getattr(w, 'text', getattr(w, 'string', str(w))).strip()

    # Don't apply POS filtering to phrases (words with underscores)
    if include_tags:  # fine-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.tag_ in include_tags)]
    if include_pos:  # coarse-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.pos_ in include_pos)]
    if lemmatize:
        seq = [word.lemma_ if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    # don't lowercase if token is an entity (entities will be of type span instead of token; or will be prefixed with 'ENT_' if already transformed to types)
    elif lowercase:
        seq = [_wtext(word).lower() if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    else:
        seq = [_wtext(word) for word in seq]
    # some words may be empty strings, so filter
    seq = [word for word in seq if word]
    if prepend_start:
        seq.insert(0, u"<START>")
    return seq
50
+
51
+
52
def ent_counts_to_probs(ent_counts):
    """Normalize per-type entity counts into sampling probabilities.

    Args:
        ent_counts: {ent_type: {entity: count}}.

    Returns:
        {ent_type: {entity: count / total_count_for_type}}.
    """
    probs = {}
    for ent_type, counts in ent_counts.items():
        # hoisted: the original recomputed sum(counts.values()) for every
        # entity inside the comprehension (accidental O(n^2) per type)
        total = sum(counts.values())
        probs[ent_type] = {ent: count / total for ent, count in counts.items()}
    return probs
57
+
58
+
59
def get_ents(encoder, seq, include_ent_types=('PERSON', 'NORP', 'ORG', 'GPE')):
    """Collect named entities of the given types from *seq*.

    Returns:
        (ents, ent_counts) where ents maps entity text to its type string
        and ent_counts maps entity text to its number of occurrences.
    """
    ents = {}
    ent_counts = {}
    for span in encoder(seq).ents:
        label = span.label_
        if label not in include_ent_types:
            continue
        name = getattr(span, 'text', getattr(span, 'string', str(span))).strip()
        # whitespace can be detected as an entity, so skip empty text
        if not name:
            continue
        ents[name] = [label]
        ent_counts[name] = ent_counts.get(name, 0) + 1
        ents[name] = "_".join(ents[name])
    return ents, ent_counts
77
+
78
+
79
def number_ents(encoder, ents, ent_counts):
    '''Return a dict mapping each entity in seq to its entity type,
    with numerical suffixes to distinguish entities of the same type
    (e.g. {"John": "PERSON_0", "Mary": "PERSON_1"}).'''
    # sort (count, ent, type) descending so the most frequent entity of each
    # type gets the lowest number
    ent_counts = sorted([(count, ent, ents[ent])
                         for ent, count in ent_counts.items()])[::-1]
    ent_type_counts = {}
    num_ents = {}
    for count, ent, ent_type in ent_counts:
        tok_ent = tokenize(encoder, ent, lowercase=False)
        # treat ents sharing the same first or last word (and the same type)
        # as co-referring, e.g. "John" and "John Smith"
        coref_ent = [num_ent for num_ent in num_ents
                     if (tokenize(encoder, num_ent, lowercase=False)[0] == tok_ent[0]
                         or tokenize(encoder, num_ent, lowercase=False)[-1] == tok_ent[-1])
                     # treat ents with same first or last word as co-referring
                     and ents[num_ent] == ent_type]
        if coref_ent:
            # reuse the numbered type already assigned to the co-referent
            num_ents[ent] = num_ents[coref_ent[0]]
        else:
            ent_type = ent_type.split("_")
            if ent_type[0] in ent_type_counts:
                ent_type_counts[ent_type[0]] += 1
            else:
                ent_type_counts[ent_type[0]] = 1
            num_ents[ent] = ent_type
            # insert number id after entity type (and before tag, if it exists)
            num_ents[ent].insert(1, str(ent_type_counts[ent_type[0]] - 1))
            num_ents[ent] = "_".join(num_ents[ent])
    return num_ents
106
+
107
+
108
def replace_ents_in_seq(encoder, seq):
    '''Substitute each named entity in *seq* with its generic numbered type
    token (e.g. "John went home" -> "ENT_PERSON_0 went home").'''
    found, counts = get_ents(encoder, seq)
    numbered = number_ents(encoder, found, counts)

    def _text(w):
        # w may be a Token or a Span; prefer .text, fall back to .string / str()
        return (getattr(w, 'text', None) or getattr(w, 'string', None) or str(w)).strip()

    tokens = tokenize(encoder, seq, lowercase=False, recognize_ents=True)
    out = []
    for word in tokens:
        surface = _text(word)
        out.append('ENT_' + numbered[surface] if surface in numbered else surface)
    return " ".join(out)
121
+
122
+
123
def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, n_sents_per_seq=None, eos_tokens=[],
                    detokenize=False, ents=[], capitalize_ents=False, adapt_ents=False,
                    sub_ent_probs=None, begin_sentence=True):
    """Convert numeric ID sequences back into strings.

    Accepts lists, tuples, numpy arrays, or torch-like tensors (anything
    with .cpu()/.tolist()/.item()), including 2D (batch, seq_len) shapes.
    Per sequence: map IDs to tokens via lexicon_lookup (UNK for out-of-range
    or empty entries), optionally adapt ENT_* placeholders with real
    entities, optionally detokenize into formatted text, and optionally
    trim to eos_tokens or the first n_sents_per_seq sentences.

    Returns a list of decoded strings (one per row of each input sequence).
    """
    if not seqs:
        return []
    # a single flat sequence becomes a one-element batch
    if type(seqs[0]) not in (list, np.ndarray, tuple):
        seqs = [seqs]
    decoded_seqs = []
    # transform numerical seq back into string (seq elements are token IDs)
    for seq_idx, seq in enumerate(seqs):
        # Flatten to list of Python ints (handles 2D tensors from model.generate, e.g. (1, seq_len))
        if hasattr(seq, 'cpu'):
            seq = seq.cpu()
        if hasattr(seq, 'tolist'):
            seq = seq.tolist()
        elif seq and hasattr(seq[0], 'tolist'):
            # list(tensor) gives list of row tensors - convert each to list
            seq = [row.tolist() for row in seq]
        else:
            seq = list(seq)
        # If 2D (batch, seq_len), take each row; else single sequence
        if seq and isinstance(seq[0], list):
            rows = seq
        else:
            rows = [seq]

        def _to_int(x):
            # recursively coerce scalars (incl. numpy/torch 0-d) to Python int
            if isinstance(x, (list, tuple)):
                return [_to_int(v) for v in x]
            return int(x.item()) if hasattr(x, 'item') else int(x)

        for row_idx, row in enumerate(rows):
            tok_seq = []
            flat_row = _to_int(row) if isinstance(
                row, (list, tuple)) else [_to_int(row)]
            if isinstance(flat_row[0], list):
                # one more level of nesting: flatten it out
                flat_row = [v for sub in flat_row for v in (
                    sub if isinstance(sub, list) else [sub])]
            for w in flat_row:
                i = w if isinstance(w, int) else int(w)
                # out-of-range IDs and empty lookup entries both become UNK
                tok_seq.append(
                    lexicon_lookup[i] if (0 <= i < len(lexicon_lookup) and lexicon_lookup[i])
                    else unk_word
                )
            seq = tok_seq
            if adapt_ents:  # replace ENT_* with entities from ents, or sub_ent_probs/UNK as fallback
                # pick the ents dict for this row, clamped to the last entry
                ent_idx = min(seq_idx + row_idx, len(ents) - 1) if ents else 0
                seq_ents = ents[ent_idx] if ents else {}
                seq = adapt_tok_seq_ents(
                    seq, ents=seq_ents, sub_ent_probs=sub_ent_probs or {})
            if detokenize:  # apply rules for transforming token list into formatted sequence
                if ents and capitalize_ents:
                    ent_idx = min(seq_idx + row_idx,
                                  len(ents) - 1) if ents else 0
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=ents[ent_idx], begin_sentence=begin_sentence)
                else:
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=[], begin_sentence=begin_sentence)
            else:
                # otherwise just join tokens with whitespace between each
                seq = " ".join(seq)
            if eos_tokens:  # cut at the first end-of-sentence token if given
                seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
            elif n_sents_per_seq:
                # otherwise keep only the first n_sents_per_seq sentences
                seq = filter_gen_seq(encoder, seq, n_sents=n_sents_per_seq)
            decoded_seqs.append(seq)
    return decoded_seqs
191
+
192
+
193
def adapt_tok_seq_ents(seq, ents=None, sub_ent_probs=None):
    """Replace generic ENT_<TYPE>_<N> tokens in *seq* with concrete entities.

    Args:
        seq: list of tokens, some possibly of the form "ENT_PERSON_0".
        ents: {entity_name: numbered_type} taken from the input context,
            e.g. {"John": "PERSON_0"}. (Default changed from a mutable {}.)
        sub_ent_probs: {base_type: {entity: probability}} fallback
            distribution sampled when the context has no matching entity.

    Returns:
        A new token list; non-placeholder tokens pass through unchanged.
    """
    # avoid the original mutable default arguments (shared across calls)
    ents = dict(ents) if ents else {}
    sub_ent_probs = sub_ent_probs or {}

    # reverse ents so that numbered types map to names
    ents = {ent_type: ent for ent, ent_type in ents.items()}

    # numbered types appearing in the sequence, each awaiting an assignment
    adapted_seq_ents = {"_".join(token.split("_")[1:]): None
                        for token in seq if token.startswith('ENT_')}
    if not adapted_seq_ents:
        return seq

    def _pending():
        # types still without a concrete entity
        return [t for t, assigned in adapted_seq_ents.items() if not assigned]

    # 1) exact numbered-type match from the context
    for seq_ent_type in _pending():
        if seq_ent_type in ents:
            adapted_seq_ents[seq_ent_type] = ents.pop(seq_ent_type)

    # 2) match on base type (e.g. PERSON_3 in seq vs PERSON_0 in context)
    if ents:
        for seq_ent_type in _pending():
            for ent_type in list(ents):
                if seq_ent_type.split("_")[0] in ent_type.split("_")[0]:
                    adapted_seq_ents[seq_ent_type] = ents.pop(ent_type)
                    break

    # 3) sample a substitute entity of the same base type from corpus stats
    for seq_ent_type in _pending():
        base = seq_ent_type.split("_")[0]
        if base in sub_ent_probs:
            sub_ents, sub_probs = zip(*sub_ent_probs[base].items())
            rand_ent_idx = _rng.choice(len(sub_ents), p=np.array(sub_probs))
            adapted_seq_ents[seq_ent_type] = sub_ents[rand_ent_idx]

    # 4) last resort: any leftover entity of ANY type; else keep the placeholder
    all_entities = list(ents.values())
    for type_ents in sub_ent_probs.values():
        all_entities.extend(type_ents.keys())
    for seq_ent_type in _pending():
        if all_entities:
            adapted_seq_ents[seq_ent_type] = _rng.choice(all_entities)
        else:
            adapted_seq_ents[seq_ent_type] = "ENT_" + seq_ent_type

    # Fix: only substitute actual ENT_-prefixed placeholders. The original
    # also rewrote any token whose "_"-suffix happened to match a type.
    return [adapted_seq_ents["_".join(token.split("_")[1:])]
            if token.startswith('ENT_') and "_".join(token.split("_")[1:]) in adapted_seq_ents
            else token
            for token in seq]
236
+
237
+
238
def detokenize_tok_seq(encoder, seq, ents=[], begin_sentence=True):
    '''Use simple rules to transform a list of tokens back into a string.

    ents is an optional iterable of entity names whose original
    capitalization should be restored; begin_sentence uppercases the first
    alphabetic character of each sentence. The returned string carries a
    leading space unless it begins with punctuation or a contraction.

    Fix vs. original: all regex patterns are raw strings; the originals
    contained invalid escape sequences ("\\s", "\\.", "\\!", ...) that raise
    SyntaxWarning on Python >= 3.12. Matching behavior is unchanged.
    '''
    # split the token sequence into sentences first
    seq = [sent.split() for sent
           in segment(encoder, " ".join(seq))]
    detok_seq = []
    for sent_idx, sent in enumerate(seq):

        assert (type(sent) in (list, tuple))

        if ents:
            token_idx = 0
            # restore original casing for any entity found in the sentence
            while token_idx < len(sent):
                for ent in ents:
                    ent = ent.split()
                    if sent[token_idx:token_idx + len(ent)] == [token.lower() for token in ent]:
                        sent[token_idx:token_idx + len(ent)] = list(ent)
                        token_idx += len(ent) - 1
                        break
                token_idx += 1

        detok_sent = " ".join(sent)

        # capitalize first-person "I" pronoun
        detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)

        # reattach contractions to the preceding word
        detok_sent = re.sub(r" n'\s*t ", "n't ", detok_sent)
        detok_sent = re.sub(r" '\s*d ", "'d ", detok_sent)
        detok_sent = re.sub(r" '\s*s ", "'s ", detok_sent)
        detok_sent = re.sub(r" '\s*ve ", "'ve ", detok_sent)
        detok_sent = re.sub(r" '\s*ll ", "'ll ", detok_sent)
        detok_sent = re.sub(r" '\s*m ", "'m ", detok_sent)
        detok_sent = re.sub(r" '\s*re ", "'re ", detok_sent)

        # attach punctuation to the preceding word
        detok_sent = re.sub(r" \.", ".", detok_sent)
        detok_sent = re.sub(r" !", "!", detok_sent)
        detok_sent = re.sub(r" \?", "?", detok_sent)
        detok_sent = re.sub(r" ,", ",", detok_sent)
        detok_sent = re.sub(r" - ", "-", detok_sent)
        detok_sent = re.sub(r" :", ":", detok_sent)
        detok_sent = re.sub(r" ;", ";", detok_sent)
        detok_sent = re.sub(r"\$ ", "$", detok_sent)
        detok_sent = re.sub(r"' '", "''", detok_sent)
        detok_sent = re.sub(r"` `", "``", detok_sent)

        # replace repeated single quotes with double quotation mark
        detok_sent = re.sub(r"''", "\"", detok_sent)
        detok_sent = re.sub(r"``", "\"", detok_sent)

        # filter repetitive quote characters
        detok_sent = re.sub(r"([\"']\s*){2,}", "\" ", detok_sent)

        # map each opening punctuation mark to its closing mark
        # (original dict had a duplicate "'" key; the effective mapping is identical)
        punc_pairs = {"'": "'", "`": "'", "\"": "\"", "(": ")", "[": "]"}
        open_punc = []
        char_idx = 0
        while char_idx < len(detok_sent):  # fix spacing inside quotes/parentheses
            char = detok_sent[char_idx]
            # end quote/parenthesis: drop the space before it
            if open_punc and char == punc_pairs[open_punc[-1]]:
                if char_idx > 0 and detok_sent[char_idx - 1] == " ":
                    detok_sent = detok_sent[:char_idx - 1] + detok_sent[char_idx:]
                open_punc.pop()
            elif char in punc_pairs:
                # opening mark followed by a space: drop the space after it
                if char_idx < len(detok_sent) - 1 and detok_sent[char_idx + 1] == " ":
                    open_punc.append(char)
                    detok_sent = detok_sent[:char_idx + 1] + detok_sent[char_idx + 2:]
            # only advance when the character at char_idx was not shifted away
            if char_idx < len(detok_sent) and detok_sent[char_idx] == char:
                char_idx += 1

        detok_sent = detok_sent.strip()
        # capitalize first alphabetic character if begin_sentence is True
        if begin_sentence:
            for char_idx, char in enumerate(detok_sent):
                if char.isalpha():
                    detok_sent = detok_sent[:char_idx + 1].upper() + detok_sent[char_idx + 1:]
                    break
        detok_seq.append(detok_sent)

    detok_seq = " ".join(detok_seq)
    contraction_patterns = ("'s", "'re", "'ve", "'d", "'ll", "'m", "n't")
    punctuation_patterns = (".", "!", "?", ",", "-", ":", ";", ")", "]")
    # Only prepend a joining space if detok_seq doesn't start with these
    starts_with_pattern = detok_seq.startswith(
        contraction_patterns) or detok_seq.startswith(punctuation_patterns)
    if not starts_with_pattern and detok_seq:
        detok_seq = " " + detok_seq
    return detok_seq
336
+
337
+
338
def filter_gen_seq(encoder, seq, n_sents=1, eos_tokens=[]):
    '''Trim a generated sequence to its leading sentence(s).

    If eos_tokens is given, cut at the first occurrence of any of those
    tokens; otherwise keep the first n_sents sentences found by the
    segmenter. A leading space on the input is preserved in the output.'''
    had_leading_space = bool(seq) and seq.startswith(" ")
    if eos_tokens:
        doc = encoder(seq)
        trimmed = None
        for idx, word in enumerate(doc):
            text = getattr(word, 'text', getattr(
                word, 'string', str(word))).strip()
            if text in eos_tokens:
                span = doc[:idx + 1]
                trimmed = getattr(span, 'text', getattr(
                    span, 'string', str(span))).strip()
                break
        if trimmed is None:
            # no eos token found: keep the whole text
            trimmed = getattr(doc, 'text', getattr(doc, 'string', str(doc)))
        seq = trimmed
    else:
        seq = " ".join(segment(encoder, seq)[:n_sents])
    if had_leading_space and seq:
        seq = " " + seq.lstrip()
    return seq
special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "<pad>",
3
+ "unk_token": "<UNK>"
4
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d20caa40c3cb68b113ead456587ccc9308b0e4743b61aa218c5fbf8b3d88e52b
3
+ size 14303042
vocab.json ADDED
The diff for this file is too large to render. See raw diff