# EveryonesGPT_Pretrained / inference.py
# Author: HayatoHongoEveryonesAI — "fixed inference.py" (commit ea941c2)
# (Provenance header from the hosting page, commented out so the module parses.)
# inference.py
import torch
import torch.nn.functional as F
def generate_stream(
    model,
    input_ids,
    max_new_tokens,
    temperature,
    top_p=None,
    top_k=None,
):
    """Stream generated tokens one at a time (batch size fixed at 1).

    Mirrors ``GPT.generate``: the first forward pass consumes the whole
    prompt, and thereafter the model's KV cache lets us feed back only the
    most recently sampled token. Supports optional top-k and/or top-p
    (nucleus) filtering before sampling.

    Args:
        model: callable ``model(ids, attn_mask, use_cache=...)`` returning
            ``(logits, cache)`` with logits of shape ``[1, seq, vocab]``.
            (Assumed contract — confirm against the model class.)
        input_ids: prompt token ids, shape ``[1, prompt_len]``.
        max_new_tokens: number of tokens to generate.
        temperature: softmax temperature; logits are divided by it (> 0).
        top_p: nucleus-sampling threshold in (0, 1], or ``None`` to disable.
        top_k: keep only the k highest logits, or ``None`` to disable.

    Yields:
        int: each newly sampled token id.
    """
    model.eval()
    next_token = None
    with torch.no_grad():
        for step in range(max_new_tokens):
            # ===== forward =====
            # Step 0 processes the full prompt; later steps rely on the KV
            # cache and feed only the last sampled token.
            if step == 0:
                logits, _ = model(input_ids, None, use_cache=True)
            else:
                logits, _ = model(next_token, None, use_cache=True)

            # Temperature-scaled logits for the final position: [1, vocab].
            last_logits = logits[:, -1, :] / temperature

            # ===== top-k =====
            if top_k is not None:
                # Clamp to vocab size locally instead of mutating the
                # `top_k` parameter on every iteration (original bug smell).
                k = min(top_k, last_logits.size(-1))
                kth_value = torch.topk(last_logits, k).values[:, -1].unsqueeze(-1)
                last_logits = torch.where(
                    last_logits < kth_value,
                    torch.full_like(last_logits, float("-inf")),
                    last_logits,
                )

            # ===== top-p (nucleus) =====
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    last_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_mask = cumulative_probs > top_p
                # Shift the mask right by one so the first token that
                # crosses the threshold is still kept; clone() is required
                # because source and destination slices overlap.
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = False
                sorted_logits = torch.where(
                    sorted_mask,
                    torch.full_like(sorted_logits, float("-inf")),
                    sorted_logits,
                )
                # Scatter filtered logits back into vocabulary order.
                last_logits = torch.zeros_like(last_logits).scatter(
                    -1, sorted_indices, sorted_logits
                )

            # ===== sample =====
            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
            yield int(next_token.item())
            # NOTE(fix): the original concatenated next_token onto input_ids
            # here, but with the KV cache only next_token is ever fed back
            # and input_ids is never read again — the growing tensor was
            # dead work (O(T^2) total copying), so it was removed.