TheOpenMachine commited on
Commit
edf2c2e
·
verified ·
1 Parent(s): 7594205

Upload 13 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
  *.pth filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Local/private artifacts
2
+ __pycache__/
3
+ *.pyc
4
+ .env
5
+ *.log
6
+ lulu_chats/
7
+ luluv2_chats/
8
+ *lulu_memory*.json
9
+ private_artifacts/
10
+ checkpoints/
11
+ runs/
README.md CHANGED
@@ -1,3 +1,66 @@
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language:
3
+ - en
4
+ library_name: pytorch
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - text-generation
8
+ - bfloat16
9
+ - inference-only
10
+ - local-inference
11
  ---
12
+
13
+ # LULUV2 native-bf16 local inference package
14
+
15
+ This repository is prepared as an inference-only package for a native-bf16 LULUV2 checkpoint.
16
+ It is designed so users can run the model directly in native bfloat16 without an extra conversion step.
17
+
18
+ ## What is included
19
+
20
+ - `luluv2_inference_runtime.py` — stripped runtime loader and model architecture needed for inference only.
21
+ - `luluv2_live_inference.py` — streaming inference engine.
22
+ - `luluv2_optimized_engine.py` — optimized local inference engine with cache paths.
23
+ - `app.py` — local Gradio chat UI.
24
+ - `run_inference.py` — minimal command-line runner.
25
+ - `tokenizer/` — local tokenizer files and chat template.
26
+
27
+ ## What is not included
28
+
29
+ Private development tooling, data-preparation scripts, connector code, local chat logs, memory files, workspace artifacts, API keys, and secret tokens are not included.
30
+
31
+ ## Weights
32
+
33
+ Place the native-bf16 checkpoint in the repository root as:
34
+
35
+ ```text
36
+ LULUV2-bf16.pt
37
+ ```
38
+
39
+ The uploaded cleanup source did not include weights, so this package does not contain a `.pt` or `.safetensors` model file yet.
40
+ If you publish weights on Hugging Face, keep them in native bfloat16. This package includes `.gitattributes` patterns for large weight files.
41
+
42
+ ## Install
43
+
44
+ ```bash
45
+ pip install -r requirements.txt
46
+ ```
47
+
48
+ ## Run the local UI
49
+
50
+ ```bash
51
+ python app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --tokenizer-dir ./tokenizer --inbrowser
52
+ ```
53
+
54
+ ## Run from CLI
55
+
56
+ ```bash
57
+ python run_inference.py --ckpt ./LULUV2-bf16.pt --prompt "Write a short introduction to LuluV2."
58
+ ```
59
+
60
+ ## Native bf16 note
61
+
62
+ This package is intended for native bfloat16 inference. Users should be able to run the native-bf16 package directly. Hardware without bfloat16 support may require `--dtype fp16` or `--dtype fp32`, depending on their PyTorch/device setup.
63
+
64
+ ## Safety and disclosure checklist before upload
65
+
66
+ Before making the Hugging Face repository public, confirm that your base-model license permits redistribution of the final weights and that any legally required notices are present in the model card or repository files.
app.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ LULUV2 Pro local chat UI.
5
+
6
+ A clean ChatGPT-style desktop UI for the fine-tuned LULUV2 checkpoint.
7
+ It keeps the important local features only:
8
+ - chat inference
9
+ - live token streaming
10
+ - new chat / save / load chats
11
+ - persistent memory notes
12
+ - live edge monitor: tok/s, RAM, VRAM, GPU, pass2 metrics
13
+ - 32K context controls and test prompt helper
14
+
15
+ Run:
16
+ python ./app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --inbrowser
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import os
24
+ import re
25
+ from datetime import datetime
26
+ from pathlib import Path
27
+ from typing import Any, Dict, List, Tuple
28
+
29
+ import gradio as gr
30
+
31
+ from luluv2_live_inference import (
32
+ GenerationConfig,
33
+ LULUV2LiveEngine,
34
+ clean_text,
35
+ normalize_history,
36
+ system_usage,
37
+ )
38
+
39
+ APP_NAME = "LuluV2"
40
+ CHAT_DIR = Path(os.getenv("LULU_CHAT_DIR", "lulu_chats"))
41
+ MEMORY_FILE = Path(os.getenv("LULU_MEMORY_FILE", "lulu_memory.json"))
42
+
43
+ DEFAULT_SYSTEM_PROMPT = """Your name is LuluV2.
44
+ You are a local AI assistant made by Open Machine.
45
+ You run offline from the LULUV2 VWM checkpoint.
46
+ Answer directly and naturally.
47
+ Use Markdown for structure.
48
+ When writing code, use fenced code blocks with the correct language tag.
49
+ Do not output role tags, hidden scratchpad text, JSON UI fragments, or {'type':'text'} blocks.
50
+ """
51
+
52
+ PRESETS = {
53
+ "Balanced": dict(temperature=0.65, top_k=40, top_p=0.90, min_p=0.03, repetition_penalty=1.10, frequency_penalty=0.02, max_new_tokens=768),
54
+ "Precise": dict(temperature=0.35, top_k=30, top_p=0.84, min_p=0.04, repetition_penalty=1.14, frequency_penalty=0.03, max_new_tokens=512),
55
+ "Code": dict(temperature=0.42, top_k=40, top_p=0.88, min_p=0.03, repetition_penalty=1.10, frequency_penalty=0.02, max_new_tokens=1200),
56
+ "Long 32K": dict(temperature=0.55, top_k=50, top_p=0.90, min_p=0.025, repetition_penalty=1.08, frequency_penalty=0.02, max_new_tokens=1200),
57
+ }
58
+
59
+
60
+ def safe_int(value: Any, default: int, low: int | None = None, high: int | None = None) -> int:
61
+ try:
62
+ value = int(value)
63
+ except Exception:
64
+ value = default
65
+ if low is not None:
66
+ value = max(low, value)
67
+ if high is not None:
68
+ value = min(high, value)
69
+ return value
70
+
71
+
72
+ def clamp(value: Any, low: float, high: float, default: float) -> float:
73
+ try:
74
+ value = float(value)
75
+ except Exception:
76
+ return default
77
+ return max(low, min(high, value))
78
+
79
+
80
+ def esc(text: Any) -> str:
81
+ return str(text).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;")
82
+
83
+
84
+ def status_html(title: str, detail: str = "", tone: str = "neutral") -> str:
85
+ tone = tone if tone in {"neutral", "good", "warn", "bad", "live"} else "neutral"
86
+ return f"""
87
+ <div class="status-pill status-{tone}">
88
+ <span class="pulse-dot"></span>
89
+ <div><b>{esc(title)}</b><small>{esc(detail)}</small></div>
90
+ </div>
91
+ """
92
+
93
+
94
+ def read_memory() -> str:
95
+ if not MEMORY_FILE.exists():
96
+ return ""
97
+ try:
98
+ return str(json.loads(MEMORY_FILE.read_text(encoding="utf-8")).get("memory_notes", ""))
99
+ except Exception:
100
+ return ""
101
+
102
+
103
+ def write_memory(memory_notes: str) -> Tuple[str, str]:
104
+ MEMORY_FILE.write_text(
105
+ json.dumps(
106
+ {"memory_notes": memory_notes or "", "saved_at": datetime.now().isoformat(timespec="seconds"), "app": APP_NAME},
107
+ indent=2,
108
+ ensure_ascii=False,
109
+ ),
110
+ encoding="utf-8",
111
+ )
112
+ return str(MEMORY_FILE), status_html("Memory saved", str(MEMORY_FILE), "good")
113
+
114
+
115
+ def safe_chat_filename(chat_name: str, suffix: str) -> Path:
116
+ CHAT_DIR.mkdir(parents=True, exist_ok=True)
117
+ base = re.sub(r"[^a-zA-Z0-9_-]+", "_", chat_name or "chat").strip("_") or "chat"
118
+ return CHAT_DIR / f"{base}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.{suffix}"
119
+
120
+
121
+ def list_saved_chats() -> List[str]:
122
+ CHAT_DIR.mkdir(parents=True, exist_ok=True)
123
+ return [str(p) for p in sorted(CHAT_DIR.glob("*.json"), key=lambda x: x.stat().st_mtime, reverse=True)]
124
+
125
+
126
+ def save_chat(history: Any, chat_name: str, memory_notes: str) -> Tuple[str, str, List[str]]:
127
+ path = safe_chat_filename(chat_name or "Lulu chat", "json")
128
+ data = {
129
+ "chat_name": chat_name or "Lulu chat",
130
+ "history": normalize_history(history),
131
+ "memory_notes": memory_notes or "",
132
+ "saved_at": datetime.now().isoformat(timespec="seconds"),
133
+ "app": APP_NAME,
134
+ }
135
+ path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
136
+ return str(path), status_html("Chat saved", path.name, "good"), list_saved_chats()
137
+
138
+
139
+ def load_chat(path: str) -> Tuple[List[Dict[str, str]], str, str, str]:
140
+ if not path:
141
+ return [], "New chat", read_memory(), status_html("No saved chat selected", "Pick a JSON file from the sidebar.", "warn")
142
+ try:
143
+ data = json.loads(Path(path).read_text(encoding="utf-8"))
144
+ except Exception as exc:
145
+ return [], "New chat", read_memory(), status_html("Load failed", f"{type(exc).__name__}: {exc}", "bad")
146
+ return (
147
+ normalize_history(data.get("history", [])),
148
+ str(data.get("chat_name") or Path(path).stem),
149
+ str(data.get("memory_notes", read_memory())),
150
+ status_html("Chat loaded", Path(path).name, "good"),
151
+ )
152
+
153
+
154
+ def chat_to_markdown(history: Any, chat_name: str) -> str:
155
+ lines = [f"# {clean_text(chat_name) or 'LuluV2 chat'}", ""]
156
+ for item in normalize_history(history):
157
+ lines.append("## You" if item["role"] == "user" else "## LuluV2")
158
+ lines.append(item["content"])
159
+ lines.append("")
160
+ return "\n".join(lines).strip() + "\n"
161
+
162
+
163
+ def export_markdown(history: Any, chat_name: str) -> Tuple[str, str]:
164
+ path = safe_chat_filename(chat_name or "Lulu chat", "md")
165
+ path.write_text(chat_to_markdown(history, chat_name), encoding="utf-8")
166
+ return str(path), status_html("Markdown exported", path.name, "good")
167
+
168
+
169
+ def postprocess_answer(text: Any, final: bool = False) -> str:
170
+ text = clean_text(text)
171
+ # Remove common generated UI artifacts from older chat data.
172
+ text = re.sub(r"\n?\s*\[\s*\{\s*['\"]text['\"].*?['\"]type['\"]\s*:\s*['\"]text['\"]\s*\}\s*\]\s*$", "", text, flags=re.S)
173
+ text = re.sub(r"\n?\s*type\s*:\s*['\"]text['\"]\s*$", "", text, flags=re.I)
174
+ text = re.sub(r"\n{4,}", "\n\n\n", text)
175
+ if final and text.count("```") % 2 == 1:
176
+ text += "\n```"
177
+ return text.strip()
178
+
179
+
180
+ def metric_cards(engine: LULUV2LiveEngine, max_context: int) -> str:
181
+ stats = engine.stats_dict()
182
+ sys = stats.get("system", {})
183
+ model = stats.get("model", {})
184
+ pass_kl = stats.get("pass1_pass2_kl")
185
+ pass_cos = stats.get("pass1_pass2_logit_cosine")
186
+ pass_text = "base"
187
+ if pass_kl is not None and pass_cos is not None:
188
+ pass_text = f"KL {pass_kl:.3f} / cos {pass_cos:.3f}"
189
+ gpu_util = sys.get("gpu_util_percent")
190
+ gpu_temp = sys.get("gpu_temp_c")
191
+ gpu_text = "n/a" if gpu_util is None else f"{gpu_util}%"
192
+ temp_text = "n/a" if gpu_temp is None else f"{gpu_temp}°C"
193
+ return f"""
194
+ <div class="monitor-bar">
195
+ <div class="mon-card hot"><b>{float(stats.get('tokens_per_sec', 0.0)):.1f}</b><span>tok/s</span></div>
196
+ <div class="mon-card"><b>{int(stats.get('generated_tokens', 0))}</b><span>tokens</span></div>
197
+ <div class="mon-card"><b>{sys.get('python_ram', 'n/a')}</b><span>Python RAM</span></div>
198
+ <div class="mon-card"><b>{sys.get('vram_used', 'n/a')}</b><span>VRAM / {sys.get('vram_total', 'n/a')}</span></div>
199
+ <div class="mon-card"><b>{gpu_text}</b><span>GPU · {temp_text}</span></div>
200
+ <div class="mon-card"><b>{max_context//1024}K</b><span>context</span></div>
201
+ <div class="mon-card"><b>{model.get('has_pass2')}</b><span>pass2</span></div>
202
+ <div class="mon-card wide"><b>{pass_text}</b><span>pass1 → pass2</span></div>
203
+ </div>
204
+ """
205
+
206
+
207
+ def make_32k_prompt() -> str:
208
+ seed = (
209
+ "We are testing a 32K context window for LuluV2. "
210
+ "Remember these constraints: answer directly, keep code formatted, and summarize the relevant details. "
211
+ "The repeated context below is synthetic filler for a long-context stress test.\n\n"
212
+ )
213
+ block = (
214
+ "Section: VWM reconstruction. A model can use A/B atoms and c-code recipes to reconstruct behavior online. "
215
+ "Pass 1 builds a scaffold, pass 2 refines it, and the UI should keep live tokens/sec, RAM, VRAM, and pass metrics visible. "
216
+ "When asked at the end, explain the three key ideas and provide a tiny Python example.\n"
217
+ )
218
+ # Character length is approximate; token count depends on tokenizer. This usually lands around a long 20K-32K style prompt.
219
+ return seed + (block * 520) + "\nFinal question: What are the three key ideas above, and can you show a tiny Python class for tracking tokens per second?"
220
+
221
+
222
+ def create_chatbot():
223
+ kwargs = dict(
224
+ value=[],
225
+ elem_id="chatbot",
226
+ height=760,
227
+ show_label=False,
228
+ avatar_images=(None, None),
229
+ bubble_full_width=False,
230
+ )
231
+ try:
232
+ return gr.Chatbot(type="messages", render_markdown=True, sanitize_html=True, **kwargs)
233
+ except TypeError:
234
+ try:
235
+ return gr.Chatbot(render_markdown=True, sanitize_html=True, **kwargs)
236
+ except TypeError:
237
+ return gr.Chatbot(**kwargs)
238
+
239
+
240
+ def build_app(engine: LULUV2LiveEngine, default_context: int):
241
+ def respond(
242
+ message,
243
+ history,
244
+ chat_name,
245
+ system_prompt,
246
+ memory_notes,
247
+ preset,
248
+ history_turns,
249
+ max_context_tokens,
250
+ max_new_tokens,
251
+ temperature,
252
+ top_k,
253
+ top_p,
254
+ min_p,
255
+ repetition_penalty,
256
+ frequency_penalty,
257
+ greedy,
258
+ no_repeat_ngram,
259
+ stream_every,
260
+ show_pass_metrics,
261
+ ):
262
+ hist = normalize_history(history)
263
+ msg = clean_text(message)
264
+ max_context_tokens = safe_int(max_context_tokens, default_context, 128, 32768)
265
+ if not msg:
266
+ yield "", hist, status_html("Empty message", "Type something first.", "warn"), metric_cards(engine, max_context_tokens), engine.token_trace_text(), engine.stats_dict()
267
+ return
268
+
269
+ # Preset only affects initial slider defaults; live slider values are honored.
270
+ prompt = engine.build_chat_prompt(
271
+ message=msg,
272
+ history=hist,
273
+ system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT,
274
+ memory_notes=memory_notes or "",
275
+ history_turns=safe_int(history_turns, 4, 0, 32),
276
+ )
277
+ cfg = GenerationConfig(
278
+ max_new_tokens=safe_int(max_new_tokens, 768, 1, 8192),
279
+ temperature=clamp(temperature, 0.0, 2.0, 0.65),
280
+ top_k=safe_int(top_k, 40, 0, 500),
281
+ top_p=clamp(top_p, 0.01, 1.0, 0.90),
282
+ min_p=clamp(min_p, 0.0, 0.5, 0.03),
283
+ repetition_penalty=clamp(repetition_penalty, 1.0, 3.0, 1.10),
284
+ frequency_penalty=clamp(frequency_penalty, 0.0, 3.0, 0.02),
285
+ greedy=bool(greedy),
286
+ no_repeat_ngram=safe_int(no_repeat_ngram, 4, 0, 16),
287
+ stream_every=safe_int(stream_every, 1, 1, 64),
288
+ max_context_tokens=max_context_tokens,
289
+ return_pass_metrics=bool(show_pass_metrics),
290
+ )
291
+
292
+ hist.append({"role": "user", "content": msg})
293
+ hist.append({"role": "assistant", "content": "Thinking..."})
294
+ yield "", hist, status_html("Generating", "LuluV2 is reconstructing tokens live.", "live"), metric_cards(engine, max_context_tokens), engine.token_trace_text(), engine.stats_dict()
295
+
296
+ final = ""
297
+ try:
298
+ for partial in engine.generate(prompt, cfg):
299
+ final = postprocess_answer(partial, final=False)
300
+ hist[-1] = {"role": "assistant", "content": final or "..."}
301
+ yield "", hist, status_html("Generating", f"{engine.last_stats.generated_tokens} tokens · {engine.last_stats.tokens_per_sec:.1f} tok/s", "live"), metric_cards(engine, max_context_tokens), engine.token_trace_text(), engine.stats_dict()
302
+ except Exception as exc:
303
+ hist[-1] = {"role": "assistant", "content": f"Generation failed:\n\n```text\n{type(exc).__name__}: {exc}\n```"}
304
+ yield msg, hist, status_html("Generation failed", f"{type(exc).__name__}: {exc}", "bad"), metric_cards(engine, max_context_tokens), engine.token_trace_text(), engine.stats_dict()
305
+ return
306
+
307
+ final = postprocess_answer(final, final=True) or "I’m not sure how to answer that yet."
308
+ hist[-1] = {"role": "assistant", "content": final}
309
+ yield "", hist, status_html("Done", f"{engine.last_stats.generated_tokens} tokens · {engine.last_stats.tokens_per_sec:.1f} tok/s", "good"), metric_cards(engine, max_context_tokens), engine.token_trace_text(), engine.stats_dict()
310
+
311
+ def regenerate(
312
+ history,
313
+ chat_name,
314
+ system_prompt,
315
+ memory_notes,
316
+ preset,
317
+ history_turns,
318
+ max_context_tokens,
319
+ max_new_tokens,
320
+ temperature,
321
+ top_k,
322
+ top_p,
323
+ min_p,
324
+ repetition_penalty,
325
+ frequency_penalty,
326
+ greedy,
327
+ no_repeat_ngram,
328
+ stream_every,
329
+ show_pass_metrics,
330
+ ):
331
+ hist = normalize_history(history)
332
+ if not hist:
333
+ yield "", hist, status_html("Nothing to regenerate", "Send a message first.", "warn"), metric_cards(engine, safe_int(max_context_tokens, default_context)), engine.token_trace_text(), engine.stats_dict()
334
+ return
335
+ work = hist[:]
336
+ if work and work[-1]["role"] == "assistant":
337
+ work = work[:-1]
338
+ if not work or work[-1]["role"] != "user":
339
+ yield "", hist, status_html("Cannot regenerate", "Last turn is not a user message.", "warn"), metric_cards(engine, safe_int(max_context_tokens, default_context)), engine.token_trace_text(), engine.stats_dict()
340
+ return
341
+ last_msg = work[-1]["content"]
342
+ prev = work[:-1]
343
+ yield from respond(last_msg, prev, chat_name, system_prompt, memory_notes, preset, history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p, min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram, stream_every, show_pass_metrics)
344
+
345
+ def new_chat():
346
+ return [], "New chat", status_html("New chat", "Fresh conversation. Memory notes are kept.", "good")
347
+
348
+ def forget_last(history):
349
+ hist = normalize_history(history)
350
+ if len(hist) >= 2:
351
+ return hist[:-2], status_html("Forgot last turn", "Removed the latest exchange.", "good")
352
+ return [], status_html("Nothing to forget", "No full turn to remove.", "warn")
353
+
354
+ def apply_preset(name):
355
+ p = PRESETS.get(name, PRESETS["Balanced"])
356
+ context = 32768 if name == "Long 32K" else default_context
357
+ return p["temperature"], p["top_k"], p["top_p"], p["min_p"], p["repetition_penalty"], p["frequency_penalty"], p["max_new_tokens"], context
358
+
359
+ css = """
360
+ :root{
361
+ --bg:#05060d;--panel:#0b1020;--panel2:#101827;--line:rgba(148,163,184,.16);
362
+ --text:#edf2ff;--muted:#94a3b8;--accent:#8b5cf6;--accent2:#22d3ee;--good:#22c55e;--bad:#ef4444;
363
+ }
364
+ html, body, .gradio-container{
365
+ background: radial-gradient(circle at top left, rgba(139,92,246,.23), transparent 34%),
366
+ radial-gradient(circle at top right, rgba(34,211,238,.14), transparent 30%),
367
+ linear-gradient(180deg,#05060d,#070a12 62%,#02030a)!important;
368
+ color:var(--text)!important;
369
+ }
370
+ .gradio-container{max-width:1680px!important;margin:auto!important;font-family:Inter,ui-sans-serif,system-ui,-apple-system,BlinkMacSystemFont,'Segoe UI',sans-serif!important;}
371
+ footer{display:none!important}.main-wrap{gap:18px!important}.sidebar{padding:16px;border:1px solid var(--line);border-radius:28px;background:rgba(9,14,28,.76);box-shadow:0 20px 70px rgba(0,0,0,.32)}
372
+ .brand{padding:10px 4px 18px}.brand h1{margin:0;font-size:32px;letter-spacing:-.06em;color:#fff}.brand p{margin:5px 0 0;color:var(--muted);font-size:13px}.brand .badge{display:inline-flex;margin-top:12px;padding:7px 10px;border-radius:999px;border:1px solid rgba(34,211,238,.28);background:rgba(8,145,178,.12);color:#cffafe;font-weight:800;font-size:12px}
373
+ .chat-shell{padding:16px;border:1px solid var(--line);border-radius:32px;background:rgba(5,8,18,.62);box-shadow:0 30px 110px rgba(0,0,0,.38)}
374
+ #chatbot{height:760px!important;border:0!important;background:transparent!important;overflow:hidden!important}.message{font-size:15.5px!important;line-height:1.62!important}.message-wrap{max-width:900px!important}.bot .message, .assistant .message{background:rgba(15,23,42,.72)!important;border:1px solid rgba(148,163,184,.13)!important;border-radius:22px!important}.user .message{background:linear-gradient(135deg,rgba(124,58,237,.70),rgba(59,130,246,.42))!important;border:1px solid rgba(167,139,250,.35)!important;border-radius:22px!important;color:white!important}
375
+ #chatbot pre{background:#101827!important;border:1px solid rgba(148,163,184,.22)!important;border-radius:18px!important;padding:16px!important;box-shadow:inset 0 1px 0 rgba(255,255,255,.04)!important}#chatbot code{font-family:'JetBrains Mono','Cascadia Code','SFMono-Regular',Consolas,monospace!important;font-size:14px!important}#chatbot p{margin:0 0 .7em!important}#chatbot ul,#chatbot ol{margin-top:.3em!important}
376
+ .composer-card{display:flex;gap:12px;align-items:end;padding:10px;border-radius:26px;border:1px solid rgba(139,92,246,.28);background:rgba(2,6,23,.80);box-shadow:0 20px 70px rgba(139,92,246,.12)}#composer textarea{min-height:72px!important;max-height:190px!important;background:transparent!important;border:0!important;color:#fff!important;font-size:16px!important;line-height:1.5!important;box-shadow:none!important}.input-container{border:0!important;background:transparent!important}.form{border:0!important;background:transparent!important}label{color:#cbd5e1!important;font-weight:700!important}
377
+ button{border-radius:16px!important;font-weight:850!important;border:1px solid rgba(148,163,184,.16)!important;box-shadow:0 10px 28px rgba(0,0,0,.22)!important}.send-btn{min-height:56px!important;background:linear-gradient(135deg,#8b5cf6,#06b6d4)!important;color:white!important}.side-btn button,.side-btn{width:100%!important}
378
+ .monitor-bar{display:grid;grid-template-columns:repeat(8,minmax(110px,1fr));gap:10px;margin:0 0 12px}.mon-card{padding:12px 13px;border:1px solid var(--line);border-radius:18px;background:rgba(15,23,42,.78);min-height:64px}.mon-card b{display:block;font-size:20px;color:#fff;white-space:nowrap}.mon-card span{display:block;color:var(--muted);font-size:11px;margin-top:3px}.mon-card.hot{background:linear-gradient(135deg,rgba(139,92,246,.30),rgba(34,211,238,.16));border-color:rgba(34,211,238,.30)}.mon-card.wide b{font-size:15px}.status-pill{display:flex;align-items:center;gap:10px;margin:0 0 12px;padding:10px 13px;border-radius:18px;border:1px solid var(--line);background:rgba(2,6,23,.72)}.status-pill b{display:block}.status-pill small{display:block;color:var(--muted);font-size:12px}.pulse-dot{width:10px;height:10px;border-radius:99px;background:var(--accent2);box-shadow:0 0 0 7px rgba(34,211,238,.10),0 0 25px rgba(34,211,238,.55)}.status-good .pulse-dot{background:var(--good);box-shadow:0 0 0 7px rgba(34,197,94,.12),0 0 25px rgba(34,197,94,.5)}.status-bad .pulse-dot{background:var(--bad)}.status-live .pulse-dot{animation:pulse 1.1s infinite}@keyframes pulse{0%{transform:scale(1)}50%{transform:scale(1.45)}100%{transform:scale(1)}}
379
+ .gr-box,.gr-panel,.block{background:transparent!important;border-color:var(--line)!important}.sidebar textarea,.sidebar input,.sidebar select,.sidebar .wrap{background:rgba(2,6,23,.62)!important;color:#e5e7eb!important;border-color:rgba(148,163,184,.16)!important;border-radius:14px!important}.small-note{color:#94a3b8;font-size:12px}.tokenbox textarea,.jsonbox textarea{font-family:'JetBrains Mono','Cascadia Code',Consolas,monospace!important;font-size:12px!important;background:#060914!important}
380
+ @media(max-width:1100px){.monitor-bar{grid-template-columns:repeat(2,1fr)}.sidebar{display:none}.chat-shell{padding:8px}}
381
+ """
382
+
383
+ theme = gr.themes.Base(primary_hue="violet", secondary_hue="cyan", neutral_hue="slate")
384
+
385
+ with gr.Blocks(title=APP_NAME, css=css, theme=theme) as demo:
386
+ with gr.Row(elem_classes=["main-wrap"]):
387
+ with gr.Column(scale=1, min_width=270, elem_classes=["sidebar"]):
388
+ gr.HTML("""
389
+ <div class="brand">
390
+ <h1>LuluV2</h1>
391
+ <p>Offline VWM local assistant.</p>
392
+ <span class="badge">LOCAL EDGE MODE</span>
393
+ </div>
394
+ """)
395
+ new_btn = gr.Button("+ New chat", variant="primary", elem_classes=["side-btn"])
396
+ save_btn = gr.Button("Save chat", elem_classes=["side-btn"])
397
+ saved_path = gr.Textbox(label="Last saved path", interactive=False, visible=False)
398
+ saved_chats = gr.Dropdown(choices=list_saved_chats(), label="Saved chats", value=None, interactive=True)
399
+ with gr.Row():
400
+ refresh_chats = gr.Button("Refresh")
401
+ load_btn = gr.Button("Load")
402
+ export_btn = gr.Button("Export .md", elem_classes=["side-btn"])
403
+ export_path = gr.Textbox(label="Export path", interactive=False, visible=False)
404
+
405
+ with gr.Accordion("Memory", open=True):
406
+ memory_notes = gr.Textbox(label="Persistent memory notes", value=read_memory(), lines=8, placeholder="Things Lulu should remember locally...")
407
+ memory_path = gr.Textbox(label="Memory path", interactive=False, visible=False)
408
+ save_mem_btn = gr.Button("Save memory")
409
+
410
+ with gr.Accordion("Live tokens", open=False):
411
+ token_trace = gr.Textbox(label="Recent generated tokens", value="No tokens generated yet.", lines=14, elem_classes=["tokenbox"])
412
+
413
+ with gr.Accordion("Advanced", open=False):
414
+ chat_name = gr.Textbox(label="Chat name", value="New chat")
415
+ preset = gr.Dropdown(label="Preset", choices=list(PRESETS.keys()), value="Balanced")
416
+ system_prompt = gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=9)
417
+ history_turns = gr.Slider(0, 24, value=4, step=1, label="History turns sent")
418
+ max_context_tokens = gr.Slider(128, 32768, value=default_context, step=128, label="Max context tokens")
419
+ max_new_tokens = gr.Slider(16, 8192, value=768, step=16, label="Max new tokens")
420
+ temperature = gr.Slider(0.0, 2.0, value=0.65, step=0.01, label="Temperature")
421
+ top_k = gr.Slider(0, 500, value=40, step=1, label="Top-k")
422
+ top_p = gr.Slider(0.01, 1.0, value=0.90, step=0.01, label="Top-p")
423
+ min_p = gr.Slider(0.0, 0.5, value=0.03, step=0.005, label="Min-p")
424
+ repetition_penalty = gr.Slider(1.0, 3.0, value=1.10, step=0.01, label="Repetition penalty")
425
+ frequency_penalty = gr.Slider(0.0, 3.0, value=0.02, step=0.01, label="Frequency penalty")
426
+ greedy = gr.Checkbox(value=False, label="Greedy")
427
+ no_repeat_ngram = gr.Slider(0, 16, value=4, step=1, label="No-repeat ngram")
428
+ stream_every = gr.Slider(1, 64, value=1, step=1, label="Stream every N tokens")
429
+ show_pass_metrics = gr.Checkbox(value=True, label="Measure pass1/pass2 before generation")
430
+ insert_32k = gr.Button("Insert 32K stress prompt")
431
+
432
+ with gr.Column(scale=4, elem_classes=["chat-shell"]):
433
+ monitor = gr.HTML(metric_cards(engine, default_context))
434
+ status = gr.HTML(status_html("Ready", f"{engine.model_info.get('checkpoint_size')} checkpoint · {engine.model_info.get('device')}", "good"))
435
+ chatbot = create_chatbot()
436
+ with gr.Row(elem_classes=["composer-card"]):
437
+ msg = gr.Textbox(show_label=False, placeholder="Message LuluV2...", lines=3, elem_id="composer", scale=12)
438
+ send_btn = gr.Button("Send", variant="primary", elem_classes=["send-btn"], scale=2)
439
+ with gr.Row():
440
+ stop_btn = gr.Button("Stop")
441
+ regen_btn = gr.Button("Regenerate")
442
+ forget_btn = gr.Button("Forget last turn")
443
+ prompt_32k_btn = gr.Button("Try 32K prompt")
444
+ with gr.Accordion("Raw metrics", open=False):
445
+ raw_metrics = gr.JSON(label="Raw metrics")
446
+ usage_text = gr.Textbox(label="RAM / VRAM / model stats", value=system_usage(engine), lines=18, elem_classes=["jsonbox"])
447
+
448
+ inputs = [
449
+ msg, chatbot, chat_name, system_prompt, memory_notes, preset,
450
+ history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p,
451
+ min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram,
452
+ stream_every, show_pass_metrics,
453
+ ]
454
+ outputs = [msg, chatbot, status, monitor, token_trace, raw_metrics]
455
+ send_event = send_btn.click(respond, inputs=inputs, outputs=outputs)
456
+ enter_event = msg.submit(respond, inputs=inputs, outputs=outputs)
457
+ stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[send_event, enter_event])
458
+
459
+ regen_inputs = [
460
+ chatbot, chat_name, system_prompt, memory_notes, preset,
461
+ history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p,
462
+ min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram,
463
+ stream_every, show_pass_metrics,
464
+ ]
465
+ regen_event = regen_btn.click(regenerate, inputs=regen_inputs, outputs=outputs)
466
+ stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[regen_event])
467
+
468
+ new_btn.click(new_chat, outputs=[chatbot, chat_name, status])
469
+ forget_btn.click(forget_last, inputs=[chatbot], outputs=[chatbot, status])
470
+ save_btn.click(save_chat, inputs=[chatbot, chat_name, memory_notes], outputs=[saved_path, status, saved_chats])
471
+ refresh_chats.click(lambda: gr.update(choices=list_saved_chats()), outputs=[saved_chats])
472
+ load_btn.click(load_chat, inputs=[saved_chats], outputs=[chatbot, chat_name, memory_notes, status])
473
+ export_btn.click(export_markdown, inputs=[chatbot, chat_name], outputs=[export_path, status])
474
+ save_mem_btn.click(write_memory, inputs=[memory_notes], outputs=[memory_path, status])
475
+ preset.change(apply_preset, inputs=[preset], outputs=[temperature, top_k, top_p, min_p, repetition_penalty, frequency_penalty, max_new_tokens, max_context_tokens])
476
+ insert_32k.click(lambda: make_32k_prompt(), outputs=[msg])
477
+ prompt_32k_btn.click(lambda: make_32k_prompt(), outputs=[msg])
478
+
479
+ return demo
480
+
481
+
482
+ def parse_args():
483
+ ap = argparse.ArgumentParser()
484
+ ap.add_argument("--ckpt", default="LULU2_instruct_ddp.pt")
485
+ ap.add_argument("--model-py", default="luluv2_inference_runtime.py")
486
+ ap.add_argument("--tokenizer-dir", default="tokenizer")
487
+ ap.add_argument("--host", default="127.0.0.1")
488
+ ap.add_argument("--port", type=int, default=7862)
489
+ ap.add_argument("--device", default="cuda")
490
+ ap.add_argument("--dtype", default="bf16")
491
+ ap.add_argument("--max-context", type=int, default=32768)
492
+ ap.add_argument("--share", action="store_true")
493
+ ap.add_argument("--inbrowser", action="store_true")
494
+ ap.add_argument("--base-only", action="store_true")
495
+ return ap.parse_args()
496
+
497
+
498
+ def main():
499
+ args = parse_args()
500
+ os.environ.setdefault("HF_HUB_OFFLINE", "1")
501
+ os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
502
+ engine = LULUV2LiveEngine(
503
+ ckpt_path=args.ckpt,
504
+ model_py=args.model_py,
505
+ tokenizer_dir=args.tokenizer_dir,
506
+ device=args.device,
507
+ dtype=args.dtype,
508
+ local_files_only=True,
509
+ no_config_download=True,
510
+ force_base_only=bool(args.base_only),
511
+ )
512
+ demo = build_app(engine, default_context=safe_int(args.max_context, 32768, 128, 32768))
513
+ demo.queue(default_concurrency_limit=1).launch(
514
+ server_name=args.host,
515
+ server_port=int(args.port),
516
+ share=bool(args.share),
517
+ inbrowser=bool(args.inbrowser),
518
+ show_error=True,
519
+ )
520
+
521
+
522
+ if __name__ == "__main__":
523
+ main()
config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "luluv2",
3
+ "architectures": [
4
+ "Lulu2ForCausalLM"
5
+ ],
6
+ "torch_dtype": "bfloat16",
7
+ "tokenizer_class": "PreTrainedTokenizerFast",
8
+ "auto_map": {},
9
+ "inference_only_package": true
10
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_new_tokens": 512,
3
+ "temperature": 0.65,
4
+ "top_k": 40,
5
+ "top_p": 0.9,
6
+ "do_sample": true,
7
+ "eos_token_id": 151645,
8
+ "pad_token_id": 151643
9
+ }
luluv2_inference_runtime.py ADDED
@@ -0,0 +1,842 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ LULUV2 inference-only runtime.
5
+
6
+ This file intentionally contains only the code needed to load and run a
7
+ standalone native-bf16 LULUV2 checkpoint. It contains only the
8
+ runtime loader, tokenizer bridge, decoder modules, and two-pass inference path
9
+ needed for local generation.
10
+
11
+ Runtime behavior:
12
+ - loads a local checkpoint supplied by the user/repo;
13
+ - uses local tokenizer files;
14
+ - does not download or load any external model weights;
15
+ - preserves the VWM/two-pass inference path when present in the checkpoint.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ import math
22
+ import os
23
+ import time
24
+ from dataclasses import dataclass
25
+ from types import SimpleNamespace
26
+ from typing import Dict, Optional, Tuple
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ _TRANSFORMERS_IMPORT_ERROR = None
33
+ try:
34
+ from transformers import AutoTokenizer as _HFAutoTokenizer
35
+ try:
36
+ from transformers import AutoConfig as _HFAutoConfig
37
+ except Exception:
38
+ _HFAutoConfig = None
39
+ except Exception as _e:
40
+ _TRANSFORMERS_IMPORT_ERROR = _e
41
+ _HFAutoTokenizer = None
42
+ _HFAutoConfig = None
43
+
44
+ class _TokenOutput(dict):
45
+ def __getattr__(self, name):
46
+ try:
47
+ return self[name]
48
+ except KeyError as exc:
49
+ raise AttributeError(name) from exc
50
+ def to(self, device):
51
+ out = _TokenOutput()
52
+ for k, v in self.items():
53
+ out[k] = v.to(device) if torch.is_tensor(v) else v
54
+ return out
55
+
56
+ class _LocalTokenizer:
57
+ def __init__(self, path: str, tokenizer_file: Optional[str] = None, **kwargs):
58
+ import json as _json
59
+ try:
60
+ from tokenizers import Tokenizer as _TokenizerCore
61
+ except Exception as exc:
62
+ raise RuntimeError(
63
+ "transformers import failed and tokenizers is unavailable. "
64
+ "Install tokenizers or use a matching torch/transformers pair."
65
+ ) from exc
66
+ self.name_or_path = path or tokenizer_file or "<local-tokenizer>"
67
+ if tokenizer_file:
68
+ tok_file = tokenizer_file
69
+ base_dir = os.path.dirname(os.path.abspath(tok_file))
70
+ else:
71
+ base_dir = os.path.abspath(path)
72
+ tok_file = os.path.join(base_dir, "tokenizer.json")
73
+ if not os.path.exists(tok_file):
74
+ raise FileNotFoundError(f"Local tokenizer.json not found: {tok_file}")
75
+ self._tok = _TokenizerCore.from_file(tok_file)
76
+ self.vocab_size = int(self._tok.get_vocab_size())
77
+ self.model_max_length = 10**9
78
+ self.truncation_side = "left"
79
+ self.chat_template = None
80
+ self.eos_token = None
81
+ self.pad_token = None
82
+ cfg_path = os.path.join(base_dir, "tokenizer_config.json")
83
+ sp_path = os.path.join(base_dir, "special_tokens_map.json")
84
+ for p in (cfg_path, sp_path):
85
+ if os.path.exists(p):
86
+ try:
87
+ data = _json.load(open(p, "r", encoding="utf-8"))
88
+ except Exception:
89
+ data = {}
90
+ if self.chat_template is None and isinstance(data.get("chat_template"), str):
91
+ self.chat_template = data.get("chat_template")
92
+ for key, attr in (("eos_token", "eos_token"), ("pad_token", "pad_token")):
93
+ val = data.get(key)
94
+ if isinstance(val, dict):
95
+ val = val.get("content")
96
+ if isinstance(val, str):
97
+ setattr(self, attr, val)
98
+ if self.eos_token is None:
99
+ for cand in ("<|im_end|>", "<|endoftext|>", "</s>"):
100
+ if self._tok.token_to_id(cand) is not None:
101
+ self.eos_token = cand
102
+ break
103
+ if self.pad_token is None:
104
+ self.pad_token = self.eos_token
105
+ self.eos_token_id = self._tok.token_to_id(self.eos_token) if self.eos_token else None
106
+ self.pad_token_id = self._tok.token_to_id(self.pad_token) if self.pad_token else self.eos_token_id
107
+
108
+ def __len__(self):
109
+ return self.vocab_size
110
+
111
+ def __call__(self, text, return_tensors=None, truncation=False, max_length=None, add_special_tokens=True, **kwargs):
112
+ if isinstance(text, (list, tuple)):
113
+ encoded = [self._encode_one(t, add_special_tokens, truncation, max_length) for t in text]
114
+ maxlen = max(len(x) for x in encoded) if encoded else 0
115
+ pad = self.pad_token_id if self.pad_token_id is not None else 0
116
+ arr = [x + [pad] * (maxlen - len(x)) for x in encoded]
117
+ if return_tensors == "pt":
118
+ return _TokenOutput(input_ids=torch.tensor(arr, dtype=torch.long))
119
+ return _TokenOutput(input_ids=arr)
120
+ ids = self._encode_one(str(text), add_special_tokens, truncation, max_length)
121
+ if return_tensors == "pt":
122
+ return _TokenOutput(input_ids=torch.tensor([ids], dtype=torch.long))
123
+ return _TokenOutput(input_ids=ids)
124
+
125
+ def _encode_one(self, text, add_special_tokens=True, truncation=False, max_length=None):
126
+ enc = self._tok.encode(text, add_special_tokens=bool(add_special_tokens))
127
+ ids = list(enc.ids)
128
+ if truncation and max_length is not None and len(ids) > int(max_length):
129
+ if self.truncation_side == "left":
130
+ ids = ids[-int(max_length):]
131
+ else:
132
+ ids = ids[:int(max_length)]
133
+ return ids
134
+
135
+ def decode(self, ids, skip_special_tokens=True, **kwargs):
136
+ if torch.is_tensor(ids):
137
+ ids = ids.detach().cpu().tolist()
138
+ if ids and isinstance(ids[0], list):
139
+ ids = ids[0]
140
+ return self._tok.decode([int(x) for x in ids], skip_special_tokens=bool(skip_special_tokens))
141
+
142
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False, **kwargs):
143
+ chunks = []
144
+ for m in messages:
145
+ role = str(m.get("role", "user"))
146
+ content = str(m.get("content", ""))
147
+ chunks.append(f"<|im_start|>{role}\n{content}<|im_end|>")
148
+ if add_generation_prompt:
149
+ chunks.append("<|im_start|>assistant\n")
150
+ text = "\n".join(chunks)
151
+ if tokenize:
152
+ return self(text, add_special_tokens=False).input_ids
153
+ return text
154
+
155
+ class _AutoTokenizerShim:
156
+ @staticmethod
157
+ def from_pretrained(path, *args, **kwargs):
158
+ if _HFAutoTokenizer is not None:
159
+ return _HFAutoTokenizer.from_pretrained(path, *args, **kwargs)
160
+ return _LocalTokenizer(path)
161
+
162
+ class _AutoConfigShim:
163
+ @staticmethod
164
+ def from_pretrained(path, *args, **kwargs):
165
+ if _HFAutoConfig is not None:
166
+ return _HFAutoConfig.from_pretrained(path, *args, **kwargs)
167
+ raise RuntimeError(
168
+ "AutoConfig requested, but transformers failed to import. "
169
+ "Use --no-config-download / embedded model_config for LULUV2."
170
+ )
171
+ AutoTokenizer = _AutoTokenizerShim
172
+ AutoConfig = _AutoConfigShim
173
+
174
+ torch.backends.cuda.matmul.allow_tf32 = True
175
+ torch.backends.cudnn.allow_tf32 = True
176
+ if hasattr(torch, "set_float32_matmul_precision"):
177
+ torch.set_float32_matmul_precision("high")
178
+ try:
179
+ if torch.cuda.is_available():
180
+ torch.backends.cuda.enable_flash_sdp(True)
181
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
182
+ torch.backends.cuda.enable_math_sdp(False)
183
+ except Exception:
184
+ pass
185
+
186
+
187
+
188
+ def parse_dtype(name: str):
189
+ name = str(name).strip().lower()
190
+ if name in {"bf16", "bfloat16"}:
191
+ return torch.bfloat16
192
+ if name in {"fp16", "float16", "half"}:
193
+ return torch.float16
194
+ if name in {"fp32", "float32"}:
195
+ return torch.float32
196
+ raise ValueError(f"Unknown dtype: {name}")
197
+
198
+
199
+ def human_bytes(n: float) -> str:
200
+ units = ["B", "KB", "MB", "GB", "TB"]
201
+ x = float(n)
202
+ i = 0
203
+ while x >= 1024.0 and i < len(units) - 1:
204
+ x /= 1024.0
205
+ i += 1
206
+ return f"{x:.2f} {units[i]}"
207
+
208
+
209
+ def safe_torch_load(path: str, map_location="cpu"):
210
+ # PyTorch 2.6+ defaults may warn around weights_only. This checkpoint stores
211
+ # Python metadata plus tensors, so weights_only=False is intentional.
212
+ try:
213
+ return torch.load(path, map_location=map_location, weights_only=False)
214
+ except TypeError:
215
+ return torch.load(path, map_location=map_location)
216
+
217
+
218
+ def module_has_vwm(sd: Dict[str, torch.Tensor], prefix: str) -> bool:
219
+ return f"{prefix}.A" in sd and f"{prefix}.B" in sd and f"{prefix}.c" in sd
220
+
221
+
222
+ def linear_shape_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> Tuple[int, int, bool]:
223
+ if module_has_vwm(sd, prefix):
224
+ out_features = int(sd[f"{prefix}.A"].shape[0])
225
+ in_features = int(sd[f"{prefix}.B"].shape[0])
226
+ has_bias = f"{prefix}.bias" in sd
227
+ return in_features, out_features, has_bias
228
+ wkey = f"{prefix}.weight"
229
+ if wkey not in sd:
230
+ raise KeyError(f"Cannot infer Linear shape for {prefix}; missing {wkey} and VWM A/B/c")
231
+ out_features, in_features = sd[wkey].shape
232
+ has_bias = f"{prefix}.bias" in sd
233
+ return int(in_features), int(out_features), has_bias
234
+
235
+
236
+ def make_linear_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> nn.Module:
237
+ in_features, out_features, has_bias = linear_shape_from_state(sd, prefix)
238
+ if module_has_vwm(sd, prefix):
239
+ rank = int(sd[f"{prefix}.c"].shape[0])
240
+ return VWMFactorizedLinear(in_features, out_features, rank, bias=has_bias, name=prefix)
241
+ return nn.Linear(in_features, out_features, bias=has_bias)
242
+
243
+
244
+ def module_has_vwm_embedding(sd: Dict[str, torch.Tensor], prefix: str) -> bool:
245
+ return f"{prefix}.A" in sd and f"{prefix}.B" in sd and f"{prefix}.c" in sd
246
+
247
+
248
+ def embedding_shape_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> Tuple[int, int]:
249
+ if module_has_vwm_embedding(sd, prefix):
250
+ return int(sd[f"{prefix}.A"].shape[0]), int(sd[f"{prefix}.B"].shape[0])
251
+ wkey = f"{prefix}.weight"
252
+ if wkey not in sd:
253
+ raise KeyError(f"Cannot infer embedding shape for {prefix}; missing dense or VWM embedding tensors")
254
+ return int(sd[wkey].shape[0]), int(sd[wkey].shape[1])
255
+
256
+
257
+ def make_embedding_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> nn.Module:
258
+ vocab_size, hidden_size = embedding_shape_from_state(sd, prefix)
259
+ if module_has_vwm_embedding(sd, prefix):
260
+ rank = int(sd[f"{prefix}.c"].shape[0])
261
+ return VWMFactorizedEmbedding(vocab_size, hidden_size, rank, name=prefix)
262
+ return nn.Embedding(vocab_size, hidden_size)
263
+
264
+
265
+ def expand_shared_banks_into_state(ckpt: Dict, sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
266
+ """Expand experimental shared-bank storage into normal per-module A/B/c tensors."""
267
+ banks = ckpt.get("shared_banks")
268
+ if not banks:
269
+ return sd
270
+ out = dict(sd)
271
+ n = 0
272
+ for bank_id, bank in banks.items():
273
+ A = bank["A"]
274
+ B = bank["B"]
275
+ modules = bank.get("modules", {})
276
+ for prefix, m in modules.items():
277
+ out[f"{prefix}.A"] = A
278
+ out[f"{prefix}.B"] = B
279
+ out[f"{prefix}.c"] = m["c"]
280
+ if "bias" in m and m["bias"] is not None:
281
+ out[f"{prefix}.bias"] = m["bias"]
282
+ n += 1
283
+ print(f"[shared-bank] expanded {len(banks)} banks into {n} VWM modules")
284
+ return out
285
+
286
+
287
+ # -----------------------------
288
+ # VWM linear used by the exported checkpoint
289
+ # -----------------------------
290
+
291
+
292
+ class VWMFactorizedLinear(nn.Module):
293
+ """
294
+ W ~= A diag(c) B^T
295
+ y = ((x @ B) * c) @ A^T + bias
296
+
297
+ This matches LULU2 exporter's exported VWMFactorizedLinear
298
+ state names: A, B, c, optional bias.
299
+ """
300
+
301
+ def __init__(self, in_features: int, out_features: int, rank: int, bias: bool = True, name: str = ""):
302
+ super().__init__()
303
+ self.in_features = int(in_features)
304
+ self.out_features = int(out_features)
305
+ self.rank = int(rank)
306
+ self.name = name
307
+ self.A = nn.Parameter(torch.empty(out_features, rank), requires_grad=False)
308
+ self.B = nn.Parameter(torch.empty(in_features, rank), requires_grad=False)
309
+ self.c = nn.Parameter(torch.empty(rank), requires_grad=False)
310
+ self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None
311
+
312
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
313
+ # Compute in the activation dtype/device. Parameters are already moved by model.to(...).
314
+ t = torch.matmul(x, self.B.to(dtype=x.dtype))
315
+ t = t * self.c.to(dtype=x.dtype)
316
+ y = torch.matmul(t, self.A.to(dtype=x.dtype).transpose(0, 1))
317
+ if self.bias is not None:
318
+ y = y + self.bias.to(dtype=x.dtype)
319
+ return y
320
+
321
+
322
+
323
+
324
+ class VWMFactorizedEmbedding(nn.Module):
325
+ """Runtime for exported VWM embedding: E ~= A diag(c) B^T."""
326
+ def __init__(self, num_embeddings: int, embedding_dim: int, rank: int, name: str = "model.embed_tokens"):
327
+ super().__init__()
328
+ self.num_embeddings = int(num_embeddings)
329
+ self.embedding_dim = int(embedding_dim)
330
+ self.rank = int(rank)
331
+ self.name = name
332
+ self.A = nn.Parameter(torch.empty(num_embeddings, rank), requires_grad=False)
333
+ self.B = nn.Parameter(torch.empty(embedding_dim, rank), requires_grad=False)
334
+ self.c = nn.Parameter(torch.empty(rank), requires_grad=False)
335
+
336
+ @property
337
+ def weight(self):
338
+ # Dense materialization only for compatibility/debug. Normal forward avoids this.
339
+ return (self.A * self.c.view(1, -1)) @ self.B.T
340
+
341
+ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
342
+ a = F.embedding(input_ids, self.A)
343
+ t = a * self.c.to(dtype=a.dtype)
344
+ return torch.matmul(t, self.B.to(dtype=a.dtype).transpose(0, 1))
345
+
346
+
347
+ class TiedEmbeddingLMHead(nn.Module):
348
+ """LM head tied to the model embedding matrix, dense or VWM."""
349
+ def __init__(self, embedding_module: nn.Module):
350
+ super().__init__()
351
+ self.embedding_module = embedding_module
352
+
353
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
354
+ emb = self.embedding_module
355
+ if isinstance(emb, VWMFactorizedEmbedding):
356
+ # logits = h @ E.T = (h @ B) * c @ A.T
357
+ t = torch.matmul(hidden_states, emb.B.to(dtype=hidden_states.dtype))
358
+ t = t * emb.c.to(dtype=hidden_states.dtype)
359
+ return torch.matmul(t, emb.A.to(dtype=hidden_states.dtype).transpose(0, 1))
360
+ return F.linear(hidden_states, emb.weight.to(dtype=hidden_states.dtype))
361
+
362
+ # -----------------------------
363
+ # LULU2 decoder architecture
364
+ # -----------------------------
365
+
366
+
367
+ class LuluRMSNorm(nn.Module):
368
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
369
+ super().__init__()
370
+ self.weight = nn.Parameter(torch.ones(hidden_size), requires_grad=False)
371
+ self.variance_epsilon = float(eps)
372
+
373
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
374
+ input_dtype = hidden_states.dtype
375
+ hidden_states = hidden_states.float()
376
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
377
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
378
+ return self.weight.to(dtype=input_dtype) * hidden_states.to(input_dtype)
379
+
380
+
381
+ class LuluRotaryEmbedding(nn.Module):
382
+ def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 1000000.0):
383
+ super().__init__()
384
+ self.dim = int(dim)
385
+ self.max_position_embeddings = int(max_position_embeddings)
386
+ self.base = float(base)
387
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
388
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
389
+
390
+ @torch.no_grad()
391
+ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
392
+ # position_ids: [B, T]
393
+ inv_freq = self.inv_freq.to(device=x.device)
394
+ freqs = torch.einsum("bt,d->btd", position_ids.float(), inv_freq.float())
395
+ emb = torch.cat((freqs, freqs), dim=-1)
396
+ return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
397
+
398
+
399
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
400
+ x1 = x[..., : x.shape[-1] // 2]
401
+ x2 = x[..., x.shape[-1] // 2 :]
402
+ return torch.cat((-x2, x1), dim=-1)
403
+
404
+
405
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
406
+ # q/k: [B, H, T, D], cos/sin: [B, T, D]
407
+ cos = cos.unsqueeze(1)
408
+ sin = sin.unsqueeze(1)
409
+ q_embed = (q * cos) + (rotate_half(q) * sin)
410
+ k_embed = (k * cos) + (rotate_half(k) * sin)
411
+ return q_embed, k_embed
412
+
413
+
414
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
415
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
416
+ if n_rep == 1:
417
+ return hidden_states
418
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
419
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
420
+
421
+
422
+ class LuluVWMMLP(nn.Module):
423
+ def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
424
+ super().__init__()
425
+ p = f"model.layers.{layer_idx}.mlp"
426
+ self.gate_proj = make_linear_from_state(sd, f"{p}.gate_proj")
427
+ self.up_proj = make_linear_from_state(sd, f"{p}.up_proj")
428
+ self.down_proj = make_linear_from_state(sd, f"{p}.down_proj")
429
+
430
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
431
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
432
+
433
+
434
+ class LuluVWMAttention(nn.Module):
435
+ def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
436
+ super().__init__()
437
+ self.layer_idx = int(layer_idx)
438
+ self.hidden_size = int(cfg.hidden_size)
439
+ self.num_heads = int(cfg.num_attention_heads)
440
+ self.num_key_value_heads = int(getattr(cfg, "num_key_value_heads", self.num_heads))
441
+ self.head_dim = int(getattr(cfg, "head_dim", self.hidden_size // self.num_heads))
442
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
443
+ self.scaling = self.head_dim ** -0.5
444
+ self.attention_dropout = float(getattr(cfg, "attention_dropout", 0.0))
445
+
446
+ p = f"model.layers.{layer_idx}.self_attn"
447
+ self.q_proj = make_linear_from_state(sd, f"{p}.q_proj")
448
+ self.k_proj = make_linear_from_state(sd, f"{p}.k_proj")
449
+ self.v_proj = make_linear_from_state(sd, f"{p}.v_proj")
450
+ self.o_proj = make_linear_from_state(sd, f"{p}.o_proj")
451
+
452
+ rope_theta = float(getattr(cfg, "rope_theta", 1000000.0))
453
+ max_pos = int(getattr(cfg, "max_position_embeddings", 32768))
454
+ self.rotary_emb = LuluRotaryEmbedding(self.head_dim, max_position_embeddings=max_pos, base=rope_theta)
455
+
456
+ def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
457
+ bsz, q_len, _ = hidden_states.size()
458
+
459
+ query_states = self.q_proj(hidden_states)
460
+ key_states = self.k_proj(hidden_states)
461
+ value_states = self.v_proj(hidden_states)
462
+
463
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
464
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
465
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
466
+
467
+ cos, sin = self.rotary_emb(value_states, position_ids)
468
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
469
+
470
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
471
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
472
+
473
+ # Full forward is causal. This generation script recomputes the full prefix each token.
474
+ attn_output = F.scaled_dot_product_attention(
475
+ query_states,
476
+ key_states,
477
+ value_states,
478
+ attn_mask=None,
479
+ dropout_p=0.0,
480
+ is_causal=True,
481
+ scale=self.scaling,
482
+ )
483
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
484
+ return self.o_proj(attn_output)
485
+
486
+
487
+ class LuluVWMDecoderLayer(nn.Module):
488
+ def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
489
+ super().__init__()
490
+ self.self_attn = LuluVWMAttention(cfg, sd, layer_idx)
491
+ self.mlp = LuluVWMMLP(cfg, sd, layer_idx)
492
+ self.input_layernorm = LuluRMSNorm(cfg.hidden_size, eps=getattr(cfg, "rms_norm_eps", 1e-6))
493
+ self.post_attention_layernorm = LuluRMSNorm(cfg.hidden_size, eps=getattr(cfg, "rms_norm_eps", 1e-6))
494
+
495
+ def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
496
+ residual = hidden_states
497
+ hidden_states = self.input_layernorm(hidden_states)
498
+ hidden_states = self.self_attn(hidden_states, position_ids=position_ids)
499
+ hidden_states = residual + hidden_states
500
+
501
+ residual = hidden_states
502
+ hidden_states = self.post_attention_layernorm(hidden_states)
503
+ hidden_states = self.mlp(hidden_states)
504
+ hidden_states = residual + hidden_states
505
+ return hidden_states
506
+
507
+
508
+ class LuluVWMModel(nn.Module):
509
+ def __init__(self, cfg, sd: Dict[str, torch.Tensor]):
510
+ super().__init__()
511
+ self.config = cfg
512
+ vocab_size, hidden_size = embedding_shape_from_state(sd, "model.embed_tokens")
513
+ self.embed_tokens = make_embedding_from_state(sd, "model.embed_tokens")
514
+ self.layers = nn.ModuleList([LuluVWMDecoderLayer(cfg, sd, i) for i in range(int(cfg.num_hidden_layers))])
515
+ self.norm = LuluRMSNorm(hidden_size, eps=getattr(cfg, "rms_norm_eps", 1e-6))
516
+
517
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
518
+ bsz, seq_len = input_ids.shape
519
+ if position_ids is None:
520
+ position_ids = torch.arange(seq_len, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
521
+ hidden_states = self.embed_tokens(input_ids)
522
+ for layer in self.layers:
523
+ hidden_states = layer(hidden_states, position_ids=position_ids)
524
+ return self.norm(hidden_states)
525
+
526
+
527
+ class LuluVWMForCausalLM(nn.Module):
528
+ def __init__(self, cfg, sd: Dict[str, torch.Tensor]):
529
+ super().__init__()
530
+ self.config = cfg
531
+ self.model = LuluVWMModel(cfg, sd)
532
+ _, hidden_size = embedding_shape_from_state(sd, "model.embed_tokens")
533
+ self.tie_word_embeddings = bool(getattr(cfg, "tie_word_embeddings", False))
534
+ if module_has_vwm(sd, "lm_head") or "lm_head.weight" in sd:
535
+ self.lm_head = make_linear_from_state(sd, "lm_head")
536
+ else:
537
+ self.tie_word_embeddings = True
538
+ self.lm_head = TiedEmbeddingLMHead(self.model.embed_tokens)
539
+
540
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
541
+ hidden_states = self.model(input_ids=input_ids, position_ids=position_ids)
542
+ logits = self.lm_head(hidden_states)
543
+ return SimpleNamespace(logits=logits)
544
+
545
+
546
+ # -----------------------------
547
+ # config loading / inference
548
+ # -----------------------------
549
+
550
+
551
+ def infer_minimal_config_from_state(sd: Dict[str, torch.Tensor], model_id: str = "") -> SimpleNamespace:
552
+ if "model.embed_tokens.weight" in sd:
553
+ hidden_size = int(sd["model.embed_tokens.weight"].shape[1])
554
+ vocab_size = int(sd["model.embed_tokens.weight"].shape[0])
555
+ elif module_has_vwm_embedding(sd, "model.embed_tokens"):
556
+ vocab_size = int(sd["model.embed_tokens.A"].shape[0])
557
+ hidden_size = int(sd["model.embed_tokens.B"].shape[0])
558
+ else:
559
+ raise ValueError("Checkpoint is missing model.embed_tokens dense or VWM tensors. Use a full standalone checkpoint, not a delta checkpoint.")
560
+ layer_ids = []
561
+ for k in sd.keys():
562
+ if k.startswith("model.layers."):
563
+ try:
564
+ layer_ids.append(int(k.split(".")[2]))
565
+ except Exception:
566
+ pass
567
+ num_hidden_layers = max(layer_ids) + 1 if layer_ids else 0
568
+ inter_key = "model.layers.0.mlp.gate_proj.weight"
569
+ if inter_key in sd:
570
+ intermediate_size = int(sd[inter_key].shape[0])
571
+ else:
572
+ intermediate_size = 4864
573
+
574
+ # Best known defaults for LULU2. If you export
575
+ # model_config into the checkpoint, these assumptions are not used.
576
+ num_attention_heads = 14
577
+ num_key_value_heads = 2
578
+ head_dim = hidden_size // num_attention_heads
579
+ if head_dim * num_attention_heads != hidden_size:
580
+ # Fallback if a different decoder variant is used and no config is present.
581
+ # This requires explicit command-line override in practice.
582
+ num_attention_heads = 1
583
+ num_key_value_heads = 1
584
+ head_dim = hidden_size
585
+
586
+ return SimpleNamespace(
587
+ model_type="luluv2",
588
+ model_id=model_id,
589
+ vocab_size=vocab_size,
590
+ hidden_size=hidden_size,
591
+ intermediate_size=intermediate_size,
592
+ num_hidden_layers=num_hidden_layers,
593
+ num_attention_heads=num_attention_heads,
594
+ num_key_value_heads=num_key_value_heads,
595
+ head_dim=head_dim,
596
+ rms_norm_eps=1e-6,
597
+ rope_theta=1000000.0,
598
+ max_position_embeddings=32768,
599
+ attention_dropout=0.0,
600
+ tie_word_embeddings=False,
601
+ )
602
+
603
+
604
+ def namespace_from_dict(d: Dict) -> SimpleNamespace:
605
+ return SimpleNamespace(**d)
606
+
607
+
608
+ def load_runtime_config(ckpt: Dict, sd: Dict[str, torch.Tensor], args) -> SimpleNamespace:
609
+ if "model_config" in ckpt and isinstance(ckpt["model_config"], dict):
610
+ print("[config] using model_config embedded in checkpoint")
611
+ d = dict(ckpt["model_config"])
612
+ if ckpt.get("tie_word_embeddings") is True:
613
+ d["tie_word_embeddings"] = True
614
+ return namespace_from_dict(d)
615
+
616
+ model_id = args.model_id or ckpt.get("model_id") or ckpt.get("args", {}).get("model_id") or "LULU2"
617
+
618
+ if args.no_config_download:
619
+ print("[config] no embedded config and --no-config-download set; using LULU2 defaults")
620
+ cfg = infer_minimal_config_from_state(sd, model_id=model_id)
621
+ if ckpt.get("tie_word_embeddings") is True:
622
+ cfg.tie_word_embeddings = True
623
+ return cfg
624
+
625
+ print(f"[config] loading config metadata only from {model_id}; no model weights are loaded")
626
+ cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
627
+ return cfg
628
+
629
+
630
+ # -----------------------------
631
+ # generation
632
+ # -----------------------------
633
+
634
+
635
+ def build_chat_prompt(tokenizer, user_prompt: str, system_prompt: str = "You are a helpful assistant. Answer directly and naturally.") -> str:
636
+ messages = []
637
+ if system_prompt:
638
+ messages.append({"role": "system", "content": system_prompt})
639
+ messages.append({"role": "user", "content": user_prompt})
640
+ try:
641
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
642
+ except Exception:
643
+ return f"system\n{system_prompt}\nuser\n{user_prompt}\nassistant\n"
644
+
645
+
646
+ @torch.no_grad()
647
+ def sample_next(logits: torch.Tensor, temperature: float = 0.0, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
648
+ if temperature <= 0.0:
649
+ return torch.argmax(logits, dim=-1, keepdim=True)
650
+
651
+ logits = logits / max(temperature, 1e-6)
652
+ if top_k and top_k > 0:
653
+ k = min(int(top_k), logits.size(-1))
654
+ thresh = torch.topk(logits, k, dim=-1).values[..., -1, None]
655
+ logits = torch.where(logits >= thresh, logits, torch.full_like(logits, -float("inf")))
656
+ if top_p < 1.0:
657
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
658
+ probs = torch.softmax(sorted_logits, dim=-1)
659
+ cumulative_probs = torch.cumsum(probs, dim=-1)
660
+ sorted_indices_to_remove = cumulative_probs > top_p
661
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
662
+ sorted_indices_to_remove[..., 0] = False
663
+ sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, -float("inf"))
664
+ logits = torch.full_like(logits, -float("inf")).scatter(1, sorted_indices, sorted_logits)
665
+ probs = torch.softmax(logits, dim=-1)
666
+ return torch.multinomial(probs, num_samples=1)
667
+
668
+
669
+ @torch.no_grad()
670
+ def generate_text(model, tokenizer, prompt: str, device, max_new_tokens: int = 120, temperature: float = 0.0, top_k: int = 0, top_p: float = 1.0, max_context: int = 2048) -> Tuple[str, float]:
671
+ model.eval()
672
+ enc = tokenizer(prompt, return_tensors="pt")
673
+ input_ids = enc.input_ids.to(device)
674
+ eos_id = tokenizer.eos_token_id
675
+ t0 = time.time()
676
+ start_len = int(input_ids.shape[1])
677
+
678
+ for _ in range(max_new_tokens):
679
+ ctx = input_ids[:, -max_context:]
680
+ out = model(ctx)
681
+ next_logits = out.logits[:, -1, :].float()
682
+ next_id = sample_next(next_logits, temperature=temperature, top_k=top_k, top_p=top_p)
683
+ input_ids = torch.cat([input_ids, next_id.to(input_ids.device)], dim=-1)
684
+ if eos_id is not None and int(next_id.item()) == int(eos_id):
685
+ break
686
+
687
+ dt = time.time() - t0
688
+ new_tokens = max(1, int(input_ids.shape[1]) - start_len)
689
+ return tokenizer.decode(input_ids[0], skip_special_tokens=True), new_tokens / max(dt, 1e-9)
690
+
691
+
692
+ def load_tokenizer(args, ckpt):
693
+ tok_path = args.tokenizer or ckpt.get("tokenizer_dir") or ckpt.get("model_id") or ckpt.get("args", {}).get("model_id") or args.model_id
694
+ if not tok_path:
695
+ raise ValueError("Tokenizer path/name is required. Pass --tokenizer <local-dir-or-model-id>.")
696
+ # If checkpoint stores a relative tokenizer_dir like "tokenizer", resolve it
697
+ # relative to the checkpoint location so no HF lookup is needed.
698
+ ckpt_dir = os.path.dirname(os.path.abspath(args.checkpoint))
699
+ if tok_path and not os.path.isabs(tok_path):
700
+ maybe_local = os.path.join(ckpt_dir, tok_path)
701
+ if os.path.isdir(maybe_local):
702
+ tok_path = maybe_local
703
+ print(f"[tokenizer] {tok_path}")
704
+ tok = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True, local_files_only=bool(args.local_files_only))
705
+ if tok.pad_token_id is None and tok.eos_token_id is not None:
706
+ tok.pad_token = tok.eos_token
707
+ return tok
708
+
709
+
710
+ # -----------------------------
711
+ # main
712
+
713
+ # Public model aliases used by the UI/runtime.
714
+ Lulu2RMSNorm = LuluRMSNorm
715
+ Lulu2RotaryEmbedding = LuluRotaryEmbedding
716
+ Lulu2VWMMLP = LuluVWMMLP
717
+ Lulu2VWMAttention = LuluVWMAttention
718
+ Lulu2VWMDecoderLayer = LuluVWMDecoderLayer
719
+ Lulu2VWMModel = LuluVWMModel
720
+ Lulu2ForCausalLM = LuluVWMForCausalLM
721
+
722
+
723
+ class Pass2RefinementAdapter(nn.Module):
724
+ """Small gated residual adapter conditioned on pass-1 layer state."""
725
+
726
+ def __init__(self, hidden_size: int, rank: int, gate_init: float = -5.0):
727
+ super().__init__()
728
+ self.hidden_size = int(hidden_size)
729
+ self.rank = int(rank)
730
+ self.x_norm = LuluRMSNorm(hidden_size)
731
+ self.cond_norm = LuluRMSNorm(hidden_size)
732
+ self.down = nn.Linear(2 * hidden_size, rank, bias=False)
733
+ self.up = nn.Linear(rank, hidden_size, bias=False)
734
+ self.gate = nn.Parameter(torch.tensor(float(gate_init)))
735
+
736
+ nn.init.normal_(self.down.weight, mean=0.0, std=0.02 / math.sqrt(max(1, hidden_size)))
737
+ # Zero init means the two-pass model starts exactly as pass 1.
738
+ nn.init.zeros_(self.up.weight)
739
+
740
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
741
+ z = torch.cat([self.x_norm(x), self.cond_norm(cond)], dim=-1)
742
+ delta = self.up(F.silu(self.down(z)))
743
+ return torch.sigmoid(self.gate).to(dtype=x.dtype) * delta
744
+
745
+
746
+ @dataclass
747
+ class Pass2Config:
748
+ adapter_rank: int = 64
749
+ adapter_gate_init: float = -5.0
750
+ layer_gate_init: float = -5.0
751
+ pass_embed_scale: float = 0.0
752
+ mode: str = "refine_pass1_residual"
753
+
754
+
755
+ class Lulu2TwoPassForCausalLM(nn.Module):
756
+ """
757
+ Wraps a loaded LULU2 base model.
758
+
759
+ Pass 1: normal LULU2 decoder forward, producing the pass-1 residual stream.
760
+ Pass 2: starts from pass-1 residual stream and adds small gated refinements.
761
+
762
+ h2_i = h2_i + sigmoid(layer_gate_i) * (BaseLayer_i(h2_i) - h2_i)
763
+ + Adapter_i(h2_i, pass1_state_i)
764
+
765
+ With zero-initialized adapter up-projections and negative gates, the model
766
+ starts extremely close to the loaded LULU2 checkpoint and learns refinements.
767
+ """
768
+
769
+ def __init__(self, base: Lulu2ForCausalLM, cfg: Pass2Config):
770
+ super().__init__()
771
+ self.base = base
772
+ self.pass2_config = cfg
773
+ hidden = int(base.config.hidden_size)
774
+ n_layers = int(base.config.num_hidden_layers)
775
+ self.pass_embed = nn.Parameter(torch.randn(2, hidden) * float(cfg.pass_embed_scale))
776
+ self.layer_gates = nn.Parameter(torch.full((n_layers,), float(cfg.layer_gate_init)))
777
+ self.adapters = nn.ModuleList([
778
+ Pass2RefinementAdapter(hidden, int(cfg.adapter_rank), gate_init=float(cfg.adapter_gate_init))
779
+ for _ in range(n_layers)
780
+ ])
781
+
782
+ def _position_ids(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
783
+ if position_ids is not None:
784
+ return position_ids
785
+ bsz, seq_len = input_ids.shape
786
+ return torch.arange(seq_len, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
787
+
788
+ def forward_pass1_features(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
789
+ position_ids = self._position_ids(input_ids, position_ids)
790
+ h = self.base.model.embed_tokens(input_ids)
791
+ h = h + self.pass_embed[0].to(dtype=h.dtype).view(1, 1, -1)
792
+ layer_states = []
793
+ for layer in self.base.model.layers:
794
+ h = layer(h, position_ids=position_ids)
795
+ layer_states.append(h)
796
+ return h, layer_states, position_ids
797
+
798
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, return_pass1_logits: bool = False):
799
+ h1_resid, pass1_states, position_ids = self.forward_pass1_features(input_ids, position_ids=position_ids)
800
+ h1 = self.base.model.norm(h1_resid)
801
+
802
+ # Pass 2 refines pass 1; it does not discard pass 1.
803
+ h2 = h1_resid + self.pass_embed[1].to(dtype=h1_resid.dtype).view(1, 1, -1)
804
+ for i, layer in enumerate(self.base.model.layers):
805
+ before = h2
806
+ layer_out = layer(h2, position_ids=position_ids)
807
+ layer_delta = layer_out - before
808
+ layer_gate = torch.sigmoid(self.layer_gates[i]).to(dtype=h2.dtype)
809
+ adapter_delta = self.adapters[i](h2, pass1_states[i])
810
+ h2 = before + layer_gate * layer_delta + adapter_delta
811
+
812
+ h2 = self.base.model.norm(h2)
813
+ logits2 = self.base.lm_head(h2)
814
+
815
+ if return_pass1_logits:
816
+ with torch.no_grad():
817
+ logits1 = self.base.lm_head(h1)
818
+ else:
819
+ logits1 = None
820
+ return SimpleNamespace(logits=logits2, pass1_logits=logits1)
821
+
822
+
823
+ @torch.no_grad()
824
+
825
+ def load_lulu2_base(args, device, dtype):
826
+ print("[guard] LULUV2 VWM runtime: no AutoModelForCausalLM.from_pretrained call and no external-model weights loaded.")
827
+ print(f"[load] {args.checkpoint} ({human_bytes(os.path.getsize(args.checkpoint))})")
828
+ ckpt = safe_torch_load(args.checkpoint, map_location="cpu")
829
+ if "model" not in ckpt:
830
+ raise ValueError("Checkpoint missing model state dict")
831
+ sd = expand_shared_banks_into_state(ckpt, ckpt["model"])
832
+ cfg = load_runtime_config(ckpt, sd, args)
833
+ print(f"[config] hidden={cfg.hidden_size} layers={cfg.num_hidden_layers}")
834
+ base = Lulu2ForCausalLM(cfg, sd)
835
+ missing, unexpected = base.load_state_dict(sd, strict=False)
836
+ print(f"[state:base] missing={len(missing)} unexpected={len(unexpected)}")
837
+ if missing:
838
+ print("[state:base] first missing:", missing[:10])
839
+ if unexpected:
840
+ print("[state:base] first unexpected:", unexpected[:10])
841
+ base.to(device=device, dtype=dtype)
842
+ return ckpt, base
luluv2_live_inference.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ LULUV2 live local inference engine.
5
+
6
+ This is the runtime bridge for the LULUV2 fine-tuned checkpoint:
7
+ LULU2_instruct_ddp.pt / LULU2_base_ddp.pt / LULU2.pt
8
+
9
+ It imports the actual LULUV2 architecture file, loads the checkpoint,
10
+ restores pass2_state when present, uses the local tokenizer folder, and streams
11
+ tokens with live metrics.
12
+
13
+ No AutoModelForCausalLM.from_pretrained call is used here.
14
+ No external model weights are loaded.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import importlib.util
20
+ import json
21
+ import math
22
+ import os
23
+ import platform
24
+ import time
25
+ from contextlib import nullcontext
26
+ from dataclasses import dataclass, asdict
27
+ from pathlib import Path
28
+ from types import SimpleNamespace
29
+ from typing import Any, Dict, Generator, List, Optional, Tuple
30
+
31
+ import torch
32
+ import torch.nn.functional as F
33
+
34
+ try:
35
+ import psutil
36
+ except Exception:
37
+ psutil = None
38
+
39
+ try:
40
+ import pynvml
41
+ except Exception:
42
+ pynvml = None
43
+
44
+
45
+ STOP_STRINGS = [
46
+ "<|im_start|>",
47
+ "<|im_end|>",
48
+ "<|user|>",
49
+ "<|system|>",
50
+ "<|assistant|>",
51
+ "User:",
52
+ "Assistant:",
53
+ "\nuser:",
54
+ "\nassistant:",
55
+ ]
56
+
57
+
58
+ @dataclass
59
+ class GenerationConfig:
60
+ max_new_tokens: int = 512
61
+ temperature: float = 0.65
62
+ top_k: int = 40
63
+ top_p: float = 0.90
64
+ min_p: float = 0.03
65
+ repetition_penalty: float = 1.10
66
+ frequency_penalty: float = 0.02
67
+ greedy: bool = False
68
+ no_repeat_ngram: int = 4
69
+ stream_every: int = 1
70
+ max_context_tokens: int = 4096
71
+ return_pass_metrics: bool = True
72
+
73
+
74
+ @dataclass
75
+ class GenerationStats:
76
+ prompt_tokens: int = 0
77
+ generated_tokens: int = 0
78
+ elapsed_sec: float = 0.0
79
+ tokens_per_sec: float = 0.0
80
+ last_token: str = ""
81
+ last_token_id: int = -1
82
+ last_token_prob: float = 0.0
83
+ last_entropy: float = 0.0
84
+ finish_reason: str = "none"
85
+ pass1_pass2_kl: Optional[float] = None
86
+ pass1_pass2_logit_cosine: Optional[float] = None
87
+
88
+
89
+ def setup_torch():
90
+ if torch.cuda.is_available():
91
+ try:
92
+ torch.backends.cuda.matmul.allow_tf32 = True
93
+ torch.backends.cudnn.allow_tf32 = True
94
+ except Exception:
95
+ pass
96
+ try:
97
+ torch.backends.cuda.enable_flash_sdp(True)
98
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
99
+ torch.backends.cuda.enable_math_sdp(False)
100
+ except Exception:
101
+ pass
102
+ if hasattr(torch, "set_float32_matmul_precision"):
103
+ try:
104
+ torch.set_float32_matmul_precision("high")
105
+ except Exception:
106
+ pass
107
+
108
+
109
+ def human_bytes(num: float) -> str:
110
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
111
+ if abs(num) < 1024.0:
112
+ return f"{num:.2f} {unit}"
113
+ num /= 1024.0
114
+ return f"{num:.2f} PB"
115
+
116
+
117
+ def _value_to_text(value: Any) -> str:
118
+ """Coerce Gradio/Textbox/Multimodal values into plain text.
119
+
120
+ Some Gradio versions send messages as {"text": ..., "files": ...} or
121
+ content blocks like [{"type": "text", "text": ...}]. The local UI is
122
+ text-only, so we aggressively unwrap these before tokenization.
123
+ """
124
+ if value is None:
125
+ return ""
126
+ if isinstance(value, str):
127
+ return value
128
+ if isinstance(value, dict):
129
+ if "text" in value:
130
+ return _value_to_text(value.get("text"))
131
+ if "content" in value:
132
+ return _value_to_text(value.get("content"))
133
+ if "value" in value:
134
+ return _value_to_text(value.get("value"))
135
+ return "\n".join(_value_to_text(v) for v in value.values() if _value_to_text(v))
136
+ if isinstance(value, (list, tuple)):
137
+ return "\n".join(_value_to_text(v) for v in value if _value_to_text(v))
138
+ return str(value)
139
+
140
+
141
+ def clean_text(text: Any) -> str:
142
+ text = _value_to_text(text).replace("\\n", "\n")
143
+ # Cut only after obvious turn-control strings that appear in generated text.
144
+ cut_points = [text.find(s) for s in STOP_STRINGS if s in text and text.find(s) > 0]
145
+ if cut_points:
146
+ text = text[: min(cut_points)]
147
+ for s in STOP_STRINGS:
148
+ text = text.replace(s, "")
149
+ # Remove common role remnants and JSON-ish UI artifacts.
150
+ for prefix in ("assistant\n", "Assistant:", "Lulu:", "assistant:"):
151
+ if text.lstrip().startswith(prefix):
152
+ text = text.lstrip()[len(prefix):]
153
+ text = text.replace("{'type': 'text'}", "").replace('{"type": "text"}', "")
154
+ text = "\n".join(line.rstrip() for line in text.strip().splitlines())
155
+ text = "\n".join(line for line in text.splitlines() if not line.strip().startswith("type: 'text'"))
156
+ return text.strip()
157
+
158
+
159
+ def normalize_history(history: Any) -> List[Dict[str, str]]:
160
+ out: List[Dict[str, str]] = []
161
+ if not history:
162
+ return out
163
+ for item in history:
164
+ if isinstance(item, dict):
165
+ role = item.get("role")
166
+ content = clean_text(item.get("content", ""))
167
+ if role in {"user", "assistant"} and content:
168
+ out.append({"role": role, "content": content})
169
+ elif isinstance(item, (tuple, list)) and len(item) >= 2:
170
+ u = clean_text(item[0])
171
+ a = clean_text(item[1])
172
+ if u:
173
+ out.append({"role": "user", "content": u})
174
+ if a:
175
+ out.append({"role": "assistant", "content": a})
176
+ return out
177
+
178
+
179
+ def resolve_model_py(model_py: Optional[str] = None) -> str:
180
+ candidates = []
181
+ if model_py:
182
+ candidates.append(model_py)
183
+ candidates.extend(["luluv2_inference_runtime.py"])
184
+ for c in candidates:
185
+ p = Path(c)
186
+ if p.exists():
187
+ return str(p.resolve())
188
+ raise FileNotFoundError(
189
+ "Could not find the LULUV2 model file. Pass --model-py or put "
190
+ "luluv2_inference_runtime.py next to this UI."
191
+ )
192
+
193
+
194
+ def import_model_py(model_py: Optional[str] = None):
195
+ path = resolve_model_py(model_py)
196
+ spec = importlib.util.spec_from_file_location("luluv2_runtime_module", path)
197
+ if spec is None or spec.loader is None:
198
+ raise RuntimeError(f"Could not import model file: {path}")
199
+ mod = importlib.util.module_from_spec(spec)
200
+ spec.loader.exec_module(mod)
201
+ return mod, path
202
+
203
+
204
+ class LULUV2LiveEngine:
205
+ def __init__(
206
+ self,
207
+ ckpt_path: str,
208
+ model_py: Optional[str] = None,
209
+ tokenizer_dir: Optional[str] = None,
210
+ device: Optional[str] = None,
211
+ dtype: str = "bf16",
212
+ local_files_only: bool = True,
213
+ no_config_download: bool = True,
214
+ force_base_only: bool = False,
215
+ ):
216
+ setup_torch()
217
+ self.ckpt_path = str(ckpt_path)
218
+ self.ckpt_dir = Path(self.ckpt_path).resolve().parent
219
+ self.device = self._select_device(device)
220
+ self.dtype = self._dtype_from_name(dtype)
221
+ self.local_files_only = bool(local_files_only)
222
+ self.no_config_download = bool(no_config_download)
223
+ self.force_base_only = bool(force_base_only)
224
+ self.last_stats = GenerationStats()
225
+ self.recent_tokens: List[Dict[str, Any]] = []
226
+
227
+ self.goku, self.model_py_path = import_model_py(model_py)
228
+
229
+ # args object expected by the embedded LULUV2 runtime helpers.
230
+ self.args = SimpleNamespace(
231
+ checkpoint=self.ckpt_path,
232
+ tokenizer=tokenizer_dir or "",
233
+ model_id="",
234
+ no_config_download=self.no_config_download,
235
+ local_files_only=self.local_files_only,
236
+ )
237
+
238
+ print("[guard] LULUV2 local UI: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.")
239
+ print(f"[load] checkpoint={self.ckpt_path}")
240
+ self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype)
241
+
242
+ self.tokenizer = self._load_tokenizer(tokenizer_dir)
243
+ self.model, self.has_pass2 = self._maybe_wrap_pass2(base)
244
+ self.model.eval()
245
+
246
+ self.model_info = self._build_model_info()
247
+
248
+ def _select_device(self, device: Optional[str]):
249
+ if device:
250
+ return torch.device(device)
251
+ if torch.cuda.is_available():
252
+ return torch.device("cuda")
253
+ return torch.device("cpu")
254
+
255
+ def _dtype_from_name(self, name: str):
256
+ name = (name or "bf16").lower()
257
+ if name in {"bf16", "bfloat16"}:
258
+ return torch.bfloat16
259
+ if name in {"fp16", "float16", "half"}:
260
+ return torch.float16
261
+ return torch.float32
262
+
263
+ def _load_tokenizer(self, tokenizer_dir: Optional[str]):
264
+ # Prefer explicit path, then sibling tokenizer folder, then checkpoint metadata.
265
+ if tokenizer_dir:
266
+ self.args.tokenizer = tokenizer_dir
267
+ else:
268
+ sibling = self.ckpt_dir / "tokenizer"
269
+ if sibling.is_dir():
270
+ self.args.tokenizer = str(sibling)
271
+ tok = self.goku.load_tokenizer(self.args, self.base_ckpt)
272
+ if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None:
273
+ try:
274
+ tok.pad_token = tok.eos_token
275
+ except Exception:
276
+ pass
277
+ return tok
278
+
279
+ def _maybe_wrap_pass2(self, base):
280
+ ckpt = self.base_ckpt
281
+ if self.force_base_only or "pass2_state" not in ckpt:
282
+ print("[pass2] no pass2_state loaded; running base LULUV2 forward")
283
+ return base.to(self.device).eval(), False
284
+
285
+ cfg_dict = dict(ckpt.get("pass2_config") or {})
286
+ Pass2Config = self.goku.Pass2Config
287
+ pass2_cfg = Pass2Config(**{k: v for k, v in cfg_dict.items() if k in Pass2Config.__dataclass_fields__})
288
+ model = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg)
289
+ missing, unexpected = model.load_state_dict(ckpt["pass2_state"], strict=False)
290
+ print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}")
291
+ model.to(device=self.device, dtype=self.dtype)
292
+ model.eval()
293
+ return model, True
294
+
295
+ def _build_model_info(self) -> Dict[str, Any]:
296
+ total_params = sum(p.numel() for p in self.model.parameters())
297
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
298
+ c_codes = [(n, p.numel()) for n, p in self.model.named_parameters() if n.endswith(".c")]
299
+ gate_mean = None
300
+ adapter_gate_mean = None
301
+ if self.has_pass2:
302
+ with torch.no_grad():
303
+ gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item())
304
+ vals = []
305
+ for ad in self.model.adapters:
306
+ vals.append(float(torch.sigmoid(ad.gate.float()).item()))
307
+ adapter_gate_mean = float(sum(vals) / max(1, len(vals)))
308
+ ckpt_size = Path(self.ckpt_path).stat().st_size if Path(self.ckpt_path).exists() else 0
309
+ cfg = getattr(self.model.base if self.has_pass2 else self.model, "config", None)
310
+ return {
311
+ "checkpoint": self.ckpt_path,
312
+ "checkpoint_size": human_bytes(ckpt_size),
313
+ "model_py": self.model_py_path,
314
+ "device": str(self.device),
315
+ "dtype": str(self.dtype).replace("torch.", ""),
316
+ "has_pass2": self.has_pass2,
317
+ "total_params": total_params,
318
+ "trainable_params": trainable_params,
319
+ "vwm_c_modules": len(c_codes),
320
+ "vwm_c_params": sum(n for _, n in c_codes),
321
+ "pass2_layer_gate_mean": gate_mean,
322
+ "pass2_adapter_gate_mean": adapter_gate_mean,
323
+ "hidden_size": getattr(cfg, "hidden_size", None),
324
+ "layers": getattr(cfg, "num_hidden_layers", None),
325
+ "heads": getattr(cfg, "num_attention_heads", None),
326
+ "kv_heads": getattr(cfg, "num_key_value_heads", None),
327
+ "max_position_embeddings": getattr(cfg, "max_position_embeddings", None),
328
+ }
329
+
330
+ def amp_context(self):
331
+ if self.device.type == "cuda" and self.dtype in (torch.bfloat16, torch.float16):
332
+ return torch.autocast("cuda", dtype=self.dtype)
333
+ return nullcontext()
334
+
335
+ def build_chat_prompt(
336
+ self,
337
+ message: str,
338
+ history: Any,
339
+ system_prompt: str,
340
+ memory_notes: str = "",
341
+ history_turns: int = 4,
342
+ extra_context: str = "",
343
+ ) -> str:
344
+ history = normalize_history(history)
345
+ recent = history[-max(0, int(history_turns)) * 2:] if history_turns else []
346
+ system_chunks = []
347
+ if system_prompt.strip():
348
+ system_chunks.append(system_prompt.strip())
349
+ if memory_notes.strip():
350
+ system_chunks.append("Useful memory notes:\n" + memory_notes.strip())
351
+ if extra_context.strip():
352
+ system_chunks.append("Relevant local context:\n" + extra_context.strip())
353
+ system = "\n\n".join(system_chunks)
354
+
355
+ messages = []
356
+ if system:
357
+ messages.append({"role": "system", "content": system})
358
+ messages.extend(recent)
359
+ messages.append({"role": "user", "content": clean_text(message)})
360
+ try:
361
+ return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
362
+ except Exception:
363
+ parts = []
364
+ if system:
365
+ parts.append(f"<|im_start|>system\n{system}<|im_end|>")
366
+ for item in recent:
367
+ parts.append(f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>")
368
+ parts.append(f"<|im_start|>user\n{clean_text(message)}<|im_end|>")
369
+ parts.append("<|im_start|>assistant\n")
370
+ return "\n".join(parts)
371
+
372
+ def encode(self, text: str, max_context_tokens: int = 4096) -> torch.Tensor:
373
+ enc = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=int(max_context_tokens))
374
+ ids = enc.input_ids.to(self.device)
375
+ return ids
376
+
377
+ @torch.no_grad()
378
+ def pass_metrics_for_ids(self, ids: torch.Tensor) -> Tuple[Optional[float], Optional[float]]:
379
+ if not self.has_pass2:
380
+ return None, None
381
+ try:
382
+ with self.amp_context():
383
+ out = self.model(ids, return_pass1_logits=True)
384
+ if out.pass1_logits is None:
385
+ return None, None
386
+ l1 = out.pass1_logits[:, -1, :].float()
387
+ l2 = out.logits[:, -1, :].float()
388
+ kl = F.kl_div(F.log_softmax(l2, dim=-1), F.softmax(l1, dim=-1), reduction="batchmean")
389
+ cos = F.cosine_similarity(l1, l2, dim=-1).mean()
390
+ return float(kl.item()), float(cos.item())
391
+ except Exception as exc:
392
+ print(f"[metrics] pass metrics failed: {type(exc).__name__}: {exc}")
393
+ return None, None
394
+
395
+ def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
396
+ if generated.numel() == 0:
397
+ return logits
398
+ out = logits.clone()
399
+ uniq, counts = torch.unique(generated.view(-1), return_counts=True)
400
+ if cfg.repetition_penalty != 1.0:
401
+ selected = out[:, uniq]
402
+ selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty))
403
+ out[:, uniq] = selected
404
+ if cfg.frequency_penalty:
405
+ out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0)
406
+ n = int(cfg.no_repeat_ngram)
407
+ if n > 1 and generated.size(1) >= n - 1:
408
+ seq = generated[0].tolist()
409
+ prefix = tuple(seq[-(n - 1):])
410
+ banned = []
411
+ for i in range(len(seq) - n + 1):
412
+ if tuple(seq[i:i + n - 1]) == prefix:
413
+ banned.append(seq[i + n - 1])
414
+ if banned:
415
+ out[:, list(set(banned))] = -float("inf")
416
+ return out
417
+
418
+ @torch.no_grad()
419
+ def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, Any]]:
420
+ work = self._apply_penalties(logits.float(), generated, cfg)
421
+ if cfg.greedy or cfg.temperature <= 0:
422
+ probs = torch.softmax(work, dim=-1)
423
+ next_id = torch.argmax(work, dim=-1, keepdim=True)
424
+ else:
425
+ work = work / max(float(cfg.temperature), 1e-6)
426
+ if cfg.top_k > 0:
427
+ k = min(int(cfg.top_k), work.size(-1))
428
+ thresh = torch.topk(work, k, dim=-1).values[..., -1, None]
429
+ work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf")))
430
+ if 0.0 < cfg.top_p < 1.0:
431
+ sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1)
432
+ sorted_probs = torch.softmax(sorted_logits, dim=-1)
433
+ cumprobs = torch.cumsum(sorted_probs, dim=-1)
434
+ remove = cumprobs > float(cfg.top_p)
435
+ shifted = remove.clone()
436
+ shifted[..., 1:] = remove[..., :-1]
437
+ shifted[..., 0] = False
438
+ sorted_logits = sorted_logits.masked_fill(shifted, -float("inf"))
439
+ work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits)
440
+ if 0.0 < cfg.min_p < 1.0:
441
+ probs_for_minp = torch.softmax(work, dim=-1)
442
+ max_prob = probs_for_minp.max(dim=-1, keepdim=True).values
443
+ keep = probs_for_minp >= float(cfg.min_p) * max_prob
444
+ work = work.masked_fill(~keep, -float("inf"))
445
+ probs = torch.softmax(work, dim=-1)
446
+ if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0:
447
+ next_id = torch.argmax(logits, dim=-1, keepdim=True)
448
+ probs = torch.softmax(logits.float(), dim=-1)
449
+ else:
450
+ next_id = torch.multinomial(probs, 1)
451
+
452
+ prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0
453
+ entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0
454
+ return next_id, {"prob": prob, "entropy": entropy}
455
+
456
+ @torch.no_grad()
457
+ def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]:
458
+ self.model.eval()
459
+ self.recent_tokens = []
460
+ ids = self.encode(prompt, max_context_tokens=cfg.max_context_tokens)
461
+ prompt_len = int(ids.shape[1])
462
+ t0 = time.time()
463
+ pass_kl, pass_cos = (None, None)
464
+ if cfg.return_pass_metrics:
465
+ pass_kl, pass_cos = self.pass_metrics_for_ids(ids)
466
+
467
+ eos_id = getattr(self.tokenizer, "eos_token_id", None)
468
+ last_text = ""
469
+ finish_reason = "length"
470
+
471
+ for step in range(int(cfg.max_new_tokens)):
472
+ ctx = ids[:, -int(cfg.max_context_tokens):]
473
+ with self.amp_context():
474
+ out = self.model(ctx)
475
+ logits = out.logits[:, -1, :].float()
476
+ generated = ids[:, prompt_len:]
477
+ next_id, tok_stats = self._sample_next(logits, generated, cfg)
478
+ ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
479
+
480
+ token_id = int(next_id.item())
481
+ token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
482
+ self.recent_tokens.append({
483
+ "i": step + 1,
484
+ "id": token_id,
485
+ "text": token_text,
486
+ "prob": tok_stats["prob"],
487
+ "entropy": tok_stats["entropy"],
488
+ })
489
+ self.recent_tokens = self.recent_tokens[-32:]
490
+
491
+ if eos_id is not None and token_id == int(eos_id):
492
+ finish_reason = "eos"
493
+ break
494
+
495
+ if (step + 1) % int(cfg.stream_every) == 0 or step == 0:
496
+ raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
497
+ if any(s in raw for s in STOP_STRINGS):
498
+ finish_reason = "stop_string"
499
+ break
500
+ text = clean_text(raw)
501
+ if text and text != last_text:
502
+ elapsed = time.time() - t0
503
+ gen = int(ids.shape[1]) - prompt_len
504
+ self.last_stats = GenerationStats(
505
+ prompt_tokens=prompt_len,
506
+ generated_tokens=gen,
507
+ elapsed_sec=elapsed,
508
+ tokens_per_sec=gen / max(elapsed, 1e-9),
509
+ last_token=token_text,
510
+ last_token_id=token_id,
511
+ last_token_prob=tok_stats["prob"],
512
+ last_entropy=tok_stats["entropy"],
513
+ finish_reason="streaming",
514
+ pass1_pass2_kl=pass_kl,
515
+ pass1_pass2_logit_cosine=pass_cos,
516
+ )
517
+ last_text = text
518
+ yield text
519
+
520
+ raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
521
+ final = clean_text(raw)
522
+ elapsed = time.time() - t0
523
+ gen = int(ids.shape[1]) - prompt_len
524
+ self.last_stats = GenerationStats(
525
+ prompt_tokens=prompt_len,
526
+ generated_tokens=gen,
527
+ elapsed_sec=elapsed,
528
+ tokens_per_sec=gen / max(elapsed, 1e-9),
529
+ last_token=self.recent_tokens[-1]["text"] if self.recent_tokens else "",
530
+ last_token_id=self.recent_tokens[-1]["id"] if self.recent_tokens else -1,
531
+ last_token_prob=self.recent_tokens[-1]["prob"] if self.recent_tokens else 0.0,
532
+ last_entropy=self.recent_tokens[-1]["entropy"] if self.recent_tokens else 0.0,
533
+ finish_reason=finish_reason,
534
+ pass1_pass2_kl=pass_kl,
535
+ pass1_pass2_logit_cosine=pass_cos,
536
+ )
537
+ if final:
538
+ yield final
539
+
540
+ def stats_dict(self) -> Dict[str, Any]:
541
+ d = asdict(self.last_stats)
542
+ d["model"] = self.model_info
543
+ d["system"] = system_snapshot(self)
544
+ return d
545
+
546
+ def stats_text(self) -> str:
547
+ s = self.last_stats
548
+ lines = [
549
+ f"Prompt tokens: {s.prompt_tokens}",
550
+ f"Generated tokens: {s.generated_tokens}",
551
+ f"Elapsed: {s.elapsed_sec:.2f}s",
552
+ f"Decode speed: {s.tokens_per_sec:.2f} tok/s",
553
+ f"Finish reason: {s.finish_reason}",
554
+ f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f}",
555
+ f"Last entropy: {s.last_entropy:.3f}",
556
+ ]
557
+ if s.pass1_pass2_kl is not None:
558
+ lines.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}")
559
+ if s.pass1_pass2_logit_cosine is not None:
560
+ lines.append(f"Pass1/Pass2 logit cosine: {s.pass1_pass2_logit_cosine:.6f}")
561
+ lines.extend([
562
+ "",
563
+ f"Checkpoint: {self.model_info['checkpoint']}",
564
+ f"Checkpoint size: {self.model_info['checkpoint_size']}",
565
+ f"Device: {self.model_info['device']} dtype={self.model_info['dtype']}",
566
+ f"Pass2 active: {self.model_info['has_pass2']}",
567
+ f"Params: {self.model_info['total_params']:,}",
568
+ f"VWM c modules: {self.model_info['vwm_c_modules']} ({self.model_info['vwm_c_params']:,} c params)",
569
+ f"Layer gate mean: {self.model_info['pass2_layer_gate_mean']}",
570
+ f"Adapter gate mean: {self.model_info['pass2_adapter_gate_mean']}",
571
+ ])
572
+ return "\n".join(lines)
573
+
574
+ def token_trace_text(self) -> str:
575
+ if not self.recent_tokens:
576
+ return "No tokens generated yet."
577
+ rows = []
578
+ for t in self.recent_tokens[-24:]:
579
+ safe = repr(t["text"])[1:-1]
580
+ rows.append(f"{t['i']:04d} id={t['id']:<7} p={t['prob']:.4f} H={t['entropy']:.2f} {safe}")
581
+ return "\n".join(rows)
582
+
583
+
584
+ def system_snapshot(engine: Optional[LULUV2LiveEngine] = None) -> Dict[str, Any]:
585
+ """Return compact live edge-device metrics for the UI cards.
586
+
587
+ Values are safe for JSON/HTML display. NVML is used when available for
588
+ whole-device VRAM/utilization; PyTorch counters are always included.
589
+ """
590
+ snap: Dict[str, Any] = {
591
+ "python_ram": "n/a",
592
+ "system_ram": "n/a",
593
+ "system_ram_percent": 0.0,
594
+ "cpu_percent": 0.0,
595
+ "gpu_name": "CUDA unavailable",
596
+ "vram_allocated": "n/a",
597
+ "vram_reserved": "n/a",
598
+ "vram_used": "n/a",
599
+ "vram_total": "n/a",
600
+ "vram_percent": 0.0,
601
+ "gpu_util_percent": None,
602
+ "gpu_temp_c": None,
603
+ }
604
+
605
+ if psutil is not None:
606
+ try:
607
+ proc = psutil.Process(os.getpid())
608
+ vm = psutil.virtual_memory()
609
+ snap.update({
610
+ "python_ram": human_bytes(proc.memory_info().rss),
611
+ "system_ram": f"{human_bytes(vm.used)} / {human_bytes(vm.total)}",
612
+ "system_ram_percent": float(vm.percent),
613
+ "cpu_percent": float(psutil.cpu_percent(interval=0.0)),
614
+ })
615
+ except Exception:
616
+ pass
617
+
618
+ if torch.cuda.is_available():
619
+ try:
620
+ idx = torch.cuda.current_device()
621
+ props = torch.cuda.get_device_properties(idx)
622
+ allocated = int(torch.cuda.memory_allocated(idx))
623
+ reserved = int(torch.cuda.memory_reserved(idx))
624
+ total = int(props.total_memory)
625
+ snap.update({
626
+ "gpu_name": props.name,
627
+ "vram_allocated": human_bytes(allocated),
628
+ "vram_reserved": human_bytes(reserved),
629
+ "vram_used": human_bytes(allocated),
630
+ "vram_total": human_bytes(total),
631
+ "vram_percent": (100.0 * allocated / max(total, 1)),
632
+ })
633
+
634
+ if pynvml is not None:
635
+ try:
636
+ pynvml.nvmlInit()
637
+ handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
638
+ util = pynvml.nvmlDeviceGetUtilizationRates(handle)
639
+ mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
640
+ temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
641
+ snap.update({
642
+ "gpu_util_percent": int(util.gpu),
643
+ "vram_used": human_bytes(int(mem.used)),
644
+ "vram_total": human_bytes(int(mem.total)),
645
+ "vram_percent": (100.0 * float(mem.used) / max(float(mem.total), 1.0)),
646
+ "gpu_temp_c": int(temp),
647
+ })
648
+ except Exception:
649
+ pass
650
+ except Exception:
651
+ pass
652
+
653
+ return snap
654
+
655
+
656
+ def system_usage(engine: Optional[LULUV2LiveEngine] = None) -> str:
657
+ lines = [f"OS: {platform.system()} {platform.release()}"]
658
+ if psutil is not None:
659
+ proc = psutil.Process(os.getpid())
660
+ vm = psutil.virtual_memory()
661
+ lines += [
662
+ f"Python RAM: {human_bytes(proc.memory_info().rss)}",
663
+ f"System RAM: {human_bytes(vm.used)} / {human_bytes(vm.total)} ({vm.percent:.1f}%)",
664
+ f"CPU: {psutil.cpu_percent(interval=0.0):.1f}%",
665
+ ]
666
+ else:
667
+ lines.append("psutil unavailable")
668
+
669
+ if torch.cuda.is_available():
670
+ idx = torch.cuda.current_device()
671
+ props = torch.cuda.get_device_properties(idx)
672
+ lines += [
673
+ "",
674
+ f"GPU: {props.name}",
675
+ f"VRAM allocated: {human_bytes(torch.cuda.memory_allocated(idx))}",
676
+ f"VRAM reserved: {human_bytes(torch.cuda.memory_reserved(idx))}",
677
+ f"VRAM total: {human_bytes(props.total_memory)}",
678
+ ]
679
+ if pynvml is not None:
680
+ try:
681
+ pynvml.nvmlInit()
682
+ handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
683
+ util = pynvml.nvmlDeviceGetUtilizationRates(handle)
684
+ mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
685
+ temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
686
+ lines += [
687
+ f"GPU util: {util.gpu}%",
688
+ f"GPU memory: {human_bytes(mem.used)} / {human_bytes(mem.total)}",
689
+ f"GPU temperature: {temp} C",
690
+ ]
691
+ except Exception as exc:
692
+ lines.append(f"NVML unavailable: {type(exc).__name__}: {exc}")
693
+ else:
694
+ lines += ["", "GPU: CUDA unavailable"]
695
+
696
+ if engine is not None:
697
+ lines += ["", engine.stats_text()]
698
+ return "\n".join(lines)
luluv2_optimized_engine.py ADDED
@@ -0,0 +1,1133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ LULUV2 optimized local inference engine.
5
+
6
+ Goals:
7
+ - load LULU2/LULUV2 checkpoints through the existing LULUV2 model file
8
+ - no AutoModelForCausalLM.from_pretrained and no external model weights
9
+ - vectorized prompt prefill into explicit KV caches
10
+ - persistent session KV cache across turns when prompt tokens extend prior prompt
11
+ - modes: fast(pass1/base), vwm(pass1+pass2), deep(pass1+pass2 long context)
12
+ - safe fallback to slow full-prefix forward if cached path fails
13
+
14
+ This is intentionally Python-first and debuggable. It is a bridge toward
15
+ kernel/CUDA-graph optimization, not the final kernel path.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import importlib.util
20
+ import json
21
+ import math
22
+ import os
23
+ import platform
24
+ import time
25
+ import traceback
26
+ from contextlib import nullcontext
27
+ from dataclasses import dataclass, asdict
28
+ from pathlib import Path
29
+ from types import SimpleNamespace
30
+ from typing import Any, Dict, Generator, List, Optional, Tuple
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+
35
+ try:
36
+ import psutil
37
+ except Exception:
38
+ psutil = None
39
+
40
+ try:
41
+ import pynvml
42
+ except Exception:
43
+ pynvml = None
44
+
45
+ STOP_STRINGS = [
46
+ "<|im_start|>", "<|im_end|>", "<|user|>", "<|system|>", "<|assistant|>",
47
+ "User:", "Assistant:", "\nuser:", "\nassistant:",
48
+ ]
49
+
50
+
51
+ def setup_torch() -> None:
52
+ if torch.cuda.is_available():
53
+ try:
54
+ # Old API still works on current wheels; warnings are harmless.
55
+ torch.backends.cuda.matmul.allow_tf32 = True
56
+ torch.backends.cudnn.allow_tf32 = True
57
+ except Exception:
58
+ pass
59
+ try:
60
+ torch.backends.cuda.enable_flash_sdp(True)
61
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
62
+ torch.backends.cuda.enable_math_sdp(False)
63
+ except Exception:
64
+ pass
65
+ if hasattr(torch, "set_float32_matmul_precision"):
66
+ try:
67
+ torch.set_float32_matmul_precision("high")
68
+ except Exception:
69
+ pass
70
+
71
+
72
+ def human_bytes(num: float) -> str:
73
+ num = float(num)
74
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
75
+ if abs(num) < 1024.0:
76
+ return f"{num:.2f} {unit}"
77
+ num /= 1024.0
78
+ return f"{num:.2f} PB"
79
+
80
+
81
+ def value_to_text(value: Any) -> str:
82
+ if value is None:
83
+ return ""
84
+ if isinstance(value, str):
85
+ return value
86
+ if isinstance(value, dict):
87
+ for key in ("text", "content", "value"):
88
+ if key in value:
89
+ return value_to_text(value.get(key))
90
+ return "\n".join(value_to_text(v) for v in value.values() if value_to_text(v))
91
+ if isinstance(value, (list, tuple)):
92
+ return "\n".join(value_to_text(v) for v in value if value_to_text(v))
93
+ return str(value)
94
+
95
+
96
+ def clean_text(text: Any) -> str:
97
+ text = value_to_text(text).replace("\\n", "\n")
98
+ cut_points = [text.find(s) for s in STOP_STRINGS if s in text and text.find(s) > 0]
99
+ if cut_points:
100
+ text = text[: min(cut_points)]
101
+ for s in STOP_STRINGS:
102
+ text = text.replace(s, "")
103
+ text = text.strip()
104
+ for prefix in ("Assistant:", "assistant:", "Lulu:", "lulu:"):
105
+ if text.startswith(prefix):
106
+ text = text[len(prefix):].strip()
107
+ lines = [ln.rstrip() for ln in text.splitlines()]
108
+ # collapse excessive vertical whitespace without destroying code blocks too much
109
+ out: List[str] = []
110
+ blank = 0
111
+ for ln in lines:
112
+ if not ln.strip():
113
+ blank += 1
114
+ if blank <= 2:
115
+ out.append("")
116
+ else:
117
+ blank = 0
118
+ out.append(ln)
119
+ return "\n".join(out).strip()
120
+
121
+
122
+ def normalize_history(history: Any) -> List[Dict[str, str]]:
123
+ out: List[Dict[str, str]] = []
124
+ if not history:
125
+ return out
126
+ for item in history:
127
+ if isinstance(item, dict):
128
+ role = item.get("role", "")
129
+ content = clean_text(item.get("content", ""))
130
+ if role in {"user", "assistant"} and content:
131
+ out.append({"role": role, "content": content})
132
+ elif isinstance(item, (tuple, list)) and len(item) >= 2:
133
+ u = clean_text(item[0])
134
+ a = clean_text(item[1])
135
+ if u:
136
+ out.append({"role": "user", "content": u})
137
+ if a:
138
+ out.append({"role": "assistant", "content": a})
139
+ return out
140
+
141
+
142
+ def resolve_model_py(model_py: Optional[str]) -> str:
143
+ candidates: List[str] = []
144
+ if model_py:
145
+ candidates.append(model_py)
146
+ candidates.extend(["luluv2_inference_runtime.py"])
147
+ for c in candidates:
148
+ p = Path(c)
149
+ if p.exists():
150
+ return str(p.resolve())
151
+ raise FileNotFoundError("Could not find LULUV2 model file. Pass --model-py.")
152
+
153
+
154
+ def import_model_py(model_py: Optional[str]):
155
+ path = resolve_model_py(model_py)
156
+ spec = importlib.util.spec_from_file_location("luluv2_runtime_module", path)
157
+ if spec is None or spec.loader is None:
158
+ raise RuntimeError(f"Could not import model file: {path}")
159
+ mod = importlib.util.module_from_spec(spec)
160
+ spec.loader.exec_module(mod)
161
+ return mod, path
162
+
163
+
164
+ @dataclass
165
+ class GenerationConfig:
166
+ max_new_tokens: int = 512
167
+ temperature: float = 0.65
168
+ top_k: int = 40
169
+ top_p: float = 0.90
170
+ min_p: float = 0.03
171
+ repetition_penalty: float = 1.10
172
+ frequency_penalty: float = 0.02
173
+ greedy: bool = False
174
+ no_repeat_ngram: int = 4
175
+ stream_every: int = 1
176
+ max_context_tokens: int = 4096
177
+ mode: str = "vwm" # fast, vwm, deep, slow
178
+ return_pass_metrics: bool = True
179
+ use_cache: bool = True
180
+ vectorized_prefill: bool = True
181
+ persistent_cache: bool = True
182
+ compile_step: bool = False
183
+
184
+
185
+ @dataclass
186
+ class GenerationStats:
187
+ prompt_tokens: int = 0
188
+ prompt_total_tokens: int = 0
189
+ prompt_kept_tokens: int = 0
190
+ prompt_dropped_tokens: int = 0
191
+ generated_tokens: int = 0
192
+ elapsed_sec: float = 0.0
193
+ tokens_per_sec: float = 0.0
194
+ prefill_sec: float = 0.0
195
+ prefill_tps: float = 0.0
196
+ cache_hit: bool = False
197
+ cache_reused_tokens: int = 0
198
+ cache_new_prefill_tokens: int = 0
199
+ mode: str = "vwm"
200
+ backend: str = "none"
201
+ last_token: str = ""
202
+ last_token_id: int = -1
203
+ last_token_prob: float = 0.0
204
+ last_entropy: float = 0.0
205
+ finish_reason: str = "none"
206
+ pass1_pass2_kl: Optional[float] = None
207
+ pass1_pass2_logit_cosine: Optional[float] = None
208
+
209
+
210
+ class KVLayerCache:
211
+ def __init__(self):
212
+ self.k: Optional[torch.Tensor] = None # [B, H, T, Dh]
213
+ self.v: Optional[torch.Tensor] = None
214
+
215
+ @property
216
+ def length(self) -> int:
217
+ if self.k is None:
218
+ return 0
219
+ return int(self.k.shape[2])
220
+
221
+ def set(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
222
+ if k.shape[2] > max_len:
223
+ k = k[:, :, -max_len:, :]
224
+ v = v[:, :, -max_len:, :]
225
+ self.k = k.detach().contiguous()
226
+ self.v = v.detach().contiguous()
227
+
228
+ def append(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
229
+ if self.k is None:
230
+ self.set(k, v, max_len)
231
+ return
232
+ self.k = torch.cat([self.k, k.detach()], dim=2)
233
+ self.v = torch.cat([self.v, v.detach()], dim=2)
234
+ if self.k.shape[2] > max_len:
235
+ self.k = self.k[:, :, -max_len:, :].contiguous()
236
+ self.v = self.v[:, :, -max_len:, :].contiguous()
237
+
238
+
239
+ class DecoderKVCache:
240
+ def __init__(self, n_layers: int):
241
+ self.layers = [KVLayerCache() for _ in range(int(n_layers))]
242
+
243
+ def clear(self):
244
+ for layer in self.layers:
245
+ layer.k = None
246
+ layer.v = None
247
+
248
+ @property
249
+ def length(self) -> int:
250
+ if not self.layers:
251
+ return 0
252
+ return self.layers[0].length
253
+
254
+
255
+ class LULUV2OptimizedEngine:
256
+ def __init__(
257
+ self,
258
+ ckpt_path: str,
259
+ model_py: Optional[str] = None,
260
+ tokenizer_dir: Optional[str] = None,
261
+ device: Optional[str] = None,
262
+ dtype: str = "bf16",
263
+ local_files_only: bool = True,
264
+ no_config_download: bool = True,
265
+ force_base_only: bool = False,
266
+ ):
267
+ setup_torch()
268
+ self.ckpt_path = str(ckpt_path)
269
+ self.ckpt_dir = Path(self.ckpt_path).resolve().parent
270
+ self.device = self._select_device(device)
271
+ self.dtype = self._dtype_from_name(dtype)
272
+ self.local_files_only = bool(local_files_only)
273
+ self.no_config_download = bool(no_config_download)
274
+ self.force_base_only = bool(force_base_only)
275
+ self.last_stats = GenerationStats()
276
+ self.recent_tokens: List[Dict[str, Any]] = []
277
+ self.last_prompt_total_tokens: int = 0
278
+ self.last_prompt_kept_tokens: int = 0
279
+ self.last_prompt_dropped_tokens: int = 0
280
+ self.cache_ids: Optional[torch.Tensor] = None
281
+ self.cache_mode: str = ""
282
+ self.cache_max_context: int = 0
283
+ self.pass1_cache: Optional[DecoderKVCache] = None
284
+ self.pass2_cache: Optional[DecoderKVCache] = None
285
+ self.cached_logits: Optional[torch.Tensor] = None
286
+ self.cached_pass1_logits: Optional[torch.Tensor] = None
287
+ self.cached_pass2_logits: Optional[torch.Tensor] = None
288
+ self.cache_backend: str = "cold"
289
+
290
+ self.goku, self.model_py_path = import_model_py(model_py)
291
+ self.args = SimpleNamespace(
292
+ checkpoint=self.ckpt_path,
293
+ tokenizer=tokenizer_dir or "",
294
+ model_id="",
295
+ no_config_download=self.no_config_download,
296
+ local_files_only=self.local_files_only,
297
+ )
298
+
299
+ print("[guard] LULUV2 cockpit: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.")
300
+ print(f"[load] checkpoint={self.ckpt_path}")
301
+ self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype)
302
+ self.tokenizer = self._load_tokenizer(tokenizer_dir)
303
+ self.model, self.has_pass2 = self._maybe_wrap_pass2(base)
304
+ self.base = self.model.base if self.has_pass2 else self.model
305
+ self.n_layers = int(self.base.config.num_hidden_layers)
306
+ self.model.eval()
307
+ self.base.eval()
308
+ self.model_info = self._build_model_info()
309
+ self._compiled = False
310
+
311
+ def _select_device(self, device: Optional[str]):
312
+ if device:
313
+ return torch.device(device)
314
+ if torch.cuda.is_available():
315
+ return torch.device("cuda")
316
+ return torch.device("cpu")
317
+
318
+ def _dtype_from_name(self, name: str):
319
+ name = (name or "bf16").lower()
320
+ if name in {"bf16", "bfloat16"}:
321
+ return torch.bfloat16
322
+ if name in {"fp16", "float16", "half"}:
323
+ return torch.float16
324
+ return torch.float32
325
+
326
+ def _load_tokenizer(self, tokenizer_dir: Optional[str]):
327
+ if tokenizer_dir:
328
+ self.args.tokenizer = tokenizer_dir
329
+ else:
330
+ sibling = self.ckpt_dir / "tokenizer"
331
+ if sibling.is_dir():
332
+ self.args.tokenizer = str(sibling)
333
+ tok = self.goku.load_tokenizer(self.args, self.base_ckpt)
334
+ if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None:
335
+ try:
336
+ tok.pad_token = tok.eos_token
337
+ except Exception:
338
+ pass
339
+ # Long-prompt safety: for chat/RAG prompts, the latest user turn and final
340
+ # instruction are normally at the end. Right-side truncation silently drops
341
+ # exactly the part the model must answer, so force left truncation where the
342
+ # tokenizer supports it. encode() below also performs manual left truncation
343
+ # and records how many tokens were dropped.
344
+ try:
345
+ tok.truncation_side = "left"
346
+ except Exception:
347
+ pass
348
+ try:
349
+ tok.model_max_length = 10**9
350
+ except Exception:
351
+ pass
352
+ return tok
353
+
354
+ def _maybe_wrap_pass2(self, base):
355
+ ckpt = self.base_ckpt
356
+ if self.force_base_only or "pass2_state" not in ckpt:
357
+ print("[pass2] no pass2_state loaded; running base LULUV2 forward")
358
+ return base.to(self.device).eval(), False
359
+ cfg_dict = dict(ckpt.get("pass2_config") or {})
360
+ Pass2Config = self.goku.Pass2Config
361
+ fields = getattr(Pass2Config, "__dataclass_fields__", {})
362
+ pass2_cfg = Pass2Config(**{k: v for k, v in cfg_dict.items() if k in fields})
363
+ model = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg)
364
+ missing, unexpected = model.load_state_dict(ckpt["pass2_state"], strict=False)
365
+ print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}")
366
+ model.to(device=self.device, dtype=self.dtype).eval()
367
+ return model, True
368
+
369
+ def _build_model_info(self) -> Dict[str, Any]:
370
+ total_params = sum(p.numel() for p in self.model.parameters())
371
+ c_codes = [(n, p.numel()) for n, p in self.model.named_parameters() if n.endswith(".c")]
372
+ gate_mean = None
373
+ adapter_gate_mean = None
374
+ if self.has_pass2:
375
+ with torch.no_grad():
376
+ gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item())
377
+ vals = [float(torch.sigmoid(ad.gate.float()).item()) for ad in self.model.adapters]
378
+ adapter_gate_mean = sum(vals) / max(1, len(vals))
379
+ ckpt_size = Path(self.ckpt_path).stat().st_size if Path(self.ckpt_path).exists() else 0
380
+ cfg = getattr(self.base, "config", None)
381
+ return {
382
+ "checkpoint": self.ckpt_path,
383
+ "checkpoint_size": human_bytes(ckpt_size),
384
+ "model_py": self.model_py_path,
385
+ "device": str(self.device),
386
+ "dtype": str(self.dtype).replace("torch.", ""),
387
+ "has_pass2": self.has_pass2,
388
+ "total_params": total_params,
389
+ "vwm_c_modules": len(c_codes),
390
+ "vwm_c_params": sum(n for _, n in c_codes),
391
+ "pass2_layer_gate_mean": gate_mean,
392
+ "pass2_adapter_gate_mean": adapter_gate_mean,
393
+ "hidden_size": getattr(cfg, "hidden_size", None),
394
+ "layers": getattr(cfg, "num_hidden_layers", None),
395
+ "heads": getattr(cfg, "num_attention_heads", None),
396
+ "kv_heads": getattr(cfg, "num_key_value_heads", None),
397
+ "max_position_embeddings": getattr(cfg, "max_position_embeddings", None),
398
+ }
399
+
400
+ def amp_context(self):
401
+ if self.device.type == "cuda" and self.dtype in (torch.bfloat16, torch.float16):
402
+ return torch.autocast("cuda", dtype=self.dtype)
403
+ return nullcontext()
404
+
405
+ def build_chat_prompt(
406
+ self,
407
+ message: str,
408
+ history: Any,
409
+ system_prompt: str,
410
+ memory_notes: str = "",
411
+ history_turns: int = 4,
412
+ extra_context: str = "",
413
+ ) -> str:
414
+ history = normalize_history(history)
415
+ recent = history[-max(0, int(history_turns)) * 2:] if history_turns else []
416
+ system_chunks: List[str] = []
417
+ if system_prompt.strip():
418
+ system_chunks.append(system_prompt.strip())
419
+ if memory_notes.strip():
420
+ system_chunks.append("Useful memory notes:\n" + memory_notes.strip())
421
+ if extra_context.strip():
422
+ system_chunks.append("Relevant local context:\n" + extra_context.strip())
423
+ system = "\n\n".join(system_chunks)
424
+ messages: List[Dict[str, str]] = []
425
+ if system:
426
+ messages.append({"role": "system", "content": system})
427
+ messages.extend(recent)
428
+ messages.append({"role": "user", "content": clean_text(message)})
429
+ try:
430
+ return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
431
+ except Exception:
432
+ parts: List[str] = []
433
+ if system:
434
+ parts.append(f"<|im_start|>system\n{system}<|im_end|>")
435
+ for item in recent:
436
+ parts.append(f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>")
437
+ parts.append(f"<|im_start|>user\n{clean_text(message)}<|im_end|>")
438
+ parts.append("<|im_start|>assistant\n")
439
+ return "\n".join(parts)
440
+
441
+ def encode(self, text: str, max_context_tokens: int) -> torch.Tensor:
442
+ """Encode prompt with explicit left-truncation and accounting.
443
+
444
+ This avoids a common long-context failure mode: many tokenizers default to
445
+ right-side truncation, which keeps the beginning of a huge prompt and drops
446
+ the final user instruction. For chat, we almost always want the opposite.
447
+ """
448
+ max_context = max(1, int(max_context_tokens))
449
+ try:
450
+ self.tokenizer.truncation_side = "left"
451
+ except Exception:
452
+ pass
453
+
454
+ # Tokenize without tokenizer-side truncation so we know exactly whether the
455
+ # prompt was clipped. The prompt already contains chat special tokens.
456
+ try:
457
+ enc = self.tokenizer(
458
+ text,
459
+ return_tensors="pt",
460
+ truncation=False,
461
+ add_special_tokens=False,
462
+ )
463
+ except TypeError:
464
+ enc = self.tokenizer(text, return_tensors="pt", truncation=False)
465
+
466
+ ids = enc.input_ids
467
+ total = int(ids.shape[1])
468
+ dropped = max(0, total - max_context)
469
+ if dropped > 0:
470
+ ids = ids[:, -max_context:].contiguous()
471
+ # Do not reuse an older conversation cache after a hard context trim;
472
+ # the logical prefix changed and reuse can make long prompts feel like
473
+ # they are "forgetting" pieces.
474
+ self.pass1_cache = None
475
+ self.pass2_cache = None
476
+ self.cache_ids = None
477
+ self.cached_logits = None
478
+ self.cached_pass1_logits = None
479
+ self.cached_pass2_logits = None
480
+ self.cache_backend = "truncated-rebuild"
481
+
482
+ self.last_prompt_total_tokens = total
483
+ self.last_prompt_kept_tokens = int(ids.shape[1])
484
+ self.last_prompt_dropped_tokens = dropped
485
+ return ids.to(self.device)
486
+
487
+ def _position_ids(self, T: int, offset: int = 0) -> torch.Tensor:
488
+ return torch.arange(offset, offset + T, device=self.device, dtype=torch.long).unsqueeze(0)
489
+
490
+ def _attn_prefill(self, attn, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
491
+ bsz, q_len, _ = hidden_states.size()
492
+ query_states = attn.q_proj(hidden_states)
493
+ key_states = attn.k_proj(hidden_states)
494
+ value_states = attn.v_proj(hidden_states)
495
+ query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
496
+ key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
497
+ value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
498
+ cos, sin = attn.rotary_emb(value_states, position_ids)
499
+ query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
500
+ key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
501
+ value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
502
+ cache.set(key_states, value_states, max_context)
503
+ attn_output = F.scaled_dot_product_attention(
504
+ query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True, scale=attn.scaling
505
+ )
506
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
507
+ return attn.o_proj(attn_output)
508
+
509
+ def _attn_step(self, attn, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
510
+ bsz, q_len, _ = hidden_states.size()
511
+ assert q_len == 1
512
+ query_states = attn.q_proj(hidden_states)
513
+ key_states = attn.k_proj(hidden_states)
514
+ value_states = attn.v_proj(hidden_states)
515
+ query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
516
+ key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
517
+ value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
518
+ position_ids = self._position_ids(1, pos)
519
+ cos, sin = attn.rotary_emb(value_states, position_ids)
520
+ query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
521
+ key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
522
+ value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
523
+ cache.append(key_states, value_states, max_context)
524
+ if cache.k is None or cache.v is None:
525
+ raise RuntimeError("KV cache append failed")
526
+ attn_output = F.scaled_dot_product_attention(
527
+ query_states, cache.k, cache.v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scaling
528
+ )
529
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
530
+ return attn.o_proj(attn_output)
531
+
532
+ def _layer_prefill(self, layer, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
533
+ residual = hidden_states
534
+ x = layer.input_layernorm(hidden_states)
535
+ x = self._attn_prefill(layer.self_attn, x, position_ids, cache, max_context)
536
+ hidden_states = residual + x
537
+ residual = hidden_states
538
+ x = layer.post_attention_layernorm(hidden_states)
539
+ x = layer.mlp(x)
540
+ return residual + x
541
+
542
+ def _layer_step(self, layer, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
543
+ residual = hidden_states
544
+ x = layer.input_layernorm(hidden_states)
545
+ x = self._attn_step(layer.self_attn, x, pos, cache, max_context)
546
+ hidden_states = residual + x
547
+ residual = hidden_states
548
+ x = layer.post_attention_layernorm(hidden_states)
549
+ x = layer.mlp(x)
550
+ return residual + x
551
+
552
+ @torch.no_grad()
553
+ def _prefill_pass1(self, input_ids: torch.Tensor, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor, torch.Tensor]:
554
+ T = int(input_ids.shape[1])
555
+ position_ids = self._position_ids(T, 0)
556
+ cache = DecoderKVCache(self.n_layers)
557
+ h = self.base.model.embed_tokens(input_ids)
558
+ if use_pass_embed and self.has_pass2:
559
+ h = h + self.model.pass_embed[0].to(dtype=h.dtype, device=h.device).view(1, 1, -1)
560
+ layer_states: List[torch.Tensor] = []
561
+ for i, layer in enumerate(self.base.model.layers):
562
+ h = self._layer_prefill(layer, h, position_ids, cache.layers[i], max_context)
563
+ layer_states.append(h)
564
+ normed = self.base.model.norm(h)
565
+ logits = self.base.lm_head(normed)
566
+ self.pass1_cache = cache
567
+ return h, layer_states, position_ids, logits
568
+
569
+ @torch.no_grad()
570
+ def _prefill_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], position_ids: torch.Tensor, max_context: int) -> torch.Tensor:
571
+ if not self.has_pass2:
572
+ raise RuntimeError("pass2 requested but checkpoint has no pass2_state")
573
+ cache = DecoderKVCache(self.n_layers)
574
+ h2 = h1_resid + self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device).view(1, 1, -1)
575
+ for i, layer in enumerate(self.base.model.layers):
576
+ before = h2
577
+ layer_out = self._layer_prefill(layer, h2, position_ids, cache.layers[i], max_context)
578
+ layer_delta = layer_out - before
579
+ gate = torch.sigmoid(self.model.layer_gates[i]).to(dtype=h2.dtype, device=h2.device)
580
+ adapter_delta = self.model.adapters[i](h2, pass1_states[i])
581
+ h2 = before + gate * layer_delta + adapter_delta
582
+ normed = self.base.model.norm(h2)
583
+ logits = self.base.lm_head(normed)
584
+ self.pass2_cache = cache
585
+ return logits
586
+
587
+ @torch.no_grad()
588
+ def _step_pass1(self, token_id: torch.Tensor, pos: int, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
589
+ if self.pass1_cache is None:
590
+ self.pass1_cache = DecoderKVCache(self.n_layers)
591
+ h = self.base.model.embed_tokens(token_id)
592
+ if use_pass_embed and self.has_pass2:
593
+ h = h + self.model.pass_embed[0].to(dtype=h.dtype, device=h.device).view(1, 1, -1)
594
+ states: List[torch.Tensor] = []
595
+ for i, layer in enumerate(self.base.model.layers):
596
+ h = self._layer_step(layer, h, pos, self.pass1_cache.layers[i], max_context)
597
+ states.append(h)
598
+ logits = self.base.lm_head(self.base.model.norm(h))
599
+ return h, states, logits
600
+
601
+ @torch.no_grad()
602
+ def _step_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], pos: int, max_context: int) -> torch.Tensor:
603
+ if not self.has_pass2:
604
+ raise RuntimeError("pass2 step requested but unavailable")
605
+ if self.pass2_cache is None:
606
+ self.pass2_cache = DecoderKVCache(self.n_layers)
607
+ h2 = h1_resid + self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device).view(1, 1, -1)
608
+ for i, layer in enumerate(self.base.model.layers):
609
+ before = h2
610
+ layer_out = self._layer_step(layer, h2, pos, self.pass2_cache.layers[i], max_context)
611
+ layer_delta = layer_out - before
612
+ gate = torch.sigmoid(self.model.layer_gates[i]).to(dtype=h2.dtype, device=h2.device)
613
+ adapter_delta = self.model.adapters[i](h2, pass1_states[i])
614
+ h2 = before + gate * layer_delta + adapter_delta
615
+ return self.base.lm_head(self.base.model.norm(h2))
616
+
617
+ def _ids_prefix_len(self, old: torch.Tensor, new: torch.Tensor) -> int:
618
+ if old is None or old.numel() == 0 or new.numel() == 0:
619
+ return 0
620
+ old1 = old[0]
621
+ new1 = new[0]
622
+ max_n = min(int(old1.numel()), int(new1.numel()))
623
+ if max_n == 0:
624
+ return 0
625
+ # Fast path: old is exact prefix of new.
626
+ if int(old1.numel()) <= int(new1.numel()) and torch.equal(old1, new1[: old1.numel()]):
627
+ return int(old1.numel())
628
+ # Conservative fallback, scan from max down; prompts are usually exact-prefix or reset.
629
+ for n in range(max_n, 0, -1):
630
+ if torch.equal(old1[:n], new1[:n]):
631
+ return n
632
+ return 0
633
+
634
+
635
+ @torch.no_grad()
636
+ def _token_prefill_context(self, input_ids: torch.Tensor, cfg: GenerationConfig, use_pass2: bool, use_pass_embed: bool, max_context: int) -> None:
637
+ """
638
+ Conservative cache builder.
639
+
640
+ It fills the same pass1/pass2 KV caches by walking the prompt one token at a time.
641
+ This is slower than vectorized prefill but much safer across checkpoint/runtime variants,
642
+ and it still gives a valid decode cache + persistent cache for the generated tokens.
643
+ """
644
+ self.pass1_cache = DecoderKVCache(self.n_layers)
645
+ self.pass2_cache = DecoderKVCache(self.n_layers) if use_pass2 else None
646
+ self.cached_logits = None
647
+ self.cached_pass1_logits = None
648
+ self.cached_pass2_logits = None
649
+
650
+ T = int(input_ids.shape[1])
651
+ for pos in range(T):
652
+ tok = input_ids[:, pos:pos + 1]
653
+ h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed)
654
+ if use_pass2:
655
+ logits2 = self._step_pass2(h1, states, pos, max_context)
656
+ self.cached_logits = logits2
657
+ self.cached_pass1_logits = logits1
658
+ self.cached_pass2_logits = logits2
659
+ else:
660
+ self.cached_logits = logits1
661
+ self.cached_pass1_logits = logits1
662
+ self.cached_pass2_logits = None
663
+
664
+ @torch.no_grad()
665
+ def _prepare_cached_context(self, input_ids: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, bool, int, int, str]:
666
+ mode = self._effective_mode(cfg.mode)
667
+ max_context = int(cfg.max_context_tokens)
668
+ use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
669
+ use_pass_embed = bool(use_pass2)
670
+ T = int(input_ids.shape[1])
671
+ if T > max_context:
672
+ input_ids = input_ids[:, -max_context:]
673
+ T = max_context
674
+
675
+ # If mode/context changed, persistent cache is invalid.
676
+ cache_ok = (
677
+ cfg.persistent_cache
678
+ and self.cache_ids is not None
679
+ and self.cache_mode == mode
680
+ and self.cache_max_context == max_context
681
+ and self.pass1_cache is not None
682
+ )
683
+ prefix = self._ids_prefix_len(self.cache_ids, input_ids) if cache_ok else 0
684
+ cache_hit = bool(cache_ok and prefix == int(self.cache_ids.shape[1]) and prefix <= T and prefix > 0)
685
+
686
+ t0 = time.time()
687
+ if cache_hit:
688
+ # Process only suffix between prior cached prompt and new prompt.
689
+ suffix = input_ids[:, prefix:]
690
+ for j in range(int(suffix.shape[1])):
691
+ tok = suffix[:, j : j + 1]
692
+ pos = prefix + j
693
+ h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed)
694
+ if use_pass2:
695
+ logits2 = self._step_pass2(h1, states, pos, max_context)
696
+ self.cached_logits = logits2
697
+ self.cached_pass1_logits = logits1
698
+ self.cached_pass2_logits = logits2
699
+ else:
700
+ self.cached_logits = logits1
701
+ self.cached_pass1_logits = logits1
702
+ self.cached_pass2_logits = None
703
+ self.cache_ids = input_ids.detach().clone()
704
+ self.cache_backend = "persistent-kv-suffix" if suffix.numel() else "persistent-kv-hit"
705
+ return input_ids, True, prefix, int(suffix.shape[1]), self.cache_backend
706
+
707
+ # Reset and prefill. Prefer vectorized prefill, but fall back to conservative
708
+ # token prefill if the runtime variant does not support our vectorized cache path.
709
+ self.pass1_cache = None
710
+ self.pass2_cache = None
711
+ backend = "vectorized-prefill"
712
+ if bool(cfg.vectorized_prefill):
713
+ try:
714
+ h1, states, pos_ids, logits1 = self._prefill_pass1(input_ids, max_context, use_pass_embed=use_pass_embed)
715
+ if use_pass2:
716
+ logits2 = self._prefill_pass2(h1, states, pos_ids, max_context)
717
+ self.cached_logits = logits2
718
+ self.cached_pass1_logits = logits1
719
+ self.cached_pass2_logits = logits2
720
+ else:
721
+ self.cached_logits = logits1
722
+ self.cached_pass1_logits = logits1
723
+ self.cached_pass2_logits = None
724
+ except Exception as exc:
725
+ if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
726
+ print("[cache] vectorized prefill failed; using token-prefill cache.")
727
+ traceback.print_exc()
728
+ self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
729
+ backend = "token-prefill-cache"
730
+ else:
731
+ self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
732
+ backend = "token-prefill-cache"
733
+
734
+ self.cache_ids = input_ids.detach().clone()
735
+ self.cache_mode = mode
736
+ self.cache_max_context = max_context
737
+ self.cache_backend = backend
738
+ return input_ids, False, 0, T, self.cache_backend
739
+
740
+ def _effective_mode(self, mode: str) -> str:
741
+ mode = (mode or "vwm").lower()
742
+ if mode in {"fast", "base", "pass1"}:
743
+ return "fast"
744
+ if mode in {"deep", "32k", "long"}:
745
+ return "deep"
746
+ if mode in {"slow", "full"}:
747
+ return "slow"
748
+ return "vwm"
749
+
750
+ @torch.no_grad()
751
+ def pass_metrics_from_logits(self, logits1: Optional[torch.Tensor], logits2: Optional[torch.Tensor]) -> Tuple[Optional[float], Optional[float]]:
752
+ if logits1 is None or logits2 is None:
753
+ return None, None
754
+ try:
755
+ l1 = logits1[:, -1, :].float()
756
+ l2 = logits2[:, -1, :].float()
757
+ kl = F.kl_div(F.log_softmax(l2, dim=-1), F.softmax(l1, dim=-1), reduction="batchmean")
758
+ cos = F.cosine_similarity(l1, l2, dim=-1).mean()
759
+ return float(kl.item()), float(cos.item())
760
+ except Exception:
761
+ return None, None
762
+
763
+ def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
764
+ if generated.numel() == 0:
765
+ return logits
766
+ out = logits.clone()
767
+ uniq, counts = torch.unique(generated.view(-1), return_counts=True)
768
+ if cfg.repetition_penalty != 1.0:
769
+ selected = out[:, uniq]
770
+ selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty))
771
+ out[:, uniq] = selected
772
+ if cfg.frequency_penalty:
773
+ out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0)
774
+ n = int(cfg.no_repeat_ngram)
775
+ if n > 1 and generated.size(1) >= n - 1:
776
+ seq = generated[0].tolist()
777
+ prefix = tuple(seq[-(n - 1):])
778
+ banned = []
779
+ for i in range(len(seq) - n + 1):
780
+ if tuple(seq[i:i + n - 1]) == prefix:
781
+ banned.append(seq[i + n - 1])
782
+ if banned:
783
+ out[:, list(set(banned))] = -float("inf")
784
+ return out
785
+
786
+ @torch.no_grad()
787
+ def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, float]]:
788
+ work = self._apply_penalties(logits.float(), generated, cfg)
789
+ if cfg.greedy or cfg.temperature <= 0:
790
+ probs = torch.softmax(work, dim=-1)
791
+ next_id = torch.argmax(work, dim=-1, keepdim=True)
792
+ else:
793
+ work = work / max(float(cfg.temperature), 1e-6)
794
+ if cfg.top_k > 0:
795
+ k = min(int(cfg.top_k), work.size(-1))
796
+ thresh = torch.topk(work, k, dim=-1).values[..., -1, None]
797
+ work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf")))
798
+ if 0.0 < cfg.top_p < 1.0:
799
+ sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1)
800
+ sorted_probs = torch.softmax(sorted_logits, dim=-1)
801
+ cumprobs = torch.cumsum(sorted_probs, dim=-1)
802
+ remove = cumprobs > float(cfg.top_p)
803
+ shifted = remove.clone()
804
+ shifted[..., 1:] = remove[..., :-1]
805
+ shifted[..., 0] = False
806
+ sorted_logits = sorted_logits.masked_fill(shifted, -float("inf"))
807
+ work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits)
808
+ if 0.0 < cfg.min_p < 1.0:
809
+ probs_for_minp = torch.softmax(work, dim=-1)
810
+ max_prob = probs_for_minp.max(dim=-1, keepdim=True).values
811
+ keep = probs_for_minp >= float(cfg.min_p) * max_prob
812
+ work = work.masked_fill(~keep, -float("inf"))
813
+ probs = torch.softmax(work, dim=-1)
814
+ if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0:
815
+ next_id = torch.argmax(logits, dim=-1, keepdim=True)
816
+ probs = torch.softmax(logits.float(), dim=-1)
817
+ else:
818
+ next_id = torch.multinomial(probs, 1)
819
+ prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0
820
+ entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0
821
+ return next_id, {"prob": prob, "entropy": entropy}
822
+
823
+ @torch.no_grad()
824
+ def _slow_generate(self, ids: torch.Tensor, prompt_len: int, cfg: GenerationConfig) -> Generator[str, None, None]:
825
+ # Compatibility path: full prefix recompute every token.
826
+ eos_id = getattr(self.tokenizer, "eos_token_id", None)
827
+ last_text = ""
828
+ t0 = time.time()
829
+ for step in range(int(cfg.max_new_tokens)):
830
+ ctx = ids[:, -int(cfg.max_context_tokens):]
831
+ with self.amp_context():
832
+ out = self.model(ctx) if self._effective_mode(cfg.mode) != "fast" else self.base(ctx)
833
+ logits = out.logits[:, -1, :].float()
834
+ generated = ids[:, prompt_len:]
835
+ next_id, tok_stats = self._sample_next(logits, generated, cfg)
836
+ ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
837
+ token_id = int(next_id.item())
838
+ token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
839
+ self._record_token(step + 1, token_id, token_text, tok_stats)
840
+ if eos_id is not None and token_id == int(eos_id):
841
+ break
842
+ if (step + 1) % int(cfg.stream_every) == 0 or step == 0:
843
+ raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
844
+ if any(s in raw for s in STOP_STRINGS):
845
+ break
846
+ text = clean_text(raw)
847
+ if text and text != last_text:
848
+ elapsed = time.time() - t0
849
+ gen = int(ids.shape[1]) - prompt_len
850
+ self.last_stats = GenerationStats(prompt_tokens=prompt_len, prompt_total_tokens=self.last_prompt_total_tokens, prompt_kept_tokens=self.last_prompt_kept_tokens, prompt_dropped_tokens=self.last_prompt_dropped_tokens, generated_tokens=gen, elapsed_sec=elapsed, tokens_per_sec=gen / max(elapsed, 1e-9), mode=cfg.mode, backend="slow-full-prefix", last_token=token_text, last_token_id=token_id, last_token_prob=tok_stats["prob"], last_entropy=tok_stats["entropy"], finish_reason="streaming")
851
+ last_text = text
852
+ yield text
853
+ final = clean_text(self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True))
854
+ if final:
855
+ yield final
856
+
857
+ def _record_token(self, i: int, token_id: int, token_text: str, tok_stats: Dict[str, float]) -> None:
858
+ self.recent_tokens.append({"i": i, "id": token_id, "text": token_text, "prob": tok_stats.get("prob", 0.0), "entropy": tok_stats.get("entropy", 0.0)})
859
+ self.recent_tokens = self.recent_tokens[-64:]
860
+
861
+ @torch.no_grad()
862
+ def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]:
863
+ self.model.eval()
864
+ self.base.eval()
865
+ self.recent_tokens = []
866
+ mode = self._effective_mode(cfg.mode)
867
+ if mode == "deep":
868
+ cfg.max_context_tokens = max(int(cfg.max_context_tokens), 16384)
869
+ ids = self.encode(prompt, max_context_tokens=int(cfg.max_context_tokens))
870
+ prompt_len = int(ids.shape[1])
871
+ if self.last_prompt_dropped_tokens > 0:
872
+ print(f"[context] prompt clipped: kept={self.last_prompt_kept_tokens} total={self.last_prompt_total_tokens} dropped={self.last_prompt_dropped_tokens}")
873
+ t_start = time.time()
874
+ prefill_sec = 0.0
875
+ cache_hit = False
876
+ reused = 0
877
+ new_prefill = prompt_len
878
+ backend = ""
879
+ pass_kl = None
880
+ pass_cos = None
881
+
882
+ if (not cfg.use_cache) or mode == "slow":
883
+ yield from self._slow_generate(ids, prompt_len, cfg)
884
+ return
885
+
886
+ try:
887
+ with self.amp_context():
888
+ t_pref = time.time()
889
+ ids, cache_hit, reused, new_prefill, backend = self._prepare_cached_context(ids, cfg)
890
+ prefill_sec = time.time() - t_pref
891
+ pass_kl, pass_cos = self.pass_metrics_from_logits(self.cached_pass1_logits, self.cached_pass2_logits) if cfg.return_pass_metrics else (None, None)
892
+ except Exception as exc:
893
+ print(f"[cache] cached path failed; falling back to slow full-prefix: {type(exc).__name__}: {exc}")
894
+ if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
895
+ traceback.print_exc()
896
+ self.pass1_cache = None
897
+ self.pass2_cache = None
898
+ self.cache_ids = None
899
+ yield from self._slow_generate(ids, prompt_len, cfg)
900
+ return
901
+
902
+ eos_id = getattr(self.tokenizer, "eos_token_id", None)
903
+ last_text = ""
904
+ finish_reason = "length"
905
+ use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
906
+ use_pass_embed = bool(use_pass2)
907
+
908
+ for step in range(int(cfg.max_new_tokens)):
909
+ logits = self.cached_logits[:, -1, :].float() if self.cached_logits is not None and self.cached_logits.dim() == 3 else self.cached_logits.float()
910
+ generated = ids[:, prompt_len:]
911
+ next_id, tok_stats = self._sample_next(logits, generated, cfg)
912
+ token_id = int(next_id.item())
913
+ token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
914
+ self._record_token(step + 1, token_id, token_text, tok_stats)
915
+ ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
916
+
917
+ if eos_id is not None and token_id == int(eos_id):
918
+ finish_reason = "eos"
919
+ break
920
+
921
+ pos = int(ids.shape[1]) - 1
922
+ try:
923
+ with self.amp_context():
924
+ h1, states, logits1 = self._step_pass1(next_id.to(self.device), pos, int(cfg.max_context_tokens), use_pass_embed=use_pass_embed)
925
+ if use_pass2:
926
+ logits2 = self._step_pass2(h1, states, pos, int(cfg.max_context_tokens))
927
+ self.cached_logits = logits2
928
+ self.cached_pass1_logits = logits1
929
+ self.cached_pass2_logits = logits2
930
+ else:
931
+ self.cached_logits = logits1
932
+ self.cached_pass1_logits = logits1
933
+ self.cached_pass2_logits = None
934
+ if self.cache_ids is not None:
935
+ self.cache_ids = torch.cat([self.cache_ids, next_id.detach().to(self.cache_ids.device)], dim=-1)
936
+ if self.cache_ids.shape[1] > int(cfg.max_context_tokens):
937
+ self.cache_ids = self.cache_ids[:, -int(cfg.max_context_tokens):]
938
+ except Exception as exc:
939
+ print(f"[decode-cache] step failed; falling back for this request: {type(exc).__name__}: {exc}")
940
+ # Finish with slow path from current ids; do not pretend cache is valid.
941
+ self.cache_ids = None
942
+ yield from self._slow_generate(ids, prompt_len, cfg)
943
+ return
944
+
945
+ if (step + 1) % int(cfg.stream_every) == 0 or step == 0:
946
+ raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
947
+ if any(s in raw for s in STOP_STRINGS):
948
+ finish_reason = "stop_string"
949
+ break
950
+ text = clean_text(raw)
951
+ if text and text != last_text:
952
+ elapsed = time.time() - t_start
953
+ gen = int(ids.shape[1]) - prompt_len
954
+ self.last_stats = GenerationStats(
955
+ prompt_tokens=prompt_len,
956
+ prompt_total_tokens=self.last_prompt_total_tokens,
957
+ prompt_kept_tokens=self.last_prompt_kept_tokens,
958
+ prompt_dropped_tokens=self.last_prompt_dropped_tokens,
959
+ generated_tokens=gen,
960
+ elapsed_sec=elapsed,
961
+ tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9),
962
+ prefill_sec=prefill_sec,
963
+ prefill_tps=(new_prefill / max(prefill_sec, 1e-9)),
964
+ cache_hit=cache_hit,
965
+ cache_reused_tokens=reused,
966
+ cache_new_prefill_tokens=new_prefill,
967
+ mode=mode,
968
+ backend=backend,
969
+ last_token=token_text,
970
+ last_token_id=token_id,
971
+ last_token_prob=tok_stats["prob"],
972
+ last_entropy=tok_stats["entropy"],
973
+ finish_reason="streaming",
974
+ pass1_pass2_kl=pass_kl,
975
+ pass1_pass2_logit_cosine=pass_cos,
976
+ )
977
+ last_text = text
978
+ yield text
979
+
980
+ raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
981
+ final = clean_text(raw)
982
+ elapsed = time.time() - t_start
983
+ gen = int(ids.shape[1]) - prompt_len
984
+ self.last_stats = GenerationStats(
985
+ prompt_tokens=prompt_len,
986
+ prompt_total_tokens=self.last_prompt_total_tokens,
987
+ prompt_kept_tokens=self.last_prompt_kept_tokens,
988
+ prompt_dropped_tokens=self.last_prompt_dropped_tokens,
989
+ generated_tokens=gen,
990
+ elapsed_sec=elapsed,
991
+ tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9),
992
+ prefill_sec=prefill_sec,
993
+ prefill_tps=(new_prefill / max(prefill_sec, 1e-9)),
994
+ cache_hit=cache_hit,
995
+ cache_reused_tokens=reused,
996
+ cache_new_prefill_tokens=new_prefill,
997
+ mode=mode,
998
+ backend=backend,
999
+ last_token=self.recent_tokens[-1]["text"] if self.recent_tokens else "",
1000
+ last_token_id=self.recent_tokens[-1]["id"] if self.recent_tokens else -1,
1001
+ last_token_prob=self.recent_tokens[-1]["prob"] if self.recent_tokens else 0.0,
1002
+ last_entropy=self.recent_tokens[-1]["entropy"] if self.recent_tokens else 0.0,
1003
+ finish_reason=finish_reason,
1004
+ pass1_pass2_kl=pass_kl,
1005
+ pass1_pass2_logit_cosine=pass_cos,
1006
+ )
1007
+ if final:
1008
+ yield final
1009
+
1010
+ def clear_session_cache(self) -> None:
1011
+ self.pass1_cache = None
1012
+ self.pass2_cache = None
1013
+ self.cache_ids = None
1014
+ self.cached_logits = None
1015
+ self.cached_pass1_logits = None
1016
+ self.cached_pass2_logits = None
1017
+ self.cache_backend = "cleared"
1018
+
1019
+ def stats_dict(self) -> Dict[str, Any]:
1020
+ return {"generation": asdict(self.last_stats), "model": self.model_info, "system": system_snapshot(self)}
1021
+
1022
+ def stats_text(self) -> str:
1023
+ s = self.last_stats
1024
+ lines = [
1025
+ f"Mode: {s.mode} | backend={s.backend}",
1026
+ f"Prompt tokens: {s.prompt_tokens} kept / {getattr(s, 'prompt_total_tokens', s.prompt_tokens)} total / {getattr(s, 'prompt_dropped_tokens', 0)} dropped",
1027
+ f"Generated tokens: {s.generated_tokens}",
1028
+ f"Elapsed: {s.elapsed_sec:.2f}s | prefill={s.prefill_sec:.2f}s ({s.prefill_tps:.1f} tok/s)",
1029
+ f"Decode speed: {s.tokens_per_sec:.2f} tok/s",
1030
+ f"Cache: hit={s.cache_hit} reused={s.cache_reused_tokens} new_prefill={s.cache_new_prefill_tokens}",
1031
+ f"Finish reason: {s.finish_reason}",
1032
+ f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f} H={s.last_entropy:.2f}",
1033
+ ]
1034
+ if s.pass1_pass2_kl is not None:
1035
+ lines.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}")
1036
+ if s.pass1_pass2_logit_cosine is not None:
1037
+ lines.append(f"Pass1/Pass2 cosine: {s.pass1_pass2_logit_cosine:.6f}")
1038
+ lines.extend([
1039
+ "",
1040
+ f"Checkpoint: {self.model_info['checkpoint']}",
1041
+ f"Checkpoint size: {self.model_info['checkpoint_size']}",
1042
+ f"Device: {self.model_info['device']} dtype={self.model_info['dtype']}",
1043
+ f"Pass2 active: {self.model_info['has_pass2']}",
1044
+ f"Params: {self.model_info['total_params']:,}",
1045
+ f"VWM c modules: {self.model_info['vwm_c_modules']} ({self.model_info['vwm_c_params']:,} c params)",
1046
+ ])
1047
+ return "\n".join(lines)
1048
+
1049
+ def token_trace_text(self) -> str:
1050
+ if not self.recent_tokens:
1051
+ return "No tokens generated yet."
1052
+ rows = []
1053
+ for t in self.recent_tokens[-48:]:
1054
+ safe = repr(t["text"])[1:-1]
1055
+ rows.append(f"{t['i']:04d} id={t['id']:<7} p={t['prob']:.4f} H={t['entropy']:.2f} {safe}")
1056
+ return "\n".join(rows)
1057
+
1058
+
1059
+ def system_snapshot(engine: Optional[LULUV2OptimizedEngine] = None) -> Dict[str, Any]:
1060
+ snap: Dict[str, Any] = {
1061
+ "python_ram": "n/a", "system_ram": "n/a", "system_ram_percent": 0.0,
1062
+ "cpu_percent": 0.0, "gpu_name": "CUDA unavailable", "vram_allocated": "n/a",
1063
+ "vram_reserved": "n/a", "vram_used": "n/a", "vram_total": "n/a",
1064
+ "vram_percent": 0.0, "gpu_util_percent": None, "gpu_temp_c": None,
1065
+ }
1066
+ if psutil is not None:
1067
+ try:
1068
+ proc = psutil.Process(os.getpid())
1069
+ vm = psutil.virtual_memory()
1070
+ snap.update({
1071
+ "python_ram": human_bytes(proc.memory_info().rss),
1072
+ "system_ram": f"{human_bytes(vm.used)} / {human_bytes(vm.total)}",
1073
+ "system_ram_percent": float(vm.percent),
1074
+ "cpu_percent": float(psutil.cpu_percent(interval=0.0)),
1075
+ })
1076
+ except Exception:
1077
+ pass
1078
+ if torch.cuda.is_available():
1079
+ try:
1080
+ idx = torch.cuda.current_device()
1081
+ props = torch.cuda.get_device_properties(idx)
1082
+ allocated = int(torch.cuda.memory_allocated(idx))
1083
+ reserved = int(torch.cuda.memory_reserved(idx))
1084
+ total = int(props.total_memory)
1085
+ snap.update({
1086
+ "gpu_name": props.name,
1087
+ "vram_allocated": human_bytes(allocated),
1088
+ "vram_reserved": human_bytes(reserved),
1089
+ "vram_used": human_bytes(allocated),
1090
+ "vram_total": human_bytes(total),
1091
+ "vram_percent": 100.0 * allocated / max(total, 1),
1092
+ })
1093
+ if pynvml is not None:
1094
+ try:
1095
+ pynvml.nvmlInit()
1096
+ handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
1097
+ util = pynvml.nvmlDeviceGetUtilizationRates(handle)
1098
+ mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
1099
+ temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
1100
+ snap.update({
1101
+ "gpu_util_percent": int(util.gpu),
1102
+ "vram_used": human_bytes(int(mem.used)),
1103
+ "vram_total": human_bytes(int(mem.total)),
1104
+ "vram_percent": 100.0 * float(mem.used) / max(float(mem.total), 1.0),
1105
+ "gpu_temp_c": int(temp),
1106
+ })
1107
+ except Exception:
1108
+ pass
1109
+ except Exception:
1110
+ pass
1111
+ return snap
1112
+
1113
+
1114
+ def system_usage(engine: Optional[LULUV2OptimizedEngine] = None) -> str:
1115
+ snap = system_snapshot(engine)
1116
+ lines = [
1117
+ f"OS: {platform.system()} {platform.release()}",
1118
+ f"Python RAM: {snap['python_ram']}",
1119
+ f"System RAM: {snap['system_ram']} ({snap['system_ram_percent']:.1f}%)",
1120
+ f"CPU: {snap['cpu_percent']:.1f}%",
1121
+ "",
1122
+ f"GPU: {snap['gpu_name']}",
1123
+ f"VRAM used: {snap['vram_used']} / {snap['vram_total']} ({snap['vram_percent']:.1f}%)",
1124
+ f"VRAM allocated: {snap['vram_allocated']}",
1125
+ f"VRAM reserved: {snap['vram_reserved']}",
1126
+ ]
1127
+ if snap.get("gpu_util_percent") is not None:
1128
+ lines.append(f"GPU util: {snap['gpu_util_percent']}%")
1129
+ if snap.get("gpu_temp_c") is not None:
1130
+ lines.append(f"GPU temp: {snap['gpu_temp_c']} C")
1131
+ if engine is not None:
1132
+ lines.extend(["", engine.stats_text()])
1133
+ return "\n".join(lines)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.1
2
+ tokenizers>=0.15
3
+ transformers>=4.40
4
+ gradio>=4.0
5
+
6
+ psutil>=5.9
7
+ nvidia-ml-py>=12.0; platform_system != "Darwin"
run_chat.ps1 ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ $ErrorActionPreference = "Stop"
2
+ python .\app.py --ckpt .\LULUV2-bf16.pt --model-py .\luluv2_inference_runtime.py --tokenizer-dir .\tokenizer --inbrowser
run_chat.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ python ./app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --tokenizer-dir ./tokenizer --inbrowser
run_inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Small CLI launcher for LULUV2 native-bf16 local inference."""
3
+ from __future__ import annotations
4
+ import argparse
5
+ import torch
6
+ from luluv2_live_inference import LULUV2LiveEngine, GenerationConfig
7
+
8
+
9
+ def main():
10
+ p = argparse.ArgumentParser("LULUV2 local inference")
11
+ p.add_argument("--ckpt", default="LULUV2-bf16.pt", help="Path to the native-bf16 checkpoint file")
12
+ p.add_argument("--tokenizer-dir", default="tokenizer", help="Local tokenizer directory")
13
+ p.add_argument("--prompt", required=True, help="User prompt")
14
+ p.add_argument("--system", default="You are LuluV2, a helpful local AI assistant.")
15
+ p.add_argument("--max-new-tokens", type=int, default=512)
16
+ p.add_argument("--temperature", type=float, default=0.65)
17
+ p.add_argument("--top-p", type=float, default=0.90)
18
+ p.add_argument("--top-k", type=int, default=40)
19
+ p.add_argument("--device", default=None)
20
+ p.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
21
+ args = p.parse_args()
22
+
23
+ engine = LULUV2LiveEngine(
24
+ ckpt_path=args.ckpt,
25
+ model_py="luluv2_inference_runtime.py",
26
+ tokenizer_dir=args.tokenizer_dir,
27
+ device=args.device,
28
+ dtype=args.dtype,
29
+ local_files_only=True,
30
+ no_config_download=True,
31
+ )
32
+ cfg = GenerationConfig(
33
+ max_new_tokens=args.max_new_tokens,
34
+ temperature=args.temperature,
35
+ top_p=args.top_p,
36
+ top_k=args.top_k,
37
+ )
38
+ history = []
39
+ for text in engine.generate_stream(args.prompt, history, args.system, cfg):
40
+ print(text, end="", flush=True)
41
+ print()
42
+
43
+
44
+ if __name__ == "__main__":
45
+ torch.set_grad_enabled(False)
46
+ main()