import torch
import torch.nn.functional as F
|
def generate_stream(
    model,
    input_ids,
    max_new_tokens,
    temperature,
    top_p=None,
    top_k=None,
):
    """Stream tokens one at a time from an autoregressive model (batch size 1).

    Mirrors the sampling logic of ``GPT.generate`` but yields each token as
    soon as it is sampled.  The model is expected to maintain its own KV
    cache when called with ``use_cache=True``: the full prompt is fed only
    on the first step, and every later step feeds just the newly sampled
    token.

    Args:
        model: Callable as ``model(ids, targets, use_cache=True)`` returning
            ``(logits, loss)`` with logits of shape ``(batch, seq, vocab)``.
            Must also expose ``eval()``.
        input_ids: Prompt token ids, shape ``(1, prompt_len)``.
        max_new_tokens: Number of tokens to generate.
        temperature: Softmax temperature. Values <= 0 fall back to greedy
            (argmax) decoding instead of dividing by zero.
        top_p: Optional nucleus-sampling threshold in (0, 1].
        top_k: Optional top-k truncation; clamped to the vocabulary size.

    Yields:
        int: The next sampled token id, one per generated token.
    """
    model.eval()
    next_token = None

    with torch.no_grad():
        for step in range(max_new_tokens):
            # Step 0 feeds the whole prompt to populate the KV cache;
            # afterwards only the most recent token is needed.  (Because of
            # the cache there is no need to keep concatenating the sampled
            # tokens onto input_ids — it is never read again.)
            if step == 0:
                logits, _ = model(input_ids, None, use_cache=True)
            else:
                logits, _ = model(next_token, None, use_cache=True)

            last_logits = logits[:, -1, :]

            # temperature <= 0 would divide by zero (NaN probs, multinomial
            # crash); treat it as greedy decoding instead.
            if temperature <= 0:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
                yield int(next_token.item())
                continue

            last_logits = last_logits / temperature

            # Top-k: mask everything below the k-th largest logit.  Use a
            # local so the caller-supplied top_k argument is not mutated.
            if top_k is not None:
                k = min(top_k, last_logits.size(-1))
                values, _ = torch.topk(last_logits, k)
                min_value = values[:, -1].unsqueeze(-1)
                last_logits = torch.where(
                    last_logits < min_value,
                    torch.full_like(last_logits, float("-inf")),
                    last_logits,
                )

            # Top-p (nucleus): drop the tail whose cumulative probability
            # exceeds top_p, always keeping at least the most likely token.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    last_logits, descending=True
                )
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                sorted_mask = cumulative_probs > top_p
                # Shift the mask right so the first token that crosses the
                # threshold is still kept (standard nucleus convention).
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = False

                sorted_logits = torch.where(
                    sorted_mask,
                    torch.full_like(sorted_logits, float("-inf")),
                    sorted_logits,
                )
                # Scatter the masked logits back to vocabulary order;
                # sorted_indices is a full permutation, so every position
                # of the zeros tensor is overwritten.
                last_logits = torch.zeros_like(last_logits).scatter(
                    -1, sorted_indices, sorted_logits
                )

            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            yield int(next_token.item())
|
|