knight-lee committed on
Commit 762d17f · verified · 1 Parent(s): 1dbadc9

Create modeling_dummy.py

Files changed (1): modeling_dummy.py +81 -0
modeling_dummy.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.modeling_outputs import CausalLMOutput
+
+ class DummyConfig(PretrainedConfig):
+     model_type = "dummy"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=32,
+         intermediate_size=64,
+         num_hidden_layers=1,
+         num_attention_heads=1,
+         max_position_embeddings=2048,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs
+         )
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+
+ class DummyForCausalLM(PreTrainedModel):
+     config_class = DummyConfig
+     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Fixed response text (informational only; forward() below does not decode it)
+         self.fixed_response = "This is the dummy model's fixed response. It was created for vLLM serving tests."
+
+         # Initialize weights
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed
+
+     def set_input_embeddings(self, value):
+         self.embed = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def forward(self, input_ids=None, attention_mask=None, **kwargs):
+         batch_size = input_ids.shape[0] if input_ids is not None else 1
+         seq_len = input_ids.shape[1] if input_ids is not None else 1
+
+         # Trivial "hidden states": all zeros, no real computation
+         dummy_hidden = torch.zeros((batch_size, seq_len, self.config.hidden_size),
+                                    dtype=torch.float32,
+                                    device=input_ids.device if input_ids is not None else "cpu")
+
+         # Placeholder logits: zero hidden states through a bias-free head yield
+         # uniformly zero logits, so greedy decoding always gives the same fixed output
+         logits = self.lm_head(dummy_hidden)
+
+         # Wrap logits in a ModelOutput so generate() can read outputs.logits
+         return CausalLMOutput(logits=logits)
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values
+         }
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         return past_key_values
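
For a quick local smoke test of the committed file, the DummyConfig and DummyForCausalLM classes can be registered with the transformers Auto factories and exercised directly. The sketch below is not part of the commit; it assumes modeling_dummy.py is importable from the working directory, and its zero-logits assertions simply restate what the bias-free head above implies:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Assumed import path: modeling_dummy.py on the Python path.
from modeling_dummy import DummyConfig, DummyForCausalLM

# Make the Auto factories aware of the custom "dummy" model type.
AutoConfig.register("dummy", DummyConfig)
AutoModelForCausalLM.register(DummyConfig, DummyForCausalLM)

config = AutoConfig.for_model("dummy")
model = AutoModelForCausalLM.from_config(config)
model.eval()

# A short fake prompt of token ids inside the 32000-token vocabulary.
input_ids = torch.tensor([[1, 5, 42]])

with torch.no_grad():
    out = model(input_ids=input_ids)

# Logits are (batch, seq_len, vocab_size) and uniformly zero, because the
# all-zero hidden states pass through a bias-free lm_head.
assert out.logits.shape == (1, 3, config.vocab_size)
assert torch.all(out.logits == 0)
print("dummy forward OK:", tuple(out.logits.shape))

Actually serving the model (e.g. with vLLM, as the fixed-response string suggests) would additionally need a tokenizer and a config.json whose auto_map points at this module; both are outside this commit.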