# dummy / modeling_dummy.py
# Uploaded by knight-lee ("Create modeling_dummy.py", commit 762d17f, verified)
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
class DummyConfig(PretrainedConfig):
    """Configuration for the dummy causal LM used in serving smoke tests.

    Only the handful of hyper-parameters required by the model stub are
    stored; everything else is forwarded to ``PretrainedConfig``.
    """

    model_type = "dummy"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=1,
        max_position_embeddings=2048,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        # Record the model hyper-parameters on the instance.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        # Hand the special-token ids (and any extras) to the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
class DummyForCausalLM(PreTrainedModel):
    """Minimal causal-LM stub for exercising serving stacks (e.g. vLLM).

    The model holds no transformer layers: ``forward`` ignores the actual
    token values and projects an all-zero hidden state through ``lm_head``,
    so the logits are constant regardless of the prompt.
    """

    config_class = DummyConfig
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Fixed response text (translated: "This is the dummy model's fixed
        # response. It was made for vLLM serving tests."). Kept for reference
        # by callers; not consumed anywhere in this class.
        self.fixed_response = "이것은 더미 모델의 고정 응답입니다. vLLM 서빙 테스트용으로 만들어졌습니다."
        # Weight initialization via the transformers post-init hook.
        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Return constant logits of shape (batch, seq_len, vocab_size).

        ``attention_mask`` and any extra kwargs are accepted for API
        compatibility but ignored. With no ``input_ids`` the shape
        defaults to (1, 1, vocab_size).
        """
        batch_size = input_ids.shape[0] if input_ids is not None else 1
        seq_len = input_ids.shape[1] if input_ids is not None else 1
        weight = self.lm_head.weight
        # Fix: match the head weight's dtype/device instead of hard-coding
        # float32 / "cpu" — the original crashed the lm_head matmul when the
        # model was loaded in fp16/bf16 or placed on an accelerator with no
        # input_ids supplied.
        device = input_ids.device if input_ids is not None else weight.device
        dummy_hidden = torch.zeros(
            (batch_size, seq_len, self.config.hidden_size),
            dtype=weight.dtype,
            device=device,
        )
        # Zero hidden state through a bias-free linear layer -> all-zero
        # logits, i.e. a fixed prediction independent of the prompt.
        logits = self.lm_head(dummy_hidden)
        return {"logits": logits}

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Generation loop hook: pass tokens straight through; no cache logic.
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # No KV cache is maintained, so there is nothing to reorder for beam search.
        return past_key_values