Yifei Wang committed on
Commit 9168fe7 · 1 Parent(s): fccac4e

Add RAG toggle app and dependencies

README.md CHANGED
@@ -5,7 +5,7 @@ emoji: ✨
  colorFrom: indigo
  colorTo: purple
  sdk: gradio
- app_file: app.py
+ app_file: app_rag.py
  pinned: false
  ---
  ## Project Structure (Updated)
app_rag.py ADDED
@@ -0,0 +1,797 @@
+ from __future__ import annotations
+
+ import traceback
+ import sys
+ from pathlib import Path
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
+
+ import os
+ import queue
+ import re
+ import threading
+ import time
+ from functools import lru_cache
+
+ import gradio as gr
+
+ from numen_scriptorium.inference.qwen import get_model_device, load_model, stream_generate
+
+
+ BASE_MODEL = os.getenv("NS_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
+ ADAPTER = os.getenv("NS_ADAPTER", "outputs/qwen2_5_7b_boh_qlora/best").strip() or None
+ USE_4BIT = os.getenv("NS_USE_4BIT", "1") == "1"
+ DEFAULT_INSTRUCTION = os.getenv("NS_DEFAULT_INSTRUCTION", "请将输入翻译为中文,并保持原文风格。")
+
+ _RUNTIME_LOADED = False
+ _ACTIVE_STOP_EVENT: threading.Event | None = None
+ _STOP_LOCK = threading.Lock()
+
+
+ @lru_cache(maxsize=1)
+ def _get_rag_resource_summary() -> str:
+     from infer_hybrid_RAG import rag_resource_summary
+
+     return rag_resource_summary()
+
+
+ def _format_mode_indicator(use_rag: bool) -> str:
+     if use_rag:
+         resources = _get_rag_resource_summary()
+         return (
+             "### Active mode\n"
+             "- **Mode:** `RAG (hybrid)`\n"
+             f"- **Resources:** `{resources}`"
+         )
+     return (
+         "### Active mode\n"
+         "- **Mode:** `Non-RAG (existing stream_generate pipeline)`\n"
+         f"- **Resources:** `base={BASE_MODEL}, adapter={ADAPTER or 'None'}, 4bit={USE_4BIT}`"
+     )
+
+
+ def _on_mode_toggled(use_rag: bool):
+     return _format_mode_indicator(use_rag)
+
+
+ def _is_rag_runtime_loaded() -> bool:
+     try:
+         from infer_hybrid_RAG import get_rag_runtime
+
+         return get_rag_runtime.cache_info().currsize > 0
+     except Exception:
+         return False
+
+
+ def _set_active_stop_event(stop_event: threading.Event | None):
+     global _ACTIVE_STOP_EVENT
+     lock = _STOP_LOCK
+     # During interpreter shutdown, module globals can be partially torn down.
+     # Fall back to a best-effort direct assignment instead of raising.
+     if lock is None:
+         _ACTIVE_STOP_EVENT = stop_event
+         return
+     try:
+         with lock:
+             _ACTIVE_STOP_EVENT = stop_event
+     except Exception:
+         _ACTIVE_STOP_EVENT = stop_event
+
+
+ def _request_stop():
+     lock = _STOP_LOCK
+     if lock is None:
+         event = _ACTIVE_STOP_EVENT
+         if event is not None:
+             event.set()
+         return
+     try:
+         with lock:
+             if _ACTIVE_STOP_EVENT is not None:
+                 _ACTIVE_STOP_EVENT.set()
+     except Exception:
+         event = _ACTIVE_STOP_EVENT
+         if event is not None:
+             event.set()
+
+
+ def _on_stop_clicked():
+     _request_stop()
+     return _format_status(
+         stage="Stop requested",
+         loaded=_RUNTIME_LOADED,
+         device="unknown",
+         loading_percent="--",
+         error="Stop requested. Waiting for backend generation to halt.",
+     )
+
+
+ def _on_clear_clicked():
+     # Clear should also stop any in-flight generation to avoid concurrent
+     # updates from the stream generator after the UI has been reset.
+     _request_stop()
+     return (
+         DEFAULT_INSTRUCTION,
+         "",
+         False,
+         "",
+         _format_status(stage="Idle", loaded=_RUNTIME_LOADED, device="unknown", loading_percent="0%"),
+         _format_mode_indicator(False),
+         "0.00s",
+     )
+
+
+ def _format_loading_percent(value: int) -> str:
+     return f"{max(0, min(100, int(value)))}%"
+
+
+ def _infer_example_label(instruction: str, user_input: str, idx: int) -> str:
+     lower_instruction = instruction.lower()
+     if "sun's design" in user_input.lower():
+         return "BoH EN→ZH (Sun's Design)"
+     if "velvet lesson" in user_input.lower() or "moth and dream" in lower_instruction:
+         return "Moth&Dream EN→ZH (Velvet Lesson)"
+     if "deposition" in lower_instruction:
+         return "EN Generation (Deposition)"
+     if "generate one entry" in lower_instruction or "catalog" in lower_instruction:
+         return "EN Generation (Catalog Entry)"
+     return f"Example {idx + 1}"
+
+
+ def _load_demo_examples():
+     """Load examples from demo_examples.txt / demo_example.txt.
+
+     Expected per block:
+     - python infer_qlora_qwen3_boh.py ...
+     - --instruction "..."
+     - --input "..."
+     - optional --max_new_tokens <int>
+     """
+     candidate_files = [
+         Path(__file__).resolve().parent / "demo_examples.txt",
+         Path(__file__).resolve().parent / "demo_example.txt",
+     ]
+     file_path = next((p for p in candidate_files if p.exists()), None)
+     if file_path is None:
+         return [], "⚠️ Examples file not found (expected demo_examples.txt)."
+
+     try:
+         raw = file_path.read_text(encoding="utf-8")
+     except Exception:
+         return [], "⚠️ Could not read examples file."
+
+     block_pattern = re.compile(
+         r"python\s+infer_qlora_qwen3_boh\.py(?P<body>.*?)(?=(?:\n\s*python\s+infer_qlora_qwen3_boh\.py)|\Z)",
+         re.DOTALL,
+     )
+     instruction_pattern = re.compile(r'--instruction\s+"(?P<instruction>.*?)"\s*`', re.DOTALL)
+     input_pattern = re.compile(r'--input\s+"(?P<input>.*?)"\s*`', re.DOTALL)
+     max_tokens_pattern = re.compile(r"--max_new_tokens\s+(?P<max_new_tokens>\d+)")
+
+     parsed = []
+     for idx, block in enumerate(block_pattern.finditer(raw)):
+         body = block.group("body")
+         instruction_match = instruction_pattern.search(body)
+         input_match = input_pattern.search(body)
+         if not instruction_match or not input_match:
+             continue
+
+         instruction = instruction_match.group("instruction").strip()
+         user_input = input_match.group("input").strip()
+         max_match = max_tokens_pattern.search(body)
+         max_new_tokens = int(max_match.group("max_new_tokens")) if max_match else None
+
+         parsed.append(
+             {
+                 "label": _infer_example_label(instruction, user_input, idx),
+                 "instruction": instruction,
+                 "input": user_input,
+                 "max_new_tokens": max_new_tokens,
+                 "use_rag": False,
+             }
+         )
+
+     if not parsed:
+         return [], "⚠️ Failed to parse demo examples. Please check examples file format."
+
+     has_rag_example = any("rag" in ex["label"].lower() or ex.get("use_rag") for ex in parsed)
+     if not has_rag_example:
+         parsed.append(
+             {
+                 "label": "RAG Example (hybrid terms)",
+                 "instruction": "You are a translator. Translate English into Chinese while preserving lore style and preferred lore term mappings.",
+                 "input": "In Emesa, the Sun-in-Splendour is named in a black corundum tablet beside the Grail and the Forge.",
+                 "max_new_tokens": 384,
+                 "use_rag": True,
+             }
+         )
+     return parsed, None
+
+
+ def _apply_example(example: dict):
+     max_tokens_update = (
+         example["max_new_tokens"] if example.get("max_new_tokens") is not None else gr.update()
+     )
+     use_rag = bool(example.get("use_rag", False))
+     return example["instruction"], example["input"], max_tokens_update, use_rag, _format_mode_indicator(use_rag)
+
+
+ def _format_status(
+     *,
+     stage: str,
+     loaded: bool,
+     device: str,
+     loading_percent: str | None = None,
+     elapsed: float | None = None,
+     error: str | None = None,
+     stream_chunks: int | None = None,
+     output_chars: int | None = None,
+ ):
+     lines = [
+         "### Model / System status",
+         f"- **Stage:** {stage}",
+         f"- **Model loaded:** {'✅ Yes' if loaded else '❌ No'}",
+         f"- **Device:** `{device}`",
+         f"- **Base model:** `{BASE_MODEL}`",
+         f"- **Adapter:** `{ADAPTER or 'None'}`",
+         f"- **4-bit quantization:** `{USE_4BIT}`",
+     ]
+     if loading_percent is not None:
+         lines.append(f"- **Model loading:** `{loading_percent}`")
+     if elapsed is not None:
+         lines.append(f"- **Time per request:** `{elapsed:.2f}s`")
+     if stream_chunks is not None:
+         lines.append(f"- **Stream chunks received:** `{stream_chunks}`")
+     if output_chars is not None:
+         lines.append(f"- **Output characters so far:** `{output_chars}`")
+     if error:
+         lines.append(f"- **Error:** ⚠️ {error}")
+     return "\n".join(lines)
+
+
+ @lru_cache(maxsize=1)
+ def get_runtime():
+     global _RUNTIME_LOADED
+     runtime = load_model(base_model=BASE_MODEL, lora_dir=ADAPTER, use_4bit=USE_4BIT)
+     _RUNTIME_LOADED = True
+     return runtime
+
+
+ def run_inference_stream(
+     instruction: str,
+     user_input: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     seed: int,
+ ):
+     set_active_stop = _set_active_stop_event
+     start = time.perf_counter()
+     device = "unknown"
+     stage = "Preparing request"
+     load_progress = 0
+     cleaned_instruction = instruction.strip() or DEFAULT_INSTRUCTION
+     cleaned_input = user_input.strip()
+     normalized_seed = None if seed is None or int(seed) < 0 else int(seed)
+     stop_event = threading.Event()
+     set_active_stop(stop_event)
+
+     if not cleaned_input:
+         msg = "⚠️ Please provide input text before running generation."
+         yield (
+             msg,
+             _format_status(
+                 stage="Waiting for input",
+                 loaded=_RUNTIME_LOADED,
+                 device=device,
+                 loading_percent=_format_loading_percent(load_progress),
+             ),
+             "0.00s",
+         )
+         set_active_stop(None)
+         return
+
+     try:
+         stage = "Loading model"
+         if _RUNTIME_LOADED:
+             tokenizer, model = get_runtime()
+             load_progress = 100
+             yield (
+                 "",
+                 _format_status(
+                     stage="Model ready (cached)",
+                     loaded=True,
+                     device=device,
+                     loading_percent=_format_loading_percent(load_progress),
+                 ),
+                 f"{time.perf_counter() - start:.2f}s",
+             )
+         else:
+             runtime_box: dict[str, tuple] = {}
+             err_box: dict[str, Exception] = {}
+
+             def _loader():
+                 try:
+                     runtime_box["runtime"] = get_runtime()
+                 except Exception as exc:
+                     err_box["error"] = exc
+
+             loader_thread = threading.Thread(target=_loader, daemon=True)
+             loader_thread.start()
+
+             load_progress = 3
+             while loader_thread.is_alive():
+                 if stop_event.is_set():
+                     elapsed = time.perf_counter() - start
+                     yield (
+                         "⚠️ Stop requested. Model loading may continue in background.",
+                         _format_status(
+                             stage="Stopped during model loading",
+                             loaded=False,
+                             device=device,
+                             loading_percent=_format_loading_percent(load_progress),
+                             elapsed=elapsed,
+                         ),
+                         f"{elapsed:.2f}s",
+                     )
+                     return
+
+                 load_progress = min(95, load_progress + 4)
+                 elapsed = time.perf_counter() - start
+                 yield (
+                     "",
+                     _format_status(
+                         stage=f"Loading model ({load_progress}%)",
+                         loaded=False,
+                         device=device,
+                         loading_percent=_format_loading_percent(load_progress),
+                         elapsed=elapsed,
+                     ),
+                     f"{elapsed:.2f}s",
+                 )
+                 time.sleep(0.2)
+
+             loader_thread.join()
+             if "error" in err_box:
+                 raise err_box["error"]
+             tokenizer, model = runtime_box["runtime"]
+             load_progress = 100
+
+         device = get_model_device(model)
+
+         stage = "Tokenizing / preparing generation"
+         elapsed = time.perf_counter() - start
+         yield (
+             "",
+             _format_status(
+                 stage=stage,
+                 loaded=True,
+                 device=device,
+                 loading_percent=_format_loading_percent(load_progress),
+                 elapsed=elapsed,
+                 stream_chunks=0,
+                 output_chars=0,
+             ),
+             f"{elapsed:.2f}s",
+         )
+
+         stage = "Generating"
+         partial = ""
+         chunk_count = 0
+         token_queue: queue.Queue[str | None] = queue.Queue()
+         error_queue: queue.Queue[Exception] = queue.Queue()
+
+         def _token_producer():
+             try:
+                 for token in stream_generate(
+                     tokenizer=tokenizer,
+                     model=model,
+                     instruction=cleaned_instruction,
+                     user_input=cleaned_input,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     top_p=top_p,
+                     do_sample=True,
+                     seed=normalized_seed,
+                     stop_event=stop_event,
+                 ):
+                     token_queue.put(token)
+             except Exception as exc:
+                 error_queue.put(exc)
+             finally:
+                 token_queue.put(None)
+
+         producer = threading.Thread(target=_token_producer, daemon=True)
+         producer.start()
+
+         first_token_seen = False
+         while True:
+             if stop_event.is_set():
+                 elapsed = time.perf_counter() - start
+                 yield (
+                     partial.strip(),
+                     _format_status(
+                         stage="Stopped by user",
+                         loaded=True,
+                         device=device,
+                         loading_percent=_format_loading_percent(load_progress),
+                         elapsed=elapsed,
+                         stream_chunks=chunk_count,
+                         output_chars=len(partial.strip()),
+                     ),
+                     f"{elapsed:.2f}s",
+                 )
+                 return
+
+             if not error_queue.empty():
+                 raise error_queue.get()
+
+             try:
+                 delta = token_queue.get(timeout=0.2)
+             except queue.Empty:
+                 elapsed = time.perf_counter() - start
+                 wait_stage = "Generating (waiting for first token)" if not first_token_seen else "Generating"
+                 yield (
+                     partial,
+                     _format_status(
+                         stage=wait_stage,
+                         loaded=True,
+                         device=device,
+                         loading_percent=_format_loading_percent(load_progress),
+                         elapsed=elapsed,
+                         stream_chunks=chunk_count,
+                         output_chars=len(partial),
+                     ),
+                     f"{elapsed:.2f}s",
+                 )
+                 continue
+
+             if delta is None:
+                 if not error_queue.empty():
+                     raise error_queue.get()
+                 break
+
+             first_token_seen = True
+             chunk_count += 1
+             partial += delta
+             elapsed = time.perf_counter() - start
+             yield (
+                 partial,
+                 _format_status(
+                     stage=stage,
+                     loaded=True,
+                     device=device,
+                     loading_percent=_format_loading_percent(load_progress),
+                     elapsed=elapsed,
+                     stream_chunks=chunk_count,
+                     output_chars=len(partial),
+                 ),
+                 f"{elapsed:.2f}s",
+             )
+
+         elapsed = time.perf_counter() - start
+         yield (
+             partial.strip(),
+             _format_status(
+                 stage="Done",
+                 loaded=True,
+                 device=device,
+                 loading_percent=_format_loading_percent(load_progress),
+                 elapsed=elapsed,
+                 stream_chunks=chunk_count,
+                 output_chars=len(partial.strip()),
+             ),
+             f"{elapsed:.2f}s",
+         )
+     except Exception as e:
+         elapsed = time.perf_counter() - start
+         tb = traceback.format_exc()
+
+         print("=== Generation failure traceback ===")
+         print(tb)
+
+         err = f"{type(e).__name__}: {e}"
+         yield (
+             f"⚠️ Generation failed: {err}",
+             _format_status(
+                 stage=stage,
+                 loaded=_RUNTIME_LOADED,
+                 device=device,
+                 loading_percent=_format_loading_percent(load_progress),
+                 elapsed=elapsed,
+                 error=err,
+             ),
+             f"{elapsed:.2f}s",
+         )
+     finally:
+         set_active_stop(None)
+
+
+ def run_rag_inference_stream(
+     instruction: str,
+     user_input: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     seed: int,
+ ):
+     set_active_stop = _set_active_stop_event
+     start = time.perf_counter()
+     cleaned_instruction = instruction.strip() or DEFAULT_INSTRUCTION
+     cleaned_input = user_input.strip()
+     normalized_seed = None if seed is None or int(seed) < 0 else int(seed)
+     stop_event = threading.Event()
+     set_active_stop(stop_event)
+     resources = "(lazy-loaded)"
+
+     if not cleaned_input:
+         yield (
+             "⚠️ Please provide input text before running generation.",
+             _format_status(stage="Waiting for input", loaded=False, device="unknown", loading_percent="0%"),
+             "0.00s",
+         )
+         set_active_stop(None)
+         return
+
+     try:
+         stage = "Loading RAG pipeline"
+         yield (
+             "",
+             _format_status(stage=stage, loaded=_is_rag_runtime_loaded(), device="unknown", loading_percent="5%"),
+             f"{time.perf_counter() - start:.2f}s",
+         )
+
+         from infer_hybrid_RAG import rag_answer_stream, rag_resource_summary
+
+         resources = rag_resource_summary()
+         yield (
+             "",
+             _format_status(
+                 stage="Retrieving (hybrid)",
+                 loaded=_is_rag_runtime_loaded(),
+                 device="unknown",
+                 loading_percent="25%",
+             ),
+             f"{time.perf_counter() - start:.2f}s",
+         )
+
+         token_queue: queue.Queue[str | None] = queue.Queue()
+         error_queue: queue.Queue[Exception] = queue.Queue()
+
+         def _token_producer():
+             try:
+                 for token in rag_answer_stream(
+                     instruction=cleaned_instruction,
+                     user_input=cleaned_input,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     top_p=top_p,
+                     do_sample=True,
+                     seed=normalized_seed,
+                     stop_event=stop_event,
+                 ):
+                     token_queue.put(token)
+             except Exception as exc:
+                 error_queue.put(exc)
+             finally:
+                 token_queue.put(None)
+
+         producer = threading.Thread(target=_token_producer, daemon=True)
+         producer.start()
+
+         partial = ""
+         chunk_count = 0
+         first_token_seen = False
+         while True:
+             if stop_event.is_set():
+                 elapsed = time.perf_counter() - start
+                 yield (
+                     partial.strip(),
+                     _format_status(
+                         stage="Stopped by user (RAG)",
+                         loaded=_is_rag_runtime_loaded(),
+                         device="auto",
+                         loading_percent="--",
+                         elapsed=elapsed,
+                         stream_chunks=chunk_count,
+                         output_chars=len(partial.strip()),
+                     ),
+                     f"{elapsed:.2f}s",
+                 )
+                 return
+
+             if not error_queue.empty():
+                 raise error_queue.get()
+
+             try:
+                 delta = token_queue.get(timeout=0.2)
+             except queue.Empty:
+                 elapsed = time.perf_counter() - start
+                 wait_stage = (
+                     "Generating with RAG (loading/retrieving...)"
+                     if not first_token_seen
+                     else "Generating with RAG"
+                 )
+                 yield (
+                     partial,
+                     _format_status(
+                         stage=wait_stage,
+                         loaded=_is_rag_runtime_loaded(),
+                         device="auto",
+                         loading_percent="90%" if first_token_seen else "60%",
+                         elapsed=elapsed,
+                         stream_chunks=chunk_count,
+                         output_chars=len(partial),
+                     ),
+                     f"{elapsed:.2f}s",
+                 )
+                 continue
+
+             if delta is None:
+                 if not error_queue.empty():
+                     raise error_queue.get()
+                 break
+
+             first_token_seen = True
+             chunk_count += 1
+             partial += delta
+             elapsed = time.perf_counter() - start
+             yield (
+                 partial,
+                 _format_status(
+                     stage="Generating with RAG",
+                     loaded=_is_rag_runtime_loaded(),
+                     device="auto",
+                     loading_percent="95%",
+                     elapsed=elapsed,
+                     stream_chunks=chunk_count,
+                     output_chars=len(partial),
+                 ),
+                 f"{elapsed:.2f}s",
+             )
+
+         elapsed = time.perf_counter() - start
+         yield (
+             partial.strip(),
+             _format_status(
+                 stage=f"Done (RAG) · {resources}",
+                 loaded=_is_rag_runtime_loaded(),
+                 device="auto",
+                 loading_percent="100%",
+                 elapsed=elapsed,
+                 stream_chunks=chunk_count,
+                 output_chars=len(partial.strip()),
+             ),
+             f"{elapsed:.2f}s",
+         )
+     except Exception as e:
+         elapsed = time.perf_counter() - start
+         err = f"{type(e).__name__}: {e}"
+         tb = traceback.format_exc()
+         print("=== RAG generation failure traceback ===")
+         print(tb)
+         yield (
+             f"⚠️ RAG generation failed: {err}",
+             _format_status(
+                 stage="RAG failure",
+                 loaded=False,
+                 device="unknown",
+                 loading_percent="--",
+                 elapsed=elapsed,
+                 error=err,
+             ),
+             f"{elapsed:.2f}s",
+         )
+     finally:
+         set_active_stop(None)
+
+
+ def run_inference_with_mode(
+     instruction: str,
+     user_input: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     seed: int,
+     use_rag: bool,
+ ):
+     # Routing note: checkbox OFF -> existing non-RAG stream_generate path,
+     # checkbox ON -> hybrid RAG retrieval + generation path.
+     if use_rag:
+         yield from run_rag_inference_stream(
+             instruction=instruction,
+             user_input=user_input,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             seed=seed,
+         )
+         return
+
+     yield from run_inference_stream(
+         instruction=instruction,
+         user_input=user_input,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         seed=seed,
+     )
+
+
+ with gr.Blocks(title="Numen Scriptorium Demo") as demo:
+     gr.Markdown("# ✨ Numen Scriptorium · HF Demo")
+     gr.Markdown(
+         "This demo can: (1) translate EN↔ZH with Book-of-Hours/Cultist-Simulator-like tone, and (2) rewrite/generate text with instructed tone and nouns.\n\n"
+         "For lore-like quality, load a matching LoRA adapter (base model alone is not enough).\n\n"
+         "**How to use**\n"
+         "1. Keep or edit the instruction.\n"
+         "2. Paste your input text.\n"
+         "3. Click **Run** to generate output."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, lines=3)
+             user_input = gr.Textbox(label="Input", placeholder="在这里输入待翻译/待改写文本", lines=8)
+             use_rag = gr.Checkbox(label="Use RAG (hybrid)", value=False)
+             mode_panel = gr.Markdown(_format_mode_indicator(False), label="Inference mode")
+
+             with gr.Accordion("Advanced settings", open=False):
+                 max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
+                 temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
+                 top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
+                 seed = gr.Number(
+                     value=-1,
+                     precision=0,
+                     label="seed (-1 = random)",
+                     info="Use a fixed integer seed for more reproducible sampling.",
+                 )
+
+             gr.Markdown("### Examples")
+             gr.Markdown("Click an example button to auto-fill Instruction and Input.")
+             parsed_examples, example_warning = _load_demo_examples()
+             if example_warning:
+                 gr.Markdown(example_warning)
+
+             with gr.Row():
+                 for example in parsed_examples:
+                     example_btn = gr.Button(example["label"], variant="secondary")
+                     example_btn.click(
+                         fn=lambda ex=example: _apply_example(ex),
+                         inputs=None,
+                         outputs=[instruction, user_input, max_new_tokens, use_rag, mode_panel],
+                     )
+
+             with gr.Row():
+                 run_btn = gr.Button("Run", variant="primary")
+                 stop_btn = gr.Button("Stop")
+                 clear_btn = gr.Button("Clear")
+
+         with gr.Column(scale=2):
+             output = gr.Markdown(label="Output", value="")
+             elapsed_text = gr.Textbox(label="Elapsed", value="0.00s", interactive=False)
+             status_panel = gr.Markdown(
+                 _format_status(stage="Idle", loaded=False, device="unknown", loading_percent="0%"),
+                 label="Model / System status",
+             )
+
+     use_rag.change(fn=_on_mode_toggled, inputs=[use_rag], outputs=[mode_panel])
+
+     run_event = run_btn.click(
+         fn=run_inference_with_mode,
+         inputs=[instruction, user_input, max_new_tokens, temperature, top_p, seed, use_rag],
+         outputs=[output, status_panel, elapsed_text],
+     )
+
+     stop_btn.click(fn=_on_stop_clicked, inputs=None, outputs=[status_panel], cancels=[run_event])
+
+     clear_btn.click(
+         fn=_on_clear_clicked,
+         inputs=None,
+         outputs=[instruction, user_input, use_rag, output, status_panel, mode_panel, elapsed_text],
+         cancels=[run_event],
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=1).launch()
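For reference, `_load_demo_examples` above parses PowerShell-style command blocks: the `--instruction` and `--input` regexes each expect the closing quote to be followed by a backtick line continuation. A block the parser would accept might look like the following (hypothetical contents; the actual demo_examples.txt is not part of this commit):

```text
python infer_qlora_qwen3_boh.py `
  --instruction "Translate the passage into Chinese, keeping the lore tone." `
  --input "The Sun's Design is not ours to question." `
  --max_new_tokens 256
```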
build_vector_db.py ADDED
@@ -0,0 +1,100 @@
+ import json
+ import os
+ import chromadb
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+ def load_json(filepath):
+     if not os.path.exists(filepath):
+         print(f"[ERROR] File not found: {filepath}")
+         return {}
+     with open(filepath, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+ def build_vector_db():
+     # m3e-base is recommended: strong retrieval quality for Chinese text with a small footprint.
+     print("[1] Loading embedding model...")
+     embedder = SentenceTransformer('moka-ai/m3e-base', device='cuda' if torch.cuda.is_available() else 'cpu')
+
+     print("[2] Initialising local Chroma vector database...")
+     # This creates a "chroma_data" folder in the current directory to persist the data.
+     chroma_client = chromadb.PersistentClient(path="./chroma_data")
+
+     # Create or fetch a collection, the rough equivalent of a table in a relational database.
+     collection = chroma_client.get_or_create_collection(name="mansus_lore")
+
+     print("[3] Reading JSON data...")
+     hours_data = load_json("data/hours_merged.json")
+     history_data = load_json("data/mansus_history_events_rag.json")
+
+     documents = []  # plain-text chunks
+     metadatas = []  # metadata (for filtering and linking back to the knowledge graph)
+     ids = []        # unique IDs
+
+     print("[4] Processing Hours texts...")
+     for hour in hours_data.get("hours", []):
+         hour_id = hour.get("id", "")
+         desc = hour.get("desc_cn", "")
+         name = hour.get("name_cn", "")
+
+         if not hour_id or not desc:
+             continue
+
+         documents.append(f"【司辰档案】{name}:{desc}")
+         metadatas.append({
+             "type": "hour",
+             "entity_id": hour_id,
+             "entity_name": name
+         })
+         ids.append(f"doc_{hour_id}")
+
+     print("[5] Processing Mansus history events...")
+     for era_name, era_obj in history_data.items():
+         for event_title, event_obj in era_obj.get("events", {}).items():
+             # Prefer the condensed summaries generated earlier with an LLM.
+             summary = event_obj.get("summary_cn", "")
+             if not summary:
+                 # If there is no summary, join the raw paragraphs instead.
+                 summary = "\n".join(event_obj.get("paragraphs", []))
+
+             if summary.strip():
+                 documents.append(f"【历史事件】{era_name} - {event_title}:\n{summary}")
+                 metadatas.append({
+                     "type": "event",
+                     "era": era_name,
+                     "event_title": event_title
+                 })
+                 ids.append(f"doc_event_{event_title}")
+
+             # Process sub-events (h4).
+             for sub_title, sub_obj in event_obj.get("subevents", {}).items():
+                 sub_summary = sub_obj.get("summary_cn", "")
+                 if not sub_summary:
+                     sub_summary = "\n".join(sub_obj.get("paragraphs", []))
+
+                 if sub_summary.strip():
+                     documents.append(f"【历史事件】{era_name} - {event_title} ({sub_title}):\n{sub_summary}")
+                     metadatas.append({
+                         "type": "subevent",
+                         "era": era_name,
+                         "parent_event": event_title,
+                         "event_title": sub_title
+                     })
+                     ids.append(f"doc_subevent_{sub_title}")
+
+     print(f"[6] Embedding {len(documents)} text chunks and writing them to the database...")
+     # Embed everything in one batch.
+     embeddings = embedder.encode(documents, show_progress_bar=True).tolist()
+
+     # Bulk upsert into ChromaDB.
+     # Note: with tens of thousands of rows, insert in batches; here the data is only
+     # a few hundred rows, so a single upsert is fine.
+     collection.upsert(
+         documents=documents,
+         embeddings=embeddings,
+         metadatas=metadatas,
+         ids=ids
+     )
+
+     print("[7] Vector store built. Data persisted under ./chroma_data.")
+
+ if __name__ == "__main__":
+     build_vector_db()
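A minimal sanity check of the store built above, assuming `./chroma_data` exists and `moka-ai/m3e-base` is available; the query string is a hypothetical example:

```python
import chromadb
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("moka-ai/m3e-base")
client = chromadb.PersistentClient(path="./chroma_data")
collection = client.get_collection(name="mansus_lore")

query = "太阳大战"  # any lore question works here
results = collection.query(
    query_embeddings=embedder.encode([query]).tolist(),
    n_results=3,
)
# Chroma returns one result list per query; inspect the first (and only) one.
for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
    print(meta["type"], "→", doc[:60])
```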
infer_hybrid_RAG.py ADDED
@@ -0,0 +1,212 @@
+ from __future__ import annotations
+
+ import json
+ import os
+ import sys
+ from functools import lru_cache
+ from pathlib import Path
+ from threading import Event
+ from typing import Iterator
+
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
+
+ from numen_scriptorium.inference.qwen import generate, load_model, stream_generate
+
+
+ RAG_BASE_MODEL = os.getenv("NS_RAG_BASE_MODEL", os.getenv("NS_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct"))
+ RAG_ADAPTER = os.getenv("NS_RAG_ADAPTER", os.getenv("NS_ADAPTER", "ICGenAIShare06/boh-qlora-adapter/best")).strip() or None
+ RAG_USE_4BIT = os.getenv("NS_RAG_USE_4BIT", os.getenv("NS_USE_4BIT", "1")) == "1"
+ RAG_CHROMA_DIR = os.getenv("NS_RAG_CHROMA_DIR", "chroma_data")
+ RAG_COLLECTION = os.getenv("NS_RAG_COLLECTION", "mansus_lore")
+ RAG_ALIAS_FILE = os.getenv("NS_RAG_ALIAS_FILE", "data/hours_merged.json")
+ RAG_EMBED_MODEL = os.getenv("NS_RAG_EMBED_MODEL", "moka-ai/m3e-base")
+
+
+ def _resolve_repo_path(path_like: str) -> Path:
+     p = Path(path_like)
+     if p.exists():
+         return p
+     return Path(__file__).resolve().parent / p
+
+
+ class HybridRetriever:
+     def __init__(self, chroma_dir: str, collection_name: str, alias_file: str, embed_model: str):
+         import chromadb
+         import torch
+         from sentence_transformers import SentenceTransformer
+
+         chroma_path = _resolve_repo_path(chroma_dir)
+         self.chroma_client = chromadb.PersistentClient(path=str(chroma_path))
+         self.collection = self.chroma_client.get_or_create_collection(name=collection_name)
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.embedder = SentenceTransformer(embed_model, device=device)
+         self.alias_map = self._load_alias_map(alias_file)
+
+     @staticmethod
+     def _load_alias_map(alias_file: str) -> dict[str, str]:
+         path = _resolve_repo_path(alias_file)
+         if not path.exists():
+             return {}
+         with path.open("r", encoding="utf-8") as f:
+             hours_data = json.load(f)
+
+         alias_map: dict[str, str] = {}
+         for hour in hours_data.get("hours", []):
+             standard_name = hour.get("name_cn", "")
+             for alias in hour.get("aliases", []):
+                 alias = alias.strip()
+                 if alias:
+                     alias_map[alias] = standard_name
+         return alias_map
+
+     def retrieve_dict(self, query: str, stop_event: Event | None = None) -> dict[str, str]:
+         rag_dict: dict[str, str] = {}
+         lowered = query.lower()
+         for alias, std_name in self.alias_map.items():
+             if stop_event is not None and stop_event.is_set():
+                 break
+             if len(alias) <= 2:
+                 continue
+             if alias.lower() in lowered:
+                 rag_dict[alias] = std_name
+         return rag_dict
+
+     def retrieve_context(self, query: str, top_k: int = 1) -> str:
+         query_embedding = self.embedder.encode([query]).tolist()
+         results = self.collection.query(query_embeddings=query_embedding, n_results=top_k)
+         docs = results.get("documents", [[]])
+         vector_context = docs[0] if docs else []
+         return "\n".join(vector_context)
+
+
+ @lru_cache(maxsize=1)
+ def get_hybrid_retriever() -> HybridRetriever:
+     return HybridRetriever(
+         chroma_dir=RAG_CHROMA_DIR,
+         collection_name=RAG_COLLECTION,
+         alias_file=RAG_ALIAS_FILE,
+         embed_model=RAG_EMBED_MODEL,
+     )
+
+
+ @lru_cache(maxsize=1)
+ def get_rag_runtime():
+     return load_model(base_model=RAG_BASE_MODEL, lora_dir=RAG_ADAPTER, use_4bit=RAG_USE_4BIT)
+
+
+ def rag_resource_summary() -> str:
+     return (
+         f"base={RAG_BASE_MODEL}, adapter={RAG_ADAPTER or 'None'}, "
+         f"embed={RAG_EMBED_MODEL}, chroma={RAG_CHROMA_DIR}/{RAG_COLLECTION}, alias={RAG_ALIAS_FILE}"
+     )
+
+
+ def prepare_rag_input(
+     user_input: str,
+     stop_event: Event | None = None,
+     top_k: int = 1,
+ ) -> tuple[str, dict[str, str], str]:
+     retriever = get_hybrid_retriever()
+     rag_dict = retriever.retrieve_dict(user_input, stop_event=stop_event)
+     vector_context = ""
+     if stop_event is None or not stop_event.is_set():
+         try:
+             vector_context = retriever.retrieve_context(user_input, top_k=top_k)
+         except Exception:
+             vector_context = ""
+
+     injected_text = user_input
+     for eng_term, cn_term in rag_dict.items():
+         if stop_event is not None and stop_event.is_set():
+             break
+         if eng_term in injected_text:
+             injected_text = injected_text.replace(eng_term, f"{eng_term}({cn_term})")
+     return injected_text, rag_dict, vector_context
+
+
+ def _build_rag_instruction(base_instruction: str, rag_dict: dict[str, str], vector_context: str) -> str:
+     glossary = "\n".join(f"- {k} -> {v}" for k, v in rag_dict.items()) or "- (no matched terms)"
+     context = vector_context.strip() or "(no retrieved context)"
+     return (
+         f"{base_instruction.strip()}\n\n"
+         "[RAG glossary: use these preferred translations when relevant]\n"
+         f"{glossary}\n\n"
+         "[RAG retrieved background context: reference only, do not copy verbatim]\n"
+         f"{context}"
+     )
+
+
+ def rag_answer(
+     instruction: str,
+     user_input: str,
+     *,
+     max_new_tokens: int = 512,
+     temperature: float = 0.3,
+     top_p: float = 0.85,
+     do_sample: bool = True,
+     seed: int | None = None,
+     stop_event: Event | None = None,
+ ) -> str:
+     if stop_event is not None and stop_event.is_set():
+         return ""
+     injected_text, rag_dict, vector_context = prepare_rag_input(user_input, stop_event=stop_event)
+     if stop_event is not None and stop_event.is_set():
+         return ""
+
+     tokenizer, model = get_rag_runtime()
+     rag_instruction = _build_rag_instruction(instruction, rag_dict, vector_context)
+     return generate(
+         tokenizer=tokenizer,
+         model=model,
+         instruction=rag_instruction,
+         user_input=injected_text,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=do_sample,
+         seed=seed,
+     )
+
+
+ def rag_answer_stream(
+     instruction: str,
+     user_input: str,
+     *,
+     max_new_tokens: int = 512,
+     temperature: float = 0.3,
+     top_p: float = 0.85,
+     do_sample: bool = True,
+     seed: int | None = None,
+     stop_event: Event | None = None,
+ ) -> Iterator[str]:
+     if stop_event is not None and stop_event.is_set():
+         return
+     injected_text, rag_dict, vector_context = prepare_rag_input(user_input, stop_event=stop_event)
+     if stop_event is not None and stop_event.is_set():
+         return
+
+     tokenizer, model = get_rag_runtime()
+     rag_instruction = _build_rag_instruction(instruction, rag_dict, vector_context)
+     yield from stream_generate(
+         tokenizer=tokenizer,
+         model=model,
+         instruction=rag_instruction,
+         user_input=injected_text,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=do_sample,
+         seed=seed,
+         stop_event=stop_event,
+     )
+
+
+ if __name__ == "__main__":
+     sample_instruction = (
+         "You are a translator. Translate the English text into Chinese and keep lore-related style and terms coherent."
+     )
+     sample_input = (
+         "In the city of Emesa, Elagabalus lies beneath black corundum, and the Sun-in-Splendour watches in silence."
+     )
+     print(rag_answer(sample_instruction, sample_input, max_new_tokens=200))
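To make the glossary-injection step in `prepare_rag_input` concrete, here is a self-contained sketch of what it does to the input text, using a hypothetical alias mapping (the real mappings come from data/hours_merged.json):

```python
# Hypothetical alias -> canonical-name entry; "<name_cn>" stands in for the
# actual Chinese name stored in data/hours_merged.json.
text = "The Sun-in-Splendour watches in silence."
rag_dict = {"Sun-in-Splendour": "<name_cn>"}

injected = text
for eng_term, cn_term in rag_dict.items():
    if eng_term in injected:
        injected = injected.replace(eng_term, f"{eng_term}({cn_term})")

print(injected)
# -> The Sun-in-Splendour(<name_cn>) watches in silence.
```

The same mappings are also rendered as a `- term -> translation` glossary in `_build_rag_instruction`, so the model sees the preferred term both inline and in the instruction.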
kg_merge.py ADDED
@@ -0,0 +1,101 @@
+ import json
+ import os
+
+ def load_json(filepath):
+     if not os.path.exists(filepath):
+         print(f"[ERROR] File not found: {filepath}")
+         return {}
+     with open(filepath, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+ def build_knowledge_graph():
+     print("[1] Loading data...")
+     hours_data = load_json("data/hours_merged.json")
+     history_data = load_json("data/mansus_history_events_rag.json")
+
+     triplets = []
+     alias_map = {}
+
+     print("[2] Parsing Hour entities and extracting internal relations (origins, factions)...")
+
+     hours_list = hours_data.get("hours", [])
+     for hour in hours_list:
+         hour_id = hour.get("id", "")
+         hour_name = hour.get("name_cn", "")
+         if not hour_id:
+             continue
+
+         # Extract origins (HAS_ORIGIN).
+         for origin in hour.get("origin", []):
+             triplets.append({
+                 "head_id": hour_id, "head_name": hour_name,
+                 "relation": "HAS_ORIGIN",
+                 "tail_id": f"origin.{origin}", "tail_name": origin
+             })
+
+         # Extract factions (BELONGS_TO).
+         for faction in hour.get("factions", []):
+             triplets.append({
+                 "head_id": hour_id, "head_name": hour_name,
+                 "relation": "BELONGS_TO",
+                 "tail_id": f"faction.{faction}", "tail_name": faction
+             })
+
+         # Build an inverted alias index, used later to spot Hours in the history text.
+         for alias in hour.get("aliases", []):
+             if alias.strip():
+                 # Record the Hour ID and canonical name for this alias.
+                 alias_map[alias.strip()] = {"id": hour_id, "name": hour_name}
+
+     print(f" -> Extracted {len(alias_map)} aliases for entity-linking matches.")
+
+     print("[3] Scanning history events and building participation relations (PARTICIPATED_IN)...")
+     # Walk every era and event in the Mansus history.
+     for era_name, era_obj in history_data.items():
+         events = era_obj.get("events", {})
+
+         for event_title, event_obj in events.items():
+             # Join the main event's summary (and optionally its paragraphs) into one searchable text.
+             texts_to_search = [event_obj.get("summary_cn", "")]  # + event_obj.get("paragraphs", [])
+             full_text = "\n".join(texts_to_search)
+
+             # Use the alias map to look for traces of Hours in the text.
+             matched_hours = set()
+             for alias, hour_info in alias_map.items():
+                 if alias in full_text:
+                     matched_hours.add((hour_info["id"], hour_info["name"]))
+
+             # For every match, emit a participation triplet.
+             for h_id, h_name in matched_hours:
+                 triplets.append({
+                     "head_id": h_id, "head_name": h_name,
+                     "relation": "PARTICIPATED_IN",
+                     "tail_id": f"event.{event_title}", "tail_name": event_title
+                 })
+
+             # Scan sub-events (h4) the same way.
+             for sub_title, sub_obj in event_obj.get("subevents", {}).items():
+                 sub_texts = sub_obj.get("paragraphs", []) + [sub_obj.get("summary_cn", "")]
+                 sub_full_text = "\n".join(sub_texts)
+
+                 sub_matched = set()
+                 for alias, hour_info in alias_map.items():
+                     if alias in sub_full_text:
+                         sub_matched.add((hour_info["id"], hour_info["name"]))
+
+                 for h_id, h_name in sub_matched:
+                     triplets.append({
+                         "head_id": h_id, "head_name": h_name,
+                         "relation": "PARTICIPATED_IN",
+                         "tail_id": f"event.{sub_title}", "tail_name": sub_title
+                     })
+
+     print(f"[4] Done. Generated {len(triplets)} knowledge-graph triplet edges.")
+
+     output_file = "kg_triplets.json"
+     with open(output_file, "w", encoding="utf-8") as f:
+         json.dump(triplets, f, ensure_ascii=False, indent=2)
+     print(f"[5] Data saved to {output_file}")
+
+ if __name__ == "__main__":
+     build_knowledge_graph()
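A quick way to inspect the exported graph, assuming kg_triplets.json was produced by the script above:

```python
import json
from collections import Counter

with open("kg_triplets.json", "r", encoding="utf-8") as f:
    triplets = json.load(f)

# Count edges per relation type defined above.
print(Counter(t["relation"] for t in triplets))
# e.g. Counter({'PARTICIPATED_IN': ..., 'HAS_ORIGIN': ..., 'BELONGS_TO': ...})
```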
requirements.txt CHANGED
@@ -4,4 +4,6 @@ transformers>=4.45.0
  peft>=0.12.0
  accelerate>=0.33.0
  sentencepiece>=0.2.0
- bitsandbytes
+ bitsandbytes
+ chromadb>=0.5.0
+ sentence-transformers>=3.0.1
src/numen_scriptorium/inference/qwen.py CHANGED
@@ -61,7 +61,20 @@ def load_model(base_model: str, lora_dir: str | None, use_4bit: bool = True):
 
      model = base
      if lora_dir:
-         model = PeftModel.from_pretrained(base, _resolve_path(lora_dir))
+         resolved_lora = _resolve_path(lora_dir)
+         try:
+             model = PeftModel.from_pretrained(base, resolved_lora)
+         except ValueError as exc:
+             # Common misconfiguration: passing a ".../best" suffix when the
+             # adapter files are actually stored at the repo root.
+             # Try a graceful fallback before surfacing the original error.
+             lora_text = str(lora_dir).rstrip("/\\")
+             if lora_text.endswith("/best") or lora_text.endswith("\\best"):
+                 parent_lora = lora_text.rsplit("/", 1)[0].rsplit("\\", 1)[0]
+                 resolved_parent = _resolve_path(parent_lora)
+                 model = PeftModel.from_pretrained(base, resolved_parent)
+             else:
+                 raise exc
 
      model.eval()
      return tokenizer, model
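A minimal sketch of the fallback this hunk adds (names taken from the diff; whether the adapter repo actually stores its files at the root rather than under /best is an assumption):

```python
from numen_scriptorium.inference.qwen import load_model

# If loading "ICGenAIShare06/boh-qlora-adapter/best" raises a ValueError
# because the adapter files live at the repo root, load_model now retries
# with "ICGenAIShare06/boh-qlora-adapter" before re-raising.
tokenizer, model = load_model(
    base_model="Qwen/Qwen2.5-7B-Instruct",
    lora_dir="ICGenAIShare06/boh-qlora-adapter/best",
    use_4bit=True,
)
```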
summarise_manus.py ADDED
@@ -0,0 +1,324 @@
+ import json
+ import os
+ import time
+ import re
+ import requests
+ from bs4 import BeautifulSoup
+ from typing import Dict, List, Any
+
+ from google import genai
+ from google.genai import types
+
+ # ========= Configuration =========
+
+ WIKI_URL = "https://mansus.huijiwiki.com/wiki/%E6%BC%AB%E5%AE%BF%E5%8E%86%E5%8F%B2"
+ OUTPUT_JSON = "mansus_history_events_rag.json"
+
+ # Read the API key from the environment and pass it to the client explicitly.
+ client = genai.Client(api_key=os.environ.get("MY_API_KEY"))
+ GEMINI_MODEL = "gemini-2.5-flash"
+
+ # ========= Helpers =========
+
+
+ HTML_CACHE_PATH = "data/mansus_history.html"
+
+ def fetch_html(url: str) -> str:
+     if os.path.exists(HTML_CACHE_PATH):
+         with open(HTML_CACHE_PATH, "r", encoding="utf-8") as f:
+             return f.read()
+
+     headers = {
+         # add your headers here
+     }
+     resp = requests.get(url, headers=headers, timeout=20)
+     resp.raise_for_status()
+     resp.encoding = resp.apparent_encoding
+     html = resp.text
+
+     os.makedirs(os.path.dirname(HTML_CACHE_PATH) or ".", exist_ok=True)
+     with open(HTML_CACHE_PATH, "w", encoding="utf-8") as f:
+         f.write(html)
+
+     return html
+
+ def parse_article_structure(html: str) -> Dict[str, Any]:
+     soup = BeautifulSoup(html, "html.parser")
+     article = soup.find("article", class_="wiki-body-section", role="main")
+     if not article:
+         raise RuntimeError("Cannot find target <article> section.")
+
+     data: Dict[str, Any] = {}
+
+     # 1. Seed an "introduction" state: any <p> tags that appear before the first
+     #    <h2> are caught automatically and filed under this pseudo-event.
+     current_era = "引言"
+     current_h3 = "漫宿历史与时代划分"
+     current_h4 = None
+
+     data[current_era] = {
+         "title": current_era,
+         "events": {
+             current_h3: {
+                 "level": "h3",
+                 "paragraphs": [],
+                 "subevents": {}
+             }
+         }
+     }
+
+     # Walk the DOM tree.
+     for el in article.descendants:
+         if not getattr(el, "name", None):
+             continue
+         name = el.name.lower()
+
+         if name == "h2":
+             # A new h2 switches the era.
+             current_era = el.get_text(strip=True)
+             data.setdefault(current_era, {"title": current_era, "events": {}})
+             current_h3 = None
+             current_h4 = None
+
+         elif name == "h3":
+             if not current_era:
+                 continue
+             current_h3 = el.get_text(strip=True)
+             current_h4 = None
+             data[current_era]["events"].setdefault(
+                 current_h3,
+                 {"level": "h3", "paragraphs": [], "subevents": {}}
+             )
+
+         elif name == "h4":
+             if not current_era or not current_h3:
+                 continue
+             current_h4 = el.get_text(strip=True)
+             data[current_era]["events"][current_h3]["subevents"].setdefault(
+                 current_h4,
+                 {"level": "h4", "paragraphs": []}
+             )
+
+         elif name == "p":
+             if not current_era or not current_h3:
+                 continue
+
+             text = el.get_text(strip=True)
+             if not text:
+                 continue
+
+             text = re.sub(r'\[\d+\]', '', text)  # strip footnote markers like [1]
+
+             event_obj = data[current_era]["events"][current_h3]
+             if current_h4:
+                 event_obj["subevents"][current_h4]["paragraphs"].append(text)
+             else:
+                 event_obj["paragraphs"].append(text)
+
+     # 2. Post-pass cleanup: drop "empty shell" nodes that have no paragraph content.
+     cleaned_data = {}
+     for era, era_obj in data.items():
+         valid_events = {}
+         for h3_title, event_obj in era_obj["events"].items():
+             has_h3_paras = len(event_obj["paragraphs"]) > 0
+
+             # Also prune empty h4 sub-events along the way.
+             valid_subevents = {}
+             for h4_title, sub_obj in event_obj["subevents"].items():
+                 if len(sub_obj["paragraphs"]) > 0:
+                     valid_subevents[h4_title] = sub_obj
+             event_obj["subevents"] = valid_subevents
+
+             # Keep the h3 event if it has paragraphs itself or via its h4 children.
+             if has_h3_paras or len(valid_subevents) > 0:
+                 valid_events[h3_title] = event_obj
+
+         # Keep the era (h2) as long as it still contains valid events.
+         if len(valid_events) > 0:
+             era_obj["events"] = valid_events
+             cleaned_data[era] = era_obj
+     return cleaned_data
+
+
+ def is_conflict_or_death_event(title: str, paragraphs: List[str]) -> bool:
+     """
+     Rough heuristic for major "Hour conflict / death" events, used to decide
+     the summary length. Extend the keyword list as needed.
+     """
+     text = title + "\n" + "\n".join(paragraphs)
+     keywords = [
+         "覆石之战", "太阳大战", "大战", "战争",
+         "被", "杀死", "斩杀", "粉碎", "饮干",
+         "除名", "分裂", "死亡", "陨落", "毁灭", "击败", "猎杀",
+     ]
+     # Simple rule: flag high-risk words such as 战/大战 ("war") or 被…杀死/斩杀 ("was slain").
+     for kw in keywords:
+         if kw in text:
+             return True
+     return False
+
+
+ def summarise_event_text(
+     era: str,
+     title: str,
+     paragraphs: List[str],
+     is_conflict: bool
+ ) -> str:
+     full_text = "\n\n".join(paragraphs)
+
+     if is_conflict:
+         length_hint = "请写 4~6 句中文摘要,适当具体描述关键冲突、参与者与结果。"
+     else:
+         length_hint = "请写 2~4 句中文摘要,突出关键参与者、起因与后果。"
+
+     # Harden the prompt: forbid verbatim copying to get past RECITATION blocking.
+     system_prompt = (
+         "你是一个世界观设定编辑,现在要为漫宿相关的历史事件生成适合 RAG 的精炼摘要。\n"
+         "总体要求:\n"
+         "1. 使用中文输出。\n"
+         "2. 保持信息密度高,不写旁白、不写对白,不编造新设定。\n"
+         "3. 尽量保留关键参与者(司辰/派系/起源)、事件起因与影响。\n"
+         "4. 【极其重要】绝对不可使用引号原样摘抄原文的词句!必须完全使用你自己的语言进行转述(Paraphrase),否则会被判定为抄袭。\n"
+     )
+
+     user_prompt = (
+         f"时代(h2):{era}\n"
+         f"事件标题:{title}\n\n"
+         f"原始段落:\n{full_text}\n\n"
+         f"{length_hint}"
+     )
+
+     # Relax the safety thresholds.
+     safety_settings = [
+         types.SafetySetting(
+             category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+             threshold=types.HarmBlockThreshold.BLOCK_NONE,
+         ),
+         types.SafetySetting(
+             category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
+             threshold=types.HarmBlockThreshold.BLOCK_NONE,
+         ),
+         types.SafetySetting(
+             category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+             threshold=types.HarmBlockThreshold.BLOCK_NONE,
+         ),
+         types.SafetySetting(
+             category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+             threshold=types.HarmBlockThreshold.BLOCK_NONE,
+         ),
+     ]
+
+     resp = client.models.generate_content(
+         model=GEMINI_MODEL,
+         contents=user_prompt,
+         config=types.GenerateContentConfig(
+             system_instruction=system_prompt,
+             temperature=0.4,
+             max_output_tokens=2048,
+             safety_settings=safety_settings
+         )
+     )
+
+     # Diagnostics: if the text was truncated, report which filter was hit.
+     if resp.candidates:
+         finish_reason = resp.candidates[0].finish_reason.name
+         if finish_reason != "STOP":
+             print(f"\n[BLOCK WARNING] Event '{title}' was truncated unexpectedly. Reason code: {finish_reason}")
+             # RECITATION means the model still copied the source; SAFETY means other sensitive terms remain.
+
+     return resp.text.strip()
+
+ def build_rag_json(structured: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Output structure:
+     {
+       era_h2: {
+         "title": ...,
+         "events": {
+           h3_title: {
+             "level": "h3",
+             "paragraphs": [...],
+             "summary_cn": "...",
+             "subevents": {
+               h4_title: {
+                 "level": "h4",
+                 "paragraphs": [...],
+                 "summary_cn": "..."
+               }
+             }
+           }
+         }
+       }
+     }
+     """
+     rag = {}
+
+     for era, era_obj in structured.items():
+         rag[era] = {"title": era_obj["title"], "events": {}}
+         for h3_title, event_obj in era_obj["events"].items():
+             paragraphs_h3 = event_obj.get("paragraphs", [])
+             subevents = event_obj.get("subevents", {})
+
+             # Summarise the h3 main event first.
+             event_entry = {
+                 "level": "h3",
+                 "paragraphs": paragraphs_h3,
+                 "summary_cn": ""
+             }
+             if paragraphs_h3:
+                 is_conflict = is_conflict_or_death_event(h3_title, paragraphs_h3)
+                 try:
+                     summary = summarise_event_text(era, h3_title, paragraphs_h3, is_conflict)
+                     time.sleep(1.0)
+                 except Exception as e:
+                     print(f"[WARN] summarise failed for {era} / {h3_title}: {e}")
+                     summary = ""
+                 event_entry["summary_cn"] = summary
+
+             # Then summarise each h4 sub-event.
+             subevents_out = {}
+             for h4_title, sub_obj in subevents.items():
+                 paras_h4 = sub_obj.get("paragraphs", [])
+                 if not paras_h4:
+                     continue
+                 is_conflict_sub = is_conflict_or_death_event(h4_title, paras_h4)
+                 try:
+                     summary_h4 = summarise_event_text(era, h4_title, paras_h4, is_conflict_sub)
+                     time.sleep(1.0)
+                 except Exception as e:
+                     print(f"[WARN] summarise failed for {era} / {h3_title} / {h4_title}: {e}")
+                     summary_h4 = ""
+                 subevents_out[h4_title] = {
+                     "level": "h4",
+                     "paragraphs": paras_h4,
+                     "summary_cn": summary_h4
+                 }
+
+             event_entry["subevents"] = subevents_out
+             rag[era]["events"][h3_title] = event_entry
+
+     return rag
+
+
+ def main():
+     print("[1] Fetching page...")
+     html = fetch_html(WIKI_URL)
+
+     print("[2] Parsing article structure (h2/h3/h4/p)...")
+     structured = parse_article_structure(html)
+
+     print("[3] Summarising events via Gemini (with conflict-aware length)...")
+     rag_json = build_rag_json(structured)
+
+     print(f"[4] Saving JSON to {OUTPUT_JSON}...")
+     os.makedirs(os.path.dirname(OUTPUT_JSON) or ".", exist_ok=True)
+     with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
+         json.dump(rag_json, f, ensure_ascii=False, indent=2)
+
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
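A spot-check of the generated summaries, assuming mansus_history_events_rag.json exists and follows the structure in the build_rag_json docstring above:

```python
import json

with open("mansus_history_events_rag.json", "r", encoding="utf-8") as f:
    rag = json.load(f)

# Print the first 40 characters of each non-empty h3 summary.
for era, era_obj in rag.items():
    for title, event in era_obj["events"].items():
        if event["summary_cn"]:
            print(era, "/", title, "→", event["summary_cn"][:40])
```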