barunsaha commited on
Commit
c3cb785
·
verified ·
1 Parent(s): c151b27

Upload s05_generate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. s05_generate.py +435 -0
s05_generate.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""Step 5: Generate Bengali Text
2
+ ==============================
3
+ Load a trained checkpoint and generate text, optionally conditioning
4
+ on a specific author or completing a given prompt.
5
+
6
+ Usage:
7
+ # Generate with a random author
8
+ python s05_generate.py
9
+
10
+ # Generate conditioned on a specific author
11
+ python s05_generate.py --author "রবীন্দ্রনাথ ঠাকুর" --type poem
12
+
13
+ # Complete a given Bengali prompt
14
+ python s05_generate.py --prompt $'<|bow|><|author:জীবনানন্দ দাশ|><|poem|>\nহাজার বছর ধরে'
15
+
16
+ # Interactive mode
17
+ python s05_generate.py --interactive
18
+
19
+ # Show all available authors
20
+ python s05_generate.py --list-authors
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ from contextlib import nullcontext
26
+ from pathlib import Path
27
+
28
+ import sentencepiece as spm
29
+ import torch
30
+ import torch.nn.functional as F
31
+
32
+ from s00_model import Banalata, ModelConfig
33
+
34
+ # Model and tokenizer configs
35
+ MODULE_PATH = Path(__file__).resolve().parent
36
+ CKPT_PATH = MODULE_PATH / 'checkpoints/ckpt_best.pt'
37
+ TOK_DIR = MODULE_PATH / 'tokenizer'
38
+
39
+ # Defaults used when --prompt is given without --author / --type
40
+ DEFAULT_AUTHOR = 'জীবনানন্দ দাশ'
41
+ DEFAULT_TYPE = 'poem'
42
+
43
+
44
+ def load_model_and_tokenizer(ckpt_path: str, device: torch.device):
45
+ """Load checkpoint, reconstruct model, load tokenizer."""
46
+ tok_config = json.loads((TOK_DIR / 'tokenizer_config.json').read_text(encoding='utf-8'))
47
+ sp = spm.SentencePieceProcessor()
48
+ sp.load(str(MODULE_PATH / tok_config['model_path']))
49
+
50
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
51
+ mcfg_dict = ckpt['mcfg']
52
+ mcfg = ModelConfig(**mcfg_dict)
53
+
54
+ model = Banalata(mcfg).to(device)
55
+ state = ckpt['model']
56
+ state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
57
+ model.load_state_dict(state)
58
+ model.eval()
59
+
60
+ print(
61
+ f'Loaded checkpoint (iter={ckpt.get("iter", "?")}, '
62
+ f'val_loss={ckpt.get("best_val", "?"):.4f})'
63
+ )
64
+ return model, sp, tok_config
65
+
66
+
67
+ @torch.inference_mode()
68
+ def generate(
69
+ model: Banalata,
70
+ sp,
71
+ tok_config: dict,
72
+ device: torch.device,
73
+ author: str | None = None,
74
+ content_type: str | None = None,
75
+ prompt: str | None = None,
76
+ max_tokens: int = 300,
77
+ temperature: float = 0.85,
78
+ top_p: float = 0.92,
79
+ repetition_penalty: float = 1.0,
80
+ n_samples: int = 1,
81
+ ) -> list[str]:
82
+ """Generate text samples.
83
+
84
+ Prompt construction priority:
85
+ 1. If `prompt` given: encode it directly (author/type ignored — embed them in the prompt string)
86
+ 2. If `author` given: <|bow|><|author:NAME|>[<|poem|> or <|prose|>]
87
+ 3. Otherwise: <|bow|> only
88
+
89
+ repetition_penalty: divides logits of already-seen tokens before sampling.
90
+ 1.0 = no penalty (original behaviour)
91
+ 1.2 = light penalty, reduces mild loops
92
+ 1.3 = recommended default, handles most repetition
93
+ 1.5 = aggressive, may hurt coherence for highly repetitive styles (e.g. Lalan)
94
+ """
95
+ bow_id = tok_config.get('bow_id')
96
+ eow_id = tok_config.get('eow_id')
97
+ special_ids = set(tok_config.get('special_tokens', {}).values())
98
+
99
+ results = []
100
+
101
+ if device.type == 'cuda':
102
+ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
103
+ elif device.type == 'mps':
104
+ ctx = torch.amp.autocast(device_type='mps', dtype=torch.bfloat16)
105
+ else:
106
+ ctx = nullcontext()
107
+
108
+ for _ in range(n_samples):
109
+ if prompt:
110
+ # --prompt accepts plain Bengali text only.
111
+ # Author and type conditioning come from --author / --type args,
112
+ # falling back to DEFAULT_AUTHOR / DEFAULT_TYPE if not set.
113
+ effective_author = content_type and author # use explicit if both given
114
+ eff_author = author or DEFAULT_AUTHOR
115
+ eff_type = content_type or DEFAULT_TYPE
116
+
117
+ aut_tok = f'<|author:{eff_author}|>'
118
+ aut_id = sp.piece_to_id(aut_tok)
119
+ if aut_id == sp.unk_id():
120
+ available = tok_config.get('author_tokens', [])
121
+ matches = [t for t in available if eff_author in t]
122
+ aut_id = sp.piece_to_id(matches[0]) if matches else None
123
+ if aut_id:
124
+ print(f'Using author token: {matches[0]}')
125
+ else:
126
+ print(f"Author '{eff_author}' not found, omitting.")
127
+
128
+ type_tok = f'<|{eff_type}|>'
129
+ type_id = sp.piece_to_id(type_tok)
130
+ if type_id == sp.unk_id():
131
+ print(f"Type token '{type_tok}' not found, omitting.")
132
+ type_id = None
133
+
134
+ # Build: <|bow|><|author:NAME|><|poem|>\nplain text
135
+ text_ids = sp.encode(prompt, out_type=int)
136
+ prefix_ids = [x for x in [bow_id, aut_id, type_id] if x is not None]
137
+ prompt_ids = prefix_ids + text_ids
138
+
139
+ elif author:
140
+ # Author + optional type conditioning
141
+ aut_tok = f'<|author:{author}|>'
142
+ aut_id = sp.piece_to_id(aut_tok)
143
+ if aut_id == sp.unk_id():
144
+ available = tok_config.get('author_tokens', [])
145
+ matches = [t for t in available if author in t]
146
+ if matches:
147
+ aut_tok = matches[0]
148
+ aut_id = sp.piece_to_id(aut_tok)
149
+ print(f'Using author token: {aut_tok}')
150
+ else:
151
+ print(f"Author '{author}' not found. Using <|bow|> only.")
152
+ aut_id = None
153
+
154
+ type_id = None
155
+ if content_type:
156
+ type_tok = f'<|{content_type}|>'
157
+ type_id = sp.piece_to_id(type_tok)
158
+ if type_id == sp.unk_id():
159
+ print(f"Type token '{type_tok}' not found, ignoring.")
160
+ type_id = None
161
+
162
+ prompt_ids = [x for x in [bow_id, aut_id, type_id] if x is not None]
163
+ if not prompt_ids:
164
+ prompt_ids = [bow_id] if bow_id else []
165
+
166
+ else:
167
+ prompt_ids = [bow_id] if bow_id else []
168
+
169
+ if not prompt_ids:
170
+ prompt_ids = [sp.bos_id()]
171
+
172
+ idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)
173
+ with ctx:
174
+ out = _generate_tokens(
175
+ model,
176
+ idx,
177
+ max_new_tokens=max_tokens,
178
+ temperature=temperature,
179
+ top_p=top_p,
180
+ eow_id=eow_id,
181
+ repetition_penalty=repetition_penalty,
182
+ )
183
+
184
+ tokens = out[0].tolist()
185
+ content_ids = [t for t in tokens if t not in special_ids]
186
+ text = sp.decode(content_ids)
187
+ results.append(text.strip())
188
+
189
+ return results
190
+
191
+
192
+ @torch.inference_mode()
193
+ def _generate_tokens(
194
+ model: Banalata,
195
+ idx: torch.Tensor,
196
+ max_new_tokens: int,
197
+ temperature: float,
198
+ top_p: float,
199
+ eow_id: int | None,
200
+ repetition_penalty: float = 1.3,
201
+ ) -> torch.Tensor:
202
+ """Core autoregressive loop with repetition penalty and nucleus sampling.
203
+
204
+ Repetition penalty (from the original "CTRL" paper):
205
+ - For each token already in the sequence, divide its logit by the penalty.
206
+ - Positive logits become smaller (less likely).
207
+ - Negative logits become more negative (even less likely).
208
+ - Applied BEFORE temperature scaling so temperature still controls overall sharpness.
209
+ """
210
+ for _ in range(max_new_tokens):
211
+ idx_cond = idx[:, -model.cfg.context_len :]
212
+ logits, _ = model(idx_cond)
213
+ logits = logits[:, -1, :] # (1, vocab_size)
214
+
215
+ # Repetition penalty
216
+ # Collect unique token ids seen so far in the full sequence
217
+ if repetition_penalty != 1.0:
218
+ seen = idx[0].unique()
219
+ # Penalise: divide positive logits, multiply negative logits
220
+ # This preserves sign while reducing magnitude in both directions
221
+ logits[0, seen] = torch.where(
222
+ logits[0, seen] > 0,
223
+ logits[0, seen] / repetition_penalty,
224
+ logits[0, seen] * repetition_penalty,
225
+ )
226
+
227
+ # Temperature
228
+ logits = logits / temperature
229
+
230
+ # Top-p (nucleus) sampling
231
+ probs = F.softmax(logits, dim=-1)
232
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True)
233
+ cumulative = torch.cumsum(sorted_probs, dim=-1)
234
+ sorted_probs[cumulative - sorted_probs > top_p] = 0.0
235
+ sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
236
+ next_token = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
237
+
238
+ idx = torch.cat([idx, next_token], dim=1)
239
+
240
+ if eow_id is not None and next_token.item() == eow_id:
241
+ break
242
+
243
+ return idx
244
+
245
+
246
+ def list_authors(tok_config: dict):
247
+ """List all available author names for conditioning."""
248
+ tokens = tok_config.get('author_tokens', [])
249
+ print(f'\nAvailable author tokens ({len(tokens)}):')
250
+ for t in sorted(tokens):
251
+ name = t.replace('<|author:', '').replace('|>', '')
252
+ print(f' --author "{name}"')
253
+
254
+
255
+ def interactive_mode(model, sp, tok_config, device):
256
+ """Start an interactive REPL session for Bengali text generation."""
257
+ print('\n' + '=' * 55)
258
+ print('Banalata — Interactive Mode')
259
+ print('Commands:')
260
+ print(' [Enter] alone — generate with random author')
261
+ print(' author: NAME — set author (Bengali name)')
262
+ print(' type: poem|prose — set content type')
263
+ print(' prompt: TEXT — set raw prompt (overrides author/type)')
264
+ print(' temp: 0.8 — set temperature (default 0.85)')
265
+ print(' top_p: 0.9 — set top-p (default 0.92)')
266
+ print(' penalty: 1.3 — set repetition penalty (default 1.3)')
267
+ print(' tokens: 200 — set max output tokens')
268
+ print(' authors — list available authors')
269
+ print(' quit — exit')
270
+ print('=' * 55 + '\n')
271
+
272
+ import random
273
+
274
+ author = None
275
+ content_type = None
276
+ prompt = None
277
+ temp = 0.85
278
+ top_p = 0.92
279
+ rep_penalty = 1.3
280
+ max_tokens = 250
281
+
282
+ while True:
283
+ try:
284
+ cmd = input('>>> ').strip()
285
+ except (EOFError, KeyboardInterrupt):
286
+ break
287
+
288
+ if cmd.lower() in ('quit', 'exit', 'q'):
289
+ break
290
+ elif cmd.lower() == 'authors':
291
+ list_authors(tok_config)
292
+ elif cmd.lower().startswith('author:'):
293
+ author = cmd.split(':', 1)[1].strip()
294
+ prompt = None
295
+ print(f'Author set to: {author}')
296
+ elif cmd.lower().startswith('type:'):
297
+ content_type = cmd.split(':', 1)[1].strip().lower()
298
+ if content_type not in ('poem', 'prose'):
299
+ print("Type must be 'poem' or 'prose'")
300
+ content_type = None
301
+ else:
302
+ print(f'Type set to: {content_type}')
303
+ elif cmd.lower().startswith('prompt:'):
304
+ prompt = cmd.split(':', 1)[1].strip()
305
+ author = None
306
+ content_type = None
307
+ print(f'Prompt set to: {prompt}')
308
+ elif cmd.lower().startswith('temp:'):
309
+ temp = float(cmd.split(':', 1)[1].strip())
310
+ print(f'Temperature: {temp}')
311
+ elif cmd.lower().startswith('top_p:'):
312
+ top_p = float(cmd.split(':', 1)[1].strip())
313
+ print(f'Top-p: {top_p}')
314
+ elif cmd.lower().startswith('penalty:'):
315
+ rep_penalty = float(cmd.split(':', 1)[1].strip())
316
+ print(f'Repetition penalty: {rep_penalty}')
317
+ elif cmd.lower().startswith('tokens:'):
318
+ max_tokens = int(cmd.split(':', 1)[1].strip())
319
+ print(f'Max tokens: {max_tokens}')
320
+ elif cmd == '':
321
+ if author is None and prompt is None:
322
+ tokens = tok_config.get('author_tokens', [])
323
+ if tokens:
324
+ tok = random.choice(tokens)
325
+ author = tok.replace('<|author:', '').replace('|>', '')
326
+ print(f'(Random author: {author})')
327
+
328
+ results = generate(
329
+ model,
330
+ sp,
331
+ tok_config,
332
+ device,
333
+ author=author,
334
+ content_type=content_type,
335
+ prompt=prompt,
336
+ max_tokens=max_tokens,
337
+ temperature=temp,
338
+ top_p=top_p,
339
+ repetition_penalty=rep_penalty,
340
+ )
341
+ print(f'\n{"-" * 50}')
342
+ print(results[0])
343
+ print(f'{"-" * 50}\n')
344
+
345
+
346
+ # ------------------------------------------------------------------------
347
+ # Main
348
+ # ------------------------------------------------------------------------
349
+
350
+
351
+ def main():
352
+ """Main execution function for text generation via command-line arguments."""
353
+ parser = argparse.ArgumentParser(
354
+ description='Banalata Text Generation',
355
+ formatter_class=argparse.RawDescriptionHelpFormatter,
356
+ epilog="""
357
+ Examples:
358
+ python s05_generate.py --author "রবীন্দ্রনাথ ঠাকুর" --type poem
359
+ python s05_generate.py --author "জীবনানন্দ দাশ" --type poem --penalty 1.2
360
+ python s05_generate.py --prompt $'<|bow|><|author:জীবনানন্দ দাশ|><|poem|>\\nহাজার বছর ধরে'
361
+ python s05_generate.py --interactive
362
+ """,
363
+ )
364
+ parser.add_argument('--ckpt', default=CKPT_PATH)
365
+ parser.add_argument(
366
+ '--author', default=DEFAULT_AUTHOR, help="Bengali author name, e.g. 'রবীন্দ্রনাথ ঠাকুর'"
367
+ )
368
+ parser.add_argument(
369
+ '--type',
370
+ dest='content_type',
371
+ choices=['poem', 'prose'],
372
+ default=DEFAULT_TYPE,
373
+ help='Content type token to prepend (only used with --author)',
374
+ )
375
+ parser.add_argument(
376
+ '--prompt',
377
+ default=None,
378
+ help='Raw prompt string. Embed special tokens directly for full control: '
379
+ "$'<|bow|><|author:NAME|><|poem|>\\nopening line'",
380
+ )
381
+ parser.add_argument('--max-tokens', type=int, default=300)
382
+ parser.add_argument('--temperature', type=float, default=0.85)
383
+ parser.add_argument('--top-p', type=float, default=0.92)
384
+ parser.add_argument(
385
+ '--penalty',
386
+ type=float,
387
+ default=1.3,
388
+ help='Repetition penalty. 1.0=disabled, 1.2=light, 1.3=default, 1.5=aggressive',
389
+ )
390
+ parser.add_argument('--n-samples', type=int, default=1)
391
+ parser.add_argument('--interactive', action='store_true')
392
+ parser.add_argument('--list-authors', action='store_true')
393
+ args = parser.parse_args()
394
+
395
+ device = torch.device(
396
+ 'cuda'
397
+ if torch.cuda.is_available()
398
+ else 'mps'
399
+ if torch.backends.mps.is_available()
400
+ else 'cpu'
401
+ )
402
+
403
+ model, sp, tok_config = load_model_and_tokenizer(args.ckpt, device)
404
+
405
+ if args.list_authors:
406
+ list_authors(tok_config)
407
+ return
408
+
409
+ if args.interactive:
410
+ interactive_mode(model, sp, tok_config, device)
411
+ return
412
+
413
+ results = generate(
414
+ model,
415
+ sp,
416
+ tok_config,
417
+ device,
418
+ author=args.author,
419
+ content_type=args.content_type,
420
+ prompt=args.prompt,
421
+ max_tokens=args.max_tokens,
422
+ temperature=args.temperature,
423
+ top_p=args.top_p,
424
+ repetition_penalty=args.penalty,
425
+ n_samples=args.n_samples,
426
+ )
427
+
428
+ for i, text in enumerate(results):
429
+ if args.n_samples > 1:
430
+ print(f'\n--- Sample {i + 1}')
431
+ print(text)
432
+
433
+
434
+ if __name__ == '__main__':
435
+ main()