DigitalDaimyo committed on
Commit
8f41d1a
·
verified ·
1 Parent(s): b14dbfc

Upload generation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generation.py +205 -0
generation.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text Generation Utilities for ASA Models
3
+
4
+ Simple, dependency-free text generation with common decoding strategies.
5
+
6
+ Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
7
+ """
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Set, Tuple, List
12
+
13
+
14
+ __all__ = ['generate']
15
+
16
+
17
+ def _forward_logits(model, input_ids, attention_mask=None):
18
+ """Extract logits from various model output formats."""
19
+ out = model(input_ids, attention_mask=attention_mask) if attention_mask is not None else model(input_ids)
20
+
21
+ if isinstance(out, torch.Tensor):
22
+ return out
23
+ if isinstance(out, (tuple, list)):
24
+ return out[0]
25
+ if isinstance(out, dict):
26
+ for key in ["logits", "out", "y", "pred"]:
27
+ if key in out:
28
+ return out[key]
29
+ raise TypeError(f"Unrecognized model output type: {type(out)}")
30
+
31
+
32
+ def _apply_repetition_penalty(logits: torch.Tensor, input_ids: torch.Tensor, penalty: float):
33
+ """Apply repetition penalty to logits (GPT-2 style)."""
34
+ if penalty is None or penalty == 1.0:
35
+ return logits
36
+
37
+ B = logits.size(0)
38
+ for b in range(B):
39
+ prev_tokens = torch.unique(input_ids[b])
40
+ l = logits[b, prev_tokens]
41
+ logits[b, prev_tokens] = torch.where(l < 0, l * penalty, l / penalty)
42
+ return logits
43
+
44
+
45
+ def _top_k_top_p_filtering(
46
+ logits: torch.Tensor,
47
+ top_k: int = 0,
48
+ top_p: float = 1.0,
49
+ min_tokens_to_keep: int = 1
50
+ ):
51
+ """Filter logits using top-k and nucleus (top-p) filtering."""
52
+ B, V = logits.shape
53
+ top_k = int(top_k) if top_k is not None else 0
54
+ top_p = float(top_p) if top_p is not None else 1.0
55
+
56
+ if top_k > 0 and top_k < V:
57
+ kth = torch.topk(logits, top_k, dim=-1).values[:, -1].unsqueeze(-1)
58
+ logits = logits.masked_fill(logits < kth, float("-inf"))
59
+
60
+ if top_p < 1.0:
61
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
62
+ probs = F.softmax(sorted_logits, dim=-1)
63
+ cum = probs.cumsum(dim=-1)
64
+
65
+ remove = cum > top_p
66
+ if min_tokens_to_keep > 1:
67
+ remove[:, :min_tokens_to_keep] = False
68
+ remove = torch.cat([
69
+ torch.zeros((B, 1), device=logits.device, dtype=torch.bool),
70
+ remove[:, :-1]
71
+ ], dim=-1)
72
+
73
+ sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
74
+ logits = torch.full_like(logits, float("-inf"))
75
+ logits.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
76
+
77
+ return logits
78
+
79
+
80
+ def _update_seen_ngrams(seen: Set, tokens: List[int], n: int):
81
+ """Add n-gram to seen set."""
82
+ if n > 0 and len(tokens) >= n:
83
+ seen.add(tuple(tokens[-n:]))
84
+
85
+
86
+ def _seed_seen_ngrams(input_ids: torch.Tensor, n: int) -> Set:
87
+ """Initialize seen n-grams from input."""
88
+ seen = set()
89
+ if n <= 0:
90
+ return seen
91
+ tokens = input_ids[0].tolist()
92
+ if len(tokens) >= n:
93
+ for i in range(len(tokens) - n + 1):
94
+ seen.add(tuple(tokens[i:i+n]))
95
+ return seen
96
+
97
+
98
+ def _banned_from_seen(seen: Set, input_ids: torch.Tensor, n: int) -> Set:
99
+ """Get tokens banned by n-gram constraint."""
100
+ if n <= 0 or input_ids.shape[1] < n - 1:
101
+ return set()
102
+
103
+ prefix = tuple(input_ids[0, -(n - 1):].tolist())
104
+ banned = set()
105
+ for ng in seen:
106
+ if ng[:-1] == prefix:
107
+ banned.add(ng[-1])
108
+ return banned
109
+
110
+
111
@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 120,
    max_seq_len: int = 1024,
    strategy: str = "sample",
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    eos_token_id: Optional[int] = None,
    device: str = "cuda",
) -> str:
    """
    Generate text from a prompt using various decoding strategies.

    Args:
        model: ASA language model
        tokenizer: HuggingFace tokenizer
        prompt: Input text prompt
        max_new_tokens: Maximum tokens to generate
        max_seq_len: Maximum sequence length (truncates context if exceeded)
        strategy: "greedy" or "sample"
        temperature: Sampling temperature (higher = more random)
        top_k: Keep only top k tokens (0 = disabled)
        top_p: Nucleus sampling threshold (1.0 = disabled)
        repetition_penalty: Penalty for repeating tokens (1.0 = disabled)
        no_repeat_ngram_size: Block repeating n-grams (0 = disabled)
        eos_token_id: Stop generation at this token
        device: Device to run on

    Returns:
        Generated text (including prompt)

    Note:
        Batch size 1 is assumed throughout: all n-gram bookkeeping indexes
        batch row 0, and the EOS check uses ``.item()``. Each step runs a
        full forward pass over the whole sequence (no KV cache is used).

    Example:
        >>> text = generate(
        ...     model, tokenizer,
        ...     prompt="The capital of France is",
        ...     max_new_tokens=20,
        ...     strategy="greedy"
        ... )
    """
    model.eval()

    # Encode the prompt; the tokenizer's attention mask is not forwarded to
    # the model here (single unpadded sequence).
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)

    # Fall back to the tokenizer's EOS id when none was given explicitly.
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Pre-populate the n-gram blocklist with n-grams already in the prompt.
    seen = _seed_seen_ngrams(input_ids, no_repeat_ngram_size)

    for _ in range(max_new_tokens):
        # Truncate if exceeding context length
        if input_ids.shape[1] > max_seq_len:
            input_ids = input_ids[:, -max_seq_len:]
            # Re-seed from the truncated window: n-grams that scrolled out
            # of the context are intentionally forgotten.
            seen = _seed_seen_ngrams(input_ids, no_repeat_ngram_size)

        # Full-sequence forward; only the final position's logits are used.
        # Cast to float32 and clone so the in-place penalty below cannot
        # corrupt the model's own output tensor.
        logits = _forward_logits(model, input_ids)
        next_logits = logits[:, -1, :].to(torch.float32).clone()

        # Apply repetition penalty
        next_logits = _apply_repetition_penalty(next_logits, input_ids, repetition_penalty)

        # Block repeated n-grams
        banned = _banned_from_seen(seen, input_ids, no_repeat_ngram_size)
        if banned:
            next_logits[0, list(banned)] = float("-inf")

        # Decode strategy
        if strategy == "greedy":
            next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
        elif strategy == "sample":
            # Clamp temperature away from zero to avoid division by zero.
            temp = max(1e-6, float(temperature))
            next_logits = next_logits / temp
            next_logits = _top_k_top_p_filtering(next_logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            raise ValueError(f"Unknown strategy '{strategy}'. Use 'greedy' or 'sample'.")

        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Update n-gram tracking (only the newly completed trailing n-gram).
        tokens = input_ids[0].tolist()
        _update_seen_ngrams(seen, tokens, no_repeat_ngram_size)

        # Check for EOS
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)