HayatoHongoEveryonesAI committed
Commit ea941c2 · 1 Parent(s): ebc5472

fixed inference.py

Files changed (1)
  1. inference.py +37 -23
inference.py CHANGED
@@ -11,51 +11,65 @@ def generate_stream(
     top_k=None,
 ):
     """
-    Minimal, safe streaming generation
-    - assumes batch size = 1
-    - KV cache
+    Streaming generation (batch size = 1, fixed)
+    - same logic as GPT.generate
+    - uses the KV cache
+    - supports top-k / top-p
     """
     model.eval()
     next_token = None

     with torch.no_grad():
         for i in range(max_new_tokens):
+
+            # ===== forward =====
             if i == 0:
                 logits, _ = model(input_ids, None, use_cache=True)
             else:
                 logits, _ = model(next_token, None, use_cache=True)

-            logits = logits[:, -1, :] / temperature
+            # last token logits
+            last_logits = logits[:, -1, :] / temperature  # [1, vocab]

-            # top-k
+            # ===== top-k =====
             if top_k is not None:
-                top_k = min(top_k, logits.size(-1))
-                values, _ = torch.topk(logits, top_k)
-                min_val = values[:, -1].unsqueeze(-1)
-                logits = torch.where(
-                    logits < min_val,
-                    torch.full_like(logits, float("-inf")),
-                    logits,
+                top_k = min(top_k, last_logits.size(-1))
+                values, _ = torch.topk(last_logits, top_k)
+                min_value = values[:, -1].unsqueeze(-1)
+                last_logits = torch.where(
+                    last_logits < min_value,
+                    torch.full_like(last_logits, float("-inf")),
+                    last_logits,
                 )

-            # top-p (nucleus)
+            # ===== top-p (nucleus) =====
             if top_p is not None:
-                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-                probs = F.softmax(sorted_logits, dim=-1)
-                cum_probs = torch.cumsum(probs, dim=-1)
+                sorted_logits, sorted_indices = torch.sort(
+                    last_logits, descending=True
+                )
+                sorted_probs = F.softmax(sorted_logits, dim=-1)
+                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

-                mask = cum_probs > top_p
-                mask[..., 1:] = mask[..., :-1]
-                mask[..., 0] = False
+                sorted_mask = cumulative_probs > top_p
+                # this is the important part: add clone()
+                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+                sorted_mask[..., 0] = False
+
+                sorted_logits = torch.where(
+                    sorted_mask,
+                    torch.full_like(sorted_logits, float("-inf")),
+                    sorted_logits,
+                )

-                sorted_logits[mask] = -float("inf")
-                logits = torch.zeros_like(logits).scatter(
+                last_logits = torch.zeros_like(last_logits).scatter(
                     -1, sorted_indices, sorted_logits
                 )

-            probs = F.softmax(logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
+            # ===== sample =====
+            probs = F.softmax(last_logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]

             yield int(next_token.item())

+            # concatenate for the next step
             input_ids = torch.cat([input_ids, next_token], dim=1)
 
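The line the author flags as important in the new top-p block is the actual bug
this commit fixes. The old code shifted the cutoff mask in place,
mask[..., 1:] = mask[..., :-1], where the source and the destination are
overlapping views of the same storage; PyTorch detects the overlap and
typically raises a RuntimeError asking you to clone() the tensor first. The
shift itself exists so that the first token to cross the top_p threshold is
still kept; without it, a single dominant token could leave the nucleus empty.
A minimal standalone sketch of the failure and the fix (not code from this
repo):

    import torch

    mask = torch.tensor([[False, False, True, True]])

    # Broken: source mask[..., :-1] and destination mask[..., 1:] alias
    # the same memory, so PyTorch refuses the overlapping in-place copy:
    #     mask[..., 1:] = mask[..., :-1]

    # Fixed, as in this commit: clone() materializes the source first,
    # making the right-shift safe.
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False  # always keep the most probable token

    print(mask)  # tensor([[False, False, False,  True]])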
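For completeness, a sketch of how the fixed generator might be driven. This is
not part of the commit: the parameter order of generate_stream beyond model and
input_ids is an assumption, and MockModel is a hypothetical stand-in for
whatever GPT-style model the repo pairs with inference.py (anything whose
forward returns (logits, loss) and accepts use_cache=True, as the diff
assumes).

    import torch

    from inference import generate_stream  # the function shown above

    class MockModel(torch.nn.Module):
        """Stand-in with the (logits, loss) forward convention from the diff."""

        def __init__(self, vocab_size=16):
            super().__init__()
            self.vocab_size = vocab_size

        def forward(self, idx, targets=None, use_cache=False):
            # Random logits; a real model would attend over its KV cache.
            return torch.randn(idx.size(0), idx.size(1), self.vocab_size), None

    model = MockModel()
    prompt_ids = torch.tensor([[1, 2, 3]])  # placeholder prompt, batch size 1

    for token_id in generate_stream(
        model,
        prompt_ids,
        max_new_tokens=8,
        temperature=0.8,
        top_k=5,
        top_p=0.9,
    ):
        print(token_id, end=" ", flush=True)  # tokens arrive one at a time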