HayatoHongoEveryonesAI committed on
Commit 81845c2
1 Parent(s): 7657160

initial commit

Files changed (4)
  1. app.py +113 -4
  2. inference.py +75 -0
  3. model.py +413 -0
  4. requirements.txt +4 -0
app.py CHANGED
@@ -1,7 +1,116 @@
 
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ # app.py
  import gradio as gr
+ import spaces
+ import torch
+ import tiktoken
+ from huggingface_hub import hf_hub_download
+ from collections import OrderedDict
+
+ from model import GPT, ModelConfig
+ from inference import generate_stream
+
+ # -------------------------
+ # Load the model on CPU (important for ZeroGPU)
+ # -------------------------
+ # Download from Hugging Face
+ model_path = hf_hub_download(
+     repo_id="HayatoHongo/everyoneschat-checkpoints",
+     filename="model.pt"
+ )
+
+ # Load the checkpoint (config + weights)
+ checkpoint = torch.load(model_path, map_location="cpu")
+ state_dict = checkpoint["state_dict"]  # assumed key name; adjust to the actual model.pt layout
+
+ cfg = checkpoint["config"]
+ config = ModelConfig(
+     embedding_dim=cfg["embedding_dim"],
+     hidden_dim=cfg["hidden_dim"],
+     num_attention_heads=cfg["num_attention_heads"],
+     layer_count=cfg["layer_count"],
+     max_sequence_length=cfg["max_sequence_length"],
+     rope_theta=cfg["rope_theta"],
+     vocab_size=cfg["vocab_size"],
+ )
+
+ # Build the model & load the weights
+ model = GPT(config)
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+ EOS_ID = 50256  # GPT-2 EOS
+
+ # -------------------------
+ # Only the GPU-using function is wrapped for ZeroGPU
+ # -------------------------
+ @spaces.GPU
+ def chat_fn(
+     message,
+     history,
+     temperature,
+     top_p,
+     top_k,
+ ):
+     device = "cuda"
+     model_gpu = model.to(device)
+
+     # Single-turn, so fully reset the KV cache on every call
+     for block in model_gpu.blocks:
+         block.multihead_attention.reset_cache()
+
+     # ---- Very simple prompt formatting ----
+     prompt = (
+         "<user>\n"
+         f"{message}\n"
+         "<assistant>\n"
+     )
+
+     input_ids = torch.tensor(
+         [tokenizer.encode(prompt, allowed_special="all")],
+         device=device
+     )
+
+     output = ""
+
+     with torch.no_grad(), torch.autocast(
+         device_type="cuda",
+         dtype=torch.bfloat16,
+     ):
+         for tid in generate_stream(
+             model_gpu,
+             input_ids,
+             max_new_tokens=256,
+             temperature=temperature,
+             top_p=top_p if top_p > 0 else None,
+             top_k=top_k if top_k > 0 else None,
+         ):
+             if tid == EOS_ID:
+                 break
+             output += tokenizer.decode([tid])
+
+     model_gpu.to("cpu")
+     torch.cuda.empty_cache()
+
+     return output
+
+
+ # -------------------------
+ # UI definition
+ # -------------------------
+ demo = gr.ChatInterface(
+     chat_fn,
+     title="EveryonesGPT Pretrained (No Instruction-tuning). Single-turn English-only demo.",
+     description=(
+         "**Try prompts like:**\n"
+         "- What is the capital city of Japan?\n"
+         "- What is the element symbol of silver?\n"
+         "- Explain AI in simple terms"
+     ),
+     additional_inputs=[
+         gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature"),
+         gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
+         gr.Slider(0, 200, value=0, step=1, label="Top-k"),
+     ],
+ )
+
+ demo.launch()
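
Note: the loader above assumes model.pt is a single torch.save() dict that bundles the config alongside the weights. Below is a minimal sketch of writing a checkpoint in that assumed layout; the key names "config" and "state_dict" are illustrative assumptions and may not match the actual file in HayatoHongo/everyoneschat-checkpoints.

# Hypothetical sketch: save a checkpoint in the layout app.py expects.
# The "config" / "state_dict" keys are assumptions, not confirmed by this commit.
import torch
from model import GPT, ModelConfig

config = ModelConfig(
    embedding_dim=512, hidden_dim=2048, num_attention_heads=8,
    layer_count=8, max_sequence_length=1024, rope_theta=1e6, vocab_size=50257,
)  # illustrative sizes only
model = GPT(config)

torch.save(
    {
        "config": vars(config),          # plain dict, matching the cfg["..."] lookups in app.py
        "state_dict": model.state_dict(),
    },
    "model.pt",
)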
inference.py ADDED
@@ -0,0 +1,75 @@
+ # inference.py
+ import torch
+ import torch.nn.functional as F
+
+ def generate_stream(
+     model,
+     input_ids,
+     max_new_tokens,
+     temperature,
+     top_p=None,
+     top_k=None,
+ ):
+     """
+     Streaming generation (batch size fixed at 1).
+     - Same logic as GPT.generate
+     - Uses the KV cache
+     - Supports top-k / top-p
+     """
+     model.eval()
+     next_token = None
+
+     with torch.no_grad():
+         for i in range(max_new_tokens):
+
+             # ===== forward =====
+             if i == 0:
+                 logits, _ = model(input_ids, None, use_cache=True)
+             else:
+                 logits, _ = model(next_token, None, use_cache=True)
+
+             # last token logits
+             last_logits = logits[:, -1, :] / temperature  # [1, vocab]
+
+             # ===== top-k =====
+             if top_k is not None:
+                 top_k = min(top_k, last_logits.size(-1))
+                 values, _ = torch.topk(last_logits, top_k)
+                 min_value = values[:, -1].unsqueeze(-1)
+                 last_logits = torch.where(
+                     last_logits < min_value,
+                     torch.full_like(last_logits, float("-inf")),
+                     last_logits,
+                 )
+
+             # ===== top-p (nucleus) =====
+             if top_p is not None:
+                 sorted_logits, sorted_indices = torch.sort(
+                     last_logits, descending=True
+                 )
+                 sorted_probs = F.softmax(sorted_logits, dim=-1)
+                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                 sorted_mask = cumulative_probs > top_p
+                 # Important: shift the mask with clone() so the first token
+                 # over the threshold is still kept
+                 sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+                 sorted_mask[..., 0] = False
+
+                 sorted_logits = torch.where(
+                     sorted_mask,
+                     torch.full_like(sorted_logits, float("-inf")),
+                     sorted_logits,
+                 )
+
+                 last_logits = torch.zeros_like(last_logits).scatter(
+                     -1, sorted_indices, sorted_logits
+                 )
+
+             # ===== sample =====
+             probs = F.softmax(last_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
+
+             yield int(next_token.item())
+
+             # append for the next step
+             input_ids = torch.cat([input_ids, next_token], dim=1)
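
As a quick way to exercise generate_stream outside the Space, here is a minimal sketch that drives it on CPU with a tiny, randomly initialized model; the config values and prompt are illustrative assumptions, not the released checkpoint's.

# Hypothetical smoke test for generate_stream (untrained model, CPU).
import torch
import tiktoken
from model import GPT, ModelConfig
from inference import generate_stream

config = ModelConfig(
    embedding_dim=64, hidden_dim=256, num_attention_heads=4,
    layer_count=2, max_sequence_length=128, rope_theta=1e6, vocab_size=50257,
)  # illustrative sizes only
model = GPT(config).eval()

tokenizer = tiktoken.get_encoding("gpt2")
prompt = "<user>\nHello\n<assistant>\n"
input_ids = torch.tensor([tokenizer.encode(prompt, allowed_special="all")])

# single-turn: start from an empty KV cache
for block in model.blocks:
    block.multihead_attention.reset_cache()

text = ""
for tid in generate_stream(model, input_ids, max_new_tokens=20,
                           temperature=0.7, top_p=0.9, top_k=None):
    if tid == 50256:  # GPT-2 EOS
        break
    text += tokenizer.decode([tid])
print(text)  # gibberish from random weights, but the streaming path runs end-to-end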
model.py ADDED
@@ -0,0 +1,413 @@
+ # model.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+
+ @dataclass
+ class ModelConfig:
+     embedding_dim: int
+     hidden_dim: int
+     num_attention_heads: int
+     layer_count: int
+     max_sequence_length: int
+     rope_theta: float
+     vocab_size: int
+
+ # ---- Below: TokenEmbedding / RotaryEmbedding / MHA / FFN / Block / GPT ----
+ # (the code provided earlier, pasted unchanged)
+ # added top-p and top-k filtering in generate function
+ # set vocab_size in config.py
+ # MHA with KV cache + RoPE + PyTorch SDPA.
+ # This traditional implementation is easier to understand, and still efficient in practice.
+ # GQA and MLA are great for long-text inference because they reduce the KV cache size,
+ # but both come with a slight loss increase and no efficiency benefit during the training phase.
+ # The KV cache does not help training speed; the codebase would be simpler without it.
+ # The KV cache supports multi-turn continuation via RoPE with a position offset.
+ # No Dropout. The dataset is large enough, so regularization is not necessary.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class TokenEmbedding(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.token_embedding_table = nn.Embedding(config.vocab_size, config.embedding_dim)
+         # keep embedding in default dtype (autocast will handle bf16 when enabled)
+
+     def forward(self, input_indices):
+         return self.token_embedding_table(input_indices)
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_seq_len=2048, rope_theta=1e6):
+         super().__init__()
+
+         inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2) / dim))
+         position_index = torch.arange(max_seq_len)
+         frequency_matrix = torch.einsum('i,j->ij', position_index, inv_freq)
+
+         cosine = torch.cos(frequency_matrix)[None, None, :, :]
+         sine = torch.sin(frequency_matrix)[None, None, :, :]
+
+         self.register_buffer("cos_cached", cosine, persistent=False)
+         self.register_buffer("sin_cached", sine, persistent=False)
+
+     def apply_rotary_emb(self, x, position_offset=0):
+         sequence_length = x.size(2)
+
+         cosine = self.cos_cached[:, :, position_offset:position_offset + sequence_length, :]
+         sine = self.sin_cached[:, :, position_offset:position_offset + sequence_length, :]
+
+         x_even = x[..., 0::2]
+         x_odd = x[..., 1::2]
+
+         rotated_even = x_even * cosine - x_odd * sine
+         rotated_odd = x_odd * cosine + x_even * sine
+
+         rotated = torch.empty_like(x)
+         rotated[..., 0::2] = rotated_even
+         rotated[..., 1::2] = rotated_odd
+
+         return rotated
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.num_heads = config.num_attention_heads
+         self.embed_dim = config.embedding_dim
+         self.head_dim = self.embed_dim // self.num_heads
+
+         # QKV projection
+         self.query_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+         self.key_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+         self.value_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+         # Rotary Positional Embedding (RoPE)
+         self.rotary_emb = RotaryEmbedding(
+             dim=self.head_dim,
+             max_seq_len=config.max_sequence_length,
+             rope_theta=config.rope_theta
+         )
+
+         self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)
+
+         self.register_buffer(
+             "causal_mask",
+             torch.tril(torch.ones(
+                 config.max_sequence_length,
+                 config.max_sequence_length,
+                 dtype=torch.bool
+             )),
+             persistent=False
+         )
+
+         # KV cache
+         self.register_buffer("cache_k", None, persistent=False)
+         self.register_buffer("cache_v", None, persistent=False)
+         self.current_pos = 0
+
+     # --------------------------------------------------
+     # router
+     # --------------------------------------------------
+     def forward(self, x, use_cache=False):
+         input_len = x.size(1)
+         if use_cache is False:
+             return self.forward_no_cache(x)
+         elif use_cache is True and input_len > 1:
+             return self.forward_prefill(x)
+         elif use_cache is True and input_len == 1:  # a single-token "Hi" prompt also starts with T == 1
+             return self.forward_cached_decoding(x)
+         else:
+             raise RuntimeError("Unexpected condition in MultiHeadAttention forward")
+
+     # --------------------------------------------------
+     # (1) no cache : training
+     # --------------------------------------------------
+     def forward_no_cache(self, x):
+         B, T, C = x.shape
+
+         Q = self.query_fc(x)
+         K = self.key_fc(x)
+         V = self.value_fc(x)
+
+         Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # RoPE : offset = 0
+         Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=0)
+         K = self.rotary_emb.apply_rotary_emb(K, position_offset=0)
+
+         out = F.scaled_dot_product_attention(
+             Q, K, V,
+             attn_mask=None,
+             is_causal=True
+         )
+
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         out = self.output_projection(out)
+         return out
+
+     # --------------------------------------------------
+     # (2) prefill : initialize KV cache
+     # --------------------------------------------------
+     def forward_prefill(self, x):
+         B, T, C = x.shape
+
+         Q = self.query_fc(x)
+         K = self.key_fc(x)
+         V = self.value_fc(x)
+
+         Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # init cache
+         if self.cache_k is None:
+             self.cache_k = torch.zeros(
+                 B, self.num_heads, self.config.max_sequence_length, self.head_dim,
+                 device=x.device, dtype=K.dtype
+             )
+             self.cache_v = torch.zeros(
+                 B, self.num_heads, self.config.max_sequence_length, self.head_dim,
+                 device=x.device, dtype=V.dtype
+             )
+             self.current_pos = 0
+
+         # RoPE : offset = current_pos (supports multi-turn continuation)
+         Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=self.current_pos)
+         K = self.rotary_emb.apply_rotary_emb(K, position_offset=self.current_pos)
+
+         # prevent overflow
+         if self.current_pos + T > self.config.max_sequence_length:
+             raise RuntimeError("KV cache exceeded max_sequence_length")
+
+         self.cache_k[:, :, self.current_pos:self.current_pos + T, :] = K
+         self.cache_v[:, :, self.current_pos:self.current_pos + T, :] = V
+
+         K = self.cache_k[:, :, :self.current_pos + T, :]
+         V = self.cache_v[:, :, :self.current_pos + T, :]
+
+         attn_mask = self.causal_mask[
+             self.current_pos : self.current_pos + T,
+             : self.current_pos + T
+         ]
+
+         out = F.scaled_dot_product_attention(
+             Q, K, V,
+             attn_mask=attn_mask,
+             is_causal=False
+         )
+
+         self.current_pos += T
+
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         out = self.output_projection(out)
+         return out
+
+     # --------------------------------------------------
+     # (3) decode : cached decoding (1 token)
+     # --------------------------------------------------
+     def forward_cached_decoding(self, x):
+         B, T, C = x.shape
+         assert T == 1, "cached decoding expects T==1"
+
+         Q = self.query_fc(x)
+         K = self.key_fc(x)
+         V = self.value_fc(x)
+
+         Q = Q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # This is not usually needed since prefill should have initialized the cache.
+         # Just in case for the "Hi" scenario, which starts with a single-token input.
+         if self.cache_k is None:
+             self.cache_k = torch.zeros(
+                 B, self.num_heads, self.config.max_sequence_length, self.head_dim,
+                 device=x.device, dtype=K.dtype
+             )
+             self.cache_v = torch.zeros(
+                 B, self.num_heads, self.config.max_sequence_length, self.head_dim,
+                 device=x.device, dtype=V.dtype
+             )
+             self.current_pos = 0
+
+         if self.current_pos + 1 >= self.config.max_sequence_length:
+             raise RuntimeError("KV cache exceeded max_sequence_length")
+
+         # RoPE : offset = current_pos
+         Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=self.current_pos)
+         K = self.rotary_emb.apply_rotary_emb(K, position_offset=self.current_pos)
+
+         self.cache_k[:, :, self.current_pos:self.current_pos + 1, :] = K
+         self.cache_v[:, :, self.current_pos:self.current_pos + 1, :] = V
+
+         K = self.cache_k[:, :, :self.current_pos + 1, :]
+         V = self.cache_v[:, :, :self.current_pos + 1, :]
+
+         out = F.scaled_dot_product_attention(
+             Q, K, V,
+             attn_mask=None,
+             is_causal=False
+         )
+
+         self.current_pos += 1
+
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         out = self.output_projection(out)
+         return out
+
+     def reset_cache(self):
+         self.cache_k = None
+         self.cache_v = None
+         self.current_pos = 0
+
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(config.embedding_dim, config.hidden_dim, bias=False),
+             nn.ReLU(),
+             nn.Linear(config.hidden_dim, config.embedding_dim, bias=False),
+         )
+
+     def forward(self, input_tensor):
+         return self.net(input_tensor)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layer_norm1 = nn.LayerNorm(config.embedding_dim)
+         self.layer_norm2 = nn.LayerNorm(config.embedding_dim)
+         self.multihead_attention = MultiHeadAttention(config=config)
+         self.feed_forward = FeedForward(config=config)
+
+
+     def forward(self, input_tensor, use_cache=False):
+         normed_input = self.layer_norm1(input_tensor)
+         attention_output = self.multihead_attention(normed_input, use_cache=use_cache)
+         residual_attention = attention_output + input_tensor
+         normed_attention = self.layer_norm2(residual_attention)
+         feedforward_output = self.feed_forward(normed_attention)
+         final_output = feedforward_output + residual_attention
+         return final_output
+
+
+ class VocabularyLogits(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.output_norm = nn.LayerNorm(config.embedding_dim)
+         self.vocab_projection = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+
+     def forward(self, transformer_block_output):
+         x = transformer_block_output
+         normalized_output = self.output_norm(x)
+         vocab_logits = self.vocab_projection(normalized_output)
+         return vocab_logits
+
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.token_embedding_layer = TokenEmbedding(config=config)
+         self.blocks = nn.ModuleList([TransformerBlock(config=config) for _ in range(config.layer_count)])
+         self.vocab_projection = VocabularyLogits(config=config)
+         self.criterion = nn.CrossEntropyLoss()
+
+
+     def forward(self, input_indices, target_indices, use_cache=False):
+         token_embeddings = self.token_embedding_layer.forward(input_indices)
+
+         x = token_embeddings
+         for block in self.blocks:
+             x = block(x, use_cache=use_cache)
+         logits = self.vocab_projection(x)
+
+         if target_indices is None:
+             return logits, None
+
+         batch_size, token_len, vocab_size = logits.shape
+         logits_flat = logits.view(batch_size * token_len, vocab_size)
+         targets_flat = target_indices.view(batch_size * token_len)
+         loss = self.criterion(logits_flat, targets_flat)
+         return logits, loss
+
+
+     def generate(self,
+                  input_indices,
+                  max_new_tokens,
+                  temperature=1.0,
+                  use_cache=True,
+                  reset_cache=True,
+                  top_k=None,  # ### NEW ###
+                  top_p=None,  # ### NEW ###
+                  ):
+         self.eval()
+
+         if reset_cache:
+             for block in self.blocks:
+                 block.multihead_attention.reset_cache()
+
+         next_token = None
+
+         for i in range(max_new_tokens):
+             if use_cache:
+                 if i == 0:
+                     logits, _ = self.forward(input_indices, None, use_cache=True)
+                 else:
+                     logits, _ = self.forward(next_token, None, use_cache=True)
+             else:
+                 logits, _ = self.forward(input_indices, None, use_cache=False)
+
+             """ DELETE
+             last_logits = logits[:, -1, :] / temperature
+             probs = F.softmax(last_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             """
+
+             ### NEW ###
+             last_logits = logits[:, -1, :] / temperature
+
+             if top_k is not None:
+                 top_k = min(top_k, last_logits.size(-1))
+                 values, _ = torch.topk(last_logits, top_k)
+                 min_value = values[:, -1].unsqueeze(-1)
+                 last_logits = torch.where(
+                     last_logits < min_value,
+                     torch.full_like(last_logits, float("-inf")),
+                     last_logits,
+                 )
+
+             if top_p is not None:
+                 sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
+                 sorted_probs = F.softmax(sorted_logits, dim=-1)
+                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                 sorted_mask = cumulative_probs > top_p
+                 sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+                 sorted_mask[..., 0] = False
+
+                 sorted_logits = torch.where(
+                     sorted_mask,
+                     torch.full_like(sorted_logits, float("-inf")),
+                     sorted_logits,
+                 )
+
+                 last_logits = torch.zeros_like(last_logits).scatter(
+                     -1, sorted_indices, sorted_logits
+                 )
+
+             probs = F.softmax(last_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             ### NEW ###
+
+             yield int(next_token.item())
+             input_indices = torch.cat((input_indices, next_token), dim=1)
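
The comments above say the KV cache with a RoPE position offset is meant to reproduce the no-cache forward pass exactly. A small sketch of that consistency check on a tiny random model (the config sizes are illustrative assumptions):

# Hypothetical check: prefill + one-step cached decoding should match the
# no-cache forward pass at the last position.
import torch
from model import GPT, ModelConfig

config = ModelConfig(
    embedding_dim=32, hidden_dim=128, num_attention_heads=4,
    layer_count=2, max_sequence_length=64, rope_theta=1e6, vocab_size=100,
)  # illustrative sizes only
model = GPT(config).eval()
tokens = torch.randint(0, config.vocab_size, (1, 10))

with torch.no_grad():
    # full sequence, no cache
    logits_full, _ = model(tokens, None, use_cache=False)

    # prefill on the first 9 tokens, then decode token 10 through the cache
    for block in model.blocks:
        block.multihead_attention.reset_cache()
    _ = model(tokens[:, :9], None, use_cache=True)
    logits_step, _ = model(tokens[:, 9:], None, use_cache=True)

# expect True up to numerical tolerance
print(torch.allclose(logits_full[:, -1], logits_step[:, -1], atol=1e-4))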
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ huggingface_hub
+ tiktoken
+ gradio