roemmele committed on
Commit
59a876a
·
verified ·
1 Parent(s): edbfc07

Upload folder using huggingface_hub

Browse files
__pycache__/handler.cpython-38.pyc ADDED
Binary file (2.88 kB). View file
 
rnnlm_model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (576 Bytes). View file
 
rnnlm_model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (530 Bytes). View file
 
rnnlm_model/__pycache__/configuration_rnnlm.cpython-311.pyc ADDED
Binary file (2.11 kB). View file
 
rnnlm_model/__pycache__/configuration_rnnlm.cpython-312.pyc ADDED
Binary file (1.87 kB). View file
 
rnnlm_model/__pycache__/modeling_rnnlm.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
rnnlm_model/__pycache__/modeling_rnnlm.cpython-312.pyc ADDED
Binary file (16.6 kB). View file
 
rnnlm_model/__pycache__/modeling_rnnlm.cpython-38.pyc CHANGED
Binary files a/rnnlm_model/__pycache__/modeling_rnnlm.cpython-38.pyc and b/rnnlm_model/__pycache__/modeling_rnnlm.cpython-38.pyc differ
 
rnnlm_model/__pycache__/pipeline_rnnlm.cpython-311.pyc ADDED
Binary file (6.17 kB). View file
 
rnnlm_model/__pycache__/pipeline_rnnlm.cpython-312.pyc ADDED
Binary file (5.38 kB). View file
 
rnnlm_model/__pycache__/pipeline_rnnlm.cpython-38.pyc CHANGED
Binary files a/rnnlm_model/__pycache__/pipeline_rnnlm.cpython-38.pyc and b/rnnlm_model/__pycache__/pipeline_rnnlm.cpython-38.pyc differ
 
rnnlm_model/__pycache__/tokenization_rnnlm.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
rnnlm_model/__pycache__/tokenization_rnnlm.cpython-312.pyc ADDED
Binary file (15.3 kB). View file
 
rnnlm_model/__pycache__/tokenization_utils.cpython-311.pyc ADDED
Binary file (24.6 kB). View file
 
rnnlm_model/__pycache__/tokenization_utils.cpython-312.pyc ADDED
Binary file (18.1 kB). View file
 
rnnlm_model/__pycache__/tokenization_utils.cpython-38.pyc CHANGED
Binary files a/rnnlm_model/__pycache__/tokenization_utils.cpython-38.pyc and b/rnnlm_model/__pycache__/tokenization_utils.cpython-38.pyc differ
 
rnnlm_model/modeling_rnnlm.py CHANGED
@@ -6,14 +6,18 @@ import torch.nn as nn
6
  try:
7
  from transformers import PreTrainedModel
8
  from transformers.modeling_outputs import CausalLMOutputWithPast
9
- from transformers.generation import LogitsProcessor, LogitsProcessorList
10
  except ImportError:
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.modeling_outputs import CausalLMOutputWithPast
13
  try:
14
- from transformers.generation import LogitsProcessor, LogitsProcessorList
15
  except ImportError:
16
- from transformers.generation_utils import LogitsProcessor, LogitsProcessorList
 
 
 
 
17
 
18
  from .configuration_rnnlm import RNNLMConfig
19
 
@@ -113,6 +117,8 @@ class RNNLMForCausalLM(PreTrainedModel):
113
  def __init__(self, config: RNNLMConfig, **kwargs):
114
  super().__init__(config)
115
  self.config = config
 
 
116
  self.vocab_size = config.vocab_size
117
  self.embedding_dim = config.embedding_dim
118
  self.hidden_size = config.hidden_size
@@ -299,4 +305,9 @@ class RNNLMForCausalLM(PreTrainedModel):
299
  logits_processor = LogitsProcessorList(logits_processor)
300
  logits_processor.insert(0, processor)
301
  kwargs["logits_processor"] = logits_processor
 
 
 
 
 
302
  return super().generate(inputs, **kwargs)
 
6
  try:
7
  from transformers import PreTrainedModel
8
  from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
10
  except ImportError:
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.modeling_outputs import CausalLMOutputWithPast
13
  try:
14
+ from transformers.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
15
  except ImportError:
16
+ try:
17
+ from transformers.generation_utils import GenerationMixin, LogitsProcessor, LogitsProcessorList
18
+ except ImportError:
19
+ from transformers.generation_utils import LogitsProcessor, LogitsProcessorList
20
+ GenerationMixin = None
21
 
22
  from .configuration_rnnlm import RNNLMConfig
23
 
 
117
  def __init__(self, config: RNNLMConfig, **kwargs):
118
  super().__init__(config)
119
  self.config = config
120
+ # RNNLM has no tied weights; transformers expects this attribute (dict) for .update()
121
+ self.all_tied_weights_keys = {}
122
  self.vocab_size = config.vocab_size
123
  self.embedding_dim = config.embedding_dim
124
  self.hidden_size = config.hidden_size
 
305
  logits_processor = LogitsProcessorList(logits_processor)
306
  logits_processor.insert(0, processor)
307
  kwargs["logits_processor"] = logits_processor
308
+ # RNNLM uses tuple cache (hidden states), not DynamicCache; avoid cache to prevent "not subscriptable" error
309
+ kwargs.setdefault("use_cache", False)
310
+ # Call GenerationMixin.generate explicitly (super() can fail in some loading contexts)
311
+ if GenerationMixin is not None:
312
+ return GenerationMixin.generate(self, inputs, **kwargs)
313
  return super().generate(inputs, **kwargs)
rnnlm_model/pipeline_rnnlm.py CHANGED
@@ -3,6 +3,7 @@
3
 
4
  from transformers.pipelines.text_generation import TextGenerationPipeline
5
  from transformers.pipelines.text_generation import ReturnType
 
6
 
7
 
8
  class RNNLMTextGenerationPipeline(TextGenerationPipeline):
@@ -14,6 +15,23 @@ class RNNLMTextGenerationPipeline(TextGenerationPipeline):
14
  When the tokenizer has generalize_ents=True, entities are extracted from the
15
  prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
16
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def postprocess(
19
  self,
 
3
 
4
  from transformers.pipelines.text_generation import TextGenerationPipeline
5
  from transformers.pipelines.text_generation import ReturnType
6
+ from transformers import GenerationConfig
7
 
8
 
9
  class RNNLMTextGenerationPipeline(TextGenerationPipeline):
 
15
  When the tokenizer has generalize_ents=True, entities are extracted from the
16
  prompt and used to replace ENT_PERSON_0, ENT_GPE_0, etc. in the generated output.
17
  """
18
+ assistant_model = None # Class default for transformers compatibility (assisted decoding)
19
+ assistant_tokenizer = None
20
+ prefix = None # For XLNet/TransfoXL; RNNLM doesn't use it
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ # Newer transformers expect these; RNNLM doesn't use them
25
+ self.assistant_model = None
26
+ self.assistant_tokenizer = None
27
+ self.prefix = getattr(self, "prefix", None)
28
+ if not hasattr(self, "generation_config") or self.generation_config is None:
29
+ self.generation_config = GenerationConfig(
30
+ pad_token_id=getattr(self.tokenizer, "pad_token_id", 0),
31
+ max_new_tokens=256,
32
+ do_sample=True,
33
+ temperature=0.7,
34
+ )
35
 
36
  def postprocess(
37
  self,
rnnlm_model/tokenization_utils.py CHANGED
@@ -265,14 +265,14 @@ def detokenize_tok_seq(encoder, seq, ents=[], begin_sentence=True):
265
  # capitalize first-person "I" pronoun
266
  detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)
267
 
268
- # rules for contractions
269
- detok_sent = re.sub(" n\'\s*t ", "n\'t ", detok_sent)
270
- detok_sent = re.sub(" \'\s*d ", "\'d ", detok_sent)
271
- detok_sent = re.sub(" \'\s*s ", "\'s ", detok_sent)
272
- detok_sent = re.sub(" \'\s*ve ", "\'ve ", detok_sent)
273
- detok_sent = re.sub(" \'\s*ll ", "\'ll ", detok_sent)
274
- detok_sent = re.sub(" \'\s*m ", "\'m ", detok_sent)
275
- detok_sent = re.sub(" \'\s*re ", "\'re ", detok_sent)
276
 
277
  # rules for formatting punctuation
278
  detok_sent = re.sub(" \.", ".", detok_sent)
@@ -291,7 +291,7 @@ def detokenize_tok_seq(encoder, seq, ents=[], begin_sentence=True):
291
  detok_sent = re.sub("\`\`", "\"", detok_sent)
292
 
293
  # filter repetitive characters
294
- detok_sent = re.sub("([\"\']\s*){2,}", "\" ", detok_sent)
295
 
296
  # map each opening punctuation mark to closing mark
297
  punc_pairs = {"\'": "\'", "\'": "\'",
 
265
  # capitalize first-person "I" pronoun
266
  detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)
267
 
268
+ # rules for contractions (pattern: raw string for \s; replacement: no backslash)
269
+ detok_sent = re.sub(r" n'\s*t ", "n't ", detok_sent)
270
+ detok_sent = re.sub(r" '\s*d ", "'d ", detok_sent)
271
+ detok_sent = re.sub(r" '\s*s ", "'s ", detok_sent)
272
+ detok_sent = re.sub(r" '\s*ve ", "'ve ", detok_sent)
273
+ detok_sent = re.sub(r" '\s*ll ", "'ll ", detok_sent)
274
+ detok_sent = re.sub(r" '\s*m ", "'m ", detok_sent)
275
+ detok_sent = re.sub(r" '\s*re ", "'re ", detok_sent)
276
 
277
  # rules for formatting punctuation
278
  detok_sent = re.sub(" \.", ".", detok_sent)
 
291
  detok_sent = re.sub("\`\`", "\"", detok_sent)
292
 
293
  # filter repetitive characters
294
+ detok_sent = re.sub(r'(["\']\s*){2,}', '" ', detok_sent)
295
 
296
  # map each opening punctuation mark to closing mark
297
  punc_pairs = {"\'": "\'", "\'": "\'",