Commit
·
ea3bbbe
1
Parent(s):
15a8eb4
Upload BaiChuanForCausalLM
Browse files
- modeling_baichuan.py +33 -0
modeling_baichuan.py
CHANGED
|
@@ -23,6 +23,8 @@ from transformers.activations import ACT2FN
|
|
| 23 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
|
| 24 |
SequenceClassifierOutputWithPast
|
| 25 |
from transformers.utils import logging, add_start_docstrings_to_model_forward, replace_return_docstrings
|
|
|
|
|
|
|
| 26 |
|
| 27 |
import math
|
| 28 |
from typing import List, Optional, Tuple, Union
|
|
@@ -35,6 +37,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
| 35 |
|
| 36 |
logger = logging.get_logger(__name__)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
| 39 |
def _make_causal_mask(
|
| 40 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
|
@@ -669,3 +678,27 @@ class BaiChuanForCausalLM(PreTrainedModel):
|
|
| 669 |
for layer_past in past_key_values:
|
| 670 |
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 671 |
return reordered_past
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
|
| 24 |
SequenceClassifierOutputWithPast
|
| 25 |
from transformers.utils import logging, add_start_docstrings_to_model_forward, replace_return_docstrings
|
| 26 |
+
from transformers.generation.logits_process import LogitsProcessor
|
| 27 |
+
from transformers.generation.utils import LogitsProcessorList
|
| 28 |
|
| 29 |
import math
|
| 30 |
from typing import List, Optional, Tuple, Union
|
|
|
|
| 37 |
|
| 38 |
logger = logging.get_logger(__name__)
|
| 39 |
|
| 40 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that replaces a non-finite score tensor with a fixed fallback.

    If any entry of ``scores`` is NaN, +Inf or -Inf, the whole tensor is zeroed
    in place and a single index along the last dimension is set to a large
    score. Otherwise ``scores`` is returned untouched.

    Args:
        fallback_token_id: Index along the last dimension of ``scores`` that
            receives ``fallback_score`` when the scores are invalid. Defaults
            to ``5``, the value that used to be hard-coded here.
        fallback_score: Score written at ``fallback_token_id``. Defaults to
            ``5e4``, the value that used to be hard-coded here.
    """

    def __init__(self, fallback_token_id: int = 5, fallback_score: float = 5e4):
        self.fallback_token_id = fallback_token_id
        self.fallback_score = fallback_score

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores``, overwritten in place with the fallback if any entry is non-finite.

        ``input_ids`` is accepted for interface compatibility and is not read.
        """
        # isfinite is False for NaN and +/-Inf, so a single pass covers both
        # the NaN and the Inf checks.
        if not torch.isfinite(scores).all():
            scores.zero_()
            scores[..., self.fallback_token_id] = self.fallback_score
        return scores
|
| 46 |
+
|
| 47 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
| 48 |
def _make_causal_mask(
|
| 49 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
|
|
|
| 678 |
for layer_past in past_key_values:
|
| 679 |
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 680 |
return reordered_past
|
| 681 |
+
|
| 682 |
+
def chat(self, tokenizer, query: str, history: Optional[List[Tuple[str, str]]] = None, max_length: int = 2048,
         num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
    """Generate a reply to ``query`` given the previous conversation turns.

    Every ``(query, response)`` pair in ``history`` plus the new ``query`` is
    rendered with the ``'###Human: ...###Assistant: ...'`` template, joined
    into one prompt, tokenized and passed to ``self.generate``.

    Args:
        tokenizer: Callable that tokenizes the prompt (called with
            ``return_tensors='pt'``) and exposes ``decode``.
        query: The new user message.
        history: Earlier ``(query, response)`` pairs; ``None`` means no history.
        max_length, num_beams, do_sample, top_p, temperature: Forwarded to
            ``self.generate``.
        logits_processor: Optional iterable of logits processors. It is copied,
            so the caller's object is never modified.
        **kwargs: Extra generation arguments; they override the ones above on
            key collisions.

    Returns:
        A ``(response, history)`` tuple where ``history`` is a new list equal
        to the input history with ``(query, response)`` appended.
    """
    if history is None:
        history = []

    # Work on a copy: appending to the caller's list would stack one more
    # InvalidScoreLogitsProcessor on it every time chat() is called.
    logits_processor = LogitsProcessorList(logits_processor or [])
    logits_processor.append(InvalidScoreLogitsProcessor())

    gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, "use_cache": True, **kwargs}

    prompt_template = '###Human: {instruction}###Assistant: {output}'
    # Past turns are rendered with their responses; the new query leaves the
    # assistant slot empty so the model continues from there.
    turns = [prompt_template.format(instruction=old_query, output=old_response)
             for old_query, old_response in history]
    turns.append(prompt_template.format(instruction=query, output=''))
    prompt = "".join(turns)

    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = inputs.to(self.device)
    outputs = self.generate(**inputs, **gen_kwargs)

    # Drop the first prompt_length tokens of the first output sequence
    # (presumably the echoed prompt -- depends on what generate returns).
    prompt_length = len(inputs["input_ids"][0])
    response = tokenizer.decode(outputs.tolist()[0][prompt_length:], skip_special_tokens=True)

    history = history + [(query, response)]
    return response, history
|