OneFineStarstuff committed on
Commit
4333c13
·
verified ·
1 Parent(s): 60118eb

Create open_ended_question_generator_secure.py

Browse files
open_ended_question_generator_secure.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ open_ended_question_generator_secure.py
6
+
7
+ End-to-end script to generate open-ended questions from context(s) with:
8
+ - Robust list-formatted parsing
9
+ - CLI with single or batch inputs (TXT/CSV)
10
+ - Reproducibility (seed)
11
+ - Device auto-select (CUDA / MPS / CPU)
12
+ - Export to JSON / CSV / TXT
13
+ - Optional authenticated encryption via Fernet (AES-128-CBC with HMAC-SHA256; key derived from the password with PBKDF2)
14
+ - Optional decryption utility
15
+
16
+ Dependencies:
17
+ pip install torch transformers cryptography
18
+
19
+ Example:
20
+ python open_ended_question_generator_secure.py --generate \
21
+ --context "AGI for cosmology" --n 5 --model gpt2-large \
22
+ --out questions.json --format json --encrypt --password "your-secret"
23
+ """
24
+
25
+ import os
26
+ import re
27
+ import csv
28
+ import json
29
+ import argparse
30
+ import getpass
31
+ import base64
32
+ import sys
33
+ from typing import List, Dict, Tuple, Optional
34
+
35
+ import torch
36
+ from transformers import AutoTokenizer, AutoModelForCausalLM
37
+
38
+ # --- Optional encryption deps ---
39
+ try:
40
+ from cryptography.fernet import Fernet
41
+ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
42
+ from cryptography.hazmat.primitives import hashes
43
+ from cryptography.hazmat.backends import default_backend
44
+ except Exception:
45
+ Fernet = None # Will validate at runtime if encryption/decryption is used.
46
+
47
+
48
+ # ----------------------------
49
+ # Device selection
50
+ # ----------------------------
51
def select_device() -> torch.device:
    """Pick the torch device to run on.

    The MPS backend is probed first, then CUDA; CPU is the fallback when
    neither backend reports itself as available.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        name = "mps"
    elif torch.cuda.is_available():
        name = "cuda"
    else:
        name = "cpu"
    return torch.device(name)
57
+
58
+
59
+ # ----------------------------
60
+ # Prompt and parsing
61
+ # ----------------------------
62
# Instruction prompt sent to the model. ``{context}`` and ``{n}`` are filled in
# by build_prompt(). The numbered "..." lines only illustrate the expected
# output layout.
PROMPT_TEMPLATE = """You are a master at generating deep, open-ended, and thought-provoking questions.
Each question must be:
- Self-contained and understandable without extra context.
- Exploratory (not answerable with yes/no).
- Written in clear, engaging language.

Context:
{context}

Output exactly {n} questions as a numbered list, one per line, formatted like:
1. ...
2. ...
3. ...
No extra commentary, no headings, no explanations — just the list.
"""


def build_prompt(context: str, n: int) -> str:
    """Return the generation prompt for *context*, asking for *n* questions.

    Args:
        context: Source text the questions should be about; surrounding
            whitespace is stripped before insertion.
        n: Number of questions requested in the instruction line.
    """
    return PROMPT_TEMPLATE.format(context=context.strip(), n=n)
80
+
81
# One numbered list item such as "1. What ...": group 1 is the item number,
# group 2 the item text without surrounding whitespace.
_Q_LINE_RE = re.compile(r"^\s*(\d+)\.\s+(.*\S)\s*$")


def normalize_q(q: str) -> str:
    """Return *q* trimmed and ending with exactly one question mark.

    Text that already ends with "?" is returned unchanged (after trimming).
    Otherwise trailing sentence punctuation is removed before "?" is appended,
    so "Explain this." becomes "Explain this?" rather than "Explain this.?".
    """
    q = q.strip()
    if q.endswith("?"):
        return q
    stem = q.rstrip(".!;:, \t")
    # Fall back to the untouched text when it is made of punctuation only.
    return (stem or q) + "?"


def parse_questions_from_text(text: str, n: int) -> List[str]:
    """Extract up to *n* distinct questions from numbered-list lines in *text*.

    Only lines shaped like "<number>. <text>" are considered. Items without
    any letter or digit (for example the "1. ..." placeholders of the prompt
    template) are skipped. Duplicates are detected case-insensitively and the
    first occurrence wins; input order is preserved.

    Args:
        text: Raw model output, possibly containing unrelated lines.
        n: Maximum number of questions to return; values <= 0 yield [].

    Returns:
        A list of at most *n* normalized questions.
    """
    if n <= 0:
        return []
    seen = set()
    unique: List[str] = []
    for line in text.splitlines():
        match = _Q_LINE_RE.match(line)
        if match is None:
            continue
        body = match.group(2)
        if not any(ch.isalnum() for ch in body):
            continue  # placeholder such as "..." — not a real question
        question = normalize_q(body)
        key = question.lower().strip()
        if key in seen:
            continue
        seen.add(key)
        unique.append(question)
        if len(unique) == n:
            break
    return unique
107
+
108
+
109
+ # ----------------------------
110
+ # Model loading and generation
111
+ # ----------------------------
112
def load_model_and_tokenizer(model_name: str, device: torch.device):
    """Load a causal LM and its tokenizer, placing the model on *device*.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Target device for the model weights.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)

    # Tokenizers without a pad token (GPT-2 style) reuse EOS for padding.
    lacks_pad = tokenizer.pad_token_id is None
    if lacks_pad and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    return model, tokenizer
120
+
121
def generate_questions_once(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> List[str]:
    """Run one sampling pass and return the questions parsed from it.

    Args:
        model: Causal language model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        device: Device the encoded prompt is moved to.
        context: Text the questions should be about.
        n: Number of questions requested (also the parse limit).
        max_new_tokens: Generation budget for the completion.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.

    Returns:
        Up to *n* questions; fewer (possibly none) when the model does not
        follow the numbered-list format.
    """
    prompt = build_prompt(context, n)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, otherwise the template's own "1. ..." example lines (and
    # any numbered lines in the context) would be parsed as questions.
    completion_ids = output[0][prompt_len:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return parse_questions_from_text(decoded, n)
150
+
151
def generate_questions(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int = 3,
    max_new_tokens: int = 200,
    temperature: float = 0.95,
    top_p: float = 0.95,
    top_k: int = 50,
    seed: Optional[int] = None,
    attempts: int = 3,
) -> List[str]:
    """Generate exactly *n* distinct questions for *context*.

    Sampling is repeated up to *attempts* times, merging new questions
    (compared case-insensitively) until *n* have been collected. If the model
    still falls short, the list is padded so callers always receive *n* items.

    Args:
        model: Causal language model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        device: Device used for generation.
        context: Text the questions should be about.
        n: Number of questions to return.
        max_new_tokens: Generation budget per attempt.
        temperature: Base sampling temperature. It is raised by 0.1 per retry
            and always clamped to the range [0.7, 1.2].
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.
        seed: When given, seeds torch (and all CUDA devices on CUDA) first.
        attempts: Maximum number of sampling passes.

    Returns:
        A list of *n* questions (empty when ``n <= 0``).
    """
    if seed is not None:
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed_all(seed)

    collected: List[str] = []
    seen = set()  # lower-cased keys of everything already in `collected`
    for attempt in range(attempts):
        if len(collected) >= n:
            break
        # Slightly raise the temperature on retries to improve variety.
        temp = min(1.2, max(0.7, temperature + 0.1 * attempt))
        batch = generate_questions_once(
            model, tokenizer, device, context, n, max_new_tokens, temp, top_p, top_k
        )
        for question in batch:
            if len(collected) >= n:
                break
            key = question.lower().strip()
            if key in seen:
                continue
            seen.add(key)
            collected.append(question)

    # If still short, pad with simple variants (rare).
    while len(collected) < n:
        if collected:
            collected.append(collected[-1] + " (expand)")
        else:
            collected.append("What deeper questions arise from this context?")
    return collected[:n]
188
+
189
+
190
+ # ----------------------------
191
+ # Batch input handling
192
+ # ----------------------------
193
def load_contexts(source_text: Optional[str], source_file: Optional[str]) -> List[Tuple[str, str]]:
    """
    Return the contexts to process as a list of (context_id, context_text).

    - If source_text contains non-whitespace text, it is returned as the single
      context "context_1" and source_file is ignored.
    - CSV file (".csv" extension): requires a 'context' column; rows with an
      empty context are skipped, ids keep the 1-based row position.
    - Any other file: split on lines containing only '---'; without such a
      delimiter the whole file is one context.

    Raises:
        ValueError: no input was given, the CSV lacks a 'context' column, or
            the file contains no usable context.
        FileNotFoundError: source_file does not exist.
    """
    if source_text and source_text.strip():
        return [("context_1", source_text.strip())]
    if not source_file:
        raise ValueError("Either --context or --context-file is required.")
    if not os.path.exists(source_file):
        raise FileNotFoundError(f"Context file not found: {source_file}")

    out: List[Tuple[str, str]] = []
    ext = os.path.splitext(source_file)[1].lower()
    # "utf-8-sig" reads plain UTF-8 as well, but drops a leading BOM so that
    # the first CSV header (or first context) is not prefixed with U+FEFF.
    if ext == ".csv":
        with open(source_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for an empty file.
            if not reader.fieldnames or "context" not in reader.fieldnames:
                raise ValueError("CSV must have a 'context' column.")
            for i, row in enumerate(reader, start=1):
                ctx = (row.get("context") or "").strip()
                if ctx:
                    out.append((f"context_{i}", ctx))
    else:
        # Plain text / markdown: split on '---' delimiter lines if present.
        with open(source_file, "r", encoding="utf-8-sig") as f:
            content = f.read()
        parts = re.split(r"^\s*---\s*$", content, flags=re.MULTILINE)
        stripped = (p.strip() for p in parts)
        out.extend(
            (f"context_{i}", ctx)
            for i, ctx in enumerate((p for p in stripped if p), start=1)
        )
    if not out:
        raise ValueError("No context found in file.")
    return out
230
+
231
+
232
+ # ----------------------------
233
+ # Output writers
234
+ # ----------------------------
235
def write_json(out_path: str, rows: List[Dict]):
    """Serialize *rows* to *out_path* as indented, non-ASCII-escaped JSON (UTF-8)."""
    payload = json.dumps(rows, ensure_ascii=False, indent=2)
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.write(payload)
238
+
239
def write_csv(out_path: str, rows: List[Dict], n: int):
    """Write *rows* to *out_path* as CSV.

    Columns are ``context_id``, ``context`` and ``q1`` .. ``q<n>``; a header
    row is always written.
    """
    question_cols = [f"q{idx}" for idx in range(1, n + 1)]
    columns = ["context_id", "context", *question_cols]
    with open(out_path, "w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
246
+
247
def write_txt(out_path: str, rows: List[Dict], n: int):
    """Write *rows* to *out_path* as plain text.

    Each row becomes a ``[context_id]`` line, the stripped context, the
    questions ``q1`` .. ``q<n>`` as a numbered list, and a blank separator line.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        for row in rows:
            print(f"[{row['context_id']}]", file=handle)
            print(row["context"].strip(), file=handle)
            for idx in range(1, n + 1):
                print(f"{idx}. {row[f'q{idx}']}", file=handle)
            print(file=handle)
255
+
256
+
257
+ # ----------------------------
258
+ # Encryption / Decryption
259
+ # ----------------------------
260
# Signature bytes written at the very start of every encrypted file.
MAGIC = b"QSEC1"


def require_crypto():
    """Raise RuntimeError if the optional 'cryptography' import did not succeed."""
    if Fernet is not None:
        return
    raise RuntimeError("Encryption requested but 'cryptography' is not installed. Run: pip install cryptography")
265
+
266
def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """Derive a Fernet-compatible key from *password* and *salt*.

    Uses PBKDF2-HMAC-SHA256 with 200,000 iterations to produce 32 bytes, then
    returns them URL-safe base64 encoded.
    """
    stretcher = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=200_000,
        backend=default_backend(),
    )
    raw_key = stretcher.derive(password.encode("utf-8"))
    return base64.urlsafe_b64encode(raw_key)
276
+
277
def encrypt_file(in_path: str, out_path: str, password: str):
    """Encrypt *in_path* with a password-derived key and write it to *out_path*.

    Output layout: MAGIC, then a fresh 16-byte random salt, then the Fernet
    token. The input file is left untouched.

    Raises:
        RuntimeError: the 'cryptography' package is unavailable.
    """
    require_crypto()
    with open(in_path, "rb") as src:
        plaintext = src.read()

    salt = os.urandom(16)
    token = Fernet(derive_key_from_password(password, salt)).encrypt(plaintext)

    with open(out_path, "wb") as dst:
        dst.write(MAGIC + salt + token)
287
+
288
def decrypt_file(in_path: str, out_path: str, password: str):
    """Decrypt a file produced by encrypt_file() and write the plaintext to *out_path*.

    Raises:
        RuntimeError: the 'cryptography' package is unavailable.
        ValueError: the file does not start with MAGIC or is too short to
            hold a salt plus at least one byte of ciphertext.
    """
    require_crypto()
    with open(in_path, "rb") as src:
        blob = src.read()

    salt_start = len(MAGIC)
    salt_end = salt_start + 16
    well_formed = blob.startswith(MAGIC) and len(blob) > salt_end
    if not well_formed:
        raise ValueError("Invalid or unsupported encrypted file.")

    key = derive_key_from_password(password, blob[salt_start:salt_end])
    plaintext = Fernet(key).decrypt(blob[salt_end:])

    with open(out_path, "wb") as dst:
        dst.write(plaintext)
301
+
302
+
303
+ # ----------------------------
304
+ # Main CLI
305
+ # ----------------------------
306
def main():
    """Command-line entry point.

    Exactly one mode flag is required:

    * ``--generate``: read context(s) from ``--context`` or ``--context-file``,
      generate ``--n`` questions per context and emit them as JSON/CSV/TXT,
      either to ``--out`` or to stdout. With ``--encrypt`` an additional
      ``<out>.enc`` copy is written; the plaintext file is kept.
    * ``--decrypt``: decrypt ``--in`` into ``--out-decrypted``.

    Invalid argument combinations and unusable inputs are reported through
    ``parser.error`` (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(description="Generate deep open-ended questions with optional encryption/decryption.")
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--generate", action="store_true", help="Generate questions from context(s).")
    mode.add_argument("--decrypt", action="store_true", help="Decrypt an encrypted file (no generation).")

    # Generation inputs
    parser.add_argument("--context", type=str, help="Inline context text.")
    parser.add_argument("--context-file", type=str, help="Path to TXT/MD (split by ---) or CSV with 'context' column.")
    parser.add_argument("--n", type=int, default=3, help="Number of questions to generate per context.")
    parser.add_argument("--model", type=str, default="gpt2-large", help="HuggingFace model name.")
    parser.add_argument("--max-new-tokens", type=int, default=220, help="Max new tokens for generation.")
    parser.add_argument("--temperature", type=float, default=0.95, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p nucleus sampling.")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    parser.add_argument("--attempts", type=int, default=3, help="Max attempts to reach exactly n questions.")

    # Output
    parser.add_argument("--out", type=str, default=None, help="Output file path. If omitted, prints to stdout.")
    parser.add_argument("--format", type=str, choices=["json", "csv", "txt"], default="json", help="Output format when generating.")
    parser.add_argument("--encrypt", action="store_true", help="Encrypt the output file after generation.")
    # NOTE: a password given on the command line may be visible to other users
    # (shell history, process listings); omitting it triggers a hidden prompt.
    parser.add_argument("--password", type=str, default=None, help="Password for encryption/decryption. If omitted, prompts securely.")

    # Decryption I/O
    parser.add_argument("--in", dest="in_path", type=str, help="Input file for decryption (encrypted).")
    parser.add_argument("--out-decrypted", type=str, help="Output file for decrypted plaintext.")

    args = parser.parse_args()

    if args.decrypt:
        # Decrypt mode
        if not args.in_path or not args.out_decrypted:
            parser.error("--decrypt requires --in and --out-decrypted.")
        password = args.password or getpass.getpass("Enter password: ")
        decrypt_file(args.in_path, args.out_decrypted, password)
        print(f"Decrypted to: {args.out_decrypted}")
        return

    # Generate mode: reject unusable combinations before loading the model.
    if args.n < 1:
        parser.error("--n must be at least 1.")
    if args.encrypt and not args.out:
        parser.error("--encrypt requires --out.")
    try:
        contexts = load_contexts(args.context, args.context_file)
    except (ValueError, FileNotFoundError) as exc:
        parser.error(str(exc))

    device = select_device()
    model, tokenizer = load_model_and_tokenizer(args.model, device)

    rows: List[Dict] = []
    for ctx_id, ctx in contexts:
        qs = generate_questions(
            model=model,
            tokenizer=tokenizer,
            device=device,
            context=ctx,
            n=args.n,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            seed=args.seed,
            attempts=args.attempts,
        )
        row = {"context_id": ctx_id, "context": ctx}
        for i, q in enumerate(qs, start=1):
            row[f"q{i}"] = q
        rows.append(row)

    # Output
    if args.out:
        out_path = args.out
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        if args.format == "json":
            write_json(out_path, rows)
        elif args.format == "csv":
            write_csv(out_path, rows, args.n)
        else:
            write_txt(out_path, rows, args.n)

        print(f"Saved: {out_path}")
        if args.encrypt:
            password = args.password or getpass.getpass("Enter password: ")
            enc_path = out_path + ".enc"
            encrypt_file(out_path, enc_path, password)
            print(f"Encrypted copy: {enc_path}")
        return

    # Print to stdout in selected format
    if args.format == "json":
        print(json.dumps(rows, ensure_ascii=False, indent=2))
    elif args.format == "csv":
        # Minimal CSV to stdout
        fieldnames = ["context_id", "context"] + [f"q{i}" for i in range(1, args.n + 1)]
        writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    else:
        for r in rows:
            print(f"[{r['context_id']}]")
            print(r["context"].strip())
            for i in range(1, args.n + 1):
                print(f"{i}. {r[f'q{i}']}")
            print()


if __name__ == "__main__":
    main()