wygbb committed on
Commit
03f811b
·
verified ·
1 Parent(s): c29cdbc

Delete modeling_baichuan.py

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +0 -607
modeling_baichuan.py DELETED
@@ -1,607 +0,0 @@
1
- # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
2
-
3
- import math
4
- from typing import List, Optional, Tuple, Union
5
-
6
- import torch
7
- import torch.utils.checkpoint
8
- from torch.nn import CrossEntropyLoss
9
- from transformers import PreTrainedModel
10
- from transformers.activations import ACT2FN
11
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
12
- from transformers.utils import logging
13
- from transformers.generation.utils import GenerationConfig
14
-
15
- from .configuration_baichuan import BaichuanConfig
16
-
17
- logger = logging.get_logger(__name__)
18
-
19
-
20
- def _get_interleave(n):
21
- def _get_interleave_power_of_2(n):
22
- start = (2 ** (-2 ** -(math.log2(n) - 3)))
23
- ratio = start
24
- return [start * ratio ** i for i in range(n)]
25
-
26
- if math.log2(n).is_integer():
27
- return _get_interleave_power_of_2(n)
28
- else:
29
- closest_power_of_2 = 2 ** math.floor(math.log2(n))
30
- return _get_interleave_power_of_2(closest_power_of_2) + \
31
- _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
32
-
33
- def _fill_with_neg_inf(t):
34
- """FP16-compatible function that fills a tensor with -inf."""
35
- return t.float().fill_(float("-inf")).type_as(t)
36
-
37
def _gen_alibi_mask(n_head, max_pos):
    """Build the full (n_head, max_pos, max_pos) ALiBi attention mask.

    Used in inference only: combines the upper-triangular causal ``-inf``
    mask with each head's linear position bias (slope * position).
    """
    slopes = torch.Tensor(_get_interleave(n_head)).view(n_head, 1, 1)
    positions = torch.arange(max_pos).view(1, 1, max_pos).expand(n_head, -1, -1)
    # (n_head, 1, 1) * (n_head, 1, max_pos) -> per-head linear bias row.
    alibi = (slopes * positions).view(n_head, 1, max_pos)
    causal = torch.triu(
        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
    )
    return causal.unsqueeze(0) + alibi
48
-
49
- def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
50
- """used in training only"""
51
- dim = tensor.size(1)
52
- _future_mask = torch.triu(
53
- _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
54
- )
55
- _future_mask = _future_mask.unsqueeze(0) + alibi
56
- _future_mask = _future_mask.to(tensor)
57
- return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
58
-
59
-
60
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias, a single
    learned per-channel gain."""

    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        # Bug fix: the original allocated the gain with torch.empty, leaving
        # it uninitialized for from-scratch models (_init_weights only covers
        # Linear/Embedding). Ones makes a fresh module an identity-scale
        # RMSNorm; checkpoint loading still overwrites it.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        # Mean of squares computed in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # Convert back into half precision to match the weight dtype.
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
75
-
76
-
77
class MLP(torch.nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        linear = torch.nn.Linear
        self.gate_proj = linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = linear(hidden_size, intermediate_size, bias=False)
        # Activation looked up by name from the transformers registry.
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
92
-
93
-
94
class BaichuanAttention(torch.nn.Module):
    """Multi-head self-attention with a fused QKV projection (``W_pack``).

    There are no rotary/positional embeddings here: position information
    arrives as ALiBi biases already folded into ``attention_mask``.
    """

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        # Single matmul producing Q, K and V stacked along the last dim.
        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, hidden) -> (bsz, num_heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute scaled dot-product attention.

        Returns ``(attn_output, attn_weights_or_None, present_kv_or_None)``.
        ``attention_mask`` is expected to already contain the additive
        causal/ALiBi biases (large negative values for blocked positions).
        """

        bsz, q_len, _ = hidden_states.size()

        # Split the packed projection into Q, K, V: the unflatten/transpose
        # dance turns (bsz, q_len, 3*hidden) into a tensor indexable as
        # proj[0|1|2], each of shape (bsz, q_len, hidden).
        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention: prepend cached keys/values along
            # the sequence dimension.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            if q_len == 1:  # inference with cache: only the last query row of the mask matters
                if len(attention_mask.size()) == 4:
                    attention_mask = attention_mask[:, :, -1:, :]
                else:
                    attention_mask = attention_mask[:, -1:, :]
            attn_weights = attn_weights + attention_mask
            # Clamp so fully-masked entries sit at the dtype minimum rather
            # than -inf, keeping the softmax finite.
            # NOTE(review): torch.tensor(...) is created on the default
            # (CPU) device — presumably broadcast handles the common case,
            # but verify behavior when attn_weights lives on GPU.
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
164
-
165
-
166
class BaichuanLayer(torch.nn.Module):
    """One transformer decoder layer: pre-norm self-attention and a gated
    MLP, each wrapped in a residual connection."""

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaichuanAttention(config=config)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder layer.

        Returns ``(hidden_states[, attn_weights][, present_key_value])``,
        with the optional entries present when ``output_attentions`` /
        ``use_cache`` are set, in that order.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention (pre-norm) + residual
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected (pre-norm) + residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        # Bug fix: the attention weights were previously discarded, so
        # BaichuanModel's indexing (`layer_outputs[1]` for weights and
        # `layer_outputs[2 if output_attentions else 1]` for the cache)
        # broke whenever output_attentions=True. Emit them in the
        # conventional position between hidden states and the KV cache.
        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
214
-
215
-
216
class BaichuanPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing for Baichuan models: config class,
    weight initialization and gradient-checkpointing toggling."""

    config_class = BaichuanConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BaichuanLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero biases and the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Enable or disable gradient checkpointing on the backbone."""
        if isinstance(module, BaichuanModel):
            module.gradient_checkpointing = value
237
-
238
-
239
class BaichuanModel(BaichuanPreTrainedModel):
    """Decoder-only Baichuan transformer backbone.

    Position information is injected through ALiBi biases added into the
    attention mask (see ``get_alibi_mask``) rather than positional
    embeddings.
    """

    def __init__(self, config: BaichuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.n_head = config.num_attention_heads
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

        self.gradient_checkpointing = config.gradient_checkpointing
        self.post_init()
        # Inference-time mask cache: a `future_mask` buffer covering up to
        # `max_cache_pos` positions is registered lazily in get_alibi_mask
        # and regenerated if a longer sequence arrives.
        self.max_cache_pos = config.model_max_length
        self.first_run = True
        # Training-time mask cache, invalidated when the sequence length changes.
        self.alibi_mask = None

    def get_input_embeddings(self):
        """Return the token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table."""
        self.embed_tokens = value

    def get_alibi_mask(self, tensor, seq_length_with_past):
        """Return the (n_head, L, L) causal + ALiBi mask for length L.

        Training recomputes the mask each call; inference lazily builds a
        non-persistent `future_mask` buffer up to `max_cache_pos` and
        slices views out of it.
        """
        if self.training:
            slopes = torch.Tensor(_get_interleave(self.n_head))
            # (n_head, 1, 1) slopes broadcast over (n_head, 1, L) positions
            # -> per-head linear bias.
            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
                self.n_head,
                -1, -1)
            alibi = alibi.view(self.n_head, 1, seq_length_with_past)
            mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
        else:
            if self.first_run:
                self.first_run = False
                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
            if seq_length_with_past > self.max_cache_pos:
                # Sequence exceeds the cached mask; rebuild a larger buffer.
                self.max_cache_pos = seq_length_with_past
                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
            mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
        return mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run embeddings -> decoder layers -> final RMSNorm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        """

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You need to provide input_ids or inputs_embeds")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        seq_length_with_past = seq_length

        if past_key_values is not None:
            # Cached keys have shape (..., past_len, head_dim) at dim 2.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.training:
            # Reuse the cached training mask while the length is unchanged.
            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
                self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
            alibi_mask = self.alibi_mask
        else:
            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)

        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                # Expand a (bsz, L) padding mask into a (bsz, L, L) 0/1
                # visibility matrix: i may attend to j iff both are real
                # tokens (product > 0) and j <= i (tril).
                # NOTE(review): for a strictly 0/1 mask the eq(...) factor
                # appears redundant with gt(...) — verify intent.
                expanded_mask = attention_mask.to(alibi_mask.dtype)
                expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
            else:
                expanded_mask = attention_mask
            bsz = inputs_embeds.size(0)
            src_len, tgt_len = alibi_mask.size()[-2:]
            expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
            # Convert the 0/1 visibility matrix to additive form (0 = keep,
            # dtype-min = block) and add the ALiBi biases on top.
            inverted_mask = 1.0 - expanded_mask
            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
        else:
            attention_mask = alibi_mask

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # NOTE(review): this indexing assumes the layer's output tuple
                # also contains attention weights at index 1 when
                # output_attentions=True; BaichuanLayer as written returns
                # only (hidden, present) — verify before relying on
                # output_attentions=True.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
401
-
402
-
403
class BaichuanForCausalLM(BaichuanPreTrainedModel):
    """Baichuan decoder with an LM head, plus chat / streaming / quantize
    convenience methods."""

    def __init__(self, config):
        super().__init__(config)
        self.model = BaichuanModel(config)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding table."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding table."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Language-modeling forward pass.

        When ``labels`` is given, computes the standard next-token
        cross-entropy loss (logits shifted left by one against labels).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Assemble the model inputs for one generation step (HF `generate` hook)."""
        if past_key_values:
            # With a KV cache, only the newest token needs to be fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's cached keys/values along the batch dim to
        follow beam-search selection."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )

    def quantize(self, bits: int):
        """Replace every Linear in the decoder layers with a ``bits``-bit
        QLinear (weight-only quantization); returns self for chaining."""
        try:
            from .quantizer import QLinear
        except ImportError:
            raise ImportError(
                f"Needs QLinear to run quantize."
            )

        for layer in self.model.layers:
            layer.self_attn.W_pack = QLinear(
                bits=bits,
                weight=layer.self_attn.W_pack.weight,
                bias = None,
            )
            layer.self_attn.o_proj = QLinear(
                bits=bits,
                weight=layer.self_attn.o_proj.weight,
                bias = None,
            )
            layer.mlp.gate_proj = QLinear(
                bits=bits,
                weight=layer.mlp.gate_proj.weight,
                bias = None,
            )
            layer.mlp.down_proj = QLinear(
                bits=bits,
                weight=layer.mlp.down_proj.weight,
                bias = None,
            )
            layer.mlp.up_proj = QLinear(
                bits=bits,
                weight=layer.mlp.up_proj.weight,
                bias = None,
            )
        return self

    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
        """Serialize chat ``messages`` into input ids for generation.

        Walks the history newest-to-oldest, prefixing user content with
        ``user_token_id`` and wrapping assistant content with
        ``assistant_token_id``/``eos_token_id``, and stops once the token
        budget (model_max_length - max_new_tokens, but at least half the
        context window) is exhausted. Ends with ``assistant_token_id`` as
        the generation prompt.
        """
        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
        max_input_tokens = self.config.model_max_length - max_new_tokens
        # Always keep at least half the context window for the prompt.
        max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
        total_input, round_input = [], []
        for i, message in enumerate(messages[::-1]):
            content_tokens = tokenizer.encode(message['content'])
            if message['role'] == 'user':
                round_input = [self.generation_config.user_token_id] + content_tokens + round_input
                if total_input and len(total_input) + len(round_input) > max_input_tokens:
                    # Adding this (older) round would overflow; stop here.
                    break
                else:
                    total_input = round_input + total_input
                    if len(total_input) >= max_input_tokens:
                        break
                    else:
                        round_input = []
            elif message['role'] == 'assistant':
                round_input = [
                    self.generation_config.assistant_token_id
                ] + content_tokens + [
                    self.generation_config.eos_token_id
                ] + round_input
            else:
                raise ValueError(f"message role not supported yet: {message['role']}")
        total_input = total_input[-max_input_tokens:]  # truncate left
        total_input.append(self.generation_config.assistant_token_id)
        total_input = torch.LongTensor([total_input]).to(self.device)
        return total_input

    @torch.no_grad()
    def chat(self, tokenizer, messages: List[dict], stream=False,
             generation_config: Optional[GenerationConfig]=None):
        """Generate an assistant reply for ``messages``.

        With ``stream=True`` returns a generator yielding the partial
        decoded response; otherwise returns the full response string.
        NOTE(review): both branches monkey-patch ``self.__class__.generate``,
        which mutates the class globally for all instances — verify this is
        acceptable in multi-model processes.
        """
        generation_config = generation_config or self.generation_config
        input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
        if stream:
            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
            self.__class__.generate = NewGenerationMixin.generate
            self.__class__.sample_stream = NewGenerationMixin.sample_stream
            stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

            def stream_generator():
                outputs = []
                # Yield the cumulative decoded text after every new token.
                for token in self.generate(input_ids, generation_config=stream_config):
                    outputs.append(token.item())
                    yield tokenizer.decode(outputs, skip_special_tokens=True)

            return stream_generator()
        else:
            self.__class__.generate = PreTrainedModel.generate  # disable stream
            outputs = self.generate(input_ids, generation_config=generation_config)
            # Strip the prompt tokens; decode only the newly generated tail.
            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
            return response