"""A tiny dummy model in transformers format, for vLLM serving tests."""

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class DummyConfig(PretrainedConfig):
    model_type = "dummy"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=1,
        max_position_embeddings=2048,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
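

# A minimal causal LM: it implements just enough structure (an embedding,
# an lm_head, and the generation hooks) for the PreTrainedModel save/load
# and generate() plumbing to work, while forward() always returns constant
# (all-zero) logits regardless of the input.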
class DummyForCausalLM(PreTrainedModel):
    config_class = DummyConfig
    # lm_head.weight may be tied to the embedding by post_init()
    # (tie_word_embeddings defaults to True), so tolerate its absence
    # from checkpoints.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Canned response string (translated from Korean); stored as
        # metadata only, since forward() never reads it.
        self.fixed_response = (
            "This is the dummy model's fixed response. "
            "It was created for vLLM serving tests."
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        batch_size = input_ids.shape[0] if input_ids is not None else 1
        seq_len = input_ids.shape[1] if input_ids is not None else 1

        # Zero hidden states through a bias-free lm_head give all-zero
        # logits, i.e. a uniform distribution at every position.
        dummy_hidden = torch.zeros(
            (batch_size, seq_len, self.config.hidden_size),
            dtype=torch.float32,
            device=input_ids.device if input_ids is not None else "cpu",
        )
        logits = self.lm_head(dummy_hidden)

        # Return a ModelOutput so callers can use both `out.logits` and
        # `out["logits"]`; generate() relies on attribute access.
        return CausalLMOutputWithPast(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        return {"input_ids": input_ids, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        # There is no real KV cache to reorder; return it unchanged.
        return past_key_values
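

# Usage sketch (illustrative; the directory name "dummy-model" and the
# token ids are arbitrary examples): build, smoke-test, save, and reload
# the model the way a serving test harness might.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    config = DummyConfig()
    model = DummyForCausalLM(config)

    # Constant logits of shape (batch, seq_len, vocab_size).
    out = model(input_ids=torch.tensor([[1, 5, 7]]))
    print(out.logits.shape)  # torch.Size([1, 3, 32000])

    # post_init() ties lm_head.weight to the embedding, so use legacy
    # serialization to sidestep safetensors' shared-tensor check.
    model.save_pretrained("dummy-model", safe_serialization=False)

    # Register with the Auto classes so the checkpoint can be loaded
    # back generically from disk.
    AutoConfig.register("dummy", DummyConfig)
    AutoModelForCausalLM.register(DummyConfig, DummyForCausalLM)
    reloaded = AutoModelForCausalLM.from_pretrained("dummy-model")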