# inference.py
import torch
import torch.nn.functional as F


def generate_stream(
    model,
    input_ids,
    max_new_tokens,
    temperature,
    top_p=None,
    top_k=None,
):
    """
    Streaming generation (batch size fixed at 1).
    - Same logic as GPT.generate
    - Uses the KV cache
    - Supports top-k / top-p
    """
    model.eval()
    next_token = None
    with torch.no_grad():
        for i in range(max_new_tokens):
            # ===== forward =====
            if i == 0:
                # first step: feed the full prompt and build the KV cache
                logits, _ = model(input_ids, None, use_cache=True)
            else:
                # later steps: only the newly sampled token is needed
                logits, _ = model(next_token, None, use_cache=True)
            # logits of the last position, scaled by temperature
            last_logits = logits[:, -1, :] / temperature  # [1, vocab]
            # ===== top-k =====
            if top_k is not None:
                top_k = min(top_k, last_logits.size(-1))
                values, _ = torch.topk(last_logits, top_k)
                min_value = values[:, -1].unsqueeze(-1)
                last_logits = torch.where(
                    last_logits < min_value,
                    torch.full_like(last_logits, float("-inf")),
                    last_logits,
                )
            # ===== top-p (nucleus) =====
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    last_logits, descending=True
                )
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_mask = cumulative_probs > top_p
                # Important: clone() here -- the shift reads and writes
                # overlapping memory, so copy before assigning in place.
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = False
                sorted_logits = torch.where(
                    sorted_mask,
                    torch.full_like(sorted_logits, float("-inf")),
                    sorted_logits,
                )
                # scatter back to the original (unsorted) vocabulary order
                last_logits = torch.zeros_like(last_logits).scatter(
                    -1, sorted_indices, sorted_logits
                )
            # ===== sample =====
            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
            yield int(next_token.item())
            # Append for the next step (the forward above only consumes
            # next_token, so this mainly keeps the full sequence around).
            input_ids = torch.cat([input_ids, next_token], dim=1)
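

# ----------------------------------------------------------------------
# Usage sketch (an added assumption, not part of the generator above):
# one way to drive generate_stream from a prompt. `tokenizer` and the
# concrete `model` are hypothetical placeholders -- the only requirement
# taken from the code above is the `logits, _ = model(ids, None,
# use_cache=True)` call signature and a vocabulary that matches the
# tokenizer. Kept commented out so the module stays import-safe.
# ----------------------------------------------------------------------
# prompt_ids = torch.tensor([tokenizer.encode("Once upon a time")])  # [1, T]
# for token_id in generate_stream(
#     model,
#     prompt_ids,
#     max_new_tokens=100,
#     temperature=0.8,
#     top_p=0.9,
#     top_k=50,
# ):
#     print(tokenizer.decode([token_id]), end="", flush=True)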
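

# ----------------------------------------------------------------------
# Minimal self-check (an added sketch, not in the original file): reruns
# only the top-p filtering step on toy logits so the mask shift and the
# scatter back to vocabulary order can be inspected in isolation. Uses
# only the torch / F imports above; the toy numbers are arbitrary.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    toy_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    top_p = 0.9
    sorted_logits, sorted_indices = torch.sort(toy_logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > top_p
    # shift right so the token that first crosses top_p is still kept
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    sorted_logits = torch.where(
        sorted_mask, torch.full_like(sorted_logits, float("-inf")), sorted_logits
    )
    filtered = torch.zeros_like(toy_logits).scatter(-1, sorted_indices, sorted_logits)
    # tokens outside the nucleus end up with probability 0 after softmax
    print(F.softmax(filtered, dim=-1))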