Upload 13 files
Browse files- .gitattributes +1 -32
- .gitignore +11 -0
- README.md +64 -1
- app.py +523 -0
- config.json +10 -0
- generation_config.json +9 -0
- luluv2_inference_runtime.py +842 -0
- luluv2_live_inference.py +698 -0
- luluv2_optimized_engine.py +1133 -0
- requirements.txt +7 -0
- run_chat.ps1 +2 -0
- run_chat.sh +3 -0
- run_inference.py +46 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Local/private artifacts
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
.env
|
| 5 |
+
*.log
|
| 6 |
+
lulu_chats/
|
| 7 |
+
luluv2_chats/
|
| 8 |
+
*lulu_memory*.json
|
| 9 |
+
private_artifacts/
|
| 10 |
+
checkpoints/
|
| 11 |
+
runs/
|
README.md
CHANGED
|
@@ -1,3 +1,66 @@
|
|
| 1 |
---
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
library_name: pytorch
|
| 5 |
+
pipeline_tag: text-generation
|
| 6 |
+
tags:
|
| 7 |
+
- text-generation
|
| 8 |
+
- bfloat16
|
| 9 |
+
- inference-only
|
| 10 |
+
- local-inference
|
| 11 |
---
|
| 12 |
+
|
| 13 |
+
# LULUV2 native-bf16 local inference package
|
| 14 |
+
|
| 15 |
+
This repository is prepared as an inference-only package for a native-bf16 LULUV2 checkpoint.
|
| 16 |
+
It is designed so users can run the model directly in native bfloat16 without an extra conversion step.
|
| 17 |
+
|
| 18 |
+
## What is included
|
| 19 |
+
|
| 20 |
+
- `luluv2_inference_runtime.py` — stripped runtime loader and model architecture needed for inference only.
|
| 21 |
+
- `luluv2_live_inference.py` — streaming inference engine.
|
| 22 |
+
- `luluv2_optimized_engine.py` — optimized local inference engine with cache paths.
|
| 23 |
+
- `app.py` — local Gradio chat UI.
|
| 24 |
+
- `run_inference.py` — minimal command-line runner.
|
| 25 |
+
- `tokenizer/` — local tokenizer files and chat template.
|
| 26 |
+
|
| 27 |
+
## What is not included
|
| 28 |
+
|
| 29 |
+
Private development tooling, data-preparation scripts, connector code, local chat logs, memory files, workspace artifacts, API keys, and secret tokens are not included.
|
| 30 |
+
|
| 31 |
+
## Weights
|
| 32 |
+
|
| 33 |
+
Place the native-bf16 checkpoint in the repository root as:
|
| 34 |
+
|
| 35 |
+
```text
|
| 36 |
+
LULUV2-bf16.pt
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
The uploaded cleanup source did not include weights, so this package does not contain a `.pt` or `.safetensors` model file yet.
|
| 40 |
+
If you publish weights on Hugging Face, keep them in native bfloat16. This package includes `.gitattributes` patterns for large weight files.
|
| 41 |
+
|
| 42 |
+
## Install
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install -r requirements.txt
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Run the local UI
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
python app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --tokenizer-dir ./tokenizer --inbrowser
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Run from CLI
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
python run_inference.py --ckpt ./LULUV2-bf16.pt --prompt "Write a short introduction to LuluV2."
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Native bf16 note
|
| 61 |
+
|
| 62 |
+
This package is intended for native bfloat16 inference. Users should be able to run the native-bf16 package directly. Hardware without bfloat16 support may require `--dtype fp16` or `--dtype fp32`, depending on their PyTorch/device setup.
|
| 63 |
+
|
| 64 |
+
## Safety and disclosure checklist before upload
|
| 65 |
+
|
| 66 |
+
Before making the Hugging Face repository public, confirm that your base-model license permits redistribution of the final weights and that any legally required notices are present in the model card or repository files.
|
app.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
LULUV2 Pro local chat UI.
|
| 5 |
+
|
| 6 |
+
A clean ChatGPT-style desktop UI for the fine-tuned LULUV2 checkpoint.
|
| 7 |
+
It keeps the important local features only:
|
| 8 |
+
- chat inference
|
| 9 |
+
- live token streaming
|
| 10 |
+
- new chat / save / load chats
|
| 11 |
+
- persistent memory notes
|
| 12 |
+
- live edge monitor: tok/s, RAM, VRAM, GPU, pass2 metrics
|
| 13 |
+
- 32K context controls and test prompt helper
|
| 14 |
+
|
| 15 |
+
Run:
|
| 16 |
+
python ./app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --inbrowser
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, List, Tuple
|
| 28 |
+
|
| 29 |
+
import gradio as gr
|
| 30 |
+
|
| 31 |
+
from luluv2_live_inference import (
|
| 32 |
+
GenerationConfig,
|
| 33 |
+
LULUV2LiveEngine,
|
| 34 |
+
clean_text,
|
| 35 |
+
normalize_history,
|
| 36 |
+
system_usage,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Display name: used as the Gradio page title and stamped into saved JSON files.
APP_NAME = "LuluV2"
# Directory for saved chats and Markdown exports; override with LULU_CHAT_DIR.
CHAT_DIR = Path(os.getenv("LULU_CHAT_DIR", "lulu_chats"))
# JSON file holding the persistent memory notes; override with LULU_MEMORY_FILE.
MEMORY_FILE = Path(os.getenv("LULU_MEMORY_FILE", "lulu_memory.json"))

# Fallback system prompt used when the UI's system-prompt field is empty.
DEFAULT_SYSTEM_PROMPT = """Your name is LuluV2.
You are a local AI assistant made by Open Machine.
You run offline from the LULUV2 VWM checkpoint.
Answer directly and naturally.
Use Markdown for structure.
When writing code, use fenced code blocks with the correct language tag.
Do not output role tags, hidden scratchpad text, JSON UI fragments, or {'type':'text'} blocks.
"""

# Named sampling presets. Each value holds the slider settings pushed to the UI
# when the preset is selected.
PRESETS = {
    "Balanced": dict(temperature=0.65, top_k=40, top_p=0.90, min_p=0.03, repetition_penalty=1.10, frequency_penalty=0.02, max_new_tokens=768),
    "Precise": dict(temperature=0.35, top_k=30, top_p=0.84, min_p=0.04, repetition_penalty=1.14, frequency_penalty=0.03, max_new_tokens=512),
    "Code": dict(temperature=0.42, top_k=40, top_p=0.88, min_p=0.03, repetition_penalty=1.10, frequency_penalty=0.02, max_new_tokens=1200),
    "Long 32K": dict(temperature=0.55, top_k=50, top_p=0.90, min_p=0.025, repetition_penalty=1.08, frequency_penalty=0.02, max_new_tokens=1200),
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def safe_int(value: Any, default: int, low: int | None = None, high: int | None = None) -> int:
    """Coerce *value* to an int, falling back to *default*, then clamp it.

    Args:
        value: Anything ``int()`` may accept (UI widgets can hand over strings,
            floats or ``None``).
        default: Used when *value* cannot be converted.
        low: Optional inclusive lower bound.
        high: Optional inclusive upper bound.

    Returns:
        The converted (or default) value, limited to ``[low, high]`` where
        those bounds are given. The bounds apply to *default* as well.
    """
    try:
        result = int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported type; ValueError: bad text or NaN;
        # OverflowError: +/-infinity.
        result = default
    if low is not None:
        result = max(low, result)
    if high is not None:
        result = min(high, result)
    return result
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def clamp(value: Any, low: float, high: float, default: float) -> float:
    """Coerce *value* to a float limited to ``[low, high]``.

    Args:
        value: Anything ``float()`` may accept.
        low: Inclusive lower bound.
        high: Inclusive upper bound.
        default: Returned as-is (not clamped) when *value* is not a usable number.

    Returns:
        The clamped float, or *default* when conversion fails or yields NaN.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN compares false against everything, so max/min would silently turn it
    # into `high`; treat it as invalid input instead. (`x != x` is true only for NaN.)
    if number != number:
        return default
    return max(low, min(high, number))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def esc(text: Any) -> str:
    """Escape *text* for interpolation into HTML text or a double-quoted attribute.

    The value is stringified first, so non-string inputs are accepted.
    """
    # "&" must be replaced first, otherwise the entities produced by the later
    # replacements would themselves be re-escaped.
    return (
        str(text)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
    )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def status_html(title: str, detail: str = "", tone: str = "neutral") -> str:
    """Render the status pill shown above the chat as an HTML fragment.

    Args:
        title: Bold headline; HTML-escaped before insertion.
        detail: Smaller secondary line; HTML-escaped before insertion.
        tone: One of "neutral", "good", "warn", "bad", "live". Any other value
            falls back to "neutral" so the CSS class name stays within the known set.

    Returns:
        The HTML markup for the pill.
    """
    tone = tone if tone in {"neutral", "good", "warn", "bad", "live"} else "neutral"
    return f"""
<div class="status-pill status-{tone}">
<span class="pulse-dot"></span>
<div><b>{esc(title)}</b><small>{esc(detail)}</small></div>
</div>
"""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def read_memory() -> str:
    """Return the persisted memory notes, or "" when none can be read.

    Best-effort by design: a missing, unreadable, or malformed memory file is
    treated as "no notes" rather than an error.
    """
    try:
        data = json.loads(MEMORY_FILE.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # invalid JSON and invalid UTF-8.
        return ""
    if not isinstance(data, dict):
        return ""
    notes = data.get("memory_notes")
    # A JSON null must not surface as the literal text "None".
    return "" if notes is None else str(notes)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def write_memory(memory_notes: str) -> Tuple[str, str]:
    """Persist the memory notes to ``MEMORY_FILE`` as JSON.

    Args:
        memory_notes: Text to store; a falsy value is stored as "".

    Returns:
        ``(path, status_html)``: the memory file path and a status pill that
        reports success or the I/O error.
    """
    payload = json.dumps(
        {"memory_notes": memory_notes or "", "saved_at": datetime.now().isoformat(timespec="seconds"), "app": APP_NAME},
        indent=2,
        ensure_ascii=False,
    )
    try:
        # LULU_MEMORY_FILE may point into a directory that does not exist yet.
        MEMORY_FILE.parent.mkdir(parents=True, exist_ok=True)
        MEMORY_FILE.write_text(payload, encoding="utf-8")
    except OSError as exc:
        return str(MEMORY_FILE), status_html("Memory save failed", f"{type(exc).__name__}: {exc}", "bad")
    return str(MEMORY_FILE), status_html("Memory saved", str(MEMORY_FILE), "good")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def safe_chat_filename(chat_name: str, suffix: str) -> Path:
    """Build a filesystem-safe, timestamped path inside ``CHAT_DIR``.

    Creates ``CHAT_DIR`` if needed. Characters outside ``[a-zA-Z0-9_-]`` in
    *chat_name* are collapsed to underscores; an empty result becomes "chat".

    Args:
        chat_name: Human-readable chat title.
        suffix: File extension without the leading dot (e.g. "json", "md").

    Returns:
        A path that does not exist at the time of the call.
    """
    CHAT_DIR.mkdir(parents=True, exist_ok=True)
    base = re.sub(r"[^a-zA-Z0-9_-]+", "_", chat_name or "chat").strip("_") or "chat"
    # Keep the name well below common filesystem limits for a single component.
    base = base[:80].rstrip("_") or "chat"
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    path = CHAT_DIR / f"{base}_{stamp}.{suffix}"
    # The timestamp has one-second resolution; avoid overwriting a file saved
    # under the same name within the same second.
    counter = 1
    while path.exists():
        path = CHAT_DIR / f"{base}_{stamp}_{counter}.{suffix}"
        counter += 1
    return path
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def list_saved_chats() -> List[str]:
    """Return the saved chat JSON paths in ``CHAT_DIR``, newest first.

    Creates ``CHAT_DIR`` if it does not exist, so the result is ``[]`` on a
    fresh install.
    """
    CHAT_DIR.mkdir(parents=True, exist_ok=True)

    def _mtime(path: Path) -> float:
        # A file may vanish between glob() and stat(); sort it last instead of failing.
        try:
            return path.stat().st_mtime
        except OSError:
            return 0.0

    return [str(p) for p in sorted(CHAT_DIR.glob("*.json"), key=_mtime, reverse=True)]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def save_chat(history: Any, chat_name: str, memory_notes: str) -> Tuple[str, str, List[str]]:
    """Write the conversation to a new timestamped JSON file in ``CHAT_DIR``.

    Args:
        history: Chat history in any form ``normalize_history`` accepts.
        chat_name: Chat title; defaults to "Lulu chat" when empty.
        memory_notes: Memory notes stored alongside the chat.

    Returns:
        ``(path, status_html, saved_chats)``: the written path ("" on failure),
        a status pill, and the refreshed list of saved chats.
    """
    name = chat_name or "Lulu chat"
    data = {
        "chat_name": name,
        "history": normalize_history(history),
        "memory_notes": memory_notes or "",
        "saved_at": datetime.now().isoformat(timespec="seconds"),
        "app": APP_NAME,
    }
    try:
        path = safe_chat_filename(name, "json")
        path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
    except OSError as exc:
        # Report I/O problems in the status pill, as load_chat does.
        return "", status_html("Save failed", f"{type(exc).__name__}: {exc}", "bad"), list_saved_chats()
    return str(path), status_html("Chat saved", path.name, "good"), list_saved_chats()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def load_chat(path: str) -> Tuple[List[Dict[str, str]], str, str, str]:
    """Load a chat previously written by ``save_chat``.

    Args:
        path: Path to the saved JSON file; empty means nothing was selected.

    Returns:
        ``(history, chat_name, memory_notes, status_html)``. On any failure the
        history is empty, the name is "New chat", the notes come from
        ``read_memory()``, and the status pill describes the problem.
    """
    if not path:
        return [], "New chat", read_memory(), status_html("No saved chat selected", "Pick a JSON file from the sidebar.", "warn")
    source = Path(path)
    try:
        data = json.loads(source.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers invalid JSON and invalid UTF-8.
        return [], "New chat", read_memory(), status_html("Load failed", f"{type(exc).__name__}: {exc}", "bad")
    if not isinstance(data, dict):
        # Valid JSON that is not an object (e.g. a bare list) has no chat fields.
        return [], "New chat", read_memory(), status_html("Load failed", "Saved chat is not a JSON object.", "bad")
    notes = data.get("memory_notes")
    if notes is None:
        # Missing or null notes: fall back to the persistent memory file.
        notes = read_memory()
    return (
        normalize_history(data.get("history", [])),
        str(data.get("chat_name") or source.stem),
        str(notes),
        status_html("Chat loaded", source.name, "good"),
    )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def chat_to_markdown(history: Any, chat_name: str) -> str:
    """Render the conversation as a Markdown document.

    The cleaned chat name (or "LuluV2 chat") becomes the top-level heading.
    Each message gets a "## You" heading when its role is "user" and
    "## LuluV2" otherwise. The result ends with exactly one newline.
    """
    lines = [f"# {clean_text(chat_name) or 'LuluV2 chat'}", ""]
    for item in normalize_history(history):
        heading = "## You" if item["role"] == "user" else "## LuluV2"
        lines.extend((heading, item["content"], ""))
    return "\n".join(lines).strip() + "\n"
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def export_markdown(history: Any, chat_name: str) -> Tuple[str, str]:
    """Export the conversation to a timestamped ``.md`` file in ``CHAT_DIR``.

    Returns:
        ``(path, status_html)``: the written path ("" on failure) and a status pill.
    """
    try:
        path = safe_chat_filename(chat_name or "Lulu chat", "md")
        path.write_text(chat_to_markdown(history, chat_name), encoding="utf-8")
    except OSError as exc:
        # Report I/O problems in the status pill, as load_chat does.
        return "", status_html("Export failed", f"{type(exc).__name__}: {exc}", "bad")
    return str(path), status_html("Markdown exported", path.name, "good")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def postprocess_answer(text: Any, final: bool = False) -> str:
    """Clean model output for display in the chat.

    Args:
        text: Raw (possibly partial) model output; passed through ``clean_text`` first.
        final: When true, an unbalanced Markdown code fence is closed. Left off
            while streaming, since a partial answer may legitimately be mid-block.

    Returns:
        The cleaned, stripped text.
    """
    text = clean_text(text)
    # Remove common generated UI artifacts from older chat data.
    # 1) a trailing "[{'text': ..., 'type': 'text'}]" fragment
    text = re.sub(r"\n?\s*\[\s*\{\s*['\"]text['\"].*?['\"]type['\"]\s*:\s*['\"]text['\"]\s*\}\s*\]\s*$", "", text, flags=re.S)
    # 2) a trailing bare "type: 'text'" line
    text = re.sub(r"\n?\s*type\s*:\s*['\"]text['\"]\s*$", "", text, flags=re.I)
    # Collapse runs of four or more newlines down to three.
    text = re.sub(r"\n{4,}", "\n\n\n", text)
    # An odd number of ``` markers means the last code block was never closed.
    if final and text.count("```") % 2 == 1:
        text += "\n```"
    return text.strip()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def metric_cards(engine: LULUV2LiveEngine, max_context: int) -> str:
    """Render the live monitor bar (throughput, memory, GPU, pass metrics) as HTML.

    Args:
        engine: Object exposing ``stats_dict()``. This function reads the keys
            "tokens_per_sec", "generated_tokens", "pass1_pass2_kl",
            "pass1_pass2_logit_cosine", plus the nested "system" and "model" dicts.
        max_context: Context window in tokens, shown in K once it reaches 1024.

    Returns:
        HTML markup for the monitor bar. Missing values render as "n/a" (or
        "base" for the pass metrics).
    """
    stats = engine.stats_dict()
    # Named `system`, not `sys`, to avoid shadowing the stdlib module name.
    system = stats.get("system") or {}
    model = stats.get("model") or {}

    pass_kl = stats.get("pass1_pass2_kl")
    pass_cos = stats.get("pass1_pass2_logit_cosine")
    pass_text = "base"
    if pass_kl is not None and pass_cos is not None:
        pass_text = f"KL {pass_kl:.3f} / cos {pass_cos:.3f}"

    gpu_util = system.get("gpu_util_percent")
    gpu_temp = system.get("gpu_temp_c")
    gpu_text = "n/a" if gpu_util is None else f"{gpu_util}%"
    temp_text = "n/a" if gpu_temp is None else f"{gpu_temp}°C"

    # `or 0` tolerates an explicit None in the stats, not just a missing key.
    tokens_per_sec = float(stats.get("tokens_per_sec") or 0.0)
    generated_tokens = int(stats.get("generated_tokens") or 0)
    # Contexts below 1024 tokens would otherwise be displayed as "0K".
    context_text = f"{max_context // 1024}K" if max_context >= 1024 else str(max_context)

    return f"""
<div class="monitor-bar">
<div class="mon-card hot"><b>{tokens_per_sec:.1f}</b><span>tok/s</span></div>
<div class="mon-card"><b>{generated_tokens}</b><span>tokens</span></div>
<div class="mon-card"><b>{system.get('python_ram', 'n/a')}</b><span>Python RAM</span></div>
<div class="mon-card"><b>{system.get('vram_used', 'n/a')}</b><span>VRAM / {system.get('vram_total', 'n/a')}</span></div>
<div class="mon-card"><b>{gpu_text}</b><span>GPU · {temp_text}</span></div>
<div class="mon-card"><b>{context_text}</b><span>context</span></div>
<div class="mon-card"><b>{model.get('has_pass2')}</b><span>pass2</span></div>
<div class="mon-card wide"><b>{pass_text}</b><span>pass1 → pass2</span></div>
</div>
"""
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def make_32k_prompt(repeats: int = 520) -> str:
    """Build a synthetic long prompt for stress-testing the context window.

    Args:
        repeats: How many times the filler section is repeated. The default of
            520 is the original fixed size; negative values are treated as 0.

    Returns:
        An instruction header, the repeated filler section, and a final question.
    """
    seed = (
        "We are testing a 32K context window for LuluV2. "
        "Remember these constraints: answer directly, keep code formatted, and summarize the relevant details. "
        "The repeated context below is synthetic filler for a long-context stress test.\n\n"
    )
    block = (
        "Section: VWM reconstruction. A model can use A/B atoms and c-code recipes to reconstruct behavior online. "
        "Pass 1 builds a scaffold, pass 2 refines it, and the UI should keep live tokens/sec, RAM, VRAM, and pass metrics visible. "
        "When asked at the end, explain the three key ideas and provide a tiny Python example.\n"
    )
    # Character length is approximate; token count depends on tokenizer. This usually lands around a long 20K-32K style prompt.
    return seed + (block * max(0, int(repeats))) + "\nFinal question: What are the three key ideas above, and can you show a tiny Python class for tracking tokens per second?"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def create_chatbot():
    """Create the main ``gr.Chatbot``, tolerating differences between Gradio versions.

    Constructor keyword sets are tried from richest to plainest; a ``TypeError``
    (unknown keyword) moves on to the next one. The first three attempts are the
    original ones, in the original order. The remaining attempts also drop
    ``avatar_images`` / ``bubble_full_width``, so a Gradio build that rejects
    those keywords still gets a working component.

    Raises:
        TypeError: If every keyword combination is rejected.
    """
    core = dict(
        value=[],
        elem_id="chatbot",
        height=760,
        show_label=False,
    )
    legacy = dict(
        avatar_images=(None, None),
        bubble_full_width=False,
    )
    rendering = dict(render_markdown=True, sanitize_html=True)
    attempts = (
        {**core, **legacy, **rendering, "type": "messages"},
        {**core, **legacy, **rendering},
        {**core, **legacy},
        # Fallbacks for builds that reject the legacy styling keywords.
        {**core, **rendering, "type": "messages"},
        {**core, **rendering},
        core,
    )
    last_error = None
    for kwargs in attempts:
        try:
            return gr.Chatbot(**kwargs)
        except TypeError as exc:
            last_error = exc
    raise last_error
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def build_app(engine: LULUV2LiveEngine, default_context: int):
|
| 241 |
+
def respond(
    message,
    history,
    chat_name,
    system_prompt,
    memory_notes,
    preset,
    history_turns,
    max_context_tokens,
    max_new_tokens,
    temperature,
    top_k,
    top_p,
    min_p,
    repetition_penalty,
    frequency_penalty,
    greedy,
    no_repeat_ngram,
    stream_every,
    show_pass_metrics,
):
    """Stream one assistant reply for *message* (generator used as a UI callback).

    Closure inside ``build_app``: ``engine`` and ``default_context`` come from
    the enclosing scope. ``chat_name`` and ``preset`` are accepted to match the
    UI input list but are not read here.

    Yields:
        ``(textbox_value, history, status_html, monitor_html, token_trace,
        stats)`` after each streamed chunk. The textbox value is "" except
        after a failure, when the user's message is handed back so it can be
        resent.
    """
    hist = normalize_history(history)
    msg = clean_text(message)
    max_context_tokens = safe_int(max_context_tokens, default_context, 128, 32768)

    def frame(textbox, title, detail, tone):
        # One UI update: the status pill first, then fresh engine metrics.
        return (
            textbox,
            hist,
            status_html(title, detail, tone),
            metric_cards(engine, max_context_tokens),
            engine.token_trace_text(),
            engine.stats_dict(),
        )

    def speed_line():
        return f"{engine.last_stats.generated_tokens} tokens · {engine.last_stats.tokens_per_sec:.1f} tok/s"

    if not msg:
        yield frame("", "Empty message", "Type something first.", "warn")
        return

    # Preset only affects initial slider defaults; live slider values are honored.
    prompt = engine.build_chat_prompt(
        message=msg,
        history=hist,
        system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT,
        memory_notes=memory_notes or "",
        history_turns=safe_int(history_turns, 4, 0, 32),
    )
    cfg = GenerationConfig(
        max_new_tokens=safe_int(max_new_tokens, 768, 1, 8192),
        temperature=clamp(temperature, 0.0, 2.0, 0.65),
        top_k=safe_int(top_k, 40, 0, 500),
        top_p=clamp(top_p, 0.01, 1.0, 0.90),
        min_p=clamp(min_p, 0.0, 0.5, 0.03),
        repetition_penalty=clamp(repetition_penalty, 1.0, 3.0, 1.10),
        frequency_penalty=clamp(frequency_penalty, 0.0, 3.0, 0.02),
        greedy=bool(greedy),
        no_repeat_ngram=safe_int(no_repeat_ngram, 4, 0, 16),
        stream_every=safe_int(stream_every, 1, 1, 64),
        max_context_tokens=max_context_tokens,
        return_pass_metrics=bool(show_pass_metrics),
    )

    # The prompt was built from the history *before* these two entries were
    # added; the assistant entry is a placeholder overwritten while streaming.
    hist.append({"role": "user", "content": msg})
    hist.append({"role": "assistant", "content": "Thinking..."})
    yield frame("", "Generating", "LuluV2 is reconstructing tokens live.", "live")

    final = ""
    try:
        for partial in engine.generate(prompt, cfg):
            final = postprocess_answer(partial, final=False)
            hist[-1] = {"role": "assistant", "content": final or "..."}
            yield frame("", "Generating", speed_line(), "live")
    except Exception as exc:
        # UI boundary: show any engine failure in the chat instead of letting it
        # escape the callback.
        hist[-1] = {"role": "assistant", "content": f"Generation failed:\n\n```text\n{type(exc).__name__}: {exc}\n```"}
        yield frame(msg, "Generation failed", f"{type(exc).__name__}: {exc}", "bad")
        return

    final = postprocess_answer(final, final=True) or "I’m not sure how to answer that yet."
    hist[-1] = {"role": "assistant", "content": final}
    yield frame("", "Done", speed_line(), "good")
|
| 310 |
+
|
| 311 |
+
def regenerate(
    history,
    chat_name,
    system_prompt,
    memory_notes,
    preset,
    history_turns,
    max_context_tokens,
    max_new_tokens,
    temperature,
    top_k,
    top_p,
    min_p,
    repetition_penalty,
    frequency_penalty,
    greedy,
    no_repeat_ngram,
    stream_every,
    show_pass_metrics,
):
    """Discard the last assistant reply and generate a new one for the same user message.

    Closure inside ``build_app``; delegates the actual generation to ``respond``
    and yields the same six-value updates.
    """
    hist = normalize_history(history)

    def refuse(title, detail):
        # Same context clamp as respond(), so the monitor bar shows the same value.
        context = safe_int(max_context_tokens, default_context, 128, 32768)
        return "", hist, status_html(title, detail, "warn"), metric_cards(engine, context), engine.token_trace_text(), engine.stats_dict()

    if not hist:
        yield refuse("Nothing to regenerate", "Send a message first.")
        return

    work = hist[:]
    if work and work[-1]["role"] == "assistant":
        work = work[:-1]
    if not work or work[-1]["role"] != "user":
        yield refuse("Cannot regenerate", "Last turn is not a user message.")
        return

    last_msg = work[-1]["content"]
    prev = work[:-1]
    yield from respond(last_msg, prev, chat_name, system_prompt, memory_notes, preset, history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p, min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram, stream_every, show_pass_metrics)
|
| 344 |
+
|
| 345 |
+
def new_chat():
    """Reset the conversation: return ``(empty history, "New chat", status pill)``."""
    return [], "New chat", status_html("New chat", "Fresh conversation. Memory notes are kept.", "good")
|
| 347 |
+
|
| 348 |
+
def forget_last(history):
    """Drop the last two history entries (the latest exchange).

    Returns:
        ``(history, status_html)``. With fewer than two entries there is no
        full turn to remove, so the history is returned unchanged rather than
        cleared.
    """
    hist = normalize_history(history)
    if len(hist) >= 2:
        return hist[:-2], status_html("Forgot last turn", "Removed the latest exchange.", "good")
    return hist, status_html("Nothing to forget", "No full turn to remove.", "warn")
|
| 353 |
+
|
| 354 |
+
def apply_preset(name):
    """Return the slider values for preset *name* (unknown names use "Balanced").

    Closure inside ``build_app``: ``default_context`` comes from the enclosing scope.

    Returns:
        ``(temperature, top_k, top_p, min_p, repetition_penalty,
        frequency_penalty, max_new_tokens, context)``. The context is 32768
        for "Long 32K" and ``default_context`` otherwise.
    """
    p = PRESETS.get(name, PRESETS["Balanced"])
    context = 32768 if name == "Long 32K" else default_context
    return p["temperature"], p["top_k"], p["top_p"], p["min_p"], p["repetition_penalty"], p["frequency_penalty"], p["max_new_tokens"], context
|
| 358 |
+
|
| 359 |
+
css = """
|
| 360 |
+
:root{
|
| 361 |
+
--bg:#05060d;--panel:#0b1020;--panel2:#101827;--line:rgba(148,163,184,.16);
|
| 362 |
+
--text:#edf2ff;--muted:#94a3b8;--accent:#8b5cf6;--accent2:#22d3ee;--good:#22c55e;--bad:#ef4444;
|
| 363 |
+
}
|
| 364 |
+
html, body, .gradio-container{
|
| 365 |
+
background: radial-gradient(circle at top left, rgba(139,92,246,.23), transparent 34%),
|
| 366 |
+
radial-gradient(circle at top right, rgba(34,211,238,.14), transparent 30%),
|
| 367 |
+
linear-gradient(180deg,#05060d,#070a12 62%,#02030a)!important;
|
| 368 |
+
color:var(--text)!important;
|
| 369 |
+
}
|
| 370 |
+
.gradio-container{max-width:1680px!important;margin:auto!important;font-family:Inter,ui-sans-serif,system-ui,-apple-system,BlinkMacSystemFont,'Segoe UI',sans-serif!important;}
|
| 371 |
+
footer{display:none!important}.main-wrap{gap:18px!important}.sidebar{padding:16px;border:1px solid var(--line);border-radius:28px;background:rgba(9,14,28,.76);box-shadow:0 20px 70px rgba(0,0,0,.32)}
|
| 372 |
+
.brand{padding:10px 4px 18px}.brand h1{margin:0;font-size:32px;letter-spacing:-.06em;color:#fff}.brand p{margin:5px 0 0;color:var(--muted);font-size:13px}.brand .badge{display:inline-flex;margin-top:12px;padding:7px 10px;border-radius:999px;border:1px solid rgba(34,211,238,.28);background:rgba(8,145,178,.12);color:#cffafe;font-weight:800;font-size:12px}
|
| 373 |
+
.chat-shell{padding:16px;border:1px solid var(--line);border-radius:32px;background:rgba(5,8,18,.62);box-shadow:0 30px 110px rgba(0,0,0,.38)}
|
| 374 |
+
#chatbot{height:760px!important;border:0!important;background:transparent!important;overflow:hidden!important}.message{font-size:15.5px!important;line-height:1.62!important}.message-wrap{max-width:900px!important}.bot .message, .assistant .message{background:rgba(15,23,42,.72)!important;border:1px solid rgba(148,163,184,.13)!important;border-radius:22px!important}.user .message{background:linear-gradient(135deg,rgba(124,58,237,.70),rgba(59,130,246,.42))!important;border:1px solid rgba(167,139,250,.35)!important;border-radius:22px!important;color:white!important}
|
| 375 |
+
#chatbot pre{background:#101827!important;border:1px solid rgba(148,163,184,.22)!important;border-radius:18px!important;padding:16px!important;box-shadow:inset 0 1px 0 rgba(255,255,255,.04)!important}#chatbot code{font-family:'JetBrains Mono','Cascadia Code','SFMono-Regular',Consolas,monospace!important;font-size:14px!important}#chatbot p{margin:0 0 .7em!important}#chatbot ul,#chatbot ol{margin-top:.3em!important}
|
| 376 |
+
.composer-card{display:flex;gap:12px;align-items:end;padding:10px;border-radius:26px;border:1px solid rgba(139,92,246,.28);background:rgba(2,6,23,.80);box-shadow:0 20px 70px rgba(139,92,246,.12)}#composer textarea{min-height:72px!important;max-height:190px!important;background:transparent!important;border:0!important;color:#fff!important;font-size:16px!important;line-height:1.5!important;box-shadow:none!important}.input-container{border:0!important;background:transparent!important}.form{border:0!important;background:transparent!important}label{color:#cbd5e1!important;font-weight:700!important}
|
| 377 |
+
button{border-radius:16px!important;font-weight:850!important;border:1px solid rgba(148,163,184,.16)!important;box-shadow:0 10px 28px rgba(0,0,0,.22)!important}.send-btn{min-height:56px!important;background:linear-gradient(135deg,#8b5cf6,#06b6d4)!important;color:white!important}.side-btn button,.side-btn{width:100%!important}
|
| 378 |
+
.monitor-bar{display:grid;grid-template-columns:repeat(8,minmax(110px,1fr));gap:10px;margin:0 0 12px}.mon-card{padding:12px 13px;border:1px solid var(--line);border-radius:18px;background:rgba(15,23,42,.78);min-height:64px}.mon-card b{display:block;font-size:20px;color:#fff;white-space:nowrap}.mon-card span{display:block;color:var(--muted);font-size:11px;margin-top:3px}.mon-card.hot{background:linear-gradient(135deg,rgba(139,92,246,.30),rgba(34,211,238,.16));border-color:rgba(34,211,238,.30)}.mon-card.wide b{font-size:15px}.status-pill{display:flex;align-items:center;gap:10px;margin:0 0 12px;padding:10px 13px;border-radius:18px;border:1px solid var(--line);background:rgba(2,6,23,.72)}.status-pill b{display:block}.status-pill small{display:block;color:var(--muted);font-size:12px}.pulse-dot{width:10px;height:10px;border-radius:99px;background:var(--accent2);box-shadow:0 0 0 7px rgba(34,211,238,.10),0 0 25px rgba(34,211,238,.55)}.status-good .pulse-dot{background:var(--good);box-shadow:0 0 0 7px rgba(34,197,94,.12),0 0 25px rgba(34,197,94,.5)}.status-bad .pulse-dot{background:var(--bad)}.status-live .pulse-dot{animation:pulse 1.1s infinite}@keyframes pulse{0%{transform:scale(1)}50%{transform:scale(1.45)}100%{transform:scale(1)}}
|
| 379 |
+
.gr-box,.gr-panel,.block{background:transparent!important;border-color:var(--line)!important}.sidebar textarea,.sidebar input,.sidebar select,.sidebar .wrap{background:rgba(2,6,23,.62)!important;color:#e5e7eb!important;border-color:rgba(148,163,184,.16)!important;border-radius:14px!important}.small-note{color:#94a3b8;font-size:12px}.tokenbox textarea,.jsonbox textarea{font-family:'JetBrains Mono','Cascadia Code',Consolas,monospace!important;font-size:12px!important;background:#060914!important}
|
| 380 |
+
@media(max-width:1100px){.monitor-bar{grid-template-columns:repeat(2,1fr)}.sidebar{display:none}.chat-shell{padding:8px}}
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
theme = gr.themes.Base(primary_hue="violet", secondary_hue="cyan", neutral_hue="slate")
|
| 384 |
+
|
| 385 |
+
with gr.Blocks(title=APP_NAME, css=css, theme=theme) as demo:
|
| 386 |
+
with gr.Row(elem_classes=["main-wrap"]):
|
| 387 |
+
with gr.Column(scale=1, min_width=270, elem_classes=["sidebar"]):
|
| 388 |
+
gr.HTML("""
|
| 389 |
+
<div class="brand">
|
| 390 |
+
<h1>LuluV2</h1>
|
| 391 |
+
<p>Offline VWM local assistant.</p>
|
| 392 |
+
<span class="badge">LOCAL EDGE MODE</span>
|
| 393 |
+
</div>
|
| 394 |
+
""")
|
| 395 |
+
new_btn = gr.Button("+ New chat", variant="primary", elem_classes=["side-btn"])
|
| 396 |
+
save_btn = gr.Button("Save chat", elem_classes=["side-btn"])
|
| 397 |
+
saved_path = gr.Textbox(label="Last saved path", interactive=False, visible=False)
|
| 398 |
+
saved_chats = gr.Dropdown(choices=list_saved_chats(), label="Saved chats", value=None, interactive=True)
|
| 399 |
+
with gr.Row():
|
| 400 |
+
refresh_chats = gr.Button("Refresh")
|
| 401 |
+
load_btn = gr.Button("Load")
|
| 402 |
+
export_btn = gr.Button("Export .md", elem_classes=["side-btn"])
|
| 403 |
+
export_path = gr.Textbox(label="Export path", interactive=False, visible=False)
|
| 404 |
+
|
| 405 |
+
with gr.Accordion("Memory", open=True):
|
| 406 |
+
memory_notes = gr.Textbox(label="Persistent memory notes", value=read_memory(), lines=8, placeholder="Things Lulu should remember locally...")
|
| 407 |
+
memory_path = gr.Textbox(label="Memory path", interactive=False, visible=False)
|
| 408 |
+
save_mem_btn = gr.Button("Save memory")
|
| 409 |
+
|
| 410 |
+
with gr.Accordion("Live tokens", open=False):
|
| 411 |
+
token_trace = gr.Textbox(label="Recent generated tokens", value="No tokens generated yet.", lines=14, elem_classes=["tokenbox"])
|
| 412 |
+
|
| 413 |
+
with gr.Accordion("Advanced", open=False):
|
| 414 |
+
chat_name = gr.Textbox(label="Chat name", value="New chat")
|
| 415 |
+
preset = gr.Dropdown(label="Preset", choices=list(PRESETS.keys()), value="Balanced")
|
| 416 |
+
system_prompt = gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=9)
|
| 417 |
+
history_turns = gr.Slider(0, 24, value=4, step=1, label="History turns sent")
|
| 418 |
+
max_context_tokens = gr.Slider(128, 32768, value=default_context, step=128, label="Max context tokens")
|
| 419 |
+
max_new_tokens = gr.Slider(16, 8192, value=768, step=16, label="Max new tokens")
|
| 420 |
+
temperature = gr.Slider(0.0, 2.0, value=0.65, step=0.01, label="Temperature")
|
| 421 |
+
top_k = gr.Slider(0, 500, value=40, step=1, label="Top-k")
|
| 422 |
+
top_p = gr.Slider(0.01, 1.0, value=0.90, step=0.01, label="Top-p")
|
| 423 |
+
min_p = gr.Slider(0.0, 0.5, value=0.03, step=0.005, label="Min-p")
|
| 424 |
+
repetition_penalty = gr.Slider(1.0, 3.0, value=1.10, step=0.01, label="Repetition penalty")
|
| 425 |
+
frequency_penalty = gr.Slider(0.0, 3.0, value=0.02, step=0.01, label="Frequency penalty")
|
| 426 |
+
greedy = gr.Checkbox(value=False, label="Greedy")
|
| 427 |
+
no_repeat_ngram = gr.Slider(0, 16, value=4, step=1, label="No-repeat ngram")
|
| 428 |
+
stream_every = gr.Slider(1, 64, value=1, step=1, label="Stream every N tokens")
|
| 429 |
+
show_pass_metrics = gr.Checkbox(value=True, label="Measure pass1/pass2 before generation")
|
| 430 |
+
insert_32k = gr.Button("Insert 32K stress prompt")
|
| 431 |
+
|
| 432 |
+
with gr.Column(scale=4, elem_classes=["chat-shell"]):
|
| 433 |
+
monitor = gr.HTML(metric_cards(engine, default_context))
|
| 434 |
+
status = gr.HTML(status_html("Ready", f"{engine.model_info.get('checkpoint_size')} checkpoint · {engine.model_info.get('device')}", "good"))
|
| 435 |
+
chatbot = create_chatbot()
|
| 436 |
+
with gr.Row(elem_classes=["composer-card"]):
|
| 437 |
+
msg = gr.Textbox(show_label=False, placeholder="Message LuluV2...", lines=3, elem_id="composer", scale=12)
|
| 438 |
+
send_btn = gr.Button("Send", variant="primary", elem_classes=["send-btn"], scale=2)
|
| 439 |
+
with gr.Row():
|
| 440 |
+
stop_btn = gr.Button("Stop")
|
| 441 |
+
regen_btn = gr.Button("Regenerate")
|
| 442 |
+
forget_btn = gr.Button("Forget last turn")
|
| 443 |
+
prompt_32k_btn = gr.Button("Try 32K prompt")
|
| 444 |
+
with gr.Accordion("Raw metrics", open=False):
|
| 445 |
+
raw_metrics = gr.JSON(label="Raw metrics")
|
| 446 |
+
usage_text = gr.Textbox(label="RAM / VRAM / model stats", value=system_usage(engine), lines=18, elem_classes=["jsonbox"])
|
| 447 |
+
|
| 448 |
+
inputs = [
|
| 449 |
+
msg, chatbot, chat_name, system_prompt, memory_notes, preset,
|
| 450 |
+
history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p,
|
| 451 |
+
min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram,
|
| 452 |
+
stream_every, show_pass_metrics,
|
| 453 |
+
]
|
| 454 |
+
outputs = [msg, chatbot, status, monitor, token_trace, raw_metrics]
|
| 455 |
+
send_event = send_btn.click(respond, inputs=inputs, outputs=outputs)
|
| 456 |
+
enter_event = msg.submit(respond, inputs=inputs, outputs=outputs)
|
| 457 |
+
stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[send_event, enter_event])
|
| 458 |
+
|
| 459 |
+
regen_inputs = [
|
| 460 |
+
chatbot, chat_name, system_prompt, memory_notes, preset,
|
| 461 |
+
history_turns, max_context_tokens, max_new_tokens, temperature, top_k, top_p,
|
| 462 |
+
min_p, repetition_penalty, frequency_penalty, greedy, no_repeat_ngram,
|
| 463 |
+
stream_every, show_pass_metrics,
|
| 464 |
+
]
|
| 465 |
+
regen_event = regen_btn.click(regenerate, inputs=regen_inputs, outputs=outputs)
|
| 466 |
+
stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[regen_event])
|
| 467 |
+
|
| 468 |
+
new_btn.click(new_chat, outputs=[chatbot, chat_name, status])
|
| 469 |
+
forget_btn.click(forget_last, inputs=[chatbot], outputs=[chatbot, status])
|
| 470 |
+
save_btn.click(save_chat, inputs=[chatbot, chat_name, memory_notes], outputs=[saved_path, status, saved_chats])
|
| 471 |
+
refresh_chats.click(lambda: gr.update(choices=list_saved_chats()), outputs=[saved_chats])
|
| 472 |
+
load_btn.click(load_chat, inputs=[saved_chats], outputs=[chatbot, chat_name, memory_notes, status])
|
| 473 |
+
export_btn.click(export_markdown, inputs=[chatbot, chat_name], outputs=[export_path, status])
|
| 474 |
+
save_mem_btn.click(write_memory, inputs=[memory_notes], outputs=[memory_path, status])
|
| 475 |
+
preset.change(apply_preset, inputs=[preset], outputs=[temperature, top_k, top_p, min_p, repetition_penalty, frequency_penalty, max_new_tokens, max_context_tokens])
|
| 476 |
+
insert_32k.click(lambda: make_32k_prompt(), outputs=[msg])
|
| 477 |
+
prompt_32k_btn.click(lambda: make_32k_prompt(), outputs=[msg])
|
| 478 |
+
|
| 479 |
+
return demo
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def parse_args():
    """Parse the command-line options of the local chat application.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, add_argument keyword options), kept in the original
    # registration order so the generated --help output is unchanged.
    option_table = (
        ("--ckpt", {"default": "LULU2_instruct_ddp.pt"}),
        ("--model-py", {"default": "luluv2_inference_runtime.py"}),
        ("--tokenizer-dir", {"default": "tokenizer"}),
        ("--host", {"default": "127.0.0.1"}),
        ("--port", {"type": int, "default": 7862}),
        ("--device", {"default": "cuda"}),
        ("--dtype", {"default": "bf16"}),
        ("--max-context", {"type": int, "default": 32768}),
        ("--share", {"action": "store_true"}),
        ("--inbrowser", {"action": "store_true"}),
        ("--base-only", {"action": "store_true"}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def main():
    """Build the inference engine from CLI options and launch the web UI."""
    args = parse_args()

    # Default both Hugging Face offline switches to "1" unless the user has
    # already set them in the environment.
    for offline_flag in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE"):
        os.environ.setdefault(offline_flag, "1")

    engine = LULUV2LiveEngine(
        ckpt_path=args.ckpt,
        model_py=args.model_py,
        tokenizer_dir=args.tokenizer_dir,
        device=args.device,
        dtype=args.dtype,
        local_files_only=True,
        no_config_download=True,
        force_base_only=bool(args.base_only),
    )

    # NOTE(review): safe_int is defined elsewhere; the arguments presumably
    # mean (value, default, minimum, maximum) — confirm against its definition.
    context_limit = safe_int(args.max_context, 32768, 128, 32768)
    demo = build_app(engine, default_context=context_limit)

    # One queued request at a time.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(
        server_name=args.host,
        server_port=int(args.port),
        share=bool(args.share),
        inbrowser=bool(args.inbrowser),
        show_error=True,
    )
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
if __name__ == "__main__":
|
| 523 |
+
main()
|
config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "luluv2",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"Lulu2ForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"torch_dtype": "bfloat16",
|
| 7 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 8 |
+
"auto_map": {},
|
| 9 |
+
"inference_only_package": true
|
| 10 |
+
}
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_new_tokens": 512,
|
| 3 |
+
"temperature": 0.65,
|
| 4 |
+
"top_k": 40,
|
| 5 |
+
"top_p": 0.9,
|
| 6 |
+
"do_sample": true,
|
| 7 |
+
"eos_token_id": 151645,
|
| 8 |
+
"pad_token_id": 151643
|
| 9 |
+
}
|
luluv2_inference_runtime.py
ADDED
|
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
LULUV2 inference-only runtime.
|
| 5 |
+
|
| 6 |
+
This file intentionally contains only the code needed to load and run a
standalone native-bf16 LULUV2 checkpoint for local generation: the runtime
loader, tokenizer bridge, decoder modules, and two-pass inference path.
|
| 10 |
+
|
| 11 |
+
Runtime behavior:
|
| 12 |
+
- loads a local checkpoint supplied by the user/repo;
|
| 13 |
+
- uses local tokenizer files;
|
| 14 |
+
- does not download or load any external model weights;
|
| 15 |
+
- preserves the VWM/two-pass inference path when present in the checkpoint.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import time
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from types import SimpleNamespace
|
| 26 |
+
from typing import Dict, Optional, Tuple
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
|
| 32 |
+
_TRANSFORMERS_IMPORT_ERROR = None
|
| 33 |
+
try:
|
| 34 |
+
from transformers import AutoTokenizer as _HFAutoTokenizer
|
| 35 |
+
try:
|
| 36 |
+
from transformers import AutoConfig as _HFAutoConfig
|
| 37 |
+
except Exception:
|
| 38 |
+
_HFAutoConfig = None
|
| 39 |
+
except Exception as _e:
|
| 40 |
+
_TRANSFORMERS_IMPORT_ERROR = _e
|
| 41 |
+
_HFAutoTokenizer = None
|
| 42 |
+
_HFAutoConfig = None
|
| 43 |
+
|
| 44 |
+
class _TokenOutput(dict):
    """Dictionary of tokenizer outputs whose keys also read as attributes."""

    def __getattr__(self, name):
        # Reached only when regular attribute lookup fails; fall back to the
        # dictionary entries and report a missing key as AttributeError.
        try:
            value = self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc
        return value

    def to(self, device):
        """Return a new ``_TokenOutput`` with every tensor value moved to *device*.

        Non-tensor values are carried over unchanged.
        """
        return _TokenOutput(
            (key, value.to(device) if torch.is_tensor(value) else value)
            for key, value in self.items()
        )
|
| 55 |
+
|
| 56 |
+
class _LocalTokenizer:
    """Minimal tokenizer backed by a local ``tokenizer.json``.

    Wraps a ``tokenizers.Tokenizer`` and exposes only the pieces of the
    Hugging Face tokenizer interface implemented below: calling the object to
    encode text, ``decode`` and ``apply_chat_template``.
    """

    def __init__(self, path: str, tokenizer_file: Optional[str] = None, **kwargs):
        """Load the tokenizer and its special-token metadata from disk.

        Args:
            path: Directory expected to contain ``tokenizer.json``; used when
                *tokenizer_file* is not given.
            tokenizer_file: Explicit path to a tokenizer JSON file. Its parent
                directory is searched for the optional metadata files.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If the ``tokenizers`` package cannot be imported.
            FileNotFoundError: If the tokenizer JSON file does not exist.
        """
        try:
            from tokenizers import Tokenizer as _TokenizerCore
        except Exception as exc:
            raise RuntimeError(
                "transformers import failed and tokenizers is unavailable. "
                "Install tokenizers or use a matching torch/transformers pair."
            ) from exc
        self.name_or_path = path or tokenizer_file or "<local-tokenizer>"
        if tokenizer_file:
            tok_file = tokenizer_file
            base_dir = os.path.dirname(os.path.abspath(tok_file))
        else:
            base_dir = os.path.abspath(path)
            tok_file = os.path.join(base_dir, "tokenizer.json")
        if not os.path.exists(tok_file):
            raise FileNotFoundError(f"Local tokenizer.json not found: {tok_file}")
        self._tok = _TokenizerCore.from_file(tok_file)
        self.vocab_size = int(self._tok.get_vocab_size())
        self.model_max_length = 10**9
        self.truncation_side = "left"
        self.chat_template = None
        self.eos_token = None
        self.pad_token = None

        # Optional metadata. Files are read in this order, so values found in
        # special_tokens_map.json override those from tokenizer_config.json;
        # the chat template keeps the first string value found.
        cfg_path = os.path.join(base_dir, "tokenizer_config.json")
        sp_path = os.path.join(base_dir, "special_tokens_map.json")
        for meta_path in (cfg_path, sp_path):
            data = self._read_json_dict(meta_path)
            template = data.get("chat_template")
            if self.chat_template is None and isinstance(template, str):
                self.chat_template = template
            for key in ("eos_token", "pad_token"):
                val = data.get(key)
                if isinstance(val, dict):
                    # Token stored as an object; its text is under "content".
                    val = val.get("content")
                if isinstance(val, str):
                    setattr(self, key, val)

        if self.eos_token is None:
            # No EOS declared: pick the first known candidate in the vocab.
            for cand in ("<|im_end|>", "<|endoftext|>", "</s>"):
                if self._tok.token_to_id(cand) is not None:
                    self.eos_token = cand
                    break
        if self.pad_token is None:
            self.pad_token = self.eos_token
        self.eos_token_id = self._tok.token_to_id(self.eos_token) if self.eos_token else None
        self.pad_token_id = self._tok.token_to_id(self.pad_token) if self.pad_token else self.eos_token_id

    @staticmethod
    def _read_json_dict(file_path: str) -> dict:
        """Return the JSON object stored at *file_path*, or ``{}`` if unusable.

        A missing or unreadable file, invalid JSON, or a top-level value that
        is not an object all yield an empty dict, so optional metadata never
        aborts tokenizer construction.
        """
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(file_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (OSError, ValueError):
            return {}
        return data if isinstance(data, dict) else {}

    def __len__(self):
        """Return the vocabulary size reported by the wrapped tokenizer."""
        return self.vocab_size

    def __call__(self, text, return_tensors=None, truncation=False, max_length=None, add_special_tokens=True, **kwargs):
        """Encode a string, or a list/tuple of strings, into token ids.

        A list/tuple is encoded item by item and right-padded to the longest
        item with ``pad_token_id`` (``0`` when no pad id is known). The result
        is a ``_TokenOutput`` with a single ``input_ids`` entry: a
        ``torch.long`` tensor when ``return_tensors == "pt"`` (2-D, with a
        leading batch axis of 1 for a single string), otherwise plain lists.
        Extra keyword arguments are ignored.
        """
        if isinstance(text, (list, tuple)):
            encoded = [self._encode_one(t, add_special_tokens, truncation, max_length) for t in text]
            maxlen = max(len(x) for x in encoded) if encoded else 0
            pad = self.pad_token_id if self.pad_token_id is not None else 0
            arr = [x + [pad] * (maxlen - len(x)) for x in encoded]
            if return_tensors == "pt":
                return _TokenOutput(input_ids=torch.tensor(arr, dtype=torch.long))
            return _TokenOutput(input_ids=arr)
        ids = self._encode_one(str(text), add_special_tokens, truncation, max_length)
        if return_tensors == "pt":
            return _TokenOutput(input_ids=torch.tensor([ids], dtype=torch.long))
        return _TokenOutput(input_ids=ids)

    def _encode_one(self, text, add_special_tokens=True, truncation=False, max_length=None):
        """Encode one string to a list of ids, optionally truncating.

        Truncation keeps the last *max_length* ids when ``truncation_side`` is
        ``"left"`` and the first *max_length* ids otherwise.
        """
        enc = self._tok.encode(text, add_special_tokens=bool(add_special_tokens))
        ids = list(enc.ids)
        if truncation and max_length is not None and len(ids) > int(max_length):
            limit = int(max_length)
            ids = ids[-limit:] if self.truncation_side == "left" else ids[:limit]
        return ids

    def decode(self, ids, skip_special_tokens=True, **kwargs):
        """Decode token ids back to text.

        Accepts a tensor, a flat sequence of ids, a single id, or a nested
        batch, in which case only the first row is decoded.
        """
        if torch.is_tensor(ids):
            ids = ids.detach().cpu().tolist()
        # A 0-d tensor (or a bare int) yields a scalar; treat it as one token.
        ids = [ids] if isinstance(ids, int) else list(ids)
        if ids and isinstance(ids[0], (list, tuple)):
            ids = ids[0]
        return self._tok.decode([int(x) for x in ids], skip_special_tokens=bool(skip_special_tokens))

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False, **kwargs):
        """Render *messages* using ``<|im_start|>``/``<|im_end|>`` markers.

        The layout is fixed; the ``chat_template`` string loaded in
        ``__init__`` is not consulted here. Returns the rendered text, or its
        token ids (encoded without special tokens) when *tokenize* is true.
        """
        chunks = [
            f"<|im_start|>{str(m.get('role', 'user'))}\n{str(m.get('content', ''))}<|im_end|>"
            for m in messages
        ]
        if add_generation_prompt:
            chunks.append("<|im_start|>assistant\n")
        text = "\n".join(chunks)
        if tokenize:
            return self(text, add_special_tokens=False).input_ids
        return text
|
| 154 |
+
|
| 155 |
+
class _AutoTokenizerShim:
    """Stand-in for ``transformers.AutoTokenizer`` with a local fallback."""

    @staticmethod
    def from_pretrained(path, *args, **kwargs):
        """Return a tokenizer for *path*.

        Delegates to ``transformers.AutoTokenizer.from_pretrained`` when
        transformers imported successfully; otherwise builds a
        ``_LocalTokenizer`` from local files.
        """
        if _HFAutoTokenizer is not None:
            return _HFAutoTokenizer.from_pretrained(path, *args, **kwargs)
        # Forward keyword options so an explicit ``tokenizer_file`` is honored
        # by the fallback; _LocalTokenizer ignores keywords it does not know.
        return _LocalTokenizer(path, **kwargs)
|
| 161 |
+
|
| 162 |
+
class _AutoConfigShim:
    """Stand-in for ``transformers.AutoConfig`` that fails with a clear error."""

    @staticmethod
    def from_pretrained(path, *args, **kwargs):
        """Delegate to ``transformers.AutoConfig.from_pretrained`` if available.

        Raises:
            RuntimeError: If ``AutoConfig`` could not be imported. The import
                error recorded at module load, if any, is attached as the
                cause so the root problem is visible in the traceback.
        """
        if _HFAutoConfig is not None:
            return _HFAutoConfig.from_pretrained(path, *args, **kwargs)
        raise RuntimeError(
            "AutoConfig requested, but transformers failed to import. "
            "Use --no-config-download / embedded model_config for LULUV2."
        ) from _TRANSFORMERS_IMPORT_ERROR
|
| 171 |
+
AutoTokenizer = _AutoTokenizerShim
|
| 172 |
+
AutoConfig = _AutoConfigShim
|
| 173 |
+
|
| 174 |
+
# Process-wide PyTorch numeric settings, applied at import time.
# Allow TF32 for CUDA matmuls and cuDNN, and request "high" float32 matmul
# precision where this torch version exposes the setter.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")
# Best-effort selection of scaled-dot-product-attention backends: enable the
# flash and memory-efficient kernels and disable the math implementation.
# NOTE(review): with the math backend disabled there is presumably no fallback
# when neither fused kernel supports the inputs — confirm on target GPUs.
try:
    if torch.cuda.is_available():
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
except Exception:
    # Any failure here is ignored and the backend settings stay as they were.
    pass
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def parse_dtype(name: str):
    """Translate a dtype alias into the matching ``torch`` dtype.

    Matching ignores case and surrounding whitespace.

    Raises:
        ValueError: If the alias is not a known bf16/fp16/fp32 spelling.
    """
    key = str(name).strip().lower()
    alias_groups = (
        (("bf16", "bfloat16"), torch.bfloat16),
        (("fp16", "float16", "half"), torch.float16),
        (("fp32", "float32"), torch.float32),
    )
    for aliases, dtype in alias_groups:
        if key in aliases:
            return dtype
    raise ValueError(f"Unknown dtype: {key}")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def human_bytes(n: float) -> str:
    """Format a byte count with two decimals and a 1024-based unit (B..TB)."""
    suffixes = ("B", "KB", "MB", "GB", "TB")
    amount = float(n)
    for suffix in suffixes[:-1]:
        # Written as "not >=" (rather than "<") so a NaN input stops at the
        # first unit instead of falling through to the last one.
        if not amount >= 1024.0:
            return f"{amount:.2f} {suffix}"
        amount /= 1024.0
    # Anything still >= 1024 after GB is reported in TB.
    return f"{amount:.2f} {suffixes[-1]}"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def safe_torch_load(path: str, map_location="cpu"):
    """Load a checkpoint with ``torch.load`` across PyTorch versions.

    The checkpoint stores Python metadata alongside tensors, so full
    unpickling (``weights_only=False``) is requested on purpose.

    SECURITY: full unpickling can execute arbitrary code embedded in the
    file; only load checkpoints from a trusted source.
    """
    load_kwargs = {"map_location": map_location}
    try:
        return torch.load(path, weights_only=False, **load_kwargs)
    except TypeError:
        # Retry without the keyword (e.g. a torch.load that does not accept
        # ``weights_only``).
        return torch.load(path, **load_kwargs)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def module_has_vwm(sd: Dict[str, torch.Tensor], prefix: str) -> bool:
    """Return True when *sd* holds all of ``{prefix}.A``, ``.B`` and ``.c``."""
    return all(f"{prefix}.{part}" in sd for part in ("A", "B", "c"))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def linear_shape_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> Tuple[int, int, bool]:
    """Infer ``(in_features, out_features, has_bias)`` for the layer at *prefix*.

    Factorized (A/B/c) tensors take precedence; otherwise the dense
    ``{prefix}.weight`` tensor is used.

    Raises:
        KeyError: If neither the factorized tensors nor a dense weight exist.
    """
    bias_present = f"{prefix}.bias" in sd
    if module_has_vwm(sd, prefix):
        # Rows of A give the output width, rows of B the input width.
        n_out = int(sd[f"{prefix}.A"].shape[0])
        n_in = int(sd[f"{prefix}.B"].shape[0])
        return n_in, n_out, bias_present

    weight_key = f"{prefix}.weight"
    if weight_key not in sd:
        raise KeyError(f"Cannot infer Linear shape for {prefix}; missing {weight_key} and VWM A/B/c")
    # The dense weight is unpacked as (out_features, in_features).
    n_out, n_in = sd[weight_key].shape
    return int(n_in), int(n_out), bias_present
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def make_linear_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> nn.Module:
    """Create an uninitialized linear layer shaped like the tensors in *sd*.

    Returns a ``VWMFactorizedLinear`` when factorized tensors exist for
    *prefix*, otherwise a plain ``nn.Linear``. Weights are not copied here.
    """
    in_features, out_features, has_bias = linear_shape_from_state(sd, prefix)
    if not module_has_vwm(sd, prefix):
        return nn.Linear(in_features, out_features, bias=has_bias)
    # The factorization rank is the length of the diagonal vector c.
    rank = int(sd[f"{prefix}.c"].shape[0])
    return VWMFactorizedLinear(in_features, out_features, rank, bias=has_bias, name=prefix)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def module_has_vwm_embedding(sd: Dict[str, torch.Tensor], prefix: str) -> bool:
    """Return True when *sd* contains factorized embedding tensors A, B and c."""
    required_keys = {f"{prefix}.A", f"{prefix}.B", f"{prefix}.c"}
    return all(key in sd for key in required_keys)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def embedding_shape_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> Tuple[int, int]:
    """Infer ``(vocab_size, hidden_size)`` for the embedding at *prefix*.

    Raises:
        KeyError: If neither factorized nor dense embedding tensors exist.
    """
    if module_has_vwm_embedding(sd, prefix):
        # Rows of A give the vocabulary size, rows of B the hidden size.
        factor_a, factor_b = sd[f"{prefix}.A"], sd[f"{prefix}.B"]
        return int(factor_a.shape[0]), int(factor_b.shape[0])

    weight_key = f"{prefix}.weight"
    if weight_key not in sd:
        raise KeyError(f"Cannot infer embedding shape for {prefix}; missing dense or VWM embedding tensors")
    dense = sd[weight_key]
    return int(dense.shape[0]), int(dense.shape[1])
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def make_embedding_from_state(sd: Dict[str, torch.Tensor], prefix: str) -> nn.Module:
    """Create an embedding module shaped like the tensors stored in *sd*.

    Returns a ``VWMFactorizedEmbedding`` when factorized tensors exist for
    *prefix*, otherwise an ``nn.Embedding``. Weights are not copied here.
    """
    vocab_size, hidden_size = embedding_shape_from_state(sd, prefix)
    if not module_has_vwm_embedding(sd, prefix):
        return nn.Embedding(vocab_size, hidden_size)
    # The factorization rank is the length of the diagonal vector c.
    rank = int(sd[f"{prefix}.c"].shape[0])
    return VWMFactorizedEmbedding(vocab_size, hidden_size, rank, name=prefix)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def expand_shared_banks_into_state(ckpt: Dict, sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Expand shared-bank storage into ordinary per-module A/B/c entries.

    Each bank in ``ckpt["shared_banks"]`` provides one ``A`` and one ``B``
    that are referenced (not copied) by every module listed under the bank's
    ``"modules"`` mapping; each module supplies its own ``c`` and optional
    ``bias``.

    Returns:
        *sd* itself when the checkpoint has no banks, otherwise a shallow
        copy of *sd* with the expanded entries added.
    """
    banks = ckpt.get("shared_banks")
    if not banks:
        return sd

    expanded = dict(sd)
    module_count = 0
    for bank in banks.values():
        shared_a, shared_b = bank["A"], bank["B"]
        for prefix, entry in bank.get("modules", {}).items():
            expanded[f"{prefix}.A"] = shared_a
            expanded[f"{prefix}.B"] = shared_b
            expanded[f"{prefix}.c"] = entry["c"]
            # A bias entry is written only when one is actually stored.
            if "bias" in entry and entry["bias"] is not None:
                expanded[f"{prefix}.bias"] = entry["bias"]
            module_count += 1
    print(f"[shared-bank] expanded {len(banks)} banks into {module_count} VWM modules")
    return expanded
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# -----------------------------
|
| 288 |
+
# VWM linear used by the exported checkpoint
|
| 289 |
+
# -----------------------------
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class VWMFactorizedLinear(nn.Module):
    """Linear layer stored in factorized form.

    The dense weight is approximated as ``W ~= A diag(c) B^T`` and applied as
    ``y = ((x @ B) * c) @ A^T + bias``.

    State names are ``A``, ``B``, ``c`` and an optional ``bias``. All
    parameters are frozen and ``A``/``B``/``c`` start uninitialized; they are
    expected to be filled from a checkpoint.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, bias: bool = True, name: str = ""):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.rank = int(rank)
        self.name = name

        def frozen(tensor: torch.Tensor) -> nn.Parameter:
            return nn.Parameter(tensor, requires_grad=False)

        # Registration order (A, B, c, bias) determines state_dict order.
        self.A = frozen(torch.empty(out_features, rank))
        self.B = frozen(torch.empty(in_features, rank))
        self.c = frozen(torch.empty(rank))
        self.bias = frozen(torch.zeros(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the factorized projection in the dtype of *x*."""
        dtype = x.dtype
        # Project into rank space, scale by the diagonal, project back out.
        projected = torch.matmul(x, self.B.to(dtype=dtype))
        scaled = projected * self.c.to(dtype=dtype)
        out = torch.matmul(scaled, self.A.to(dtype=dtype).transpose(0, 1))
        if self.bias is None:
            return out
        return out + self.bias.to(dtype=dtype)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class VWMFactorizedEmbedding(nn.Module):
    """Embedding table stored in factorized form: ``E ~= A diag(c) B^T``.

    ``A`` has one row per token, ``B`` one row per embedding dimension and
    ``c`` holds the per-rank scale. All parameters are frozen and start
    uninitialized; they are expected to be filled from a checkpoint.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, rank: int, name: str = "model.embed_tokens"):
        super().__init__()
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.rank = int(rank)
        self.name = name
        for attr, shape in (("A", (num_embeddings, rank)), ("B", (embedding_dim, rank)), ("c", (rank,))):
            setattr(self, attr, nn.Parameter(torch.empty(*shape), requires_grad=False))

    @property
    def weight(self):
        """Dense ``[num_embeddings, embedding_dim]`` table, rebuilt on each access.

        Provided for compatibility and debugging; ``forward`` does not use it.
        """
        scaled_rows = self.A * self.c.view(1, -1)
        return scaled_rows @ self.B.T

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Look up rows of A for *input_ids*, then scale and project them."""
        rows = F.embedding(input_ids, self.A)
        dtype = rows.dtype
        scaled = rows * self.c.to(dtype=dtype)
        return torch.matmul(scaled, self.B.to(dtype=dtype).transpose(0, 1))
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class TiedEmbeddingLMHead(nn.Module):
    """Output head that reuses the input embedding matrix (dense or factorized)."""

    def __init__(self, embedding_module: nn.Module):
        super().__init__()
        self.embedding_module = embedding_module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project *hidden_states* onto the vocabulary with the tied embedding."""
        tied = self.embedding_module
        dtype = hidden_states.dtype
        if not isinstance(tied, VWMFactorizedEmbedding):
            # Dense embedding: logits = h @ E^T.
            return F.linear(hidden_states, tied.weight.to(dtype=dtype))
        # Factorized embedding: logits = ((h @ B) * c) @ A^T, which avoids
        # materializing the dense table.
        rank_space = torch.matmul(hidden_states, tied.B.to(dtype=dtype))
        rank_space = rank_space * tied.c.to(dtype=dtype)
        return torch.matmul(rank_space, tied.A.to(dtype=dtype).transpose(0, 1))
|
| 361 |
+
|
| 362 |
+
# -----------------------------
|
| 363 |
+
# LULU2 decoder architecture
|
| 364 |
+
# -----------------------------
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class LuluRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a frozen gain."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size), requires_grad=False)
        self.variance_epsilon = float(eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, then return the result in the input dtype."""
        orig_dtype = hidden_states.dtype
        as_float = hidden_states.float()
        mean_square = as_float.pow(2).mean(-1, keepdim=True)
        normalized = as_float * torch.rsqrt(mean_square + self.variance_epsilon)
        # The gain is applied after casting back to the input dtype.
        return self.weight.to(dtype=orig_dtype) * normalized.to(orig_dtype)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class LuluRotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used for rotary position embeddings."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 1000000.0):
        super().__init__()
        self.dim = int(dim)
        self.max_position_embeddings = int(max_position_embeddings)
        self.base = float(base)
        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        # Non-persistent: recomputed at construction, never saved or loaded.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)``, each ``[B, T, dim]``, for ``[B, T]`` positions.

        *x* only supplies the target device and dtype; angles are computed in
        float32 before the final cast.
        """
        freq_table = self.inv_freq.to(device=x.device).float()
        angles = torch.einsum("bt,d->btd", position_ids.float(), freq_table)
        # Duplicate the half-width angles so they cover both channel halves.
        full_angles = torch.cat((angles, angles), dim=-1)
        cos = full_angles.cos().to(dtype=x.dtype)
        sin = full_angles.sin().to(dtype=x.dtype)
        return cos, sin
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the half moved to the front."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with the given cos/sin tables.

    ``q``/``k`` are laid out as [B, H, T, D] and ``cos``/``sin`` as [B, T, D];
    the tables gain a singleton head axis so they broadcast over H.
    """
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Input is [batch, kv_heads, seq, head_dim]; output is
    [batch, kv_heads * n_rep, seq, head_dim], with the copies of one head
    adjacent to each other. ``n_rep == 1`` returns the input unchanged.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class LuluVWMMLP(nn.Module):
    """Gated feed-forward block whose three projections are built from checkpoint tensors."""

    def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
        super().__init__()
        prefix = f"model.layers.{layer_idx}.mlp"
        # Registration order matches the attribute order gate -> up -> down.
        for name in ("gate_proj", "up_proj", "down_proj"):
            setattr(self, name, make_linear_from_state(sd, f"{prefix}.{name}"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class LuluVWMAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    The four projections are built from the checkpoint state dict under
    ``model.layers.<layer_idx>.self_attn``. There is no KV cache: every call
    attends over the full sequence it is given.
    """

    def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
        super().__init__()
        self.layer_idx = int(layer_idx)
        self.hidden_size = int(cfg.hidden_size)
        self.num_heads = int(cfg.num_attention_heads)

        # An attribute that exists but is None is treated like a missing one,
        # so int()/float() never receive None.
        kv_heads = getattr(cfg, "num_key_value_heads", None)
        self.num_key_value_heads = int(kv_heads) if kv_heads is not None else self.num_heads
        head_dim = getattr(cfg, "head_dim", None)
        self.head_dim = int(head_dim) if head_dim is not None else self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        dropout = getattr(cfg, "attention_dropout", None)
        # Stored for reference only; forward() always runs with dropout_p=0.0.
        self.attention_dropout = float(dropout) if dropout is not None else 0.0

        p = f"model.layers.{layer_idx}.self_attn"
        self.q_proj = make_linear_from_state(sd, f"{p}.q_proj")
        self.k_proj = make_linear_from_state(sd, f"{p}.k_proj")
        self.v_proj = make_linear_from_state(sd, f"{p}.v_proj")
        self.o_proj = make_linear_from_state(sd, f"{p}.o_proj")

        rope_theta = float(getattr(cfg, "rope_theta", 1000000.0))
        max_pos = int(getattr(cfg, "max_position_embeddings", 32768))
        self.rotary_emb = LuluRotaryEmbedding(self.head_dim, max_position_embeddings=max_pos, base=rope_theta)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Run causal attention over ``hidden_states`` ([B, T, hidden])."""
        bsz, q_len, _ = hidden_states.size()

        def split_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            # [B, T, n_heads * D] -> [B, n_heads, T, D]
            return t.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        # value_states only supplies device/dtype to the rotary module.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Full forward is causal. This generation script recomputes the full prefix each token.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
            scale=self.scaling,
        )
        # The merged width is num_heads * head_dim, which differs from
        # hidden_size when the config sets head_dim explicitly; let reshape
        # infer it instead of assuming hidden_size.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        return self.o_proj(attn_output)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class LuluVWMDecoderLayer(nn.Module):
    """One pre-norm decoder layer: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, cfg, sd: Dict[str, torch.Tensor], layer_idx: int):
        super().__init__()
        self.self_attn = LuluVWMAttention(cfg, sd, layer_idx)
        self.mlp = LuluVWMMLP(cfg, sd, layer_idx)
        eps = getattr(cfg, "rms_norm_eps", 1e-6)
        self.input_layernorm = LuluRMSNorm(cfg.hidden_size, eps=eps)
        self.post_attention_layernorm = LuluRMSNorm(cfg.hidden_size, eps=eps)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Apply the attention sub-block, then the MLP sub-block."""
        attn_out = self.self_attn(self.input_layernorm(hidden_states), position_ids=position_ids)
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class LuluVWMModel(nn.Module):
    """Decoder stack: token embedding, ``num_hidden_layers`` decoder layers, final RMS norm."""

    def __init__(self, cfg, sd: Dict[str, torch.Tensor]):
        super().__init__()
        self.config = cfg
        # Only the hidden size is needed here; the vocabulary size is discarded.
        _, hidden_size = embedding_shape_from_state(sd, "model.embed_tokens")
        self.embed_tokens = make_embedding_from_state(sd, "model.embed_tokens")
        n_layers = int(cfg.num_hidden_layers)
        self.layers = nn.ModuleList([LuluVWMDecoderLayer(cfg, sd, idx) for idx in range(n_layers)])
        self.norm = LuluRMSNorm(hidden_size, eps=getattr(cfg, "rms_norm_eps", 1e-6))

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """Return the final normalised hidden states for ``input_ids`` ([B, T]).

        When ``position_ids`` is omitted, positions 0..T-1 are used for every row.
        """
        bsz, seq_len = input_ids.shape
        if position_ids is None:
            base_positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.long)
            position_ids = base_positions.unsqueeze(0).expand(bsz, -1)
        hidden_states = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden_states = block(hidden_states, position_ids=position_ids)
        return self.norm(hidden_states)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class LuluVWMForCausalLM(nn.Module):
    """Decoder stack plus a language-model head.

    The head is built from ``lm_head`` tensors when the state dict has them;
    otherwise it is tied to the token embedding and ``tie_word_embeddings``
    is forced to True.
    """

    def __init__(self, cfg, sd: Dict[str, torch.Tensor]):
        super().__init__()
        self.config = cfg
        self.model = LuluVWMModel(cfg, sd)
        # Call kept from the original; its result is not used here.
        _ = embedding_shape_from_state(sd, "model.embed_tokens")
        self.tie_word_embeddings = bool(getattr(cfg, "tie_word_embeddings", False))
        has_own_head = module_has_vwm(sd, "lm_head") or "lm_head.weight" in sd
        if not has_own_head:
            self.tie_word_embeddings = True
            self.lm_head = TiedEmbeddingLMHead(self.model.embed_tokens)
        else:
            self.lm_head = make_linear_from_state(sd, "lm_head")

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
        """Return a namespace whose ``logits`` attribute holds the head output."""
        hidden_states = self.model(input_ids=input_ids, position_ids=position_ids)
        return SimpleNamespace(logits=self.lm_head(hidden_states))
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# -----------------------------
|
| 547 |
+
# config loading / inference
|
| 548 |
+
# -----------------------------
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def infer_minimal_config_from_state(sd: Dict[str, torch.Tensor], model_id: str = "") -> SimpleNamespace:
    """Derive a minimal model config from tensor shapes in a state dict.

    Vocabulary and hidden size come from the token embedding (dense
    ``weight`` or the VWM ``A``/``B`` factors), the layer count from the
    highest ``model.layers.<i>`` index, and the intermediate size from layer
    0's dense ``gate_proj`` weight (4864 when that key is absent). Head
    counts are hard-coded defaults, not derived from the checkpoint.

    Raises:
        ValueError: if no token-embedding tensors are present.
    """
    dense_key = "model.embed_tokens.weight"
    if dense_key in sd:
        hidden_size = int(sd[dense_key].shape[1])
        vocab_size = int(sd[dense_key].shape[0])
    elif module_has_vwm_embedding(sd, "model.embed_tokens"):
        vocab_size = int(sd["model.embed_tokens.A"].shape[0])
        hidden_size = int(sd["model.embed_tokens.B"].shape[0])
    else:
        raise ValueError("Checkpoint is missing model.embed_tokens dense or VWM tensors. Use a full standalone checkpoint, not a delta checkpoint.")

    layer_ids = set()
    for key in sd:
        if not key.startswith("model.layers."):
            continue
        try:
            layer_ids.add(int(key.split(".")[2]))
        except ValueError:
            # Third path component is not a layer index; ignore this key.
            continue
    num_hidden_layers = max(layer_ids) + 1 if layer_ids else 0

    inter_key = "model.layers.0.mlp.gate_proj.weight"
    intermediate_size = int(sd[inter_key].shape[0]) if inter_key in sd else 4864

    # Best known defaults for LULU2. If you export
    # model_config into the checkpoint, these assumptions are not used.
    num_attention_heads = 14
    num_key_value_heads = 2
    head_dim = hidden_size // num_attention_heads
    if head_dim * num_attention_heads != hidden_size:
        # Fallback if a different decoder variant is used and no config is present.
        # This requires explicit command-line override in practice.
        num_attention_heads = 1
        num_key_value_heads = 1
        head_dim = hidden_size

    return SimpleNamespace(
        model_type="luluv2",
        model_id=model_id,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        max_position_embeddings=32768,
        attention_dropout=0.0,
        tie_word_embeddings=False,
    )
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def namespace_from_dict(d: Dict) -> SimpleNamespace:
    """Build a ``SimpleNamespace`` whose attributes are the items of ``d``."""
    fields = d
    namespace = SimpleNamespace(**fields)
    return namespace
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def load_runtime_config(ckpt: Dict, sd: Dict[str, torch.Tensor], args) -> SimpleNamespace:
    """Pick the model config for a loaded checkpoint.

    Sources, in priority order:
      1. a ``model_config`` dict embedded in the checkpoint;
      2. shape-derived defaults when ``args.no_config_download`` is set;
      3. ``AutoConfig.from_pretrained`` for the resolved model id.

    A checkpoint-level ``tie_word_embeddings: True`` overrides the chosen
    config on every path.
    """
    force_tied = ckpt.get("tie_word_embeddings") is True

    embedded = ckpt.get("model_config")
    if isinstance(embedded, dict):
        print("[config] using model_config embedded in checkpoint")
        d = dict(embedded)  # copy so the checkpoint dict is not mutated
        if force_tied:
            d["tie_word_embeddings"] = True
        return SimpleNamespace(**d)

    # The saved training args may be a dict, an argparse-style namespace, or absent.
    saved_args = ckpt.get("args")
    if isinstance(saved_args, dict):
        saved_model_id = saved_args.get("model_id")
    else:
        saved_model_id = getattr(saved_args, "model_id", None)
    model_id = args.model_id or ckpt.get("model_id") or saved_model_id or "LULU2"

    if args.no_config_download:
        print("[config] no embedded config and --no-config-download set; using LULU2 defaults")
        cfg = infer_minimal_config_from_state(sd, model_id=model_id)
    else:
        print(f"[config] loading config metadata only from {model_id}; no model weights are loaded")
        cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    if force_tied:
        cfg.tie_word_embeddings = True
    return cfg
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
# -----------------------------
|
| 631 |
+
# generation
|
| 632 |
+
# -----------------------------
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def build_chat_prompt(tokenizer, user_prompt: str, system_prompt: str = "You are a helpful assistant. Answer directly and naturally.") -> str:
    """Render a single-turn chat prompt.

    The tokenizer's chat template is tried first; if that call raises for
    any reason, a plain ``system/user/assistant`` text layout is returned.
    The system message is omitted from the template input when
    ``system_prompt`` is empty.
    """
    system_turn = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages = system_turn + [{"role": "user", "content": user_prompt}]
    try:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Best-effort fallback for tokenizers without a usable chat template.
        rendered = f"system\n{system_prompt}\nuser\n{user_prompt}\nassistant\n"
    return rendered
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
@torch.no_grad()
def sample_next(logits: torch.Tensor, temperature: float = 0.0, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Choose the next token id from ``logits`` (vocabulary on the last axis).

    Args:
        logits: unnormalised scores, e.g. [B, V] or [V].
        temperature: ``<= 0`` selects the argmax; otherwise logits are
            divided by it before sampling.
        top_k: when positive, keep only the ``top_k`` highest logits
            (values tied with the k-th one are kept too).
        top_p: when below 1.0, keep the smallest highest-probability prefix
            whose cumulative probability exceeds ``top_p``; the single best
            token is always kept.

    Returns:
        Token ids with a trailing axis of size 1.
    """
    if temperature <= 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    neg_inf = -float("inf")
    logits = logits / max(temperature, 1e-6)

    if top_k and top_k > 0:
        k = min(int(top_k), logits.size(-1))
        thresh = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = torch.where(logits >= thresh, logits, torch.full_like(logits, neg_inf))

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and never drop the most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, neg_inf)
        # Scatter along the vocabulary (last) axis so logits of any rank work.
        logits = torch.full_like(logits, neg_inf).scatter(-1, sorted_indices, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
@torch.no_grad()
def generate_text(model, tokenizer, prompt: str, device, max_new_tokens: int = 120, temperature: float = 0.0, top_k: int = 0, top_p: float = 1.0, max_context: int = 2048) -> Tuple[str, float]:
    """Autoregressively extend ``prompt`` and report throughput.

    Every step re-runs the model on the last ``max_context`` tokens (no KV
    cache) and appends one sampled token; generation stops early when the
    tokenizer's EOS id is produced.

    Returns:
        ``(text, tokens_per_second)`` where ``text`` is the decoded prompt
        plus continuation with special tokens skipped.
    """
    model.eval()
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    eos_id = tokenizer.eos_token_id
    started = time.time()
    prompt_len = int(input_ids.shape[1])

    for _step in range(max_new_tokens):
        window = input_ids[:, -max_context:]
        step_logits = model(window).logits[:, -1, :].float()
        next_id = sample_next(step_logits, temperature=temperature, top_k=top_k, top_p=top_p)
        input_ids = torch.cat([input_ids, next_id.to(input_ids.device)], dim=-1)
        hit_eos = eos_id is not None and int(next_id.item()) == int(eos_id)
        if hit_eos:
            break

    elapsed = time.time() - started
    # Clamp both terms so the rate is always finite and positive.
    produced = max(1, int(input_ids.shape[1]) - prompt_len)
    text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return text, produced / max(elapsed, 1e-9)
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def load_tokenizer(args, ckpt):
    """Load the tokenizer named by the CLI arguments or the checkpoint metadata.

    The first non-empty of ``args.tokenizer``, ``ckpt["tokenizer_dir"]``,
    ``ckpt["model_id"]``, ``ckpt["args"]["model_id"]`` and ``args.model_id``
    is used. The EOS token doubles as the pad token when no pad token is set.

    Raises:
        ValueError: if none of those sources yields a path or name.
    """
    tok_path = (
        args.tokenizer
        or ckpt.get("tokenizer_dir")
        or ckpt.get("model_id")
        or ckpt.get("args", {}).get("model_id")
        or args.model_id
    )
    if not tok_path:
        raise ValueError("Tokenizer path/name is required. Pass --tokenizer <local-dir-or-model-id>.")

    # If checkpoint stores a relative tokenizer_dir like "tokenizer", resolve it
    # relative to the checkpoint location so no HF lookup is needed.
    ckpt_dir = os.path.dirname(os.path.abspath(args.checkpoint))
    if not os.path.isabs(tok_path):
        candidate = os.path.join(ckpt_dir, tok_path)
        if os.path.isdir(candidate):
            tok_path = candidate

    print(f"[tokenizer] {tok_path}")
    local_only = bool(args.local_files_only)
    tok = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True, local_files_only=local_only)
    needs_pad = tok.pad_token_id is None and tok.eos_token_id is not None
    if needs_pad:
        tok.pad_token = tok.eos_token
    return tok
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# -----------------------------
|
| 711 |
+
# main
|
| 712 |
+
|
| 713 |
+
# Public "Lulu2*" names used by the UI/runtime. Each one is a plain alias of
# the corresponding "Lulu*" class defined above, not a subclass.
Lulu2RMSNorm = LuluRMSNorm
Lulu2RotaryEmbedding = LuluRotaryEmbedding
Lulu2VWMMLP = LuluVWMMLP
Lulu2VWMAttention = LuluVWMAttention
Lulu2VWMDecoderLayer = LuluVWMDecoderLayer
Lulu2VWMModel = LuluVWMModel
Lulu2ForCausalLM = LuluVWMForCausalLM
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class Pass2RefinementAdapter(nn.Module):
    """Low-rank gated residual adapter conditioned on a pass-1 layer state.

    The adapter normalises its input and the conditioning state separately,
    concatenates them, and maps them through ``down -> SiLU -> up``. The
    output is scaled by ``sigmoid(gate)``.
    """

    def __init__(self, hidden_size: int, rank: int, gate_init: float = -5.0):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.rank = int(rank)
        self.x_norm = LuluRMSNorm(hidden_size)
        self.cond_norm = LuluRMSNorm(hidden_size)
        self.down = nn.Linear(2 * hidden_size, rank, bias=False)
        self.up = nn.Linear(rank, hidden_size, bias=False)
        self.gate = nn.Parameter(torch.tensor(float(gate_init)))

        down_std = 0.02 / math.sqrt(max(1, hidden_size))
        nn.init.normal_(self.down.weight, mean=0.0, std=down_std)
        # Zero init means the two-pass model starts exactly as pass 1.
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return the gated refinement delta for ``x`` given ``cond``."""
        features = torch.cat([self.x_norm(x), self.cond_norm(cond)], dim=-1)
        delta = self.up(F.silu(self.down(features)))
        strength = torch.sigmoid(self.gate).to(dtype=x.dtype)
        return strength * delta
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
@dataclass
class Pass2Config:
    """Hyper-parameters of the pass-2 refinement wrapper."""

    # Bottleneck width of each per-layer refinement adapter.
    adapter_rank: int = 64
    # Initial pre-sigmoid value of each adapter's output gate.
    adapter_gate_init: float = -5.0
    # Initial pre-sigmoid value of each per-layer gate.
    layer_gate_init: float = -5.0
    # Multiplier applied to the random initialisation of the pass embeddings.
    pass_embed_scale: float = 0.0
    # Free-form mode label.
    mode: str = "refine_pass1_residual"
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
class Lulu2TwoPassForCausalLM(nn.Module):
    """
    Wraps a loaded LULU2 base model.

    Pass 1: normal LULU2 decoder forward, producing the pass-1 residual stream.
    Pass 2: starts from pass-1 residual stream and adds small gated refinements.

        h2_i = h2_i + sigmoid(layer_gate_i) * (BaseLayer_i(h2_i) - h2_i)
                    + Adapter_i(h2_i, pass1_state_i)

    With zero-initialized adapter up-projections and negative gates, the model
    starts extremely close to the loaded LULU2 checkpoint and learns refinements.
    """

    def __init__(self, base: Lulu2ForCausalLM, cfg: Pass2Config):
        super().__init__()
        self.base = base
        self.pass2_config = cfg
        hidden = int(base.config.hidden_size)
        n_layers = int(base.config.num_hidden_layers)
        # Row 0 is added to the pass-1 embeddings, row 1 to the pass-2 input.
        self.pass_embed = nn.Parameter(torch.randn(2, hidden) * float(cfg.pass_embed_scale))
        self.layer_gates = nn.Parameter(torch.full((n_layers,), float(cfg.layer_gate_init)))
        self.adapters = nn.ModuleList([
            Pass2RefinementAdapter(hidden, int(cfg.adapter_rank), gate_init=float(cfg.adapter_gate_init))
            for _ in range(n_layers)
        ])

    def _position_ids(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
        """Return ``position_ids`` unchanged, or 0..T-1 for every row when it is None."""
        if position_ids is not None:
            return position_ids
        bsz, seq_len = input_ids.shape
        return torch.arange(seq_len, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)

    def forward_pass1_features(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None):
        """Run pass 1 through the base decoder layers.

        Returns:
            ``(final_residual, per_layer_outputs, position_ids)``; the final
            residual has not been through the base model's final norm.
        """
        position_ids = self._position_ids(input_ids, position_ids)
        h = self.base.model.embed_tokens(input_ids)
        h = h + self.pass_embed[0].to(dtype=h.dtype).view(1, 1, -1)
        layer_states = []
        for layer in self.base.model.layers:
            h = layer(h, position_ids=position_ids)
            layer_states.append(h)
        return h, layer_states, position_ids

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, return_pass1_logits: bool = False):
        """Run both passes and return pass-2 logits.

        Returns:
            A namespace with ``logits`` (pass 2) and ``pass1_logits`` (pass-1
            logits computed without gradient tracking, or None unless
            ``return_pass1_logits`` is True).
        """
        h1_resid, pass1_states, position_ids = self.forward_pass1_features(input_ids, position_ids=position_ids)

        # Pass 2 refines pass 1; it does not discard pass 1.
        h2 = h1_resid + self.pass_embed[1].to(dtype=h1_resid.dtype).view(1, 1, -1)
        for i, layer in enumerate(self.base.model.layers):
            layer_delta = layer(h2, position_ids=position_ids) - h2
            layer_gate = torch.sigmoid(self.layer_gates[i]).to(dtype=h2.dtype)
            adapter_delta = self.adapters[i](h2, pass1_states[i])
            h2 = h2 + layer_gate * layer_delta + adapter_delta

        logits2 = self.base.lm_head(self.base.model.norm(h2))

        logits1 = None
        if return_pass1_logits:
            # The pass-1 norm and head run only on request; these logits are
            # produced without gradient tracking.
            with torch.no_grad():
                logits1 = self.base.lm_head(self.base.model.norm(h1_resid))
        return SimpleNamespace(logits=logits2, pass1_logits=logits1)
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
@torch.no_grad()
def load_lulu2_base(args, device, dtype):
    """Load the checkpoint at ``args.checkpoint`` and build the base model.

    The checkpoint is read onto the CPU, its state dict expanded, a config
    resolved, and the model constructed and moved to ``device``/``dtype``.
    State-dict loading is non-strict; mismatched keys are only reported.

    Returns:
        ``(ckpt, base)``: the raw checkpoint mapping and the loaded model.

    Raises:
        ValueError: if the checkpoint has no ``"model"`` entry.
    """
    ckpt_path = args.checkpoint
    print("[guard] LULUV2 VWM runtime: no AutoModelForCausalLM.from_pretrained call and no external-model weights loaded.")
    size_text = human_bytes(os.path.getsize(ckpt_path))
    print(f"[load] {ckpt_path} ({size_text})")

    ckpt = safe_torch_load(ckpt_path, map_location="cpu")
    if "model" not in ckpt:
        raise ValueError("Checkpoint missing model state dict")

    sd = expand_shared_banks_into_state(ckpt, ckpt["model"])
    cfg = load_runtime_config(ckpt, sd, args)
    print(f"[config] hidden={cfg.hidden_size} layers={cfg.num_hidden_layers}")

    base = Lulu2ForCausalLM(cfg, sd)
    missing, unexpected = base.load_state_dict(sd, strict=False)
    print(f"[state:base] missing={len(missing)} unexpected={len(unexpected)}")
    # Show at most the first ten keys of each non-empty category.
    for label, keys in (("missing", missing), ("unexpected", unexpected)):
        if keys:
            print(f"[state:base] first {label}:", keys[:10])

    base.to(device=device, dtype=dtype)
    return ckpt, base
|
luluv2_live_inference.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
LULUV2 live local inference engine.
|
| 5 |
+
|
| 6 |
+
This is the runtime bridge for the LULUV2 fine-tuned checkpoint:
|
| 7 |
+
LULU2_instruct_ddp.pt / LULU2_base_ddp.pt / LULU2.pt
|
| 8 |
+
|
| 9 |
+
It imports the actual LULUV2 architecture file, loads the checkpoint,
|
| 10 |
+
restores pass2_state when present, uses the local tokenizer folder, and streams
|
| 11 |
+
tokens with live metrics.
|
| 12 |
+
|
| 13 |
+
No AutoModelForCausalLM.from_pretrained call is used here.
|
| 14 |
+
No external model weights are loaded.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import importlib.util
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import platform
|
| 24 |
+
import time
|
| 25 |
+
from contextlib import nullcontext
|
| 26 |
+
from dataclasses import dataclass, asdict
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from types import SimpleNamespace
|
| 29 |
+
from typing import Any, Dict, Generator, List, Optional, Tuple
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import psutil
|
| 36 |
+
except Exception:
|
| 37 |
+
psutil = None
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import pynvml
|
| 41 |
+
except Exception:
|
| 42 |
+
pynvml = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Turn-control markers and role prefixes that are cut or stripped from
# generated text by the cleanup code that reads this list.
STOP_STRINGS = [
    # chat-template control tokens
    "<|im_start|>",
    "<|im_end|>",
    "<|user|>",
    "<|system|>",
    "<|assistant|>",
    # plain-text role prefixes
    "User:",
    "Assistant:",
    "\nuser:",
    "\nassistant:",
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
class GenerationConfig:
    """Sampling and streaming settings for one generation request."""

    # Length limit for the continuation.
    max_new_tokens: int = 512

    # Sampling distribution controls.
    temperature: float = 0.65
    top_k: int = 40
    top_p: float = 0.90
    min_p: float = 0.03
    greedy: bool = False

    # Repetition controls.
    repetition_penalty: float = 1.10
    frequency_penalty: float = 0.02
    no_repeat_ngram: int = 4

    # Streaming / context / metrics options.
    stream_every: int = 1
    max_context_tokens: int = 4096
    return_pass_metrics: bool = True
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
class GenerationStats:
    """Live metrics reported while a generation is running."""

    # Token counts and timing.
    prompt_tokens: int = 0
    generated_tokens: int = 0
    elapsed_sec: float = 0.0
    tokens_per_sec: float = 0.0

    # Details of the most recently produced token.
    last_token: str = ""
    last_token_id: int = -1
    last_token_prob: float = 0.0
    last_entropy: float = 0.0

    # Why generation ended; "none" until it has.
    finish_reason: str = "none"

    # Pass-1 vs pass-2 comparison metrics; None when not computed.
    pass1_pass2_kl: Optional[float] = None
    pass1_pass2_logit_cosine: Optional[float] = None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def setup_torch():
    """Apply best-effort global PyTorch performance settings.

    With CUDA available, TF32 is allowed for matmul and cuDNN, and the flash
    and memory-efficient SDP kernels are enabled while the math kernel is
    disabled. Float32 matmul precision is set to "high" when the API exists.
    Each group of settings is attempted independently and failures are ignored.
    """

    def _allow_tf32():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _select_sdp_kernels():
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)

    def _high_matmul_precision():
        torch.set_float32_matmul_precision("high")

    steps = []
    if torch.cuda.is_available():
        steps.extend((_allow_tf32, _select_sdp_kernels))
    if hasattr(torch, "set_float32_matmul_precision"):
        steps.append(_high_matmul_precision)

    for step in steps:
        try:
            step()
        except Exception:
            # Deliberately best-effort: an unsupported knob must not stop startup.
            pass
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def human_bytes(num: float) -> str:
    """Format a byte count with two decimals and a 1024-based unit (B to PB)."""
    units = ("B", "KB", "MB", "GB", "TB")
    idx = 0
    while idx < len(units):
        if abs(num) < 1024.0:
            return f"{num:.2f} {units[idx]}"
        num /= 1024.0
        idx += 1
    # Anything that is still >= 1024 TB is reported in petabytes.
    return f"{num:.2f} PB"
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _value_to_text(value: Any) -> str:
|
| 118 |
+
"""Coerce Gradio/Textbox/Multimodal values into plain text.
|
| 119 |
+
|
| 120 |
+
Some Gradio versions send messages as {"text": ..., "files": ...} or
|
| 121 |
+
content blocks like [{"type": "text", "text": ...}]. The local UI is
|
| 122 |
+
text-only, so we aggressively unwrap these before tokenization.
|
| 123 |
+
"""
|
| 124 |
+
if value is None:
|
| 125 |
+
return ""
|
| 126 |
+
if isinstance(value, str):
|
| 127 |
+
return value
|
| 128 |
+
if isinstance(value, dict):
|
| 129 |
+
if "text" in value:
|
| 130 |
+
return _value_to_text(value.get("text"))
|
| 131 |
+
if "content" in value:
|
| 132 |
+
return _value_to_text(value.get("content"))
|
| 133 |
+
if "value" in value:
|
| 134 |
+
return _value_to_text(value.get("value"))
|
| 135 |
+
return "\n".join(_value_to_text(v) for v in value.values() if _value_to_text(v))
|
| 136 |
+
if isinstance(value, (list, tuple)):
|
| 137 |
+
return "\n".join(_value_to_text(v) for v in value if _value_to_text(v))
|
| 138 |
+
return str(value)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def clean_text(text: Any) -> str:
    """Normalise a UI value or raw model output into displayable text.

    The value is first flattened with ``_value_to_text`` and literal
    backslash-n sequences are turned into real newlines. The text is then cut
    at the earliest stop string found after position 0, any remaining stop
    strings are deleted, leading role labels are removed, and JSON-ish
    ``type: text`` artifacts are dropped. Trailing whitespace is removed from
    every line and from the result as a whole.
    """
    cleaned = _value_to_text(text).replace("\\n", "\n")

    # Cut only after obvious turn-control strings that appear in generated text.
    # A marker sitting at index 0 does not trigger a cut; it is only deleted below.
    cut_at = []
    for marker in STOP_STRINGS:
        position = cleaned.find(marker)
        if position > 0:
            cut_at.append(position)
    if cut_at:
        cleaned = cleaned[: min(cut_at)]
    for marker in STOP_STRINGS:
        cleaned = cleaned.replace(marker, "")

    # Remove common role remnants and JSON-ish UI artifacts.
    for role_tag in ("assistant\n", "Assistant:", "Lulu:", "assistant:"):
        left_trimmed = cleaned.lstrip()
        if left_trimmed.startswith(role_tag):
            cleaned = left_trimmed[len(role_tag):]
    for artifact in ("{'type': 'text'}", '{"type": "text"}'):
        cleaned = cleaned.replace(artifact, "")

    kept_lines = []
    for raw_line in cleaned.strip().splitlines():
        line = raw_line.rstrip()
        if line.strip().startswith("type: 'text'"):
            continue
        kept_lines.append(line)
    return "\n".join(kept_lines).strip()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def normalize_history(history: Any) -> List[Dict[str, str]]:
    """Convert chat history into a flat list of ``{"role", "content"}`` dicts.

    Two entry shapes are understood: message dicts carrying ``role`` and
    ``content``, and ``(user, assistant)`` pairs given as a tuple or list of
    at least two items. Content is passed through ``clean_text``; entries that
    end up empty, dicts whose role is not ``"user"`` or ``"assistant"``, and
    entries of any other shape are skipped.
    """
    messages: List[Dict[str, str]] = []
    for entry in history or ():
        if isinstance(entry, dict):
            speaker = entry.get("role")
            body = clean_text(entry.get("content", ""))
            turns = [(speaker, body)] if speaker in {"user", "assistant"} else []
        elif isinstance(entry, (tuple, list)) and len(entry) >= 2:
            turns = [
                ("user", clean_text(entry[0])),
                ("assistant", clean_text(entry[1])),
            ]
        else:
            continue
        messages.extend({"role": who, "content": said} for who, said in turns if said)
    return messages
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def resolve_model_py(model_py: Optional[str] = None) -> str:
    """Locate the LULUV2 runtime source file and return its absolute path.

    Candidates are tried in this order:
      1. ``model_py`` when given;
      2. ``luluv2_inference_runtime.py`` in the current working directory;
      3. ``luluv2_inference_runtime.py`` in the directory of this script, so
         the lookup matches the "next to this UI" advice in the error message
         even when the program is started from another directory.

    Only regular files are accepted; a directory that happens to carry the
    expected name is not treated as a model file.

    Raises:
        FileNotFoundError: when none of the candidates is an existing file.
    """
    default_name = "luluv2_inference_runtime.py"
    candidates = []
    if model_py:
        candidates.append(Path(model_py))
    candidates.append(Path(default_name))
    try:
        candidates.append(Path(__file__).resolve().parent / default_name)
    except NameError:
        # __file__ is undefined in some embedded/interactive contexts.
        pass
    for candidate in candidates:
        if candidate.is_file():
            return str(candidate.resolve())
    raise FileNotFoundError(
        "Could not find the LULUV2 model file. Pass --model-py or put "
        "luluv2_inference_runtime.py next to this UI."
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def import_model_py(model_py: Optional[str] = None):
    """Import the LULUV2 runtime source file as a module.

    The file is located with ``resolve_model_py`` and loaded under the module
    name ``luluv2_runtime_module``.

    The module is registered in ``sys.modules`` before its code runs, as the
    importlib "import a source file directly" recipe requires. Machinery that
    looks a class's module up by name (for example ``dataclasses`` when
    annotations are strings, or ``pickle``) fails on modules that are not
    registered. If executing the module raises, the previous ``sys.modules``
    entry is restored so no half-initialised module is left behind.

    Returns:
        A ``(module, path)`` tuple where ``path`` is the resolved file path.

    Raises:
        FileNotFoundError: propagated from ``resolve_model_py``.
        RuntimeError: when no import spec/loader can be created for the file.
    """
    import sys  # local import: keeps this block independent of the file header

    module_name = "luluv2_runtime_module"
    path = resolve_model_py(model_py)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not import model file: {path}")
    mod = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return mod, path
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class LULUV2LiveEngine:
    """Local streaming text-generation engine for a LULUV2 checkpoint.

    The engine imports the LULUV2 runtime module from a source file, loads the
    base model and tokenizer through that module, optionally wraps the base
    model with the two-pass model when the checkpoint carries ``pass2_state``,
    and exposes prompt building, token sampling and a streaming ``generate``
    generator together with statistics helpers for the UI.

    Decoding re-runs the model on the whole context window for every new
    token, and the sampling helpers read single values with ``.item()``, so
    the engine is written for batch size 1.
    """

    def __init__(
        self,
        ckpt_path: str,
        model_py: Optional[str] = None,
        tokenizer_dir: Optional[str] = None,
        device: Optional[str] = None,
        dtype: str = "bf16",
        local_files_only: bool = True,
        no_config_download: bool = True,
        force_base_only: bool = False,
    ):
        """Load runtime module, checkpoint, tokenizer and (optionally) pass 2.

        Args:
            ckpt_path: Path of the checkpoint file.
            model_py: Optional path of the runtime source file; resolved by
                ``import_model_py`` when omitted.
            tokenizer_dir: Optional tokenizer directory; otherwise a
                ``tokenizer`` folder next to the checkpoint is used if present.
            device: Torch device string; defaults to CUDA when available.
            dtype: ``bf16``/``bfloat16``, ``fp16``/``float16``/``half``;
                any other name selects float32.
            local_files_only: Forwarded to the runtime helpers via ``args``.
            no_config_download: Forwarded to the runtime helpers via ``args``.
            force_base_only: Skip the pass-2 wrapper even if the checkpoint
                contains ``pass2_state``.
        """
        setup_torch()
        self.ckpt_path = str(ckpt_path)
        self.ckpt_dir = Path(self.ckpt_path).resolve().parent
        self.device = self._select_device(device)
        self.dtype = self._dtype_from_name(dtype)
        self.local_files_only = bool(local_files_only)
        self.no_config_download = bool(no_config_download)
        self.force_base_only = bool(force_base_only)
        self.last_stats = GenerationStats()
        self.recent_tokens: List[Dict[str, Any]] = []

        self.goku, self.model_py_path = import_model_py(model_py)

        # args object expected by the embedded LULUV2 runtime helpers.
        self.args = SimpleNamespace(
            checkpoint=self.ckpt_path,
            tokenizer=tokenizer_dir or "",
            model_id="",
            no_config_download=self.no_config_download,
            local_files_only=self.local_files_only,
        )

        print("[guard] LULUV2 local UI: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.")
        print(f"[load] checkpoint={self.ckpt_path}")
        self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype)

        self.tokenizer = self._load_tokenizer(tokenizer_dir)
        self.model, self.has_pass2 = self._maybe_wrap_pass2(base)
        self.model.eval()

        self.model_info = self._build_model_info()

    def _select_device(self, device: Optional[str]):
        """Return the requested device, else CUDA when available, else CPU."""
        if device:
            return torch.device(device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _dtype_from_name(self, name: str):
        """Map a dtype name to a torch dtype; unknown names give float32."""
        name = (name or "bf16").lower()
        if name in {"bf16", "bfloat16"}:
            return torch.bfloat16
        if name in {"fp16", "float16", "half"}:
            return torch.float16
        return torch.float32

    def _load_tokenizer(self, tokenizer_dir: Optional[str]):
        """Load the tokenizer through the runtime module.

        When the tokenizer has an EOS id but no pad id, the EOS token is
        reused as the pad token on a best-effort basis.
        """
        # Prefer explicit path, then sibling tokenizer folder, then checkpoint metadata.
        if tokenizer_dir:
            self.args.tokenizer = tokenizer_dir
        else:
            sibling = self.ckpt_dir / "tokenizer"
            if sibling.is_dir():
                self.args.tokenizer = str(sibling)
        tok = self.goku.load_tokenizer(self.args, self.base_ckpt)
        if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None:
            try:
                tok.pad_token = tok.eos_token
            except Exception:
                # Some tokenizer objects do not allow assigning pad_token.
                pass
        return tok

    def _maybe_wrap_pass2(self, base):
        """Return ``(model, has_pass2)``.

        The base model is returned unchanged (moved to the device) when pass 2
        is disabled or the checkpoint has no ``pass2_state``. Otherwise the
        two-pass wrapper is built from ``pass2_config`` (unknown keys are
        ignored) and ``pass2_state`` is loaded non-strictly.
        """
        ckpt = self.base_ckpt
        if self.force_base_only or "pass2_state" not in ckpt:
            print("[pass2] no pass2_state loaded; running base LULUV2 forward")
            return base.to(self.device).eval(), False

        cfg_dict = dict(ckpt.get("pass2_config") or {})
        Pass2Config = self.goku.Pass2Config
        pass2_cfg = Pass2Config(**{k: v for k, v in cfg_dict.items() if k in Pass2Config.__dataclass_fields__})
        model = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg)
        missing, unexpected = model.load_state_dict(ckpt["pass2_state"], strict=False)
        print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}")
        model.to(device=self.device, dtype=self.dtype)
        model.eval()
        return model, True

    def _build_model_info(self) -> Dict[str, Any]:
        """Collect static facts about the loaded model for the stats views."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        # Parameters whose name ends in ".c" are reported as "VWM c" modules.
        c_codes = [(n, p.numel()) for n, p in self.model.named_parameters() if n.endswith(".c")]
        gate_mean = None
        adapter_gate_mean = None
        if self.has_pass2:
            with torch.no_grad():
                gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item())
                vals = []
                for ad in self.model.adapters:
                    vals.append(float(torch.sigmoid(ad.gate.float()).item()))
                adapter_gate_mean = float(sum(vals) / max(1, len(vals)))
        ckpt_size = Path(self.ckpt_path).stat().st_size if Path(self.ckpt_path).exists() else 0
        cfg = getattr(self.model.base if self.has_pass2 else self.model, "config", None)
        return {
            "checkpoint": self.ckpt_path,
            "checkpoint_size": human_bytes(ckpt_size),
            "model_py": self.model_py_path,
            "device": str(self.device),
            "dtype": str(self.dtype).replace("torch.", ""),
            "has_pass2": self.has_pass2,
            "total_params": total_params,
            "trainable_params": trainable_params,
            "vwm_c_modules": len(c_codes),
            "vwm_c_params": sum(n for _, n in c_codes),
            "pass2_layer_gate_mean": gate_mean,
            "pass2_adapter_gate_mean": adapter_gate_mean,
            "hidden_size": getattr(cfg, "hidden_size", None),
            "layers": getattr(cfg, "num_hidden_layers", None),
            "heads": getattr(cfg, "num_attention_heads", None),
            "kv_heads": getattr(cfg, "num_key_value_heads", None),
            "max_position_embeddings": getattr(cfg, "max_position_embeddings", None),
        }

    def amp_context(self):
        """Return CUDA autocast for bf16/fp16 on a CUDA device, else a no-op."""
        if self.device.type == "cuda" and self.dtype in (torch.bfloat16, torch.float16):
            return torch.autocast("cuda", dtype=self.dtype)
        return nullcontext()

    def build_chat_prompt(
        self,
        message: str,
        history: Any,
        system_prompt: str,
        memory_notes: str = "",
        history_turns: int = 4,
        extra_context: str = "",
    ) -> str:
        """Render system text, recent history and the new message as a prompt.

        Args:
            message: New user message (cleaned with ``clean_text``).
            history: Earlier conversation in any shape ``normalize_history``
                accepts.
            system_prompt: Base system instruction; may be empty.
            memory_notes: Optional notes appended to the system block.
            history_turns: Number of most recent turns to keep, counted as
                two history entries per turn. Zero, negative, fractional
                (< 1) or empty values keep no history at all.
            extra_context: Optional context appended to the system block.

        Returns:
            The tokenizer's chat-template rendering with a generation prompt,
            or a manual ``<|im_start|>``/``<|im_end|>`` layout when applying
            the template fails.
        """
        history = normalize_history(history)
        # Clamp before slicing: a count of 0 would otherwise give history[-0:],
        # which is the whole history rather than none of it.
        turns = max(0, int(history_turns or 0))
        recent = history[-turns * 2:] if turns > 0 else []

        system_chunks = []
        system_text = (system_prompt or "").strip()
        notes_text = (memory_notes or "").strip()
        context_text = (extra_context or "").strip()
        if system_text:
            system_chunks.append(system_text)
        if notes_text:
            system_chunks.append("Useful memory notes:\n" + notes_text)
        if context_text:
            system_chunks.append("Relevant local context:\n" + context_text)
        system = "\n\n".join(system_chunks)

        user_text = clean_text(message)
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.extend(recent)
        messages.append({"role": "user", "content": user_text})
        try:
            return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            # No usable chat template: fall back to an explicit layout.
            parts = []
            if system:
                parts.append(f"<|im_start|>system\n{system}<|im_end|>")
            for item in recent:
                parts.append(f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>")
            parts.append(f"<|im_start|>user\n{user_text}<|im_end|>")
            parts.append("<|im_start|>assistant\n")
            return "\n".join(parts)

    def encode(self, text: str, max_context_tokens: int = 4096) -> torch.Tensor:
        """Tokenize ``text`` into input ids on the engine device.

        Prompts longer than ``max_context_tokens`` are truncated. When the
        tokenizer exposes ``truncation_side`` it is switched to ``"left"`` for
        this call only, so the end of the prompt (the newest user turn and the
        generation prompt) is what survives; this matches ``generate``, which
        also keeps the tail of the sequence.
        NOTE(review): assumes a Hugging Face style tokenizer whose default
        truncation side is "right" - confirm against the runtime's
        ``load_tokenizer``.
        """
        tok = self.tokenizer
        has_side = hasattr(tok, "truncation_side")
        previous_side = getattr(tok, "truncation_side", None)
        if has_side:
            tok.truncation_side = "left"
        try:
            enc = tok(text, return_tensors="pt", truncation=True, max_length=int(max_context_tokens))
        finally:
            if has_side:
                tok.truncation_side = previous_side
        ids = enc.input_ids.to(self.device)
        return ids

    @torch.no_grad()
    def pass_metrics_for_ids(self, ids: torch.Tensor) -> Tuple[Optional[float], Optional[float]]:
        """Compare pass-1 and pass-2 logits at the last position of ``ids``.

        Returns:
            ``(kl, cosine)`` as floats, or ``(None, None)`` when the model has
            no pass 2, exposes no pass-1 logits, or the computation fails (the
            failure is printed, not raised).
        """
        if not self.has_pass2:
            return None, None
        try:
            with self.amp_context():
                out = self.model(ids, return_pass1_logits=True)
            if out.pass1_logits is None:
                return None, None
            l1 = out.pass1_logits[:, -1, :].float()
            l2 = out.logits[:, -1, :].float()
            # kl_div(input=log q, target=p) gives KL(p || q) with p = pass 1, q = pass 2.
            kl = F.kl_div(F.log_softmax(l2, dim=-1), F.softmax(l1, dim=-1), reduction="batchmean")
            cos = F.cosine_similarity(l1, l2, dim=-1).mean()
            return float(kl.item()), float(cos.item())
        except Exception as exc:
            print(f"[metrics] pass metrics failed: {type(exc).__name__}: {exc}")
            return None, None

    def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
        """Return a copy of ``logits`` with repetition controls applied.

        Applies, in order: the repetition penalty, the frequency penalty and
        the no-repeat n-gram ban, all based on the tokens in ``generated``.
        ``logits`` is returned unchanged (same object) when nothing has been
        generated yet. The n-gram ban reads only row 0 of ``generated``.
        """
        if generated.numel() == 0:
            return logits
        out = logits.clone()
        uniq, counts = torch.unique(generated.view(-1), return_counts=True)
        if cfg.repetition_penalty != 1.0:
            selected = out[:, uniq]
            # Shrink positive logits and push non-positive ones further down.
            selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty))
            out[:, uniq] = selected
        if cfg.frequency_penalty:
            out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0)
        n = int(cfg.no_repeat_ngram)
        if n > 1 and generated.size(1) >= n - 1:
            seq = generated[0].tolist()
            prefix = tuple(seq[-(n - 1):])
            banned = []
            # Any token that followed an earlier occurrence of the current
            # (n-1)-token suffix would complete a repeated n-gram.
            for i in range(len(seq) - n + 1):
                if tuple(seq[i:i + n - 1]) == prefix:
                    banned.append(seq[i + n - 1])
            if banned:
                out[:, list(set(banned))] = -float("inf")
        return out

    @torch.no_grad()
    def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Pick the next token id from last-position ``logits``.

        Greedy mode (``cfg.greedy`` or a temperature <= 0) takes the argmax of
        the penalised logits. Otherwise temperature, top-k, top-p and min-p
        filters are applied in that order and a token is drawn from the
        resulting distribution; if that distribution is unusable (NaN or a
        non-positive/non-finite sum) the argmax of the raw logits is used.

        Returns:
            ``(next_id, stats)`` where ``stats`` holds the chosen token's
            probability (``"prob"``) and the distribution entropy
            (``"entropy"``).
        """
        work = self._apply_penalties(logits.float(), generated, cfg)
        if cfg.greedy or cfg.temperature <= 0:
            probs = torch.softmax(work, dim=-1)
            next_id = torch.argmax(work, dim=-1, keepdim=True)
        else:
            work = work / max(float(cfg.temperature), 1e-6)
            if cfg.top_k > 0:
                k = min(int(cfg.top_k), work.size(-1))
                thresh = torch.topk(work, k, dim=-1).values[..., -1, None]
                work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf")))
            if 0.0 < cfg.top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                remove = cumprobs > float(cfg.top_p)
                # Shift right by one so the token that crosses the threshold
                # is kept and the top token can never be removed.
                shifted = remove.clone()
                shifted[..., 1:] = remove[..., :-1]
                shifted[..., 0] = False
                sorted_logits = sorted_logits.masked_fill(shifted, -float("inf"))
                work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits)
            if 0.0 < cfg.min_p < 1.0:
                probs_for_minp = torch.softmax(work, dim=-1)
                max_prob = probs_for_minp.max(dim=-1, keepdim=True).values
                keep = probs_for_minp >= float(cfg.min_p) * max_prob
                work = work.masked_fill(~keep, -float("inf"))
            probs = torch.softmax(work, dim=-1)
            if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0:
                # Filtering left nothing to sample from; fall back to the raw logits.
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
                probs = torch.softmax(logits.float(), dim=-1)
            else:
                next_id = torch.multinomial(probs, 1)

        prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0
        entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0
        return next_id, {"prob": prob, "entropy": entropy}

    @torch.no_grad()
    def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]:
        """Stream the cleaned completion for ``prompt``.

        Each yielded string is the full cleaned text generated so far, not a
        delta. ``self.last_stats`` and ``self.recent_tokens`` (last 32 tokens)
        are refreshed while decoding; the final stats carry the finish reason
        ``"length"``, ``"eos"`` or ``"stop_string"``.

        Args:
            prompt: Fully rendered prompt text.
            cfg: Sampling/streaming options. A ``stream_every`` below 1 is
                treated as 1.
        """
        self.model.eval()
        self.recent_tokens = []
        ids = self.encode(prompt, max_context_tokens=cfg.max_context_tokens)
        prompt_len = int(ids.shape[1])
        t0 = time.time()
        pass_kl, pass_cos = (None, None)
        if cfg.return_pass_metrics:
            pass_kl, pass_cos = self.pass_metrics_for_ids(ids)

        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        last_text = ""
        finish_reason = "length"
        # Guard the modulo below against a zero or negative interval.
        stream_every = max(1, int(cfg.stream_every))
        max_context = int(cfg.max_context_tokens)

        for step in range(int(cfg.max_new_tokens)):
            # Sliding window: only the most recent tokens are fed to the model.
            ctx = ids[:, -max_context:]
            with self.amp_context():
                out = self.model(ctx)
            logits = out.logits[:, -1, :].float()
            generated = ids[:, prompt_len:]
            next_id, tok_stats = self._sample_next(logits, generated, cfg)
            ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)

            token_id = int(next_id.item())
            token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
            self.recent_tokens.append({
                "i": step + 1,
                "id": token_id,
                "text": token_text,
                "prob": tok_stats["prob"],
                "entropy": tok_stats["entropy"],
            })
            self.recent_tokens = self.recent_tokens[-32:]

            if eos_id is not None and token_id == int(eos_id):
                finish_reason = "eos"
                break

            if (step + 1) % stream_every == 0 or step == 0:
                raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
                if any(s in raw for s in STOP_STRINGS):
                    finish_reason = "stop_string"
                    break
                text = clean_text(raw)
                if text and text != last_text:
                    elapsed = time.time() - t0
                    gen = int(ids.shape[1]) - prompt_len
                    self.last_stats = GenerationStats(
                        prompt_tokens=prompt_len,
                        generated_tokens=gen,
                        elapsed_sec=elapsed,
                        tokens_per_sec=gen / max(elapsed, 1e-9),
                        last_token=token_text,
                        last_token_id=token_id,
                        last_token_prob=tok_stats["prob"],
                        last_entropy=tok_stats["entropy"],
                        finish_reason="streaming",
                        pass1_pass2_kl=pass_kl,
                        pass1_pass2_logit_cosine=pass_cos,
                    )
                    last_text = text
                    yield text

        raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
        final = clean_text(raw)
        elapsed = time.time() - t0
        gen = int(ids.shape[1]) - prompt_len
        self.last_stats = GenerationStats(
            prompt_tokens=prompt_len,
            generated_tokens=gen,
            elapsed_sec=elapsed,
            tokens_per_sec=gen / max(elapsed, 1e-9),
            last_token=self.recent_tokens[-1]["text"] if self.recent_tokens else "",
            last_token_id=self.recent_tokens[-1]["id"] if self.recent_tokens else -1,
            last_token_prob=self.recent_tokens[-1]["prob"] if self.recent_tokens else 0.0,
            last_entropy=self.recent_tokens[-1]["entropy"] if self.recent_tokens else 0.0,
            finish_reason=finish_reason,
            pass1_pass2_kl=pass_kl,
            pass1_pass2_logit_cosine=pass_cos,
        )
        if final:
            yield final

    def stats_dict(self) -> Dict[str, Any]:
        """Return the last run's stats plus model info and a system snapshot."""
        d = asdict(self.last_stats)
        d["model"] = self.model_info
        d["system"] = system_snapshot(self)
        return d

    def stats_text(self) -> str:
        """Render the last run's stats and static model info as plain text."""
        s = self.last_stats
        lines = [
            f"Prompt tokens: {s.prompt_tokens}",
            f"Generated tokens: {s.generated_tokens}",
            f"Elapsed: {s.elapsed_sec:.2f}s",
            f"Decode speed: {s.tokens_per_sec:.2f} tok/s",
            f"Finish reason: {s.finish_reason}",
            f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f}",
            f"Last entropy: {s.last_entropy:.3f}",
        ]
        if s.pass1_pass2_kl is not None:
            lines.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}")
        if s.pass1_pass2_logit_cosine is not None:
            lines.append(f"Pass1/Pass2 logit cosine: {s.pass1_pass2_logit_cosine:.6f}")
        lines.extend([
            "",
            f"Checkpoint: {self.model_info['checkpoint']}",
            f"Checkpoint size: {self.model_info['checkpoint_size']}",
            f"Device: {self.model_info['device']} dtype={self.model_info['dtype']}",
            f"Pass2 active: {self.model_info['has_pass2']}",
            f"Params: {self.model_info['total_params']:,}",
            f"VWM c modules: {self.model_info['vwm_c_modules']} ({self.model_info['vwm_c_params']:,} c params)",
            f"Layer gate mean: {self.model_info['pass2_layer_gate_mean']}",
            f"Adapter gate mean: {self.model_info['pass2_adapter_gate_mean']}",
        ])
        return "\n".join(lines)

    def token_trace_text(self) -> str:
        """Render up to the 24 most recent tokens, one per line."""
        if not self.recent_tokens:
            return "No tokens generated yet."
        rows = []
        for t in self.recent_tokens[-24:]:
            # repr() escapes control characters; drop its surrounding quotes.
            safe = repr(t["text"])[1:-1]
            rows.append(f"{t['i']:04d} id={t['id']:<7} p={t['prob']:.4f} H={t['entropy']:.2f} {safe}")
        return "\n".join(rows)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def system_snapshot(engine: Optional[LULUV2LiveEngine] = None) -> Dict[str, Any]:
    """Return compact live edge-device metrics for the UI cards.

    The result always contains every key, pre-filled with placeholder values.
    Host figures come from psutil when it is importable, GPU figures from the
    PyTorch CUDA counters, and NVML (when available) overrides the VRAM
    numbers with whole-device values and adds utilisation and temperature.
    Each of the three probes is best effort: if one fails, its fields keep
    their previous values. ``engine`` is accepted but not read.
    """
    metrics: Dict[str, Any] = dict(
        python_ram="n/a",
        system_ram="n/a",
        system_ram_percent=0.0,
        cpu_percent=0.0,
        gpu_name="CUDA unavailable",
        vram_allocated="n/a",
        vram_reserved="n/a",
        vram_used="n/a",
        vram_total="n/a",
        vram_percent=0.0,
        gpu_util_percent=None,
        gpu_temp_c=None,
    )

    if psutil is not None:
        try:
            this_process = psutil.Process(os.getpid())
            host_mem = psutil.virtual_memory()
            host_fields = {
                "python_ram": human_bytes(this_process.memory_info().rss),
                "system_ram": f"{human_bytes(host_mem.used)} / {human_bytes(host_mem.total)}",
                "system_ram_percent": float(host_mem.percent),
                "cpu_percent": float(psutil.cpu_percent(interval=0.0)),
            }
        except Exception:
            pass
        else:
            metrics.update(host_fields)

    if not torch.cuda.is_available():
        return metrics

    try:
        device_index = torch.cuda.current_device()
        device_props = torch.cuda.get_device_properties(device_index)
        bytes_allocated = int(torch.cuda.memory_allocated(device_index))
        bytes_reserved = int(torch.cuda.memory_reserved(device_index))
        bytes_total = int(device_props.total_memory)
        torch_fields = {
            "gpu_name": device_props.name,
            "vram_allocated": human_bytes(bytes_allocated),
            "vram_reserved": human_bytes(bytes_reserved),
            "vram_used": human_bytes(bytes_allocated),
            "vram_total": human_bytes(bytes_total),
            "vram_percent": (100.0 * bytes_allocated / max(bytes_total, 1)),
        }
        metrics.update(torch_fields)

        if pynvml is not None:
            try:
                pynvml.nvmlInit()
                nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
                utilisation = pynvml.nvmlDeviceGetUtilizationRates(nvml_handle)
                nvml_mem = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
                temperature = pynvml.nvmlDeviceGetTemperature(nvml_handle, pynvml.NVML_TEMPERATURE_GPU)
                nvml_fields = {
                    "gpu_util_percent": int(utilisation.gpu),
                    "vram_used": human_bytes(int(nvml_mem.used)),
                    "vram_total": human_bytes(int(nvml_mem.total)),
                    "vram_percent": (100.0 * float(nvml_mem.used) / max(float(nvml_mem.total), 1.0)),
                    "gpu_temp_c": int(temperature),
                }
            except Exception:
                pass
            else:
                metrics.update(nvml_fields)
    except Exception:
        pass

    return metrics
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def system_usage(engine: Optional[LULUV2LiveEngine] = None) -> str:
    """Build a multi-line plain-text report of host, GPU and engine state.

    The report starts with the OS name and release, followed by process and
    system memory and CPU load (or a note that psutil is missing), then the
    PyTorch CUDA memory counters with optional NVML readings (or a note that
    CUDA is unavailable). When ``engine`` is given, its ``stats_text()`` is
    appended after a blank line.
    """
    report = [f"OS: {platform.system()} {platform.release()}"]

    if psutil is None:
        report.append("psutil unavailable")
    else:
        this_process = psutil.Process(os.getpid())
        host_mem = psutil.virtual_memory()
        report.extend([
            f"Python RAM: {human_bytes(this_process.memory_info().rss)}",
            f"System RAM: {human_bytes(host_mem.used)} / {human_bytes(host_mem.total)} ({host_mem.percent:.1f}%)",
            f"CPU: {psutil.cpu_percent(interval=0.0):.1f}%",
        ])

    if not torch.cuda.is_available():
        report.extend(["", "GPU: CUDA unavailable"])
    else:
        device_index = torch.cuda.current_device()
        device_props = torch.cuda.get_device_properties(device_index)
        report.extend([
            "",
            f"GPU: {device_props.name}",
            f"VRAM allocated: {human_bytes(torch.cuda.memory_allocated(device_index))}",
            f"VRAM reserved: {human_bytes(torch.cuda.memory_reserved(device_index))}",
            f"VRAM total: {human_bytes(device_props.total_memory)}",
        ])
        if pynvml is not None:
            try:
                pynvml.nvmlInit()
                nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
                utilisation = pynvml.nvmlDeviceGetUtilizationRates(nvml_handle)
                nvml_mem = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
                temperature = pynvml.nvmlDeviceGetTemperature(nvml_handle, pynvml.NVML_TEMPERATURE_GPU)
                # Build all NVML lines first so a failure adds none of them.
                nvml_lines = [
                    f"GPU util: {utilisation.gpu}%",
                    f"GPU memory: {human_bytes(nvml_mem.used)} / {human_bytes(nvml_mem.total)}",
                    f"GPU temperature: {temperature} C",
                ]
                report.extend(nvml_lines)
            except Exception as exc:
                report.append(f"NVML unavailable: {type(exc).__name__}: {exc}")

    if engine is not None:
        report.extend(["", engine.stats_text()])
    return "\n".join(report)
|
luluv2_optimized_engine.py
ADDED
|
@@ -0,0 +1,1133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
LULUV2 optimized local inference engine.
|
| 5 |
+
|
| 6 |
+
Goals:
|
| 7 |
+
- load LULU2/LULUV2 checkpoints through the existing LULUV2 model file
|
| 8 |
+
- no AutoModelForCausalLM.from_pretrained and no external model weights
|
| 9 |
+
- vectorized prompt prefill into explicit KV caches
|
| 10 |
+
- persistent session KV cache across turns when prompt tokens extend prior prompt
|
| 11 |
+
- modes: fast(pass1/base), vwm(pass1+pass2), deep(pass1+pass2 long context)
|
| 12 |
+
- safe fallback to slow full-prefix forward if cached path fails
|
| 13 |
+
|
| 14 |
+
This is intentionally Python-first and debuggable. It is a bridge toward
|
| 15 |
+
kernel/CUDA-graph optimization, not the final kernel path.
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import importlib.util
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import platform
|
| 24 |
+
import time
|
| 25 |
+
import traceback
|
| 26 |
+
from contextlib import nullcontext
|
| 27 |
+
from dataclasses import dataclass, asdict
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from types import SimpleNamespace
|
| 30 |
+
from typing import Any, Dict, Generator, List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import psutil
|
| 37 |
+
except Exception:
|
| 38 |
+
psutil = None
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
import pynvml
|
| 42 |
+
except Exception:
|
| 43 |
+
pynvml = None
|
| 44 |
+
|
| 45 |
+
STOP_STRINGS = [
|
| 46 |
+
"<|im_start|>", "<|im_end|>", "<|user|>", "<|system|>", "<|assistant|>",
|
| 47 |
+
"User:", "Assistant:", "\nuser:", "\nassistant:",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def setup_torch() -> None:
    """Apply best-effort global PyTorch performance switches.

    When CUDA is available, TF32 is allowed for matmul and cuDNN, and the
    flash / memory-efficient scaled-dot-product-attention backends are enabled
    while the math backend is disabled. If the installed torch exposes
    ``set_float32_matmul_precision`` it is set to ``"high"``. Each group of
    settings is independent: a failure in one is ignored so the others still run.
    """

    def _allow_tf32() -> None:
        # Old API still works on current wheels; warnings are harmless.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _prefer_fused_sdp() -> None:
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)

    def _high_matmul_precision() -> None:
        torch.set_float32_matmul_precision("high")

    steps = []
    if torch.cuda.is_available():
        steps.extend([_allow_tf32, _prefer_fused_sdp])
    if hasattr(torch, "set_float32_matmul_precision"):
        steps.append(_high_matmul_precision)

    for step in steps:
        try:
            step()
        except Exception:
            # Deliberately best-effort: an unsupported switch must not block startup.
            pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def human_bytes(num: float) -> str:
    """Format a byte count with a binary (1024-based) unit suffix.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    units B..TB are exhausted; anything larger is reported in PB. Two decimals
    are always printed.
    """
    size = float(num)
    units = ("B", "KB", "MB", "GB", "TB")
    idx = 0
    while idx < len(units):
        if abs(size) < 1024.0:
            return f"{size:.2f} {units[idx]}"
        size /= 1024.0
        idx += 1
    return f"{size:.2f} PB"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def value_to_text(value: Any) -> str:
    """Flatten a loosely structured message value into plain text.

    Rules, in order:
      * ``None`` becomes the empty string and ``str`` is returned unchanged.
      * A dict that has a ``"text"``, ``"content"`` or ``"value"`` key (checked
        in that order) is replaced by the rendering of that entry.
      * Any other dict renders its values, and a list/tuple renders its items,
        joined by newlines with empty renderings dropped.
      * Everything else goes through ``str()``.

    Each element is rendered exactly once. Rendering it once for the filter and
    again for the output made the cost grow exponentially with nesting depth.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("text", "content", "value"):
            if key in value:
                return value_to_text(value.get(key))
        items = value.values()
    elif isinstance(value, (list, tuple)):
        items = value
    else:
        return str(value)
    rendered = (value_to_text(item) for item in items)
    return "\n".join(part for part in rendered if part)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def clean_text(text: Any) -> str:
    """Normalize a chat message or model output to plain display text.

    Steps: flatten via ``value_to_text``; turn literal backslash-n sequences
    into real newlines; cut at the earliest stop marker found past index 0;
    remove any remaining stop markers; strip a leading speaker label; right-trim
    every line and allow at most two consecutive blank lines.
    """
    cleaned = value_to_text(text).replace("\\n", "\n")

    # ``find`` returns -1 for absent markers, so ``> 0`` also drops those.
    hits = [idx for idx in (cleaned.find(marker) for marker in STOP_STRINGS) if idx > 0]
    if hits:
        cleaned = cleaned[: min(hits)]
    for marker in STOP_STRINGS:
        cleaned = cleaned.replace(marker, "")
    cleaned = cleaned.strip()

    for label in ("Assistant:", "assistant:", "Lulu:", "lulu:"):
        if cleaned.startswith(label):
            cleaned = cleaned[len(label):].strip()

    # collapse excessive vertical whitespace without destroying code blocks too much
    kept: List[str] = []
    blank_run = 0
    for raw in cleaned.splitlines():
        line = raw.rstrip()
        if line.strip():
            blank_run = 0
            kept.append(line)
            continue
        blank_run += 1
        if blank_run <= 2:
            kept.append("")
    return "\n".join(kept).strip()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def normalize_history(history: Any) -> List[Dict[str, str]]:
    """Convert chat history into a flat list of ``{"role", "content"}`` dicts.

    Two entry shapes are understood:
      * dicts with ``role``/``content`` keys; only ``user`` and ``assistant``
        roles with non-empty cleaned content are kept;
      * sequences of at least two items, read as ``(user_text, assistant_text)``;
        each non-empty side becomes its own message.
    Anything else is skipped. A falsy ``history`` yields an empty list.
    """
    messages: List[Dict[str, str]] = []
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role", "")
            content = clean_text(entry.get("content", ""))
            if role in {"user", "assistant"} and content:
                messages.append({"role": role, "content": content})
            continue
        if isinstance(entry, (tuple, list)) and len(entry) >= 2:
            user_text = clean_text(entry[0])
            assistant_text = clean_text(entry[1])
            for role, content in (("user", user_text), ("assistant", assistant_text)):
                if content:
                    messages.append({"role": role, "content": content})
    return messages
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def resolve_model_py(model_py: Optional[str]) -> str:
    """Locate the LULUV2 runtime/model source file and return its absolute path.

    Candidates are tried in order:
      1. ``model_py`` when given;
      2. ``luluv2_inference_runtime.py`` relative to the current directory;
      3. ``luluv2_inference_runtime.py`` next to this module, so the default
         still resolves when the engine is launched from another directory.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    default_name = "luluv2_inference_runtime.py"
    candidates: List[Path] = []
    if model_py:
        candidates.append(Path(model_py))
    candidates.append(Path(default_name))
    # ``__file__`` can be absent in some embedded/frozen environments.
    module_file = globals().get("__file__")
    if module_file:
        candidates.append(Path(module_file).resolve().parent / default_name)

    for candidate in candidates:
        if candidate.exists():
            return str(candidate.resolve())
    raise FileNotFoundError("Could not find LULUV2 model file. Pass --model-py.")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def import_model_py(model_py: Optional[str]):
    """Import the LULUV2 runtime source file as a module.

    The file is located with ``resolve_model_py`` and loaded under the fixed
    module name ``luluv2_runtime_module``.

    The module is registered in ``sys.modules`` before its body runs, as the
    importlib "import a source file directly" recipe does. Code that looks a
    class's module up by name (dataclasses, pickling) needs that entry. If
    execution fails, the previous ``sys.modules`` state is restored.

    Returns:
        ``(module, resolved_path)``.

    Raises:
        FileNotFoundError: propagated from ``resolve_model_py``.
        RuntimeError: if no import spec/loader can be built for the path.
    """
    import sys  # local import: the file's import block does not provide ``sys``

    module_name = "luluv2_runtime_module"
    path = resolve_model_py(model_py)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not import model file: {path}")
    mod = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(module_name)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialized module importable by name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return mod, path
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dataclass
class GenerationConfig:
    """Per-request generation settings.

    Field order is part of the interface (positional construction). Within the
    visible part of this file only ``mode`` and ``max_context_tokens`` are
    read; the remaining fields are consumed elsewhere.
    """

    # Output length and sampling knobs.
    max_new_tokens: int = 512
    temperature: float = 0.65
    top_k: int = 40
    top_p: float = 0.90
    min_p: float = 0.03
    repetition_penalty: float = 1.10
    frequency_penalty: float = 0.02
    greedy: bool = False
    no_repeat_ngram: int = 4

    # Streaming and context budget.
    stream_every: int = 1
    max_context_tokens: int = 4096

    # Execution path selection.
    mode: str = "vwm"  # fast, vwm, deep, slow
    return_pass_metrics: bool = True
    use_cache: bool = True
    vectorized_prefill: bool = True
    persistent_cache: bool = True
    compile_step: bool = False
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@dataclass
class GenerationStats:
    """Counters and diagnostics recorded for one generation call.

    All fields have neutral defaults so a fresh instance can stand in for
    "nothing generated yet". Field order is part of the interface.
    """

    # Prompt accounting.
    prompt_tokens: int = 0
    prompt_total_tokens: int = 0
    prompt_kept_tokens: int = 0
    prompt_dropped_tokens: int = 0

    # Throughput.
    generated_tokens: int = 0
    elapsed_sec: float = 0.0
    tokens_per_sec: float = 0.0
    prefill_sec: float = 0.0
    prefill_tps: float = 0.0

    # Cache behaviour.
    cache_hit: bool = False
    cache_reused_tokens: int = 0
    cache_new_prefill_tokens: int = 0

    # Execution path.
    mode: str = "vwm"
    backend: str = "none"

    # Last sampled token and termination.
    last_token: str = ""
    last_token_id: int = -1
    last_token_prob: float = 0.0
    last_entropy: float = 0.0
    finish_reason: str = "none"

    # Pass-1 vs pass-2 comparison metrics; None when not computed.
    pass1_pass2_kl: Optional[float] = None
    pass1_pass2_logit_cosine: Optional[float] = None
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class KVLayerCache:
    """Key/value tensors cached for a single attention layer.

    The sequence axis is dim 2 (the original annotation gives the layout as
    ``[B, H, T, Dh]``). Both ``set`` and ``append`` keep only the last
    ``max_len`` positions along that axis.
    """

    def __init__(self):
        self.k: Optional[torch.Tensor] = None  # [B, H, T, Dh]
        self.v: Optional[torch.Tensor] = None

    @property
    def length(self) -> int:
        """Number of cached positions (0 when nothing is stored)."""
        return 0 if self.k is None else int(self.k.shape[2])

    def set(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
        """Replace the cache with ``k``/``v``, clipped to the last ``max_len`` positions."""
        overflow = k.shape[2] > max_len
        keys = k[:, :, -max_len:, :] if overflow else k
        values = v[:, :, -max_len:, :] if overflow else v
        self.k = keys.detach().contiguous()
        self.v = values.detach().contiguous()

    def append(self, k: torch.Tensor, v: torch.Tensor, max_len: int) -> None:
        """Concatenate new positions onto the cache, then clip to ``max_len``."""
        if self.k is None:
            # First write: identical to a plain ``set``.
            self.set(k, v, max_len)
            return
        self.k = torch.cat([self.k, k.detach()], dim=2)
        self.v = torch.cat([self.v, v.detach()], dim=2)
        if self.k.shape[2] <= max_len:
            return
        # Sliding window: drop the oldest positions.
        self.k = self.k[:, :, -max_len:, :].contiguous()
        self.v = self.v[:, :, -max_len:, :].contiguous()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class DecoderKVCache:
    """One ``KVLayerCache`` per decoder layer."""

    def __init__(self, n_layers: int):
        self.layers = []
        for _ in range(int(n_layers)):
            self.layers.append(KVLayerCache())

    def clear(self):
        """Drop the stored tensors of every layer while keeping the layer slots."""
        for entry in self.layers:
            entry.k = None
            entry.v = None

    @property
    def length(self) -> int:
        """Cached sequence length, as reported by the first layer (0 if there are no layers)."""
        return self.layers[0].length if self.layers else 0
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class LULUV2OptimizedEngine:
|
| 256 |
+
def __init__(
    self,
    ckpt_path: str,
    model_py: Optional[str] = None,
    tokenizer_dir: Optional[str] = None,
    device: Optional[str] = None,
    dtype: str = "bf16",
    local_files_only: bool = True,
    no_config_download: bool = True,
    force_base_only: bool = False,
):
    """Load the checkpoint, tokenizer and optional pass-2 wrapper.

    Order of work: global torch setup, plain attribute initialisation
    (including an empty persistent-cache state), import of the runtime module,
    base checkpoint load through ``load_lulu2_base``, tokenizer load, optional
    pass-2 wrapping, then eval mode and the model-info summary.
    """
    setup_torch()

    # Basic configuration.
    self.ckpt_path = str(ckpt_path)
    self.ckpt_dir = Path(self.ckpt_path).resolve().parent
    self.device = self._select_device(device)
    self.dtype = self._dtype_from_name(dtype)
    self.local_files_only = bool(local_files_only)
    self.no_config_download = bool(no_config_download)
    self.force_base_only = bool(force_base_only)

    # Bookkeeping for the most recent request.
    self.last_stats = GenerationStats()
    self.recent_tokens: List[Dict[str, Any]] = []
    self.last_prompt_total_tokens: int = 0
    self.last_prompt_kept_tokens: int = 0
    self.last_prompt_dropped_tokens: int = 0

    # Persistent KV-cache state; starts cold.
    self.cache_ids: Optional[torch.Tensor] = None
    self.cache_mode: str = ""
    self.cache_max_context: int = 0
    self.pass1_cache: Optional[DecoderKVCache] = None
    self.pass2_cache: Optional[DecoderKVCache] = None
    self.cached_logits: Optional[torch.Tensor] = None
    self.cached_pass1_logits: Optional[torch.Tensor] = None
    self.cached_pass2_logits: Optional[torch.Tensor] = None
    self.cache_backend: str = "cold"

    # Runtime module plus the argument namespace its loaders expect.
    self.goku, self.model_py_path = import_model_py(model_py)
    self.args = SimpleNamespace(
        checkpoint=self.ckpt_path,
        tokenizer=tokenizer_dir or "",
        model_id="",
        no_config_download=self.no_config_download,
        local_files_only=self.local_files_only,
    )

    print("[guard] LULUV2 cockpit: no AutoModelForCausalLM.from_pretrained call and no external model weights loaded.")
    print(f"[load] checkpoint={self.ckpt_path}")
    self.base_ckpt, base = self.goku.load_lulu2_base(self.args, self.device, self.dtype)
    self.tokenizer = self._load_tokenizer(tokenizer_dir)
    self.model, self.has_pass2 = self._maybe_wrap_pass2(base)
    # With pass-2 the wrapper holds the base model; otherwise the model is the base.
    if self.has_pass2:
        self.base = self.model.base
    else:
        self.base = self.model
    self.n_layers = int(self.base.config.num_hidden_layers)
    self.model.eval()
    self.base.eval()
    self.model_info = self._build_model_info()
    self._compiled = False
|
| 310 |
+
|
| 311 |
+
def _select_device(self, device: Optional[str]):
|
| 312 |
+
if device:
|
| 313 |
+
return torch.device(device)
|
| 314 |
+
if torch.cuda.is_available():
|
| 315 |
+
return torch.device("cuda")
|
| 316 |
+
return torch.device("cpu")
|
| 317 |
+
|
| 318 |
+
def _dtype_from_name(self, name: str):
|
| 319 |
+
name = (name or "bf16").lower()
|
| 320 |
+
if name in {"bf16", "bfloat16"}:
|
| 321 |
+
return torch.bfloat16
|
| 322 |
+
if name in {"fp16", "float16", "half"}:
|
| 323 |
+
return torch.float16
|
| 324 |
+
return torch.float32
|
| 325 |
+
|
| 326 |
+
def _load_tokenizer(self, tokenizer_dir: Optional[str]):
|
| 327 |
+
if tokenizer_dir:
|
| 328 |
+
self.args.tokenizer = tokenizer_dir
|
| 329 |
+
else:
|
| 330 |
+
sibling = self.ckpt_dir / "tokenizer"
|
| 331 |
+
if sibling.is_dir():
|
| 332 |
+
self.args.tokenizer = str(sibling)
|
| 333 |
+
tok = self.goku.load_tokenizer(self.args, self.base_ckpt)
|
| 334 |
+
if getattr(tok, "pad_token_id", None) is None and getattr(tok, "eos_token_id", None) is not None:
|
| 335 |
+
try:
|
| 336 |
+
tok.pad_token = tok.eos_token
|
| 337 |
+
except Exception:
|
| 338 |
+
pass
|
| 339 |
+
# Long-prompt safety: for chat/RAG prompts, the latest user turn and final
|
| 340 |
+
# instruction are normally at the end. Right-side truncation silently drops
|
| 341 |
+
# exactly the part the model must answer, so force left truncation where the
|
| 342 |
+
# tokenizer supports it. encode() below also performs manual left truncation
|
| 343 |
+
# and records how many tokens were dropped.
|
| 344 |
+
try:
|
| 345 |
+
tok.truncation_side = "left"
|
| 346 |
+
except Exception:
|
| 347 |
+
pass
|
| 348 |
+
try:
|
| 349 |
+
tok.model_max_length = 10**9
|
| 350 |
+
except Exception:
|
| 351 |
+
pass
|
| 352 |
+
return tok
|
| 353 |
+
|
| 354 |
+
def _maybe_wrap_pass2(self, base):
|
| 355 |
+
ckpt = self.base_ckpt
|
| 356 |
+
if self.force_base_only or "pass2_state" not in ckpt:
|
| 357 |
+
print("[pass2] no pass2_state loaded; running base LULUV2 forward")
|
| 358 |
+
return base.to(self.device).eval(), False
|
| 359 |
+
cfg_dict = dict(ckpt.get("pass2_config") or {})
|
| 360 |
+
Pass2Config = self.goku.Pass2Config
|
| 361 |
+
fields = getattr(Pass2Config, "__dataclass_fields__", {})
|
| 362 |
+
pass2_cfg = Pass2Config(**{k: v for k, v in cfg_dict.items() if k in fields})
|
| 363 |
+
model = self.goku.Lulu2TwoPassForCausalLM(base, pass2_cfg)
|
| 364 |
+
missing, unexpected = model.load_state_dict(ckpt["pass2_state"], strict=False)
|
| 365 |
+
print(f"[pass2] loaded pass2_state missing={len(missing)} unexpected={len(unexpected)}")
|
| 366 |
+
model.to(device=self.device, dtype=self.dtype).eval()
|
| 367 |
+
return model, True
|
| 368 |
+
|
| 369 |
+
def _build_model_info(self) -> Dict[str, Any]:
    """Collect a summary dict describing the loaded model.

    It holds the checkpoint path and size, runtime file, device/dtype, the
    total parameter count, the count and size of parameters whose name ends in
    ``.c``, mean sigmoid gate values when pass-2 is loaded, and selected
    attributes of the base config (``None`` when absent).
    """
    total_params = 0
    for param in self.model.parameters():
        total_params += param.numel()

    c_sizes = [param.numel() for pname, param in self.model.named_parameters() if pname.endswith(".c")]

    gate_mean = None
    adapter_gate_mean = None
    if self.has_pass2:
        with torch.no_grad():
            gate_mean = float(torch.sigmoid(self.model.layer_gates.float()).mean().item())
            adapter_vals = [float(torch.sigmoid(ad.gate.float()).item()) for ad in self.model.adapters]
            # max(1, ...) keeps the mean defined when there are no adapters.
            adapter_gate_mean = sum(adapter_vals) / max(1, len(adapter_vals))

    ckpt_size = Path(self.ckpt_path).stat().st_size if Path(self.ckpt_path).exists() else 0
    cfg = getattr(self.base, "config", None)

    info: Dict[str, Any] = {}
    info["checkpoint"] = self.ckpt_path
    info["checkpoint_size"] = human_bytes(ckpt_size)
    info["model_py"] = self.model_py_path
    info["device"] = str(self.device)
    info["dtype"] = str(self.dtype).replace("torch.", "")
    info["has_pass2"] = self.has_pass2
    info["total_params"] = total_params
    info["vwm_c_modules"] = len(c_sizes)
    info["vwm_c_params"] = sum(c_sizes)
    info["pass2_layer_gate_mean"] = gate_mean
    info["pass2_adapter_gate_mean"] = adapter_gate_mean
    info["hidden_size"] = getattr(cfg, "hidden_size", None)
    info["layers"] = getattr(cfg, "num_hidden_layers", None)
    info["heads"] = getattr(cfg, "num_attention_heads", None)
    info["kv_heads"] = getattr(cfg, "num_key_value_heads", None)
    info["max_position_embeddings"] = getattr(cfg, "max_position_embeddings", None)
    return info
|
| 399 |
+
|
| 400 |
+
def amp_context(self):
    """Return the autocast context to run forward passes under.

    CUDA autocast with the engine dtype is used only when running on CUDA with
    bfloat16 or float16; in every other case a no-op context is returned.
    """
    on_cuda = self.device.type == "cuda"
    if not (on_cuda and self.dtype in (torch.bfloat16, torch.float16)):
        return nullcontext()
    return torch.autocast("cuda", dtype=self.dtype)
|
| 404 |
+
|
| 405 |
+
def build_chat_prompt(
    self,
    message: str,
    history: Any,
    system_prompt: str,
    memory_notes: str = "",
    history_turns: int = 4,
    extra_context: str = "",
) -> str:
    """Render the chat prompt for the next assistant turn.

    The system message is built from ``system_prompt``, ``memory_notes`` and
    ``extra_context`` (empty parts skipped, separated by blank lines). It is
    followed by the last ``history_turns`` user/assistant pairs and the cleaned
    new user message. The tokenizer's chat template is preferred; if it is
    unavailable or raises, a ChatML-style prompt is assembled by hand.

    ``history_turns`` that is zero, negative, or rounds down to zero includes
    no history. Previously such values produced a ``[0:]`` slice and silently
    included the whole history.
    """
    history = normalize_history(history)
    turns = max(0, int(history_turns)) if history_turns else 0
    recent = history[-2 * turns:] if turns > 0 else []

    # Tolerate None for the optional text blocks instead of crashing on .strip().
    system_text = (system_prompt or "").strip()
    memory_text = (memory_notes or "").strip()
    context_text = (extra_context or "").strip()

    system_chunks: List[str] = []
    if system_text:
        system_chunks.append(system_text)
    if memory_text:
        system_chunks.append("Useful memory notes:\n" + memory_text)
    if context_text:
        system_chunks.append("Relevant local context:\n" + context_text)
    system = "\n\n".join(system_chunks)

    user_text = clean_text(message)  # cleaned once, shared by both render paths

    messages: List[Dict[str, str]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.extend(recent)
    messages.append({"role": "user", "content": user_text})

    try:
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Deliberate fallback: tokenizer has no usable chat template.
        parts: List[str] = []
        if system:
            parts.append(f"<|im_start|>system\n{system}<|im_end|>")
        for item in recent:
            parts.append(f"<|im_start|>{item['role']}\n{item['content']}<|im_end|>")
        parts.append(f"<|im_start|>user\n{user_text}<|im_end|>")
        parts.append("<|im_start|>assistant\n")
        return "\n".join(parts)
|
| 440 |
+
|
| 441 |
+
def encode(self, text: str, max_context_tokens: int) -> torch.Tensor:
    """Tokenize ``text`` and keep only the last ``max_context_tokens`` ids.

    Truncation is done here rather than in the tokenizer so the number of
    dropped tokens is known. Many tokenizers default to right-side truncation,
    which would keep the start of a huge prompt and drop the final user
    instruction. The total/kept/dropped counts are recorded on the engine.
    When tokens are dropped, the persistent cache state is cleared and the
    cache backend is marked ``"truncated-rebuild"``.
    """
    budget = max(1, int(max_context_tokens))
    try:
        self.tokenizer.truncation_side = "left"
    except Exception:
        pass

    # Tokenize without tokenizer-side truncation so we know exactly whether the
    # prompt was clipped. The prompt already contains chat special tokens.
    try:
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=False,
            add_special_tokens=False,
        )
    except TypeError:
        # Tokenizer does not accept one of the keyword arguments above.
        encoded = self.tokenizer(text, return_tensors="pt", truncation=False)

    ids = encoded.input_ids
    total = int(ids.shape[1])
    dropped = max(0, total - budget)
    if dropped > 0:
        ids = ids[:, -budget:].contiguous()
        # Do not reuse an older conversation cache after a hard context trim;
        # the logical prefix changed and reuse can make long prompts feel like
        # they are "forgetting" pieces.
        for attr in (
            "pass1_cache",
            "pass2_cache",
            "cache_ids",
            "cached_logits",
            "cached_pass1_logits",
            "cached_pass2_logits",
        ):
            setattr(self, attr, None)
        self.cache_backend = "truncated-rebuild"

    self.last_prompt_total_tokens = total
    self.last_prompt_kept_tokens = int(ids.shape[1])
    self.last_prompt_dropped_tokens = dropped
    return ids.to(self.device)
|
| 486 |
+
|
| 487 |
+
def _position_ids(self, T: int, offset: int = 0) -> torch.Tensor:
|
| 488 |
+
return torch.arange(offset, offset + T, device=self.device, dtype=torch.long).unsqueeze(0)
|
| 489 |
+
|
| 490 |
+
def _attn_prefill(self, attn, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
|
| 491 |
+
bsz, q_len, _ = hidden_states.size()
|
| 492 |
+
query_states = attn.q_proj(hidden_states)
|
| 493 |
+
key_states = attn.k_proj(hidden_states)
|
| 494 |
+
value_states = attn.v_proj(hidden_states)
|
| 495 |
+
query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
|
| 496 |
+
key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
|
| 497 |
+
value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
|
| 498 |
+
cos, sin = attn.rotary_emb(value_states, position_ids)
|
| 499 |
+
query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 500 |
+
key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
|
| 501 |
+
value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
|
| 502 |
+
cache.set(key_states, value_states, max_context)
|
| 503 |
+
attn_output = F.scaled_dot_product_attention(
|
| 504 |
+
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True, scale=attn.scaling
|
| 505 |
+
)
|
| 506 |
+
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
|
| 507 |
+
return attn.o_proj(attn_output)
|
| 508 |
+
|
| 509 |
+
def _attn_step(self, attn, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
|
| 510 |
+
bsz, q_len, _ = hidden_states.size()
|
| 511 |
+
assert q_len == 1
|
| 512 |
+
query_states = attn.q_proj(hidden_states)
|
| 513 |
+
key_states = attn.k_proj(hidden_states)
|
| 514 |
+
value_states = attn.v_proj(hidden_states)
|
| 515 |
+
query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
|
| 516 |
+
key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
|
| 517 |
+
value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
|
| 518 |
+
position_ids = self._position_ids(1, pos)
|
| 519 |
+
cos, sin = attn.rotary_emb(value_states, position_ids)
|
| 520 |
+
query_states, key_states = self.goku.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 521 |
+
key_states = self.goku.repeat_kv(key_states, attn.num_key_value_groups)
|
| 522 |
+
value_states = self.goku.repeat_kv(value_states, attn.num_key_value_groups)
|
| 523 |
+
cache.append(key_states, value_states, max_context)
|
| 524 |
+
if cache.k is None or cache.v is None:
|
| 525 |
+
raise RuntimeError("KV cache append failed")
|
| 526 |
+
attn_output = F.scaled_dot_product_attention(
|
| 527 |
+
query_states, cache.k, cache.v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scaling
|
| 528 |
+
)
|
| 529 |
+
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, attn.hidden_size)
|
| 530 |
+
return attn.o_proj(attn_output)
|
| 531 |
+
|
| 532 |
+
def _layer_prefill(self, layer, hidden_states: torch.Tensor, position_ids: torch.Tensor, cache: KVLayerCache, max_context: int) -> torch.Tensor:
|
| 533 |
+
residual = hidden_states
|
| 534 |
+
x = layer.input_layernorm(hidden_states)
|
| 535 |
+
x = self._attn_prefill(layer.self_attn, x, position_ids, cache, max_context)
|
| 536 |
+
hidden_states = residual + x
|
| 537 |
+
residual = hidden_states
|
| 538 |
+
x = layer.post_attention_layernorm(hidden_states)
|
| 539 |
+
x = layer.mlp(x)
|
| 540 |
+
return residual + x
|
| 541 |
+
|
| 542 |
+
def _layer_step(self, layer, hidden_states: torch.Tensor, pos: int, cache: KVLayerCache, max_context: int) -> torch.Tensor:
|
| 543 |
+
residual = hidden_states
|
| 544 |
+
x = layer.input_layernorm(hidden_states)
|
| 545 |
+
x = self._attn_step(layer.self_attn, x, pos, cache, max_context)
|
| 546 |
+
hidden_states = residual + x
|
| 547 |
+
residual = hidden_states
|
| 548 |
+
x = layer.post_attention_layernorm(hidden_states)
|
| 549 |
+
x = layer.mlp(x)
|
| 550 |
+
return residual + x
|
| 551 |
+
|
| 552 |
+
@torch.no_grad()
def _prefill_pass1(self, input_ids: torch.Tensor, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Vectorized pass-1 prefill of the whole prompt into a fresh KV cache.

    Returns ``(final_hidden, per_layer_hidden, position_ids, logits)``.
    ``final_hidden`` is the residual stream before the final norm and
    ``logits`` covers every prompt position. The new cache is installed as
    ``self.pass1_cache`` only after all layers have run.
    """
    seq_len = int(input_ids.shape[1])
    position_ids = self._position_ids(seq_len, 0)
    kv = DecoderKVCache(self.n_layers)

    hidden = self.base.model.embed_tokens(input_ids)
    if use_pass_embed and self.has_pass2:
        # Tag the embeddings with the pass-1 marker vector.
        marker = self.model.pass_embed[0].to(dtype=hidden.dtype, device=hidden.device)
        hidden = hidden + marker.view(1, 1, -1)

    per_layer: List[torch.Tensor] = []
    for idx, layer in enumerate(self.base.model.layers):
        hidden = self._layer_prefill(layer, hidden, position_ids, kv.layers[idx], max_context)
        per_layer.append(hidden)

    logits = self.base.lm_head(self.base.model.norm(hidden))
    self.pass1_cache = kv
    return hidden, per_layer, position_ids, logits
|
| 568 |
+
|
| 569 |
+
@torch.no_grad()
def _prefill_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], position_ids: torch.Tensor, max_context: int) -> torch.Tensor:
    """Vectorized pass-2 prefill on top of the pass-1 results.

    Pass 2 starts from the pass-1 residual plus the pass-2 marker vector and
    reuses the base layers. Per layer, the layer's own update is scaled by a
    sigmoid gate and an adapter delta, computed from the layer input and the
    matching pass-1 state, is added. Installs ``self.pass2_cache`` and returns
    logits for all positions.

    Raises:
        RuntimeError: if no pass-2 weights are loaded.
    """
    if not self.has_pass2:
        raise RuntimeError("pass2 requested but checkpoint has no pass2_state")

    kv = DecoderKVCache(self.n_layers)
    marker = self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device)
    stream = h1_resid + marker.view(1, 1, -1)

    for idx, layer in enumerate(self.base.model.layers):
        layer_in = stream
        layer_out = self._layer_prefill(layer, layer_in, position_ids, kv.layers[idx], max_context)
        update = layer_out - layer_in
        gate = torch.sigmoid(self.model.layer_gates[idx]).to(dtype=layer_in.dtype, device=layer_in.device)
        adapter_update = self.model.adapters[idx](layer_in, pass1_states[idx])
        stream = layer_in + gate * update + adapter_update

    logits = self.base.lm_head(self.base.model.norm(stream))
    self.pass2_cache = kv
    return logits
|
| 586 |
+
|
| 587 |
+
@torch.no_grad()
|
| 588 |
+
def _step_pass1(self, token_id: torch.Tensor, pos: int, max_context: int, use_pass_embed: bool) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
|
| 589 |
+
if self.pass1_cache is None:
|
| 590 |
+
self.pass1_cache = DecoderKVCache(self.n_layers)
|
| 591 |
+
h = self.base.model.embed_tokens(token_id)
|
| 592 |
+
if use_pass_embed and self.has_pass2:
|
| 593 |
+
h = h + self.model.pass_embed[0].to(dtype=h.dtype, device=h.device).view(1, 1, -1)
|
| 594 |
+
states: List[torch.Tensor] = []
|
| 595 |
+
for i, layer in enumerate(self.base.model.layers):
|
| 596 |
+
h = self._layer_step(layer, h, pos, self.pass1_cache.layers[i], max_context)
|
| 597 |
+
states.append(h)
|
| 598 |
+
logits = self.base.lm_head(self.base.model.norm(h))
|
| 599 |
+
return h, states, logits
|
| 600 |
+
|
| 601 |
+
@torch.no_grad()
|
| 602 |
+
def _step_pass2(self, h1_resid: torch.Tensor, pass1_states: List[torch.Tensor], pos: int, max_context: int) -> torch.Tensor:
|
| 603 |
+
if not self.has_pass2:
|
| 604 |
+
raise RuntimeError("pass2 step requested but unavailable")
|
| 605 |
+
if self.pass2_cache is None:
|
| 606 |
+
self.pass2_cache = DecoderKVCache(self.n_layers)
|
| 607 |
+
h2 = h1_resid + self.model.pass_embed[1].to(dtype=h1_resid.dtype, device=h1_resid.device).view(1, 1, -1)
|
| 608 |
+
for i, layer in enumerate(self.base.model.layers):
|
| 609 |
+
before = h2
|
| 610 |
+
layer_out = self._layer_step(layer, h2, pos, self.pass2_cache.layers[i], max_context)
|
| 611 |
+
layer_delta = layer_out - before
|
| 612 |
+
gate = torch.sigmoid(self.model.layer_gates[i]).to(dtype=h2.dtype, device=h2.device)
|
| 613 |
+
adapter_delta = self.model.adapters[i](h2, pass1_states[i])
|
| 614 |
+
h2 = before + gate * layer_delta + adapter_delta
|
| 615 |
+
return self.base.lm_head(self.base.model.norm(h2))
|
| 616 |
+
|
| 617 |
+
def _ids_prefix_len(self, old: torch.Tensor, new: torch.Tensor) -> int:
|
| 618 |
+
if old is None or old.numel() == 0 or new.numel() == 0:
|
| 619 |
+
return 0
|
| 620 |
+
old1 = old[0]
|
| 621 |
+
new1 = new[0]
|
| 622 |
+
max_n = min(int(old1.numel()), int(new1.numel()))
|
| 623 |
+
if max_n == 0:
|
| 624 |
+
return 0
|
| 625 |
+
# Fast path: old is exact prefix of new.
|
| 626 |
+
if int(old1.numel()) <= int(new1.numel()) and torch.equal(old1, new1[: old1.numel()]):
|
| 627 |
+
return int(old1.numel())
|
| 628 |
+
# Conservative fallback, scan from max down; prompts are usually exact-prefix or reset.
|
| 629 |
+
for n in range(max_n, 0, -1):
|
| 630 |
+
if torch.equal(old1[:n], new1[:n]):
|
| 631 |
+
return n
|
| 632 |
+
return 0
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
@torch.no_grad()
def _token_prefill_context(self, input_ids: torch.Tensor, cfg: "GenerationConfig", use_pass2: bool, use_pass_embed: bool, max_context: int) -> None:
    """Conservative cache builder: prefill by stepping through the prompt token by token.

    Fresh pass-1 (and, when ``use_pass2`` is set, pass-2) caches replace any
    existing ones, then every prompt position is run through the single-token
    step functions. This is slower than vectorized prefill but exercises
    exactly the code path used during decoding. The cached logits are
    overwritten after every position, so they end up describing the last one.
    ``cfg`` is accepted for signature parity and not read here.
    """
    self.pass1_cache = DecoderKVCache(self.n_layers)
    self.pass2_cache = DecoderKVCache(self.n_layers) if use_pass2 else None
    self.cached_logits = None
    self.cached_pass1_logits = None
    self.cached_pass2_logits = None

    prompt_len = int(input_ids.shape[1])
    for pos in range(prompt_len):
        token = input_ids[:, pos:pos + 1]
        hidden, states, pass1_logits = self._step_pass1(token, pos, max_context, use_pass_embed=use_pass_embed)
        pass2_logits = self._step_pass2(hidden, states, pos, max_context) if use_pass2 else None
        # The "active" logits come from pass 2 whenever it ran.
        self.cached_logits = pass2_logits if use_pass2 else pass1_logits
        self.cached_pass1_logits = pass1_logits
        self.cached_pass2_logits = pass2_logits
|
| 663 |
+
|
| 664 |
+
@torch.no_grad()
def _prepare_cached_context(self, input_ids: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, bool, int, int, str]:
    """Bring the KV caches in line with ``input_ids`` before decoding.

    If the persistent cache is enabled, was built with the same mode and
    context limit, and its ids are an exact prefix of ``input_ids``, only
    the unseen suffix is stepped through the model.  Otherwise both caches
    are dropped and rebuilt, using the vectorized prefill when
    ``cfg.vectorized_prefill`` is set and falling back to token-by-token
    prefill if that raises.

    Args:
        input_ids: Prompt ids; clipped to the last ``cfg.max_context_tokens``
            positions when longer.
        cfg: Generation settings (``mode``, ``max_context_tokens``,
            ``persistent_cache`` and ``vectorized_prefill`` are read).

    Returns:
        ``(input_ids, cache_hit, reused_tokens, new_tokens, backend)`` where
        ``input_ids`` is the possibly clipped prompt and ``backend`` names
        the path that filled the cache.
    """
    mode = self._effective_mode(cfg.mode)
    max_context = int(cfg.max_context_tokens)
    use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
    use_pass_embed = bool(use_pass2)
    T = int(input_ids.shape[1])
    if T > max_context:
        input_ids = input_ids[:, -max_context:]
        T = max_context

    def remember(logits1, logits2):
        # Single place that records the logits decoding will sample from next.
        self.cached_logits = logits2 if use_pass2 else logits1
        self.cached_pass1_logits = logits1
        self.cached_pass2_logits = logits2

    # If mode/context changed, the persistent cache is invalid.
    cache_ok = (
        cfg.persistent_cache
        and self.cache_ids is not None
        and self.cache_mode == mode
        and self.cache_max_context == max_context
        and self.pass1_cache is not None
    )
    prefix = self._ids_prefix_len(self.cache_ids, input_ids) if cache_ok else 0
    cache_hit = bool(cache_ok and prefix == int(self.cache_ids.shape[1]) and prefix <= T and prefix > 0)

    if cache_hit:
        # Process only the suffix between the prior cached prompt and the new prompt.
        suffix = input_ids[:, prefix:]
        suffix_len = int(suffix.shape[1])
        for j in range(suffix_len):
            tok = suffix[:, j : j + 1]
            pos = prefix + j
            h1, states, logits1 = self._step_pass1(tok, pos, max_context, use_pass_embed=use_pass_embed)
            logits2 = self._step_pass2(h1, states, pos, max_context) if use_pass2 else None
            remember(logits1, logits2)
        self.cache_ids = input_ids.detach().clone()
        self.cache_backend = "persistent-kv-suffix" if suffix.numel() else "persistent-kv-hit"
        return input_ids, True, prefix, suffix_len, self.cache_backend

    # Reset and prefill. Prefer vectorized prefill, but fall back to conservative
    # token prefill if the runtime variant does not support the vectorized cache path.
    self.pass1_cache = None
    self.pass2_cache = None
    backend = "vectorized-prefill"
    if bool(cfg.vectorized_prefill):
        try:
            h1, states, pos_ids, logits1 = self._prefill_pass1(input_ids, max_context, use_pass_embed=use_pass_embed)
            logits2 = self._prefill_pass2(h1, states, pos_ids, max_context) if use_pass2 else None
            remember(logits1, logits2)
        except Exception:
            # Deliberate best-effort: any failure here is recoverable via token prefill.
            if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
                print("[cache] vectorized prefill failed; using token-prefill cache.")
                traceback.print_exc()
            self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
            backend = "token-prefill-cache"
    else:
        self._token_prefill_context(input_ids, cfg, use_pass2=use_pass2, use_pass_embed=use_pass_embed, max_context=max_context)
        backend = "token-prefill-cache"

    self.cache_ids = input_ids.detach().clone()
    self.cache_mode = mode
    self.cache_max_context = max_context
    self.cache_backend = backend
    return input_ids, False, 0, T, self.cache_backend
|
| 739 |
+
|
| 740 |
+
def _effective_mode(self, mode: str) -> str:
|
| 741 |
+
mode = (mode or "vwm").lower()
|
| 742 |
+
if mode in {"fast", "base", "pass1"}:
|
| 743 |
+
return "fast"
|
| 744 |
+
if mode in {"deep", "32k", "long"}:
|
| 745 |
+
return "deep"
|
| 746 |
+
if mode in {"slow", "full"}:
|
| 747 |
+
return "slow"
|
| 748 |
+
return "vwm"
|
| 749 |
+
|
| 750 |
+
@torch.no_grad()
def pass_metrics_from_logits(self, logits1: Optional[torch.Tensor], logits2: Optional[torch.Tensor]) -> Tuple[Optional[float], Optional[float]]:
    """Compare the last-position logits of pass 1 and pass 2.

    Args:
        logits1: Pass-1 logits indexed as ``[:, -1, :]``, or ``None``.
        logits2: Pass-2 logits indexed as ``[:, -1, :]``, or ``None``.

    Returns:
        ``(kl, cosine)``: the batch-mean KL divergence with the pass-1
        distribution as target and the pass-2 log-probabilities as input,
        and the mean cosine similarity of the raw logits.  ``(None, None)``
        is returned when either input is missing or the computation raises.
    """
    have_both = logits1 is not None and logits2 is not None
    if not have_both:
        return None, None
    try:
        last1 = logits1[:, -1, :].float()
        last2 = logits2[:, -1, :].float()
        log_pred = F.log_softmax(last2, dim=-1)
        target = F.softmax(last1, dim=-1)
        divergence = F.kl_div(log_pred, target, reduction="batchmean")
        similarity = F.cosine_similarity(last1, last2, dim=-1).mean()
        return float(divergence.item()), float(similarity.item())
    except Exception:
        # Metrics are diagnostic only; a malformed tensor must not break generation.
        return None, None
|
| 762 |
+
|
| 763 |
+
def _apply_penalties(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> torch.Tensor:
|
| 764 |
+
if generated.numel() == 0:
|
| 765 |
+
return logits
|
| 766 |
+
out = logits.clone()
|
| 767 |
+
uniq, counts = torch.unique(generated.view(-1), return_counts=True)
|
| 768 |
+
if cfg.repetition_penalty != 1.0:
|
| 769 |
+
selected = out[:, uniq]
|
| 770 |
+
selected = torch.where(selected > 0, selected / float(cfg.repetition_penalty), selected * float(cfg.repetition_penalty))
|
| 771 |
+
out[:, uniq] = selected
|
| 772 |
+
if cfg.frequency_penalty:
|
| 773 |
+
out[:, uniq] -= float(cfg.frequency_penalty) * counts.to(out.dtype).unsqueeze(0)
|
| 774 |
+
n = int(cfg.no_repeat_ngram)
|
| 775 |
+
if n > 1 and generated.size(1) >= n - 1:
|
| 776 |
+
seq = generated[0].tolist()
|
| 777 |
+
prefix = tuple(seq[-(n - 1):])
|
| 778 |
+
banned = []
|
| 779 |
+
for i in range(len(seq) - n + 1):
|
| 780 |
+
if tuple(seq[i:i + n - 1]) == prefix:
|
| 781 |
+
banned.append(seq[i + n - 1])
|
| 782 |
+
if banned:
|
| 783 |
+
out[:, list(set(banned))] = -float("inf")
|
| 784 |
+
return out
|
| 785 |
+
|
| 786 |
+
@torch.no_grad()
|
| 787 |
+
def _sample_next(self, logits: torch.Tensor, generated: torch.Tensor, cfg: GenerationConfig) -> Tuple[torch.Tensor, Dict[str, float]]:
|
| 788 |
+
work = self._apply_penalties(logits.float(), generated, cfg)
|
| 789 |
+
if cfg.greedy or cfg.temperature <= 0:
|
| 790 |
+
probs = torch.softmax(work, dim=-1)
|
| 791 |
+
next_id = torch.argmax(work, dim=-1, keepdim=True)
|
| 792 |
+
else:
|
| 793 |
+
work = work / max(float(cfg.temperature), 1e-6)
|
| 794 |
+
if cfg.top_k > 0:
|
| 795 |
+
k = min(int(cfg.top_k), work.size(-1))
|
| 796 |
+
thresh = torch.topk(work, k, dim=-1).values[..., -1, None]
|
| 797 |
+
work = torch.where(work >= thresh, work, torch.full_like(work, -float("inf")))
|
| 798 |
+
if 0.0 < cfg.top_p < 1.0:
|
| 799 |
+
sorted_logits, sorted_idx = torch.sort(work, descending=True, dim=-1)
|
| 800 |
+
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
| 801 |
+
cumprobs = torch.cumsum(sorted_probs, dim=-1)
|
| 802 |
+
remove = cumprobs > float(cfg.top_p)
|
| 803 |
+
shifted = remove.clone()
|
| 804 |
+
shifted[..., 1:] = remove[..., :-1]
|
| 805 |
+
shifted[..., 0] = False
|
| 806 |
+
sorted_logits = sorted_logits.masked_fill(shifted, -float("inf"))
|
| 807 |
+
work = torch.full_like(work, -float("inf")).scatter(1, sorted_idx, sorted_logits)
|
| 808 |
+
if 0.0 < cfg.min_p < 1.0:
|
| 809 |
+
probs_for_minp = torch.softmax(work, dim=-1)
|
| 810 |
+
max_prob = probs_for_minp.max(dim=-1, keepdim=True).values
|
| 811 |
+
keep = probs_for_minp >= float(cfg.min_p) * max_prob
|
| 812 |
+
work = work.masked_fill(~keep, -float("inf"))
|
| 813 |
+
probs = torch.softmax(work, dim=-1)
|
| 814 |
+
if torch.isnan(probs).any() or not torch.isfinite(probs.sum()) or float(probs.sum()) <= 0:
|
| 815 |
+
next_id = torch.argmax(logits, dim=-1, keepdim=True)
|
| 816 |
+
probs = torch.softmax(logits.float(), dim=-1)
|
| 817 |
+
else:
|
| 818 |
+
next_id = torch.multinomial(probs, 1)
|
| 819 |
+
prob = float(probs.gather(1, next_id).item()) if probs.numel() else 0.0
|
| 820 |
+
entropy = float((-(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)).mean().item()) if probs.numel() else 0.0
|
| 821 |
+
return next_id, {"prob": prob, "entropy": entropy}
|
| 822 |
+
|
| 823 |
+
@torch.no_grad()
def _slow_generate(self, ids: torch.Tensor, prompt_len: int, cfg: GenerationConfig) -> Generator[str, None, None]:
    """Generate without KV caches, re-running the full context every token.

    This is the compatibility path used when caching is disabled, when the
    mode is ``"slow"``, or when the cached path fails.

    Args:
        ids: Prompt ids, optionally already followed by generated tokens.
        prompt_len: Number of leading positions in ``ids`` that belong to
            the prompt; everything after is treated as generated text.
        cfg: Generation settings.

    Yields:
        The cleaned text generated so far, each time it changes at a
        streaming checkpoint, and once more at the end if non-empty.

    Side effects:
        Updates ``self.last_stats`` at every streamed chunk and once more
        when generation stops, so ``finish_reason`` reflects the real
        outcome instead of remaining ``"streaming"``.
    """
    eos_id = getattr(self.tokenizer, "eos_token_id", None)
    # A zero interval would raise ZeroDivisionError in the modulo below;
    # abs() keeps the divisibility test identical for negative values.
    stream_every = abs(int(cfg.stream_every)) or 1
    # The mode cannot change while generating, so resolve it once.
    base_only = self._effective_mode(cfg.mode) == "fast"
    last_text = ""
    finish_reason = "length"
    t0 = time.time()

    def snapshot(reason: str):
        # Describes the most recently recorded token, if any.
        elapsed = time.time() - t0
        gen = int(ids.shape[1]) - prompt_len
        last = self.recent_tokens[-1] if self.recent_tokens else {}
        return GenerationStats(
            prompt_tokens=prompt_len,
            prompt_total_tokens=self.last_prompt_total_tokens,
            prompt_kept_tokens=self.last_prompt_kept_tokens,
            prompt_dropped_tokens=self.last_prompt_dropped_tokens,
            generated_tokens=gen,
            elapsed_sec=elapsed,
            tokens_per_sec=gen / max(elapsed, 1e-9),
            mode=cfg.mode,
            backend="slow-full-prefix",
            last_token=last.get("text", ""),
            last_token_id=last.get("id", -1),
            last_token_prob=last.get("prob", 0.0),
            last_entropy=last.get("entropy", 0.0),
            finish_reason=reason,
        )

    for step in range(int(cfg.max_new_tokens)):
        ctx = ids[:, -int(cfg.max_context_tokens):]
        with self.amp_context():
            out = self.base(ctx) if base_only else self.model(ctx)
        logits = out.logits[:, -1, :].float()
        generated = ids[:, prompt_len:]
        next_id, tok_stats = self._sample_next(logits, generated, cfg)
        ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)
        token_id = int(next_id.item())
        token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
        self._record_token(step + 1, token_id, token_text, tok_stats)
        if eos_id is not None and token_id == int(eos_id):
            finish_reason = "eos"
            break
        if (step + 1) % stream_every == 0 or step == 0:
            raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
            if any(s in raw for s in STOP_STRINGS):
                finish_reason = "stop_string"
                break
            text = clean_text(raw)
            if text and text != last_text:
                self.last_stats = snapshot("streaming")
                last_text = text
                yield text

    final = clean_text(self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True))
    self.last_stats = snapshot(finish_reason)
    if final:
        yield final
|
| 856 |
+
|
| 857 |
+
def _record_token(self, i: int, token_id: int, token_text: str, tok_stats: Dict[str, float]) -> None:
|
| 858 |
+
self.recent_tokens.append({"i": i, "id": token_id, "text": token_text, "prob": tok_stats.get("prob", 0.0), "entropy": tok_stats.get("entropy", 0.0)})
|
| 859 |
+
self.recent_tokens = self.recent_tokens[-64:]
|
| 860 |
+
|
| 861 |
+
@torch.no_grad()
def generate(self, prompt: str, cfg: GenerationConfig) -> Generator[str, None, None]:
    """Stream a completion for ``prompt`` using the KV-cached decode path.

    The prompt is encoded, the caches are prepared (reusing a persistent
    cache when possible) and tokens are then sampled one at a time.  When
    caching is disabled, the mode resolves to ``"slow"``, or any cached
    step raises, generation is delegated to ``_slow_generate``.

    Args:
        prompt: Text to encode and continue.
        cfg: Generation settings.  In ``"deep"`` mode
            ``cfg.max_context_tokens`` is raised in place to at least 16384.

    Yields:
        The cleaned text generated so far, each time it changes at a
        streaming checkpoint, and once more at the end if non-empty.

    Side effects:
        Resets ``self.recent_tokens``, updates the cache attributes and
        writes ``self.last_stats`` at every streamed chunk and at the end.
    """
    self.model.eval()
    self.base.eval()
    self.recent_tokens = []
    mode = self._effective_mode(cfg.mode)
    if mode == "deep":
        cfg.max_context_tokens = max(int(cfg.max_context_tokens), 16384)
    ids = self.encode(prompt, max_context_tokens=int(cfg.max_context_tokens))
    prompt_len = int(ids.shape[1])
    if self.last_prompt_dropped_tokens > 0:
        print(f"[context] prompt clipped: kept={self.last_prompt_kept_tokens} total={self.last_prompt_total_tokens} dropped={self.last_prompt_dropped_tokens}")
    t_start = time.time()

    if (not cfg.use_cache) or mode == "slow":
        yield from self._slow_generate(ids, prompt_len, cfg)
        return

    try:
        with self.amp_context():
            t_pref = time.time()
            ids, cache_hit, reused, new_prefill, backend = self._prepare_cached_context(ids, cfg)
            prefill_sec = time.time() - t_pref
            if cfg.return_pass_metrics:
                pass_kl, pass_cos = self.pass_metrics_from_logits(self.cached_pass1_logits, self.cached_pass2_logits)
            else:
                pass_kl, pass_cos = None, None
        if self.cached_logits is None:
            # Nothing to sample from (e.g. an empty prompt); treat it like any
            # other cached-path failure instead of dereferencing None below.
            raise RuntimeError("cached context produced no logits")
    except Exception as exc:
        print(f"[cache] cached path failed; falling back to slow full-prefix: {type(exc).__name__}: {exc}")
        if os.getenv("LULUV2_CACHE_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}:
            traceback.print_exc()
        self.pass1_cache = None
        self.pass2_cache = None
        self.cache_ids = None
        yield from self._slow_generate(ids, prompt_len, cfg)
        return

    def snapshot(reason: str):
        # Describes the most recently recorded token, if any.
        elapsed = time.time() - t_start
        gen = int(ids.shape[1]) - prompt_len
        last = self.recent_tokens[-1] if self.recent_tokens else {}
        return GenerationStats(
            prompt_tokens=prompt_len,
            prompt_total_tokens=self.last_prompt_total_tokens,
            prompt_kept_tokens=self.last_prompt_kept_tokens,
            prompt_dropped_tokens=self.last_prompt_dropped_tokens,
            generated_tokens=gen,
            elapsed_sec=elapsed,
            # Decode speed excludes the time spent on prefill.
            tokens_per_sec=gen / max(elapsed - prefill_sec, 1e-9),
            prefill_sec=prefill_sec,
            prefill_tps=(new_prefill / max(prefill_sec, 1e-9)),
            cache_hit=cache_hit,
            cache_reused_tokens=reused,
            cache_new_prefill_tokens=new_prefill,
            mode=mode,
            backend=backend,
            last_token=last.get("text", ""),
            last_token_id=last.get("id", -1),
            last_token_prob=last.get("prob", 0.0),
            last_entropy=last.get("entropy", 0.0),
            finish_reason=reason,
            pass1_pass2_kl=pass_kl,
            pass1_pass2_logit_cosine=pass_cos,
        )

    eos_id = getattr(self.tokenizer, "eos_token_id", None)
    last_text = ""
    finish_reason = "length"
    use_pass2 = mode in {"vwm", "deep"} and self.has_pass2
    use_pass_embed = bool(use_pass2)
    max_context = int(cfg.max_context_tokens)
    # A zero interval would raise ZeroDivisionError in the modulo below;
    # abs() keeps the divisibility test identical for negative values.
    stream_every = abs(int(cfg.stream_every)) or 1

    for step in range(int(cfg.max_new_tokens)):
        cached = self.cached_logits
        logits = (cached[:, -1, :] if cached.dim() == 3 else cached).float()
        generated = ids[:, prompt_len:]
        next_id, tok_stats = self._sample_next(logits, generated, cfg)
        token_id = int(next_id.item())
        token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
        self._record_token(step + 1, token_id, token_text, tok_stats)
        ids = torch.cat([ids, next_id.to(ids.device)], dim=-1)

        if eos_id is not None and token_id == int(eos_id):
            finish_reason = "eos"
            break

        # Feed the sampled token back so the caches hold logits for the next step.
        pos = int(ids.shape[1]) - 1
        try:
            with self.amp_context():
                h1, states, logits1 = self._step_pass1(next_id.to(self.device), pos, max_context, use_pass_embed=use_pass_embed)
                logits2 = self._step_pass2(h1, states, pos, max_context) if use_pass2 else None
                self.cached_logits = logits2 if use_pass2 else logits1
                self.cached_pass1_logits = logits1
                self.cached_pass2_logits = logits2
            if self.cache_ids is not None:
                self.cache_ids = torch.cat([self.cache_ids, next_id.detach().to(self.cache_ids.device)], dim=-1)
                if self.cache_ids.shape[1] > max_context:
                    self.cache_ids = self.cache_ids[:, -max_context:]
        except Exception as exc:
            print(f"[decode-cache] step failed; falling back for this request: {type(exc).__name__}: {exc}")
            # Finish with slow path from current ids; do not pretend cache is valid.
            self.cache_ids = None
            yield from self._slow_generate(ids, prompt_len, cfg)
            return

        if (step + 1) % stream_every == 0 or step == 0:
            raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
            if any(s in raw for s in STOP_STRINGS):
                finish_reason = "stop_string"
                break
            text = clean_text(raw)
            if text and text != last_text:
                self.last_stats = snapshot("streaming")
                last_text = text
                yield text

    raw = self.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
    final = clean_text(raw)
    self.last_stats = snapshot(finish_reason)
    if final:
        yield final
|
| 1009 |
+
|
| 1010 |
+
def clear_session_cache(self) -> None:
    """Drop both KV caches, the cached ids and every cached logits tensor.

    ``cache_backend`` is set to ``"cleared"``; other attributes are untouched.
    """
    cache_attrs = (
        "pass1_cache",
        "pass2_cache",
        "cache_ids",
        "cached_logits",
        "cached_pass1_logits",
        "cached_pass2_logits",
    )
    for attr in cache_attrs:
        setattr(self, attr, None)
    self.cache_backend = "cleared"
|
| 1018 |
+
|
| 1019 |
+
def stats_dict(self) -> Dict[str, Any]:
    """Bundle generation stats, model info and a system snapshot into one dict.

    Returns:
        A mapping with the keys ``"generation"`` (``asdict`` of
        ``self.last_stats``), ``"model"`` (``self.model_info``) and
        ``"system"`` (``system_snapshot(self)``).
    """
    report: Dict[str, Any] = {}
    report["generation"] = asdict(self.last_stats)
    report["model"] = self.model_info
    report["system"] = system_snapshot(self)
    return report
|
| 1021 |
+
|
| 1022 |
+
def stats_text(self) -> str:
    """Render ``self.last_stats`` and ``self.model_info`` as a multi-line report.

    The pass-1/pass-2 KL and cosine lines appear only when the
    corresponding statistic is not ``None``.
    """
    s = self.last_stats
    info = self.model_info
    total_tokens = getattr(s, 'prompt_total_tokens', s.prompt_tokens)
    dropped_tokens = getattr(s, 'prompt_dropped_tokens', 0)

    report = []
    report.append(f"Mode: {s.mode} | backend={s.backend}")
    report.append(f"Prompt tokens: {s.prompt_tokens} kept / {total_tokens} total / {dropped_tokens} dropped")
    report.append(f"Generated tokens: {s.generated_tokens}")
    report.append(f"Elapsed: {s.elapsed_sec:.2f}s | prefill={s.prefill_sec:.2f}s ({s.prefill_tps:.1f} tok/s)")
    report.append(f"Decode speed: {s.tokens_per_sec:.2f} tok/s")
    report.append(f"Cache: hit={s.cache_hit} reused={s.cache_reused_tokens} new_prefill={s.cache_new_prefill_tokens}")
    report.append(f"Finish reason: {s.finish_reason}")
    report.append(f"Last token: {s.last_token!r} id={s.last_token_id} p={s.last_token_prob:.4f} H={s.last_entropy:.2f}")

    if s.pass1_pass2_kl is not None:
        report.append(f"Pass1→Pass2 KL: {s.pass1_pass2_kl:.6f}")
    if s.pass1_pass2_logit_cosine is not None:
        report.append(f"Pass1/Pass2 cosine: {s.pass1_pass2_logit_cosine:.6f}")

    # A blank line separates per-request numbers from static model facts.
    report += [
        "",
        f"Checkpoint: {info['checkpoint']}",
        f"Checkpoint size: {info['checkpoint_size']}",
        f"Device: {info['device']} dtype={info['dtype']}",
        f"Pass2 active: {info['has_pass2']}",
        f"Params: {info['total_params']:,}",
        f"VWM c modules: {info['vwm_c_modules']} ({info['vwm_c_params']:,} c params)",
    ]
    return "\n".join(report)
|
| 1048 |
+
|
| 1049 |
+
def token_trace_text(self) -> str:
    """Format the 48 most recent recorded tokens, one per line.

    Token text is shown via ``repr`` with the surrounding quotes removed so
    control characters stay visible.  A fixed placeholder sentence is
    returned when no tokens have been recorded.
    """
    if not self.recent_tokens:
        return "No tokens generated yet."
    return "\n".join(
        f"{tok['i']:04d} id={tok['id']:<7} p={tok['prob']:.4f} H={tok['entropy']:.2f} {repr(tok['text'])[1:-1]}"
        for tok in self.recent_tokens[-48:]
    )
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def system_snapshot(engine: Optional[LULUV2OptimizedEngine] = None) -> Dict[str, Any]:
    """Collect best-effort RAM, CPU and GPU figures for display.

    Each probe (psutil, torch.cuda, pynvml) runs in its own guarded section;
    if a probe is unavailable or raises, the placeholder values for that
    section are left in place.

    Args:
        engine: Accepted for call-site symmetry; not read by this function.

    Returns:
        A dict that always contains the keys ``python_ram``, ``system_ram``,
        ``system_ram_percent``, ``cpu_percent``, ``gpu_name``,
        ``vram_allocated``, ``vram_reserved``, ``vram_used``, ``vram_total``,
        ``vram_percent``, ``gpu_util_percent`` and ``gpu_temp_c``.
    """
    snap: Dict[str, Any] = dict(
        python_ram="n/a",
        system_ram="n/a",
        system_ram_percent=0.0,
        cpu_percent=0.0,
        gpu_name="CUDA unavailable",
        vram_allocated="n/a",
        vram_reserved="n/a",
        vram_used="n/a",
        vram_total="n/a",
        vram_percent=0.0,
        gpu_util_percent=None,
        gpu_temp_c=None,
    )

    if psutil is not None:
        try:
            rss = psutil.Process(os.getpid()).memory_info().rss
            vm = psutil.virtual_memory()
            # All values are computed before update() so a failure leaves snap untouched.
            snap.update(
                python_ram=human_bytes(rss),
                system_ram=f"{human_bytes(vm.used)} / {human_bytes(vm.total)}",
                system_ram_percent=float(vm.percent),
                cpu_percent=float(psutil.cpu_percent(interval=0.0)),
            )
        except Exception:
            pass

    if not torch.cuda.is_available():
        return snap

    try:
        idx = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(idx)
        allocated = int(torch.cuda.memory_allocated(idx))
        reserved = int(torch.cuda.memory_reserved(idx))
        total = int(props.total_memory)
        snap.update(
            gpu_name=props.name,
            vram_allocated=human_bytes(allocated),
            vram_reserved=human_bytes(reserved),
            vram_used=human_bytes(allocated),
            vram_total=human_bytes(total),
            vram_percent=100.0 * allocated / max(total, 1),
        )
        if pynvml is not None:
            try:
                # NOTE(review): the torch device index is passed straight to NVML;
                # confirm the two numberings agree when CUDA_VISIBLE_DEVICES is set.
                pynvml.nvmlInit()
                handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
                # NVML figures, when available, replace the torch-derived VRAM usage.
                snap.update(
                    gpu_util_percent=int(util.gpu),
                    vram_used=human_bytes(int(mem.used)),
                    vram_total=human_bytes(int(mem.total)),
                    vram_percent=100.0 * float(mem.used) / max(float(mem.total), 1.0),
                    gpu_temp_c=int(temp),
                )
            except Exception:
                pass
    except Exception:
        pass
    return snap
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def system_usage(engine: Optional[LULUV2OptimizedEngine] = None) -> str:
    """Render ``system_snapshot`` as a human-readable multi-line report.

    GPU utilisation and temperature lines are included only when the
    snapshot provides them.  When ``engine`` is given, its ``stats_text()``
    is appended after a blank line.
    """
    snap = system_snapshot(engine)
    report = [
        f"OS: {platform.system()} {platform.release()}",
        f"Python RAM: {snap['python_ram']}",
        f"System RAM: {snap['system_ram']} ({snap['system_ram_percent']:.1f}%)",
        f"CPU: {snap['cpu_percent']:.1f}%",
        "",
        f"GPU: {snap['gpu_name']}",
        f"VRAM used: {snap['vram_used']} / {snap['vram_total']} ({snap['vram_percent']:.1f}%)",
        f"VRAM allocated: {snap['vram_allocated']}",
        f"VRAM reserved: {snap['vram_reserved']}",
    ]

    gpu_util = snap.get("gpu_util_percent")
    if gpu_util is not None:
        report.append(f"GPU util: {gpu_util}%")

    gpu_temp = snap.get("gpu_temp_c")
    if gpu_temp is not None:
        report.append(f"GPU temp: {gpu_temp} C")

    if engine is not None:
        report.append("")
        report.append(engine.stats_text())
    return "\n".join(report)
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1
|
| 2 |
+
tokenizers>=0.15
|
| 3 |
+
transformers>=4.40
|
| 4 |
+
gradio>=4.0
|
| 5 |
+
|
| 6 |
+
psutil>=5.9
|
| 7 |
+
nvidia-ml-py>=12.0; platform_system != "Darwin"
|
run_chat.ps1
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
$ErrorActionPreference = "Stop"
|
| 2 |
+
python .\app.py --ckpt .\LULUV2-bf16.pt --model-py .\luluv2_inference_runtime.py --tokenizer-dir .\tokenizer --inbrowser
|
run_chat.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
python ./app.py --ckpt ./LULUV2-bf16.pt --model-py ./luluv2_inference_runtime.py --tokenizer-dir ./tokenizer --inbrowser
|
run_inference.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Small CLI launcher for LULUV2 native-bf16 local inference."""
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
import argparse
|
| 5 |
+
import torch
|
| 6 |
+
from luluv2_live_inference import LULUV2LiveEngine, GenerationConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main():
    """Parse CLI options, load the engine and stream one reply to stdout."""
    parser = argparse.ArgumentParser("LULUV2 local inference")
    option_specs = (
        ("--ckpt", dict(default="LULUV2-bf16.pt", help="Path to the native-bf16 checkpoint file")),
        ("--tokenizer-dir", dict(default="tokenizer", help="Local tokenizer directory")),
        ("--prompt", dict(required=True, help="User prompt")),
        ("--system", dict(default="You are LuluV2, a helpful local AI assistant.")),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--temperature", dict(type=float, default=0.65)),
        ("--top-p", dict(type=float, default=0.90)),
        ("--top-k", dict(type=int, default=40)),
        ("--device", dict(default=None)),
        ("--dtype", dict(default="bf16", choices=["bf16", "fp16", "fp32"])),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Everything is loaded from local files; nothing is downloaded.
    engine = LULUV2LiveEngine(
        ckpt_path=args.ckpt,
        model_py="luluv2_inference_runtime.py",
        tokenizer_dir=args.tokenizer_dir,
        device=args.device,
        dtype=args.dtype,
        local_files_only=True,
        no_config_download=True,
    )
    sampling = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
    )

    # Single-turn invocation: the conversation history starts empty.
    for chunk in engine.generate_stream(args.prompt, [], args.system, sampling):
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    # Inference only: switch autograd off globally before running.
    torch.set_grad_enabled(False)
    main()
|