Upload 34 files
Browse files- .gitattributes +47 -35
- .gitignore +21 -0
- .python-version +1 -0
- Dockerfile +62 -0
- README.md +89 -12
- api_types.py +82 -0
- app.py +941 -0
- app_stderr.log +33 -0
- app_stdout.log +22 -0
- config.local.yaml +24 -0
- config.production-modelscope.yaml +24 -0
- config.production.yaml +24 -0
- config.py +84 -0
- cuda/gemm_fp16_cublas.cpp +75 -0
- cuda/operators.cu +246 -0
- cuda/rwkv5.cu +88 -0
- cuda/rwkv5_op.cpp +34 -0
- cuda/rwkv6.cu +87 -0
- cuda/rwkv6_op.cpp +34 -0
- cuda/rwkv7.cu +77 -0
- cuda/rwkv7_op.cpp +26 -0
- cuda/wrapper.cpp +141 -0
- download_models.py +62 -0
- models/.cache/huggingface/.gitignore +1 -0
- models/.cache/huggingface/download/rwkv7-g1a-0.1b-20250728-ctx4096.pth.metadata +3 -0
- models/rwkv7-g1a-0.1b-20250728-ctx4096.pth +3 -0
- pyproject.toml +45 -0
- run_windows.ps1 +14 -0
- setup_windows.ps1 +82 -0
- tests/api_test.py +85 -0
- tests/run_local_exec.py +14 -0
- utils.py +177 -0
- uv.lock +0 -0
- verify_setup.py +54 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,47 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.
|
| 5 |
-
*.
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.
|
| 12 |
-
*.
|
| 13 |
-
*.
|
| 14 |
-
*.
|
| 15 |
-
*.
|
| 16 |
-
*.
|
| 17 |
-
*.
|
| 18 |
-
*.
|
| 19 |
-
*.
|
| 20 |
-
|
| 21 |
-
*.
|
| 22 |
-
*.
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*.
|
| 26 |
-
|
| 27 |
-
*.
|
| 28 |
-
*.
|
| 29 |
-
*.
|
| 30 |
-
|
| 31 |
-
*.
|
| 32 |
-
*.
|
| 33 |
-
*.
|
| 34 |
-
*.
|
| 35 |
-
*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.gguf* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.ggml filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.llamafile* filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
.cache
|
| 13 |
+
|
| 14 |
+
*pth
|
| 15 |
+
*.pt
|
| 16 |
+
*.st
|
| 17 |
+
*local*
|
| 18 |
+
|
| 19 |
+
dist-frontend/
|
| 20 |
+
|
| 21 |
+
.vscode/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.10
|
Dockerfile
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM node:20-alpine AS FrontendBuilder
|
| 2 |
+
|
| 3 |
+
RUN apk update && apk upgrade && \
|
| 4 |
+
apk add --no-cache bash git openssh curl rust cargo
|
| 5 |
+
|
| 6 |
+
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
| 7 |
+
RUN npm install -g pnpm
|
| 8 |
+
|
| 9 |
+
ADD https://api.github.com/repos/SolomonLeon/web-rwkv-realweb/git/refs/heads/ version_1.json
|
| 10 |
+
|
| 11 |
+
WORKDIR /app
|
| 12 |
+
RUN git clone https://github.com/SolomonLeon/web-rwkv-realweb.git /app
|
| 13 |
+
|
| 14 |
+
WORKDIR /app/web-rwkv-wasm
|
| 15 |
+
RUN ["cargo", "install", "wasm-pack", "--locked"]
|
| 16 |
+
|
| 17 |
+
WORKDIR /app
|
| 18 |
+
ENV PATH=/root/.cargo/bin:$PATH
|
| 19 |
+
RUN pnpm install
|
| 20 |
+
RUN if [ "$MODELSCOPE_ENVIRONMENT" = "studio" ]; then \
|
| 21 |
+
pnpm run build --mode target-rwkv-modelscope-space; \
|
| 22 |
+
else \
|
| 23 |
+
pnpm run build --mode target-rwkv-hf-space; \
|
| 24 |
+
fi
|
| 25 |
+
|
| 26 |
+
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS Backend
|
| 27 |
+
|
| 28 |
+
RUN <<EOF
|
| 29 |
+
apt update
|
| 30 |
+
apt install --no-install-recommends -y \
|
| 31 |
+
build-essential \
|
| 32 |
+
git \
|
| 33 |
+
cuda-nvcc-12-4 \
|
| 34 |
+
cuda-cudart-dev-12-4 \
|
| 35 |
+
python3-dev \
|
| 36 |
+
python3-pip \
|
| 37 |
+
libpython3.10-dev
|
| 38 |
+
apt clean && rm -rf /var/lib/apt/lists/*
|
| 39 |
+
EOF
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 43 |
+
|
| 44 |
+
COPY . .
|
| 45 |
+
|
| 46 |
+
RUN useradd -m -u 1000 user
|
| 47 |
+
USER user
|
| 48 |
+
|
| 49 |
+
ENV HOME=/home/user \
|
| 50 |
+
PATH=/usr/local/cuda/bin:/home/user/.local/bin:$PATH \
|
| 51 |
+
LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" \
|
| 52 |
+
CXX=/usr/bin/g++ \
|
| 53 |
+
TORCH_CUDA_ARCH_LIST="7.5"
|
| 54 |
+
WORKDIR $HOME/app
|
| 55 |
+
|
| 56 |
+
COPY --chown=user . $HOME/app
|
| 57 |
+
|
| 58 |
+
COPY --chown=user --from=FrontendBuilder /app/dist $HOME/app/dist-frontend
|
| 59 |
+
|
| 60 |
+
RUN uv sync --frozen --extra cu124
|
| 61 |
+
|
| 62 |
+
CMD ["sh", "-c", "if [ \"$MODELSCOPE_ENVIRONMENT\" = \"studio\" ]; then CONFIG_FILE=\"./config.production-modelscope.yaml\"; else CONFIG_FILE=\"./config.production.yaml\"; fi; uv run --offline --frozen app.py --config_file \"$CONFIG_FILE\""]
|
README.md
CHANGED
|
@@ -1,12 +1,89 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
|
| 3 |
+
title: RWKV HF Space
|
| 4 |
+
emoji: 🐦⬛
|
| 5 |
+
colorFrom: purple
|
| 6 |
+
colorTo: pink
|
| 7 |
+
sdk: docker
|
| 8 |
+
pinned: false
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Simple RWKV OpenAI-Compatible API
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
title: RWKV HF Space
|
| 17 |
+
emoji: 🐦⬛
|
| 18 |
+
colorFrom: purple
|
| 19 |
+
colorTo: pink
|
| 20 |
+
sdk: docker
|
| 21 |
+
pinned: false
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
# Simple RWKV OpenAI-Compatible API
|
| 26 |
+
|
| 27 |
+
## Quick Windows Setup (no Docker)
|
| 28 |
+
|
| 29 |
+
This repository was originally packaged with a Dockerfile. It now provides a `setup_windows.ps1` script that mirrors Dockerfile actions and sets up the service locally on Windows (installs Python dependencies, builds the frontend, and downloads the 0.1B model).
|
| 30 |
+
|
| 31 |
+
Prerequisites:
|
| 32 |
+
- Python 3.10+ installed and in PATH
|
| 33 |
+
- Node.js + npm (optional, required for building the frontend)
|
| 34 |
+
- (Optional) NVIDIA GPU and CUDA (for GPU runtime)
|
| 35 |
+
|
| 36 |
+
To setup locally on Windows (CPU-only):
|
| 37 |
+
|
| 38 |
+
```powershell
|
| 39 |
+
.\setup_windows.ps1 -gpu:$false -buildFrontend:$true -CONFIG_FILE config.production.yaml
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
If you have a compatible NVIDIA GPU and prefer to install GPU-enabled dependencies, run with the `-gpu` switch.
|
| 43 |
+
|
| 44 |
+
After setup, run the API:
|
| 45 |
+
|
| 46 |
+
```powershell
|
| 47 |
+
#$env:CONFIG_FILE='config.production.yaml'
|
| 48 |
+
python app.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
The default production config in `config.production.yaml` now contains a single model — the 0.1B model `rwkv7-g1a-0.1b-20250728-ctx4096` — set as default chat and reasoning model.
|
| 52 |
+
|
| 53 |
+
To download models defined in any config:
|
| 54 |
+
|
| 55 |
+
```powershell
|
| 56 |
+
python download_models.py --config config.production.yaml
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
This will store the downloaded .pth files under the `DOWNLOAD_MODEL_DIR` specified in the YAML (defaults to `./models`).
|
| 60 |
+
|
| 61 |
+
Advanced features:
|
| 62 |
+
- `reasoning` is performed in-process by the same model (no external reasoning model is used). Use a request model like `rwkv-latest:thinking` or set the reasoning suffix and the requested model will run reasoning in the same model.
|
| 63 |
+
- `web_search` functionality is available at the request level — set `web_search: true` and optionally `search_top_k` to inject search results from DuckDuckGo into the prompt. This is executed by the server and provided to the same model as context.
|
| 64 |
+
- `tools` are executed server-side and results injected into the prompt for the same model. Supported tools: `web_search` and `calc` (calculator). Example of `tools` usage:
|
| 65 |
+
|
| 66 |
+
```json
|
| 67 |
+
{
|
| 68 |
+
"model": "rwkv-latest",
|
| 69 |
+
"prompt": "Calculate 2+3*4 and tell me the result",
|
| 70 |
+
"tools": [{"name": "calc", "args": {"expression": "2+3*4"}}]
|
| 71 |
+
}
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Example: POST with `web_search` and reasoning enabled
|
| 75 |
+
|
| 76 |
+
```json
|
| 77 |
+
{
|
| 78 |
+
"model": "rwkv-latest:thinking",
|
| 79 |
+
"prompt": "Who is the current president of France?",
|
| 80 |
+
"max_tokens": 32,
|
| 81 |
+
"web_search": true,
|
| 82 |
+
"search_top_k": 3
|
| 83 |
+
}
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
The server will perform a web search for the prompt, aggregate the top 3 results, and inject those into the prompt, then run the model with reasoning enabled — all using the same model instead of an external reasoning or search model.
|
| 87 |
+
|
| 88 |
+
Streaming behavior:
|
| 89 |
+
- The API streams responses token-by-token by default (`stream: true`) and persists a `state_name` for the generation if requested (or will generate one). Provide `state_name` to resume continuation from where the previous stream stopped. The server stores model state in memory under `(model, state_name)` so subsequent requests with the same `state_name` can continue generation from that exact point.
|
api_types.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Dict, Any, Literal
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ChatMessage(BaseModel):
|
| 6 |
+
role: str = Field()
|
| 7 |
+
content: str = Field()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Logprob(BaseModel):
|
| 11 |
+
token: str
|
| 12 |
+
logprob: float
|
| 13 |
+
top_logprobs: Optional[List[Dict[str, Any]]] = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LogprobsContent(BaseModel):
|
| 17 |
+
content: Optional[List[Logprob]] = None
|
| 18 |
+
refusal: Optional[List[Logprob]] = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FunctionCall(BaseModel):
|
| 22 |
+
name: str
|
| 23 |
+
arguments: str
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ChatCompletionMessage(BaseModel):
|
| 27 |
+
role: Optional[str] = Field(
|
| 28 |
+
None, description="The role of the author of this message"
|
| 29 |
+
)
|
| 30 |
+
content: Optional[str] = Field(None, description="The contents of the message")
|
| 31 |
+
reasoning_content: Optional[str] = Field(
|
| 32 |
+
None, description="The reasoning contents of the message"
|
| 33 |
+
)
|
| 34 |
+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
|
| 35 |
+
None, description="Tool calls generated by the model"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class PromptTokensDetails(BaseModel):
|
| 40 |
+
cached_tokens: int
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CompletionTokensDetails(BaseModel):
|
| 44 |
+
reasoning_tokens: int
|
| 45 |
+
accepted_prediction_tokens: int
|
| 46 |
+
rejected_prediction_tokens: int
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Usage(BaseModel):
|
| 50 |
+
prompt_tokens: int
|
| 51 |
+
completion_tokens: int
|
| 52 |
+
total_tokens: int
|
| 53 |
+
prompt_tokens_details: Optional[PromptTokensDetails]
|
| 54 |
+
# completion_tokens_details: CompletionTokensDetails
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ChatCompletionChoice(BaseModel):
|
| 58 |
+
index: int
|
| 59 |
+
message: Optional[ChatCompletionMessage] = None
|
| 60 |
+
delta: Optional[ChatCompletionMessage] = None
|
| 61 |
+
logprobs: Optional[LogprobsContent] = None
|
| 62 |
+
finish_reason: Optional[str] = Field(
|
| 63 |
+
..., description="Reason for stopping: stop, length, content_filter, tool_calls"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ChatCompletion(BaseModel):
|
| 68 |
+
id: str = Field(..., description="Unique identifier for the chat completion")
|
| 69 |
+
object: Literal["chat.completion"] = "chat.completion"
|
| 70 |
+
created: int = Field(..., description="Unix timestamp of creation")
|
| 71 |
+
model: str
|
| 72 |
+
choices: List[ChatCompletionChoice]
|
| 73 |
+
usage: Usage
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ChatCompletionChunk(BaseModel):
|
| 77 |
+
id: str = Field(..., description="Unique identifier for the chat completion")
|
| 78 |
+
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
| 79 |
+
created: int = Field(..., description="Unix timestamp of creation")
|
| 80 |
+
model: str
|
| 81 |
+
choices: List[ChatCompletionChoice]
|
| 82 |
+
usage: Optional[Usage]
|
app.py
ADDED
|
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
|
| 4 |
+
from modelscope import patch_hub
|
| 5 |
+
|
| 6 |
+
patch_hub()
|
| 7 |
+
|
| 8 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from config import CONFIG, ModelConfig
|
| 12 |
+
from utils import (
|
| 13 |
+
cleanMessages,
|
| 14 |
+
parse_think_response,
|
| 15 |
+
remove_nested_think_tags_stack,
|
| 16 |
+
format_bytes,
|
| 17 |
+
log,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
import copy, types, gc, sys, re, time, collections, asyncio
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
from loguru import logger
|
| 23 |
+
from rich import print
|
| 24 |
+
|
| 25 |
+
from snowflake import SnowflakeGenerator
|
| 26 |
+
|
| 27 |
+
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
|
| 28 |
+
|
| 29 |
+
from typing import List, Optional, Union, Any, Dict
|
| 30 |
+
import uuid
|
| 31 |
+
from pydantic import BaseModel, Field, model_validator
|
| 32 |
+
from pydantic_settings import BaseSettings
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
|
| 40 |
+
logger.info(f"CUDA not found, fall back to cpu")
|
| 41 |
+
CONFIG.STRATEGY = "cpu fp16"
|
| 42 |
+
# Normalize STRATEGY to include precision if missing (e.g., 'cpu' -> 'cpu fp16')
|
| 43 |
+
_s = CONFIG.STRATEGY.lower()
|
| 44 |
+
if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s):
|
| 45 |
+
logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
|
| 46 |
+
CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
|
| 51 |
+
except Exception:
|
| 52 |
+
nvmlInit = None
|
| 53 |
+
nvmlDeviceGetHandleByIndex = None
|
| 54 |
+
nvmlDeviceGetMemoryInfo = None
|
| 55 |
+
|
| 56 |
+
if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
|
| 57 |
+
nvmlInit()
|
| 58 |
+
gpu_h = nvmlDeviceGetHandleByIndex(0)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def logGPUState():
|
| 62 |
+
if "cuda" in CONFIG.STRATEGY and nvmlDeviceGetMemoryInfo is not None:
|
| 63 |
+
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
| 64 |
+
logger.info(
|
| 65 |
+
f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
torch.backends.cudnn.benchmark = True
|
| 70 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 71 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 72 |
+
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
|
| 73 |
+
os.environ["RWKV_JIT_ON"] = "1"
|
| 74 |
+
os.environ["RWKV_CUDA_ON"] = (
|
| 75 |
+
"1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
from rwkv.model import RWKV
|
| 79 |
+
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 80 |
+
|
| 81 |
+
from fastapi import FastAPI, HTTPException
|
| 82 |
+
from starlette.background import BackgroundTask
|
| 83 |
+
from fastapi.responses import StreamingResponse
|
| 84 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 85 |
+
from fastapi.staticfiles import StaticFiles
|
| 86 |
+
from fastapi.middleware.gzip import GZipMiddleware
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
from api_types import (
|
| 90 |
+
ChatMessage,
|
| 91 |
+
ChatCompletion,
|
| 92 |
+
ChatCompletionChunk,
|
| 93 |
+
Usage,
|
| 94 |
+
PromptTokensDetails,
|
| 95 |
+
ChatCompletionChoice,
|
| 96 |
+
ChatCompletionMessage,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ModelStorage:
|
| 101 |
+
MODEL_CONFIG: Optional[ModelConfig] = None
|
| 102 |
+
model: Optional[RWKV] = None
|
| 103 |
+
pipeline: Optional[PIPELINE] = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
MODEL_STORAGE: Dict[str, ModelStorage] = {}
|
| 107 |
+
|
| 108 |
+
DEFALUT_MODEL_NAME = None
|
| 109 |
+
DEFAULT_REASONING_MODEL_NAME = None
|
| 110 |
+
|
| 111 |
+
# In-memory model state store to support streaming continuation/resume per state_name.
|
| 112 |
+
# Keys: (model_name, state_name) -> model_state object
|
| 113 |
+
STATE_STORE: Dict[tuple, Any] = {}
|
| 114 |
+
|
| 115 |
+
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
|
| 116 |
+
|
| 117 |
+
logGPUState()
|
| 118 |
+
|
| 119 |
+
# Enforce single 0.1b model. If multiple models are present, select only the one
|
| 120 |
+
# that matches '0.1b' literally in the service name, to obey policy of single model.
|
| 121 |
+
filtered_models = [m for m in CONFIG.MODELS if '0.1b' in m.SERVICE_NAME]
|
| 122 |
+
if len(filtered_models) == 0:
|
| 123 |
+
# If no explicit 0.1b model detected, fall back to the first provided model but warn.
|
| 124 |
+
logger.warning("No '0.1b' model detected in config; using the first available model. To ensure single 0.1b use, include a model name with '0.1b'.")
|
| 125 |
+
CONFIG.MODELS = [CONFIG.MODELS[0]]
|
| 126 |
+
elif len(filtered_models) > 1:
|
| 127 |
+
logger.warning("Multiple '0.1b' models detected; selecting the first one as the single model.")
|
| 128 |
+
CONFIG.MODELS = [filtered_models[0]]
|
| 129 |
+
else:
|
| 130 |
+
CONFIG.MODELS = [filtered_models[0]]
|
| 131 |
+
|
| 132 |
+
for model_config in CONFIG.MODELS:
|
| 133 |
+
logger.info(f"Load Model - {model_config.SERVICE_NAME}")
|
| 134 |
+
|
| 135 |
+
if model_config.MODEL_FILE_PATH == None:
|
| 136 |
+
model_config.MODEL_FILE_PATH = hf_hub_download(
|
| 137 |
+
repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
|
| 138 |
+
filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
|
| 139 |
+
local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
|
| 140 |
+
)
|
| 141 |
+
logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
|
| 142 |
+
|
| 143 |
+
if model_config.DEFAULT_CHAT:
|
| 144 |
+
if DEFALUT_MODEL_NAME != None:
|
| 145 |
+
logger.info(
|
| 146 |
+
f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
|
| 147 |
+
)
|
| 148 |
+
DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
|
| 149 |
+
|
| 150 |
+
if model_config.DEFAULT_REASONING:
|
| 151 |
+
if DEFAULT_REASONING_MODEL_NAME != None:
|
| 152 |
+
logger.info(
|
| 153 |
+
f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
|
| 154 |
+
)
|
| 155 |
+
DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
|
| 156 |
+
|
| 157 |
+
logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
|
| 158 |
+
print(model_config.DEFAULT_SAMPLER)
|
| 159 |
+
|
| 160 |
+
MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
|
| 161 |
+
MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
|
| 162 |
+
MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
|
| 163 |
+
model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
|
| 164 |
+
strategy=CONFIG.STRATEGY,
|
| 165 |
+
)
|
| 166 |
+
MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
|
| 167 |
+
MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
|
| 168 |
+
)
|
| 169 |
+
if "cuda" in CONFIG.STRATEGY:
|
| 170 |
+
torch.cuda.empty_cache()
|
| 171 |
+
gc.collect()
|
| 172 |
+
logGPUState()
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
|
| 176 |
+
logger.info(
|
| 177 |
+
f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if len(MODEL_STORAGE) == 1:
|
| 181 |
+
single_name = list(MODEL_STORAGE.keys())[0]
|
| 182 |
+
if DEFALUT_MODEL_NAME != single_name:
|
| 183 |
+
DEFALUT_MODEL_NAME = single_name
|
| 184 |
+
logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`")
|
| 185 |
+
if DEFAULT_REASONING_MODEL_NAME != single_name:
|
| 186 |
+
DEFAULT_REASONING_MODEL_NAME = single_name
|
| 187 |
+
logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class ChatCompletionRequest(BaseModel):
    """OpenAI-style chat/completion request body.

    Exactly one of ``messages`` or ``prompt`` must be provided (enforced by
    the ``model_validator`` below).  Sampler fields left as ``None`` are
    filled in later from the selected model's default sampler config.
    """

    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=True, description="Whether to stream token-by-token responses by default")
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    # Explicit `default=` keyword for consistency with the other fields.
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])
    web_search: Optional[bool] = Field(default=False, description="Whether to perform a web search and append results to the prompt")
    search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Reject payloads that provide both (or neither) of `messages`/`prompt`."""
        if not isinstance(data, dict):
            return data

        # `dict.get(...) is not None` replaces the `"k" in data and data["k"] != None`
        # anti-idiom; semantics are identical.
        messages_provided = data.get("messages") is not None
        prompt_provided = data.get("prompt") is not None

        if messages_provided and prompt_provided:
            raise ValueError("messages and prompt cannot coexist. Choose one.")
        if not messages_provided and not prompt_provided:
            raise ValueError("Either messages or prompt must be provided.")
        return data
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# FastAPI application exposing an OpenAI-compatible chat completions API.
app = FastAPI(title="RWKV OpenAI-Compatible API")

# Allow any origin/method/header so browser clients can call the API directly.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive - confirm this is intended for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Gzip-compress responses larger than 1000 bytes at a moderate compression level.
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Feed the prompt `ctx` through the model in CONFIG.CHUNK_LEN slices.

    Appends the encoded tokens to `model_tokens` and returns the final
    logits, the updated token list, and the updated model state.  Yields to
    the event loop between chunks so other requests are not starved.
    """
    ctx = ctx.replace("\r\n", "\n")

    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    pending = [int(t) for t in ms.pipeline.encode(ctx)]
    model_tokens += pending

    out = None
    while pending:
        chunk, pending = pending[: CONFIG.CHUNK_LEN], pending[CONFIG.CHUNK_LEN :]
        out, model_state = ms.model.forward(chunk, model_state)
        await asyncio.sleep(0)

    return out, model_tokens, model_state
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    """Sample up to `max_tokens` tokens, yielding one chunk dict per token.

    Each yielded chunk has keys: content (decoded token text, "" on stop),
    tokens (token ids covered by this chunk), finish_reason (None while
    generating; "stop:token:N" or "length" when finished), and state
    (the model state after the token, where applicable).
    """
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    # Fall back to server-side sampler defaults for any field the client left unset.
    temperature = request.temperature if request.temperature is not None else 0.2
    top_p = request.top_p if request.top_p is not None else 0.9
    alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
    alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
    penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5

    args = PIPELINE_ARGS(
        temperature=max(0.2, temperature),  # clamp: never sample below 0.2
        top_p=top_p,
        alpha_frequency=alpha_frequency,
        alpha_presence=alpha_presence,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],
    )  # stop generation whenever you see any token here

    occurrence = {}  # token id -> decayed repetition count, for penalties
    out_tokens: List[int] = []
    out_last = 0  # index of the first token not yet emitted as text

    # Stream token-by-token; each chunk contains a single decoded token string.

    for i in range(max_tokens):
        # Apply presence/frequency penalties to previously seen tokens.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        # out[0] -= 1e10  # disable END_OF_TEXT

        token = ms.pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        # END_OF_TEXT sampled: stop BEFORE feeding it back into the model.
        if token == 0 and request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }

            del out
            gc.collect()
            return

        out, model_state = ms.model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        # Any other client-requested stop token: stop AFTER the forward pass.
        if request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }

            del out
            gc.collect()
            return

        # Decay all repetition counters, then bump the counter of this token.
        for xxx in list(occurrence.keys()):
            occurrence[xxx] *= penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

        # Decode token to text and yield it as a single-token chunk
        decoded = ms.pipeline.decode([token])
        # filter out replacement characters (partial UTF-8 sequences)
        if "\ufffd" in decoded:
            continue

        yield {
            "content": decoded,
            "tokens": [token],
            "finish_reason": None,
            "state": model_state,
        }
        out_last = i + 1

    else:
        # Loop ran to completion without hitting a stop token.
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
async def chatResponse(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
) -> ChatCompletion:
    """Produce a complete (non-streaming) chat completion.

    Builds the prompt from `request.messages` (or uses `request.prompt`
    verbatim), optionally injects server-side tool / web-search results,
    prefill-runs the model, generates until a stop condition, and returns an
    OpenAI-style ChatCompletion with usage accounting.

    Fixes vs. previous revision:
    - A matched stop sequence now actually terminates generation; before, the
      `break` only exited the inner loop over `request.stop`, so generation
      continued until max_tokens.
    - Tokens-per-second accounting is guarded against a zero elapsed time.
    """
    createTimestamp = time.time()

    # When reasoning is enabled, seed the response with " <think" so the model
    # continues inside a think block.
    prompt = (
        f"{cleanMessages(request.messages or [])}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    # Process tools and web_search (tools executed server-side, results injected into prompt)
    if request.tools:
        try:
            for tool in request.tools:
                name = tool.get('name')
                args = tool.get('args', {})
                if name == 'web_search':
                    from utils import web_search

                    search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
                    search_top_k = int(args.get('top_k') or request.search_top_k or 3)
                    search_str = web_search(search_q, search_top_k)
                    if search_str:
                        prompt = (f"ToolResults:\n{search_str}\n\nUse these results to answer the prompt.\n\n" + prompt)
                elif name in ('calc', 'calculator'):
                    from utils import calc

                    expr = args.get('expression')
                    if expr:
                        calc_res = calc(expr)
                        prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res}\n\nUse this result to answer the prompt.\n\n" + prompt)
                else:
                    # Unsupported tool - ignore, but log for visibility.
                    logger.info(f"Unsupported tool requested: {name}")
        except Exception as e:
            # Tool execution is best-effort: log and continue with the bare prompt.
            logger.info(f"Tool processing error: {e}")
    elif request.web_search:
        try:
            from utils import web_search

            search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
            search_res = web_search(search_q, int(request.search_top_k or 3))
            if search_res:
                prompt = f"WebSearchResults:\n{search_res}\n\n" + prompt
        except Exception:
            # Best-effort: a failed search must not break the completion.
            pass
    logger.info(f"[REQ] {completionId} - prompt - {prompt}")

    # Resume from a saved state when available; otherwise prefill from scratch.
    if request.state_name:
        state_key = (request.model, request.state_name)
        if state_key in STATE_STORE:
            stored = STATE_STORE[state_key]
            model_state = stored.get('state', model_state)
            model_tokens = stored.get('model_tokens', [0])
            out = None
        else:
            out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
    else:
        out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)

    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    fullResponse = " <think" if enableReasoning else ""
    completionTokenCount = 0
    finishReason = None

    for chunk in generate(
        request,
        out,
        model_tokens,
        model_state,
        max_tokens=(
            # Reasoning without an explicit client limit gets a large budget.
            64000
            if "max_tokens" not in request.model_fields_set and enableReasoning
            else (request.max_tokens or 2048)
        ),
    ):
        # chunk['content'] is a single token's decoded text
        fullResponse += chunk["content"]
        completionTokenCount += 1

        # Check multi-token stop sequences after each token, and STOP generating
        # when one is found (the old code only broke out of this inner loop).
        for stop_words in request.stop or []:
            if stop_words in fullResponse:
                finishReason = f"stop:words:{stop_words}"
                break
        if finishReason:
            break

        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]
        await asyncio.sleep(0)

    generateTime = time.time()

    # Guard against zero elapsed time (e.g. resumed state with empty prefill).
    prefillElapsed = max(prefillTime - createTimestamp, 1e-6)
    genElapsed = max(generateTime - prefillTime, 1e-6)
    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / prefillElapsed, 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / genElapsed, 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    # Split "<think>...</think>" reasoning from the visible answer.
    reasoning_content, content = parse_think_response(fullResponse)

    response = ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    role="Assistant",
                    content=content,
                    reasoning_content=reasoning_content if reasoning_content else None,
                    tool_calls=None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )

    # Save state if requested for future resumption.
    # NOTE(review): this stores the post-prefill model_state; the states yielded
    # per-token by generate() are not captured here - confirm intended.
    try:
        if request.state_name:
            STATE_STORE[(request.model, request.state_name)] = {
                'state': model_state,
                'model_tokens': model_tokens,
            }
    except Exception:
        pass

    return response
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
async def chatResponseStream(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
):
    """SSE generator: stream a chat completion as `data: {...}` chunks.

    When `enableReasoning` is true, an incremental parser routes text inside
    `<think>...</think>` to `delta.reasoning_content` and everything else to
    `delta.content`. Ends with `data: [DONE]`.
    """
    createTimestamp = int(time.time())

    prompt = (
        f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt == None
        else request.prompt.strip()
    )
    # Process tools and web_search (tools executed server-side and results injected to prompt)
    if request.tools:
        try:
            for tool in request.tools:
                name = tool.get('name')
                args = tool.get('args', {})
                if name == 'web_search':
                    from utils import web_search

                    search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
                    search_top_k = int(args.get('top_k') or request.search_top_k or 3)
                    search_str = web_search(search_q, search_top_k)
                    if search_str:
                        prompt = (f"WebSearchResults:\n{search_str}\n\n" + prompt)
                elif name == 'calc' or name == 'calculator':
                    from utils import calc

                    expr = args.get('expression')
                    if expr:
                        calc_res = calc(expr)
                        prompt = (f"CalcResult:{expr} = {calc_res}\n\n" + prompt)
                else:
                    logger.info(f"Unsupported tool requested: {name}")
        except Exception as e:
            logger.info(f"Tool processing error: {e}")
    elif request.web_search:
        try:
            from utils import web_search

            search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
            search_res = web_search(search_q, int(request.search_top_k or 3))
            if search_res:
                prompt = f"WebSearchResults:\n{search_res}\n\n" + prompt
        except Exception:
            pass

    logger.info(f"[REQ] {completionId} - context\n```{prompt}```")

    # Resume or prefill tokens/state
    if request.state_name:
        state_key = (request.model, request.state_name)
        if state_key in STATE_STORE:
            stored = STATE_STORE[state_key]
            model_state = stored.get('state', model_state)
            model_tokens = stored.get('model_tokens', [0])
            out = None
        else:
            out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
    else:
        out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)

    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    completionTokenCount = 0
    finishReason = None

    # Initial role-announcement chunk (empty content).
    response = ChatCompletionChunk(
        id=completionId,
        created=createTimestamp,
        model=request.model,
        usage=(
            Usage(
                prompt_tokens=promptTokenCount,
                completion_tokens=completionTokenCount,
                total_tokens=promptTokenCount + completionTokenCount,
                prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
            )
            if request.include_usage
            else None
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                delta=ChatCompletionMessage(
                    role="Assistant",
                    content="",
                    reasoning_content="" if enableReasoning else None,
                    tool_calls=None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
    if response.choices and response.choices[0].delta is None:
        response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
    # Attach state_name in the initial chunk so client can save it to continue later
    r_dict = response.model_dump()
    r_dict['state_name'] = request.state_name
    yield f"data: {r_dict}\n\n"

    buffer = []

    if enableReasoning:
        # The prompt already ends with " <think"; seed the buffer so the
        # incremental tag parser sees a consistent full text.
        buffer.append("<think")

        # Incremental <think> tag parser state.
        streamConfig = {
            "isChecking": False,  # currently buffering a potential <...> tag
            "fullTextCursor": 0,  # index into fullText up to which output was routed
            "in_think": False,    # currently inside a <think> block
            "cacheStr": "",       # buffered text while checking a potential tag
        }

        for chunk in generate(
            request,
            out,
            model_tokens,
            model_state,
            max_tokens=(
                64000
                if "max_tokens" not in request.model_fields_set and enableReasoning
                else (request.max_tokens or 2048)
            ),
        ):
            completionTokenCount += 1
            # Each chunk delivers one decoded token's text (1+ chars).
            chunkContent: str = chunk["content"]
            buffer.append(chunkContent)

            fullText = "".join(buffer)

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(
                            role="Assistant",
                            content=None,
                            reasoning_content=None,
                            tool_calls=None,
                        ),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )
            if response.choices and response.choices[0].delta is None:
                response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)

            # Start buffering when an unconsumed "<" appears: it may open a tag.
            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"] and markStart != -1:
                streamConfig["isChecking"] = True

                # Flush the text preceding "<" to the appropriate field.
                if streamConfig["in_think"]:
                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.reasoning_content = fullText[streamConfig["fullTextCursor"] : markStart]
                else:
                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.content = fullText[streamConfig["fullTextCursor"] : markStart]

                streamConfig["cacheStr"] = ""
                streamConfig["fullTextCursor"] = markStart

            if streamConfig["isChecking"]:
                # Accumulate the potential tag text until we see ">".
                streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
            else:
                # Not inside a potential tag: pass the token text straight through.
                if streamConfig["in_think"]:
                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.reasoning_content = chunkContent
                else:
                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.content = chunkContent
                streamConfig["fullTextCursor"] = len(fullText)

            # Resolve the buffered tag once ">" arrives (or the stream finished).
            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
                streamConfig["isChecking"] = False

                if (
                    not streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("<think>") != -1
                ):
                    # Opening tag: switch to reasoning output.
                    streamConfig["in_think"] = True

                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.reasoning_content = (
                        delta.reasoning_content
                        if delta.reasoning_content != None
                        else "" + streamConfig["cacheStr"].replace("<think>", "")
                    )

                elif (
                    streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("</think>") != -1
                ):
                    # Closing tag: switch back to visible content.
                    streamConfig["in_think"] = False

                    delta = response.choices[0].delta
                    if delta is None:
                        delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                        response.choices[0].delta = delta
                    delta.content = (
                        delta.content
                        if delta.content != None
                        else "" + streamConfig["cacheStr"].replace("</think>", "")
                    )
                else:
                    # Not a think tag after all: flush the cached text as-is.
                    if streamConfig["in_think"]:
                        delta = response.choices[0].delta
                        if delta is None:
                            delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                            response.choices[0].delta = delta
                        delta.reasoning_content = (
                            delta.reasoning_content
                            if delta.reasoning_content != None
                            else "" + streamConfig["cacheStr"]
                        )
                    else:
                        delta = response.choices[0].delta
                        if delta is None:
                            delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                            response.choices[0].delta = delta
                        delta.content = (
                            delta.content
                            if delta.content != None
                            else "" + streamConfig["cacheStr"]
                        )
                streamConfig["fullTextCursor"] = len(fullText)

            delta = response.choices[0].delta
            if delta is None:
                delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
                response.choices[0].delta = delta
            if delta.content != None or delta.reasoning_content != None:
                # Save model state frequently (after each token) to allow resuming
                try:
                    if request.state_name:
                        STATE_STORE[(request.model, request.state_name)] = {
                            'state': model_state,
                            'model_tokens': model_tokens,
                        }
                except Exception:
                    pass
                yield f"data: {response.model_dump_json()}\n\n"
            # check stop sequences and stop streaming if we see them
            # NOTE(review): this `return` exits the generator immediately, so the
            # final usage log and the "data: [DONE]" sentinel are skipped when a
            # stop sequence fires - confirm clients tolerate a stream without [DONE].
            for stop_words in request.stop or []:
                if stop_words in ''.join(buffer):
                    finishReason = f"stop:words:{stop_words}"
                    return

            await asyncio.sleep(0)

        del streamConfig
    else:
        # Plain (non-reasoning) streaming: forward each token's text directly.
        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            buffer.append(chunk["content"])

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )

            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    genenrateTime = time.time()

    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")
    # Append the assistant's reply to the request messages for the usage log.
    if request.messages is None:
        request.messages = []
    request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"]))
    log(
        {
            **request.model_dump(),
            **responseLog,
            "completionId": completionId,
            "machineLabel": os.environ.get("MACHINE_LABEL"),
        }
    )

    del buffer

    yield "data: [DONE]\n\n"
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Resolves the model alias, merges the model's default sampler settings
    into unset request fields, restores a saved state when `state_name`
    matches, and dispatches to streaming or non-streaming generation.

    Fixes vs. previous revision:
    - The "Real" request log now logs `realRequest` (with sampler defaults
      merged) instead of the original request, which previously made the two
      [REQ] log lines identical.
    - `== None` replaced with `is None`; unused `import json` and unused
      local `model_tokens_for_resume` removed.
    """
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    modelName = request.model.split(":")[0]
    enableReasoning = ":thinking" in request.model

    if "rwkv-latest" in request.model:
        # Map to the default chat model in all cases. Do not redirect to a separate
        # reasoning model when ':thinking' is used. The same model will be used
        # and reasoning handled in-process by setting enableReasoning=True.
        if DEFALUT_MODEL_NAME is None:
            raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
        ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME)
        if not ms_def or not ms_def.MODEL_CONFIG:
            raise HTTPException(500, "Default sampler config missing for default model")
        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = DEFALUT_MODEL_NAME

    elif modelName in MODEL_STORAGE:
        ms_sel = MODEL_STORAGE.get(modelName)
        if not ms_sel or not ms_sel.MODEL_CONFIG:
            raise HTTPException(500, f"Default sampler config missing for model {modelName}")
        defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = modelName
    else:
        raise HTTPException(404, f"Can not find `{modelName}`")

    async def chatResponseStreamDisconnect():
        # Runs after the streaming response finishes/disconnects.
        logGPUState()

    # Load saved model_state when the client supplied (or we generated) a
    # state_name that matches a stored conversation.
    model_state = None
    state_name = request.state_name
    if state_name is None:
        state_name = str(uuid.uuid4())
        request.state_name = state_name
    state_key = (request.model, state_name)
    if state_key in STATE_STORE:
        model_state = STATE_STORE[state_key].get('state', None)

    # Fill unset sampler fields from the model's default sampler config.
    request_dict = request.model_dump()
    for k, v in defaultSamplerConfig.model_dump().items():
        if k in request_dict and request_dict[k] is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)

    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if request.stream:
        r = StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=BackgroundTask(chatResponseStreamDisconnect),
        )
    else:
        r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
        # Attach state_name to the non-streaming response as additional metadata.
        try:
            if isinstance(r, ChatCompletion):
                d = r.model_dump()
                d['state_name'] = state_name
                return d
        except Exception:
            pass

    return r
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
# Serve the built frontend (if present) at the web root; API routes registered
# above take precedence over the static mount.
if os.path.exists("dist-frontend"):
    app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
else:
    logger.info("dist-frontend not found; skipping static files mount")

if __name__ == "__main__":
    import uvicorn

    # Fall back to localhost:7860 when the config does not specify a bind address.
    host = CONFIG.HOST or "127.0.0.1"
    port = CONFIG.PORT or 7860
    uvicorn.run(app, host=host, port=port)
|
app_stderr.log
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
C:\Users\Administrator\Downloads\New folder (3)\RWKV\.venv\Lib\site-packages\torch\cuda\__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
|
| 2 |
+
import pynvml # type: ignore[import]
|
| 3 |
+
2025-11-23 16:35:08.739 | INFO | __main__:<module>:104 - STRATEGY - cpu fp16
|
| 4 |
+
2025-11-23 16:35:08.740 | INFO | __main__:<module>:109 - Load Model - rwkv7-g1a-0.1b-20250728-ctx4096
|
| 5 |
+
2025-11-23 16:35:09.724 | INFO | __main__:<module>:117 - Load Model - Path - models\rwkv7-g1a-0.1b-20250728-ctx4096.pth
|
| 6 |
+
2025-11-23 16:35:09.724 | INFO | __main__:<module>:133 - Load Model - Loading `rwkv7-g1a-0.1b-20250728-ctx4096`
|
| 7 |
+
2025-11-23 16:35:15.073 | INFO | __main__:<module>:151 - Load Model - DEFALUT_MODEL_NAME is `rwkv7-g1a-0.1b-20250728-ctx4096`
|
| 8 |
+
2025-11-23 16:35:15.074 | INFO | __main__:<module>:152 - Load Model - DEFAULT_REASONING_MODEL_NAME is `rwkv7-g1a-0.1b-20250728-ctx4096`
|
| 9 |
+
2025-11-23 16:35:15.080 | INFO | __main__:<module>:746 - dist-frontend not found; skipping static files mount
|
| 10 |
+
INFO: Started server process [9328]
|
| 11 |
+
INFO: Waiting for application startup.
|
| 12 |
+
INFO: Application startup complete.
|
| 13 |
+
INFO: Uvicorn running on http://0.0.0.0:7860 (Press CTRL+C to quit)
|
| 14 |
+
2025-11-23 16:35:51.067 | INFO | __main__:chat_completions:698 - [REQ] 7398519686318694400 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'Who is the current president of France?', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 3}
|
| 15 |
+
2025-11-23 16:35:51.067 | INFO | __main__:chat_completions:729 - [REQ] 7398519686318694400 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Who is the current president of France?', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 3}
|
| 16 |
+
2025-11-23 16:35:53.728 | INFO | __main__:chatResponse:363 - [REQ] 7398519686318694400 - prompt - Who is the current president of France?
|
| 17 |
+
2025-11-23 16:36:09.388 | INFO | __main__:chatResponse:402 - [RES] 7398519686318694400 - {'content': '\nThe current president of France is Emmanuel Macron.', 'finish': 'stop:words:\n\n', 'prefill_len': 9, 'prefill_tps': 1.36, 'gen_len': 6, 'gen_tps': 0.51}
|
| 18 |
+
2025-11-23 16:36:52.165 | INFO | __main__:chat_completions:698 - [REQ] 7398519942582280192 - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096:thinking', 'messages': None, 'prompt': 'Summarize the first paragraph from the search about Python programming', 'max_tokens': 60, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
|
| 19 |
+
2025-11-23 16:36:52.165 | INFO | __main__:chat_completions:729 - [REQ] 7398519942582280192 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Summarize the first paragraph from the search about Python programming', 'max_tokens': 60, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
|
| 20 |
+
2025-11-23 16:36:54.650 | INFO | __main__:chatResponse:363 - [REQ] 7398519942582280192 - prompt - Summarize the first paragraph from the search about Python programming
|
| 21 |
+
2025-11-23 16:38:03.778 | INFO | __main__:chatResponse:402 - [RES] 7398519942582280192 - {'content': ' <think.\nThe first paragraph of the search is about Python programming. It talks about how to use Python for data analysis and machine learning. The second paragraph is about how to use Python for web development. It talks about how to use Python for creating websites and applications. The third', 'finish': 'length', 'prefill_len': 13, 'prefill_tps': 1.65, 'gen_len': 56, 'gen_tps': 0.88}
|
| 22 |
+
2025-11-23 16:38:05.030 | INFO | __main__:chat_completions:698 - [REQ] 7398520248166686720 - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096:thinking', 'messages': None, 'prompt': 'Tell me a short summary of Python programming', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
|
| 23 |
+
2025-11-23 16:38:05.033 | INFO | __main__:chat_completions:729 - [REQ] 7398520248166686720 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Tell me a short summary of Python programming', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
|
| 24 |
+
2025-11-23 16:38:06.800 | INFO | __main__:chatResponse:363 - [REQ] 7398520248166686720 - prompt - Tell me a short summary of Python programming
|
| 25 |
+
2025-11-23 16:38:24.585 | INFO | __main__:chatResponse:402 - [RES] 7398520248166686720 - {'content': ' <think and how it can be used to solve problems.', 'finish': 'stop:words:\n\n', 'prefill_len': 9, 'prefill_tps': 1.55, 'gen_len': 6, 'gen_tps': 0.44}
|
| 26 |
+
2025-11-23 16:42:18.982 | INFO | __main__:chat_completions:698 - [REQ] 7398521313352130560 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
|
| 27 |
+
2025-11-23 16:42:18.982 | INFO | __main__:chat_completions:729 - [REQ] 7398521313352130560 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
|
| 28 |
+
2025-11-23 16:42:18.982 | INFO | __main__:chatResponse:363 - [REQ] 7398521313352130560 - prompt - What is two plus three times four?
|
| 29 |
+
2025-11-23 16:42:56.030 | INFO | __main__:chatResponse:402 - [RES] 7398521313352130560 - {'content': '\n100\nWhat is the difference between 0.9 and 0.8?\n0.2\nWhat is the sum of', 'finish': 'length', 'prefill_len': 9, 'prefill_tps': 2.17, 'gen_len': 28, 'gen_tps': 0.85}
|
| 30 |
+
2025-11-23 16:44:08.178 | INFO | __main__:chat_completions:698 - [REQ] 7398521771353350144 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
|
| 31 |
+
2025-11-23 16:44:08.179 | INFO | __main__:chat_completions:729 - [REQ] 7398521771353350144 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
|
| 32 |
+
2025-11-23 16:44:08.179 | INFO | __main__:chatResponse:363 - [REQ] 7398521771353350144 - prompt - What is two plus three times four?
|
| 33 |
+
2025-11-23 16:44:45.828 | INFO | __main__:chatResponse:402 - [RES] 7398521771353350144 - {'content': '\nTwo plus three times four is eight.\nWhat is the sum of the digits of two-digit numbers?\nThe sum of', 'finish': 'length', 'prefill_len': 9, 'prefill_tps': 2.28, 'gen_len': 28, 'gen_tps': 0.83}
|
app_stdout.log
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
### RWKV-7 "Goose" enabled ###
|
| 3 |
+
|
| 4 |
+
SamplerConfig(
|
| 5 |
+
max_tokens=4096,
|
| 6 |
+
temperature=1.0,
|
| 7 |
+
top_p=0.3,
|
| 8 |
+
presence_penalty=0.5,
|
| 9 |
+
count_penalty=0.5,
|
| 10 |
+
penalty_decay=0.996,
|
| 11 |
+
stop=['\n\n'],
|
| 12 |
+
stop_tokens=[0]
|
| 13 |
+
)
|
| 14 |
+
Loading models\rwkv7-g1a-0.1b-20250728-ctx4096 (cpu fp16)
|
| 15 |
+
|
| 16 |
+
INFO: 127.0.0.1:50012 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
|
| 17 |
+
INFO: 127.0.0.1:50128 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
|
| 18 |
+
INFO: 127.0.0.1:50128 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
|
| 19 |
+
INFO: 127.0.0.1:50763 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
|
| 20 |
+
INFO: 127.0.0.1:50973 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
|
| 21 |
+
INFO: 127.0.0.1:51134 - "POST /api/v1/chat/completions HTTP/1.1" 422 Unprocessable Entity
|
| 22 |
+
INFO: 127.0.0.1:51144 - "POST /api/v1/chat/completions HTTP/1.1" 422 Unprocessable Entity
|
config.local.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HOST: "0.0.0.0"
|
| 2 |
+
PORT: 7860
|
| 3 |
+
STRATEGY: "cpu fp16"
|
| 4 |
+
RWKV_CUDA_ON: False
|
| 5 |
+
CHUNK_LEN: 256
|
| 6 |
+
MODELS:
|
| 7 |
+
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 8 |
+
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
| 9 |
+
DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
|
| 10 |
+
DOWNLOAD_MODEL_DIR: "./models"
|
| 11 |
+
REASONING: True
|
| 12 |
+
DEFAULT_CHAT: True
|
| 13 |
+
DEFAULT_REASONING: True
|
| 14 |
+
DEFAULT_SAMPLER:
|
| 15 |
+
max_tokens: 4096
|
| 16 |
+
temperature: 1.0
|
| 17 |
+
top_p: 0.3
|
| 18 |
+
presence_penalty: 0.5
|
| 19 |
+
count_penalty: 0.5
|
| 20 |
+
penalty_decay: 0.996
|
| 21 |
+
stop:
|
| 22 |
+
- "\n\n"
|
| 23 |
+
stop_tokens:
|
| 24 |
+
- 0
|
config.production-modelscope.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HOST: "0.0.0.0"
|
| 2 |
+
PORT: 7860
|
| 3 |
+
STRATEGY: "cuda fp16"
|
| 4 |
+
RWKV_CUDA_ON: True
|
| 5 |
+
CHUNK_LEN: 256
|
| 6 |
+
MODELS:
|
| 7 |
+
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 8 |
+
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
| 9 |
+
DOWNLOAD_MODEL_REPO_ID: "RWKV/rwkv7-g1"
|
| 10 |
+
DOWNLOAD_MODEL_DIR: "./models"
|
| 11 |
+
REASONING: True
|
| 12 |
+
DEFAULT_CHAT: True
|
| 13 |
+
DEFAULT_REASONING: True
|
| 14 |
+
DEFAULT_SAMPLER:
|
| 15 |
+
max_tokens: 4096
|
| 16 |
+
temperature: 1.0
|
| 17 |
+
top_p: 0.3
|
| 18 |
+
presence_penalty: 0.5
|
| 19 |
+
count_penalty: 0.5
|
| 20 |
+
penalty_decay: 0.996
|
| 21 |
+
stop:
|
| 22 |
+
- "\n\n"
|
| 23 |
+
stop_tokens:
|
| 24 |
+
- 0
|
config.production.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HOST: "0.0.0.0"
|
| 2 |
+
PORT: 7860
|
| 3 |
+
STRATEGY: "cuda fp16"
|
| 4 |
+
RWKV_CUDA_ON: True
|
| 5 |
+
CHUNK_LEN: 256
|
| 6 |
+
MODELS:
|
| 7 |
+
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 8 |
+
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
| 9 |
+
DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
|
| 10 |
+
DOWNLOAD_MODEL_DIR: "./models"
|
| 11 |
+
REASONING: True
|
| 12 |
+
DEFAULT_CHAT: True
|
| 13 |
+
DEFAULT_REASONING: True
|
| 14 |
+
DEFAULT_SAMPLER:
|
| 15 |
+
max_tokens: 4096
|
| 16 |
+
temperature: 1.0
|
| 17 |
+
top_p: 0.3
|
| 18 |
+
presence_penalty: 0.5
|
| 19 |
+
count_penalty: 0.5
|
| 20 |
+
penalty_decay: 0.996
|
| 21 |
+
stop:
|
| 22 |
+
- "\n\n"
|
| 23 |
+
stop_tokens:
|
| 24 |
+
- 0
|
config.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from typing import List, Optional, Union, Any
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from pydantic_settings import BaseSettings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
    """Command-line settings: selects which YAML config file to load."""

    CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")


# Parsed once at import time; CLI args (e.g. --CONFIG_FILE) override the default.
CLI_CONFIG = CliConfig()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SamplerConfig(BaseModel):
    """Default sampler configuration for each model.

    Mirrors the DEFAULT_SAMPLER section of the YAML config; every field has a
    default, so the section may be partial or omitted entirely.
    """

    max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
    temperature: float = Field(1.0, description="Sampling temperature.")
    top_p: float = Field(0.3, description="Top-p sampling threshold.")
    presence_penalty: float = Field(0.5, description="Presence penalty.")
    count_penalty: float = Field(0.5, description="Count penalty.")
    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
    # Pydantic deep-copies these defaults per instance, so the mutable
    # literals are safe here (unlike plain function defaults).
    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
    stop_tokens: List[int] = Field([0], description="List of stop tokens.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ModelConfig(BaseModel):
    """Configuration for each individual model."""

    SERVICE_NAME: str = Field(..., description="Service name of the model.")

    # Either a local path is given directly, or the DOWNLOAD_* fields describe
    # where to fetch the weights — presumably resolved by the download helper
    # at startup; confirm against the model-loading code.
    MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")

    DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
        None, description="Model name, should end with .pth"
    )
    DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
        None, description="Model repository ID on Hugging Face Hub."
    )
    DOWNLOAD_MODEL_DIR: Optional[str] = Field(
        "./models", description="Directory to download the model to."
    )

    REASONING: bool = Field(
        False, description="Whether reasoning is enabled for this model."
    )

    DEFAULT_CHAT: bool = Field(False, description="Whether this model is the default chat model.")
    DEFAULT_REASONING: bool = Field(False, description="Whether this model is the default reasoning model.")
    DEFAULT_SAMPLER: SamplerConfig = Field(
        SamplerConfig(), description="Default sampler configuration for this model."
    )
    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class RootConfig(BaseModel):
    """Root configuration for the RWKV service."""

    HOST: Optional[str] = Field(
        "127.0.0.1", description="Host IP address to bind to."
    )  # optional: the sample YAMLs may leave HOST/PORT commented out
    PORT: Optional[int] = Field(
        8000, description="Port number to listen on."
    )  # optional for the same reason (commented out in the YAML example)
    STRATEGY: str = Field(
        "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
    )
    RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
    CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
    MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
import yaml

# Load the YAML file selected on the command line and validate it against
# RootConfig. Configuration problems are fatal at startup: report them on
# stderr and exit non-zero so shells/supervisors register the failure.
try:
    with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
        CONFIG = RootConfig.model_validate(yaml.safe_load(f))
except Exception as e:
    print(f"Pydantic Model Validation Failed: {e}", file=sys.stderr)
    sys.exit(1)  # was exit(0): a failed startup must not signal success
|
cuda/gemm_fp16_cublas.cpp
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

// Throw (rather than abort) on a cuBLAS failure. The for-loop is a macro
// trick: `condition` is evaluated exactly once and the throw body is entered
// only when the status is not CUBLAS_STATUS_SUCCESS.
#define CUBLAS_CHECK(condition)                                  \
  for (cublasStatus_t _cublas_check_status = (condition);        \
       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)           \
    throw std::runtime_error("cuBLAS error " +                   \
                             std::to_string(_cublas_check_status) + " at " + \
                             std::to_string(__LINE__));

// Same single-evaluation trick for CUDA runtime calls.
#define CUDA_CHECK(condition)                                    \
  for (cudaError_t _cuda_check_status = (condition);             \
       _cuda_check_status != cudaSuccess;)                       \
    throw std::runtime_error(                                    \
        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
        " at " + std::to_string(__LINE__));

/*
  NOTE: blas gemm is column-major by default, but we need row-major output.
  The data of row-major, transposed matrix is exactly the same as the
  column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
// Compute c = a @ b with fp16 inputs and fp32 accumulation. c may be fp16 or
// fp32. Supports 2-D (single GEMM) and 3-D (strided batched GEMM) tensors;
// tensors are assumed contiguous (leading dims are taken from the sizes).
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  // Make sure the GEMM runs on the device that owns `a`.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  const auto cuda_data_type = CUDA_R_16F;
  const auto cuda_c_data_type =
      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
  // Accumulate in fp32 for accuracy even though inputs are fp16.
  const auto compute_type = CUDA_R_32F;
  const float sp_alpha = 1.f;
  // swap a and b, and use CUBLAS_OP_N. see the notes above
  std::swap(a, b);
  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
  // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
  // negative axis is used because of the existence of batch matmul.
  const int m = a.size(-1);
  const int k = a.size(-2);
  const int n = b.size(-2);
  const int cublas_lda = m;
  const int cublas_ldb = k;
  const int cublas_ldc = m;
  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();

#if CUDA_VERSION >= 11000
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
  // Pre-CUDA-11 spelling that requests tensor-op acceleration.
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  const float sp_beta = 0.f;
  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
    // Plain single GEMM.
    CUBLAS_CHECK(cublasGemmEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
        compute_type, algo));
  } else {
    // batch matmul
    assert(a.sizes().size() == 3 && b.sizes().size() == 3);

    // Strides between consecutive matrices in each batch (contiguous layout).
    const long long int cublas_stride_a = m * k;
    const long long int cublas_stride_b = k * n;
    const long long int cublas_stride_c = m * n;
    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m,
        n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
        cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
        &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
        a.size(0), compute_type, algo));
  }
}
|
cuda/operators.cu
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include <cuda_fp16.h>
// Large negative sentinel used as a "-inf" initializer for max-based
// log-space computations.
#define MIN_VALUE (-1e38)
typedef at::Half fp16;
// Reinterpret an ATen half pointer as CUDA's native __half (same bit layout).
__half *cast(fp16 *ptr) {
  return reinterpret_cast<__half *>(ptr);
}
|
| 10 |
+
|
| 11 |
+
// RWKV WKV recurrence, one thread per (batch, channel) pair. For each of the
// T timesteps it emits y = (e1*aa + e2*vv) / (e1*bb + e2) and updates the
// running state (aa = numerator, bb = denominator, pp = running max exponent).
// All exponentials are computed relative to the running max `pp` so the
// exp() arguments stay <= 0 — a log-sum-exp-style overflow guard.
// State arrays aa/bb/pp are read and written in place (shape [B, C]).
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                               const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;   // batch index
    const int _c = idx % C;   // channel index
    const int _offset = _b * T * C + _c;        // start of this channel's time series
    const int _state_offset = _b * C + _c;      // this thread's slot in the state

    float u = _u[_c];   // per-channel "bonus" for the current token
    float w = _w[_c];   // per-channel decay (stored as a log-space increment)
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;   // stride C between consecutive timesteps
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        // Output: weighted average of past state and current (k, v),
        // rescaled by the shared max exponent p for stability.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
        // State update: decay the old state by w and fold in the new token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state so the next chunk can continue the recurrence.
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}
|
| 51 |
+
|
| 52 |
+
// Host-side launcher for kernel_wkv_forward: one thread per (batch, channel).
// B*C must be divisible by the chosen block width (asserted below).
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

// Explicit instantiations: fp16 and fp32 activations (state is always fp32).
template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);
|
| 68 |
+
|
| 69 |
+
// Row-batched matmul y[B,M] = x[B,N] @ dequant(w[N,M]) for fp32 activations
// and uint8-quantized weights. Each weight byte is decoded on the fly as
//   (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j]
// (per-column scale/offset rx/mx, per-row scale/offset ry/my). One thread
// computes one (row i, output column k) element.
__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // batch row
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}
|
| 93 |
+
|
| 94 |
+
// y = x @ dequant(w): batched-row int8 matmul dispatcher. Specialized below
// for fp32 and fp16 activation types.
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    // One thread per (row, output column); 128 output columns per block.
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}
|
| 115 |
+
|
| 116 |
+
// fp16 variant of kernel_mm_seq_fp32i8: same uint8 dequantization formula,
// with all arithmetic widened to fp32 and the result narrowed back to __half.
__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // batch row
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column

    if (i < B && k < M) {
        // Accumulate in fp32 to avoid half-precision error buildup over N.
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}
|
| 141 |
+
|
| 142 |
+
// fp16 specialization: identical launch shape to the fp32 path, but ATen
// half pointers are reinterpreted to CUDA __half before entering the kernel.
template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    // One thread per (row, output column); 128 output columns per block.
    const dim3 block(1, 128);
    const dim3 grid((B + block.x - 1) / block.x,
                    (M + block.y - 1) / block.y);
    kernel_mm_seq_fp16i8<<<grid, block>>>(B, N, M,
                                          cast(x), x_stride,
                                          w, w_stride,
                                          cast(mx), cast(rx),
                                          cast(my), cast(ry),
                                          cast(y), y_stride);
}
|
| 155 |
+
|
| 156 |
+
// Single-vector int8 matvec: y[M] += x[N] @ dequant(w[N,M]). The reduction
// over N is split across MM8_ONE_JSPLIT grid columns and the partial sums
// are merged with atomicAdd — so the caller is responsible for zeroing y
// first (presumably done at the call site; confirm against the wrapper).
#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column
    // [j0, j1): this block's slice of the N-dimension reduction.
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            // Dequantize: (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j].
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        atomicAdd(&y[k], y_local);
    }
}
|
| 184 |
+
|
| 185 |
+
// Launcher for the single-vector int8 matvec. Note the output y is always
// fp32, regardless of activation type F, because partials are atomicAdd-ed.
template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    // Grid x = slices of the N reduction, grid y = tiles of output columns.
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}
|
| 206 |
+
|
| 207 |
+
// fp16 variant of kernel_mm_one_fp32i8: activations/scales are __half but
// the accumulator and output y stay fp32 (atomicAdd merges block partials).
__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column
    // [j0, j1): this block's slice of the N-dimension reduction.
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}
|
| 233 |
+
|
| 234 |
+
// fp16 specialization of the single-vector int8 matvec launcher. ATen half
// pointers are reinterpreted to __half; the fp32 accumulator y is passed
// through unchanged since partial sums are merged via atomicAdd.
template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    // Grid x = slices of the N reduction, grid y = tiles of output columns.
    const dim3 block(1, MM8_ONE_TILE);
    const dim3 grid(MM8_ONE_JSPLIT,
                    (M + block.y - 1) / block.y);
    kernel_mm_one_fp16i8<<<grid, block>>>(N, M,
                                          cast(x),
                                          w, w_stride,
                                          cast(mx), cast(rx),
                                          cast(my), cast(ry),
                                          y);
}
|
cuda/rwkv5.cu
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
#include "ATen/ATen.h"
|
| 4 |
+
typedef at::BFloat16 bf16;
|
| 5 |
+
typedef at::Half fp16;
|
| 6 |
+
typedef float fp32;
|
| 7 |
+
|
| 8 |
+
// RWKV-5 WKV recurrence, forward pass.
// Grid: one block per (batch, head) pair (gridDim.x == B*H); blockDim.x == _N_
// (compile-time head size), so each thread owns one output channel i of a head.
// _w is the per-head, time-invariant channel decay (fp32); _u is the per-head
// "bonus" applied to the current timestep.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;  // batch index
    const int h = blockIdx.x % H;  // head index
    const int i = threadIdx.x;     // channel within the head
    _w += h*_N_;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!  (offset ignores b)

    // Per-timestep r/k and constant u/w vectors are staged in shared memory so
    // every thread can read all _N_ channels of them.
    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // Each thread keeps its own row of the head's _N_ x _N_ state in registers.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    w[i] = _w[i];
    __syncthreads();

    // t walks this (batch, head, channel)'s flattened index across all T steps.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // Vectorized 4-wide inner loop via float4 reinterpretation; assumes
        // _N_ % 4 == 0 and 16-byte alignment of the shared arrays.
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            // x = k (*) v : the current timestep's outer-product contribution.
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r (*) (u (*) x + state): current step gets the u bonus,
            // history enters through the accumulated state.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // state = state (*) w + x : exponential-decay state update.
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    // Persist the recurrent state for the next chunk of the sequence.
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
|
| 73 |
+
|
| 74 |
+
// Host launchers for the RWKV-5 forward kernel: one CUDA block per head
// (B * H blocks), _N_ threads per block (one per channel).
//
// The kernel's state addressing ignores the batch index (see the
// "wrong if B > 1 !!!" note in kernel_forward), so batched input would
// silently corrupt the recurrent state. Guard with assert(B == 1) here,
// mirroring the existing guard in rwkv7.cu.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);  // channels must equal heads x head-size
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
|
cuda/rwkv5_op.cpp
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/ATen.h"
|
| 3 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 4 |
+
typedef at::BFloat16 bf16;
|
| 5 |
+
typedef at::Half fp16;
|
| 6 |
+
typedef float fp32;
|
| 7 |
+
|
| 8 |
+
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
| 9 |
+
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
|
| 10 |
+
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
| 11 |
+
|
| 12 |
+
// Python-facing wrappers for the RWKV-5 CUDA launchers. Each pins the current
// CUDA device to the state tensor's device (important on multi-GPU hosts),
// then forwards raw data pointers. data_ptr<T>() enforces the expected dtype
// (fp32 for state/w, the wrapper's dtype for r/k/v/u/y); tensors are assumed
// contiguous -- TODO confirm callers guarantee contiguity.
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// pybind11 bindings: exposed when the extension is imported as a Python module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
// TorchScript-compatible registration under torch.ops.rwkv5.*
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
|
cuda/rwkv6.cu
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
#include "ATen/ATen.h"
|
| 4 |
+
typedef at::BFloat16 bf16;
|
| 5 |
+
typedef at::Half fp16;
|
| 6 |
+
typedef float fp32;
|
| 7 |
+
|
| 8 |
+
// RWKV-6 WKV recurrence, forward pass. Structure matches the RWKV-5 kernel,
// with one key difference: the channel decay _w is time-varying, so it is
// indexed by the flattened timestep t and reloaded into shared memory every
// step (instead of once per head before the loop).
// Grid: one block per (batch, head); blockDim.x == _N_ (head size), one thread
// per output channel.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;  // batch index
    const int h = blockIdx.x % H;  // head index
    const int i = threadIdx.x;     // channel within the head
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!  (offset ignores b)

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // This thread's row of the head's _N_ x _N_ state, kept in registers.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);  // u is constant across time; load it once
    __syncthreads();

    // t walks this (batch, head, channel)'s flattened index across all T steps.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = _w[t];          // time-varying decay (fp32)
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // Vectorized 4-wide inner loop via float4 reinterpretation; assumes
        // _N_ % 4 == 0 and 16-byte alignment of the shared arrays.
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            // x = k (*) v : current timestep's contribution.
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r (*) (u (*) x + state): bonus for the current step plus history.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // state = state (*) w + x : decay with this step's w, then add x.
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    // Persist the recurrent state for the next chunk of the sequence.
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
|
| 72 |
+
|
| 73 |
+
// Host launchers for the RWKV-6 forward kernel: one CUDA block per head
// (B * H blocks), _N_ threads per block (one per channel).
//
// The kernel's state addressing ignores the batch index (see the
// "wrong if B > 1 !!!" note in kernel_forward), so batched input would
// silently corrupt the recurrent state. Guard with assert(B == 1) here,
// mirroring the existing guard in rwkv7.cu.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);  // channels must equal heads x head-size
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    assert(B == 1);      // kernel only supports batch size 1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
|
cuda/rwkv6_op.cpp
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/ATen.h"
|
| 3 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 4 |
+
typedef at::BFloat16 bf16;
|
| 5 |
+
typedef at::Half fp16;
|
| 6 |
+
typedef float fp32;
|
| 7 |
+
|
| 8 |
+
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
| 9 |
+
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
|
| 10 |
+
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
| 11 |
+
|
| 12 |
+
// Python-facing wrappers for the RWKV-6 CUDA launchers. Each pins the current
// CUDA device to the state tensor's device (important on multi-GPU hosts),
// then forwards raw data pointers. data_ptr<T>() enforces the expected dtype
// (fp32 for state/w, the wrapper's dtype for r/k/v/u/y); tensors are assumed
// contiguous -- TODO confirm callers guarantee contiguity.
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// pybind11 bindings: exposed when the extension is imported as a Python module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
// TorchScript-compatible registration under torch.ops.rwkv6.*
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
|
cuda/rwkv7.cu
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
#include "ATen/ATen.h"
|
| 4 |
+
|
| 5 |
+
typedef at::Half fp16;
|
| 6 |
+
typedef at::BFloat16 bf16;
|
| 7 |
+
typedef float fp32;
|
| 8 |
+
|
| 9 |
+
// RWKV-7 WKV recurrence, forward pass (sequential, B == 1 only — see launchers).
// Grid: one block per (batch, head); blockDim.x == _N_ (head size), one thread
// per output channel. Unlike v5/v6 there is no u "bonus"; instead the a/b
// inputs feed the state back into itself each step:
//   sa    = dot(a_t, state_row)
//   state = state * w + k * v + sa * b
//   y     = dot(state_row, r_t)
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
                               F *__restrict__ const _y)
{
    const int e = blockIdx.x / H;  // batch index
    const int h = blockIdx.x % H;  // head index
    const int i = threadIdx.x;     // channel within the head
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!  (offset ignores e)

    // This thread's row of the head's _N_ x _N_ state, kept in registers.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

    for (int _t = 0; _t < T; _t++)
    {
        // Flattened index of (batch e, timestep _t, head h, channel i).
        const int t = e*T*C + h*_N_ + i + _t * C;
        __syncthreads();
        r[i] = float(_r[t]);
        // Decay is parameterized as exp(-exp(w)), guaranteeing a value in (0, 1).
        w[i] = __expf(-__expf(float(_w[t])));
        k[i] = float(_k[t]);
        a[i] = float(_a[t]);
        b[i] = float(_b[t]);
        __syncthreads();

        // sa = dot(a, state_row): projects the current state onto a.
        float sa = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            sa += a[j] * state[j];
        }

        float vv = float(_v[t]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            // Decay, add the k*v contribution, and feed sa back through b.
            s = s * w[j] + k[j] * vv + sa * b[j];
            y += s * r[j];  // readout against r
        }
        _y[t] = F(y);
    }
    // Persist the recurrent state for the next chunk of the sequence.
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
|
| 59 |
+
|
| 60 |
+
// Host launchers for the RWKV-7 forward kernel: one CUDA block per head
// (B * H blocks), _N_ threads per block. The kernel's state addressing has no
// batch offset, hence the explicit B == 1 guard.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
    assert(H*_N_ == C);  // channels must equal heads x head-size
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
|
cuda/rwkv7_op.cpp
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
|
| 3 |
+
|
| 4 |
+
typedef at::Half fp16;
|
| 5 |
+
typedef at::BFloat16 bf16;
|
| 6 |
+
typedef float fp32;
|
| 7 |
+
|
| 8 |
+
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
|
| 9 |
+
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
|
| 10 |
+
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
|
| 11 |
+
|
| 12 |
+
// Python-facing wrappers for the RWKV-7 CUDA launchers.
// Fix: rwkv5_op.cpp and rwkv6_op.cpp pin the CUDA device to the state tensor's
// device before launching, but these wrappers did not — on a multi-GPU host a
// kernel could be launched on the wrong device. Add the same
// OptionalCUDAGuard for consistency and correctness. data_ptr<T>() enforces
// the expected dtypes (fp32 state, wrapper dtype for the rest).
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
}
|
| 21 |
+
|
| 22 |
+
TORCH_LIBRARY(wkv7s, m) {
|
| 23 |
+
m.def("forward_bf16", forward_bf16);
|
| 24 |
+
m.def("forward_fp16", forward_fp16);
|
| 25 |
+
m.def("forward_fp32", forward_fp32);
|
| 26 |
+
}
|
cuda/wrapper.cpp
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/ATen.h"
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 5 |
+
|
| 6 |
+
typedef at::Half fp16;
|
| 7 |
+
|
| 8 |
+
template <typename F>
|
| 9 |
+
void cuda_wkv_forward(int B, int T, int C,
|
| 10 |
+
float *w, float *u, F *k, F *v, F *y,
|
| 11 |
+
float *aa, float *bb, float *pp);
|
| 12 |
+
template <typename F>
|
| 13 |
+
void cuda_mm8_seq(int B, int N, int M,
|
| 14 |
+
F *x, int x_stride,
|
| 15 |
+
uint8_t *w, int w_stride,
|
| 16 |
+
F *mx, F *rx,
|
| 17 |
+
F *my, F *ry,
|
| 18 |
+
F *y, int y_stride);
|
| 19 |
+
template <typename F>
|
| 20 |
+
void cuda_mm8_one(int N, int M,
|
| 21 |
+
F *x,
|
| 22 |
+
uint8_t *w, int w_stride,
|
| 23 |
+
F *mx, F *rx,
|
| 24 |
+
F *my, F *ry,
|
| 25 |
+
float *y);
|
| 26 |
+
|
| 27 |
+
// Dtype-dispatch wrapper for the WKV forward kernel: selects the fp16 or fp32
// instantiation from k's scalar type. w/u and the aa/bb/pp state tensors are
// always fp32. Pins the CUDA device to w's device before launching.
void wkv_forward(int64_t B, int64_t T, int64_t C,
                 torch::Tensor &w, torch::Tensor &u,
                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (k.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    default:
        // NOTE: assert compiles out under NDEBUG; unsupported dtypes would then
        // fall through silently.
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
|
| 49 |
+
|
| 50 |
+
// Batched int8 matmul dispatch: y = x @ dequant(w) for B sequences.
// The stride asserts pin the expected layouts: x/w/y row-contiguous in their
// last dimension, and the per-row/per-column quantization vectors (mx/rx,
// my/ry) fully contiguous. Dispatches on x's dtype (fp16 or fp32).
void mm8_seq(int64_t B, int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(1) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(1) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<fp16>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<fp16>(), y.stride(0));
        break;
    case c10::ScalarType::Float:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<float>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>(), y.stride(0));
        break;
    default:
        // NOTE: assert compiles out under NDEBUG.
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
|
| 84 |
+
// Single-vector int8 matvec dispatch: y += x @ dequant(w), with x of length N,
// w an (N x M) uint8 matrix (row-contiguous; w.stride(0) passed through), and
// all quantization vectors contiguous. The output y is always fp32 regardless
// of x's dtype (the CUDA kernels accumulate in fp32). Dispatches on x's dtype.
void mm8_one(int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(0) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(0) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_one(
            N, M,
            x.data_ptr<fp16>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_mm8_one(
            N, M,
            x.data_ptr<float>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>());
        break;
    default:
        // NOTE: assert compiles out under NDEBUG.
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
|
| 118 |
+
|
| 119 |
+
using torch::Tensor;
|
| 120 |
+
|
| 121 |
+
#ifndef DISABLE_CUBLAS_GEMM
|
| 122 |
+
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
| 123 |
+
#endif
|
| 124 |
+
|
| 125 |
+
// Bindings: pybind11 module for direct Python import, plus TORCH_LIBRARY
// registration under torch.ops.rwkv. The cuBLAS GEMM binding is optional and
// can be compiled out with -DDISABLE_CUBLAS_GEMM.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wkv_forward", &wkv_forward, "wkv forward");
    m.def("mm8_seq", &mm8_seq, "mm8 seq");
    m.def("mm8_one", &mm8_one, "mm8 one");
#ifndef DISABLE_CUBLAS_GEMM
    // Fixed docstring: this binds a GEMM (it previously read "gemv").
    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemm fp16 cublas");
#endif
}

TORCH_LIBRARY(rwkv, m) {
    m.def("wkv_forward", wkv_forward);
    m.def("mm8_seq", mm8_seq);
    m.def("mm8_one", mm8_one);
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
#endif
}
|
download_models.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download model weights listed in a config YAML (replicates Dockerfile download behavior without Docker).
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python download_models.py --config config.production.yaml
|
| 7 |
+
|
| 8 |
+
This script uses huggingface_hub.hf_hub_download to download specified .pth files to the
|
| 9 |
+
model's DOWNLOAD_MODEL_DIR (or ./models by default).
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import yaml
|
| 14 |
+
import time
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
    """Download every model listed in the config YAML's ``MODELS`` section.

    Each entry must provide ``DOWNLOAD_MODEL_REPO_ID`` and
    ``DOWNLOAD_MODEL_FILE_NAME``; ``DOWNLOAD_MODEL_DIR`` defaults to
    ``./models``. Downloads go through ``hf_hub_download`` and are retried up
    to 5 times with linear backoff (network hiccups are common for multi-GB
    weight files). Entries with incomplete info are skipped with a message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.production.yaml")
    parser.add_argument("--token", default=None, help="Hugging Face token (optional)")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)  # safe_load accepts a stream; no need to read() first

    models = cfg.get("MODELS", [])
    if not models:
        print("No models found in config. Nothing to download.")
        return

    for m in models:
        repo_id = m.get("DOWNLOAD_MODEL_REPO_ID")
        filename = m.get("DOWNLOAD_MODEL_FILE_NAME")
        local_dir = m.get("DOWNLOAD_MODEL_DIR", "./models")

        if repo_id is None or filename is None:
            print(f"Skipping model with incomplete download info: {m}")
            continue

        # Fix: makedirs was called twice and every message printed the literal
        # placeholder "(unknown)" instead of the file being downloaded.
        os.makedirs(local_dir, exist_ok=True)
        print(f"Downloading {filename} from repo {repo_id} into {local_dir} ...")

        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir, token=args.token)
                print(f"Downloaded {filename} to {path}")
                break
            except Exception as e:
                print(f"Attempt {attempt} failed to download {filename} from {repo_id}: {e}")
                if attempt < max_attempts:
                    print(f"Retrying in {attempt * 5} seconds...")
                    time.sleep(attempt * 5)  # linear backoff: 5s, 10s, 15s, ...
                else:
                    print(f"Failed after {max_attempts} attempts. Skipping {filename}.")


if __name__ == "__main__":
    main()
|
models/.cache/huggingface/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*
|
models/.cache/huggingface/download/rwkv7-g1a-0.1b-20250728-ctx4096.pth.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
8c8cdf8c605dc7dfdccb676b9d0c482ba002f710
|
| 2 |
+
964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
|
| 3 |
+
1763947179.4323187
|
models/rwkv7-g1a-0.1b-20250728-ctx4096.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
|
| 3 |
+
size 382223868
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "rwkv-hf-space"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi[standard]>=0.115.11",
|
| 9 |
+
"huggingface-hub>=0.29.1",
|
| 10 |
+
"loguru>=0.7.3",
|
| 11 |
+
"ninja>=1.11.1.3",
|
| 12 |
+
"numpy>=2.2.3",
|
| 13 |
+
"pydantic>=2.10.6",
|
| 14 |
+
"pydantic-settings>=2.8.1",
|
| 15 |
+
"pynvml>=12.0.0",
|
| 16 |
+
"rich>=13.9.4",
|
| 17 |
+
"rwkv>=0.8.30",
|
| 18 |
+
"setuptools>=75.8.2",
|
| 19 |
+
"snowflake-id>=1.0.2",
|
| 20 |
+
"modelscope>=1.23.0",
|
| 21 |
+
"transformers",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.optional-dependencies]
|
| 25 |
+
cpu = ["torch>=2.6.0"]
|
| 26 |
+
cu124 = ["torch>=2.6.0"]
|
| 27 |
+
|
| 28 |
+
[tool.uv]
|
| 29 |
+
conflicts = [[{ extra = "cpu" }, { extra = "cu124" }]]
|
| 30 |
+
|
| 31 |
+
[tool.uv.sources]
|
| 32 |
+
torch = [
|
| 33 |
+
{ index = "pytorch-cpu", extra = "cpu" },
|
| 34 |
+
{ index = "pytorch-cu124", extra = "cu124" },
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[[tool.uv.index]]
|
| 38 |
+
name = "pytorch-cpu"
|
| 39 |
+
url = "https://download.pytorch.org/whl/cpu"
|
| 40 |
+
explicit = true
|
| 41 |
+
|
| 42 |
+
[[tool.uv.index]]
|
| 43 |
+
name = "pytorch-cu124"
|
| 44 |
+
url = "https://download.pytorch.org/whl/cu124"
|
| 45 |
+
explicit = true
|
run_windows.ps1
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch the RWKV FastAPI app inside the project's virtualenv.
# Usage: .\run_windows.ps1 [-CONFIG_FILE <path-to-config-yaml>]
Param(
    [string]$CONFIG_FILE = 'config.production.yaml'
)

# Refuse to run without the venv created by setup_windows.ps1.
if (-not (Test-Path .\.venv\Scripts\Activate.ps1)) {
    Write-Host "Virtualenv not found. Run setup_windows.ps1 first." -ForegroundColor Red
    exit 1
}

.\.venv\Scripts\Activate.ps1
# Export the config path for app.py to pick up from the environment.
$env:CONFIG_FILE=$CONFIG_FILE

Write-Host "Starting the RWKV FastAPI app using $CONFIG_FILE..." -ForegroundColor Green
python app.py
|
setup_windows.ps1
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# One-shot local setup for Windows: create venv, install deps, download models,
# prepare config.local.yaml, and optionally build the web frontend.
# Usage: .\setup_windows.ps1 [-gpu] [-CONFIG_FILE <yaml>] [-buildFrontend]
Param(
    [switch]$gpu,
    [string]$CONFIG_FILE = 'config.production.yaml',
    [switch]$buildFrontend
)

Write-Host "Starting RWKV local setup for Windows..." -ForegroundColor Green

if (-not (Get-Command python -ErrorAction SilentlyContinue)) {
    Write-Host "Python not found. Please install Python 3.10+ and add it to PATH." -ForegroundColor Red
    exit 1
}

Write-Host "Creating virtual environment (./.venv) ..."
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install --upgrade pip setuptools wheel

if ($gpu) {
    Write-Host "GPU support requested. Installing GPU dependencies (cu124) ..." -ForegroundColor Yellow
    pip install -e .[cu124]
} else {
    Write-Host "Installing CPU-only dependencies ..." -ForegroundColor Yellow
    pip install -e .[cpu]
}

Write-Host "Installing extra tooling (huggingface_hub for downloads) ..."
pip install huggingface-hub
pip install beautifulsoup4

Write-Host "Downloading models from config: $CONFIG_FILE" -ForegroundColor Green
python .\download_models.py --config $CONFIG_FILE --token $env:HF_TOKEN

# Ensure the config file is available as config.local.yaml that the app's default reads
if (Test-Path $CONFIG_FILE) {
    Write-Host "Copying $CONFIG_FILE to config.local.yaml so the application uses it by default..." -ForegroundColor Green
    Copy-Item -Force $CONFIG_FILE config.local.yaml

    # If GPU not requested, set STRATEGY to CPU in config.local.yaml
    if (-not $gpu) {
        Write-Host "GPU not requested. Setting STRATEGY to 'cpu fp16' in config.local.yaml..." -ForegroundColor Yellow
        try {
            $yaml = (Get-Content config.local.yaml -Raw)
            # Replace the STRATEGY line with CPU + fp16 precision.
            # (A stray, syntactically invalid leftover regex line that followed
            # this replacement has been removed.)
            $yaml = $yaml -replace '(?m)^STRATEGY:.*$', 'STRATEGY: "cpu fp16"'
            $yaml | Out-File -Encoding utf8 config.local.yaml -Force
        } catch {
            Write-Host "Warning: failed to modify config.local.yaml; please set STRATEGY manually" -ForegroundColor Yellow
        }
    }
}

if ($buildFrontend) {
    if (-not (Get-Command pnpm -ErrorAction SilentlyContinue)) {
        Write-Host "pnpm not found. Installing pnpm globally via npm..." -ForegroundColor Yellow
        npm install -g pnpm
    }

    if (-not (Test-Path .\web-frontend)) {
        Write-Host "Cloning web frontend repo..."
        git clone https://github.com/SolomonLeon/web-rwkv-realweb.git web-frontend
    }

    Push-Location web-frontend
    pnpm install
    # ModelScope Studio needs a different build target than HF Spaces.
    if ($env:MODELSCOPE_ENVIRONMENT -eq "studio") {
        pnpm run build --mode target-rwkv-modelscope-space
    } else {
        pnpm run build --mode target-rwkv-hf-space
    }
    Pop-Location

    # Copy dist to the project's dist-frontend
    if (Test-Path .\web-frontend\dist) {
        Remove-Item -Recurse -Force .\dist-frontend -ErrorAction SilentlyContinue
        Copy-Item -Recurse .\web-frontend\dist .\dist-frontend
        Write-Host "Frontend built and copied to ./dist-frontend" -ForegroundColor Green
    }
}

# Backtick escapes: `n is a real newline, `$ keeps the literal variable name
# in the printed instructions (PowerShell does not interpret \n in strings).
Write-Host "Setup complete. Run the app with:`n`$env:CONFIG_FILE='$CONFIG_FILE'`npython app.py" -ForegroundColor Cyan
|
tests/api_test.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke tests for the local RWKV chat-completions API.

Requires the server to be running on 127.0.0.1:7860; each section prints the
HTTP status and the (pretty-printed) response so failures are easy to spot.
"""
import requests, json, time

BASE = 'http://127.0.0.1:7860/api/v1/chat/completions'

headers = {'Content-Type': 'application/json'}


def post_and_report(title, payload, err_label):
    """POST *payload* to BASE and print the status plus the JSON body.

    Non-JSON responses fall back to the raw text (truncated to 1000 chars);
    transport errors are reported using *err_label* to keep the section
    output identifiable.
    """
    print(title)
    try:
        r = requests.post(BASE, json=payload, timeout=120)
        print('Status', r.status_code)
        try:
            print(json.dumps(r.json(), indent=2))
        except Exception:
            print('Non-JSON response:', r.text[:1000])
    except Exception as e:
        print(f'Error in {err_label} request:', e)


post_and_report('Non-streaming example', {
    'model': 'rwkv-latest',
    'prompt': 'Who is the president of France today?',
    'stream': False,
    'max_tokens': 64,
    'temperature': 0.2,
    'include_usage': True,
}, 'non-stream')

post_and_report('\nTools: calc example', {
    'model': 'rwkv-latest',
    'prompt': 'Calculate 2+3*4 and explain the result.',
    'stream': False,
    'tools': [{'name': 'calc', 'args': {'expression': '2+3*4'}}],
    'include_usage': True,
}, 'calc tool')

post_and_report('\nTools: web_search example', {
    'model': 'rwkv-latest',
    'prompt': 'Who is the current president of France?',
    'stream': False,
    'web_search': True,
    'search_top_k': 2,
    'include_usage': True,
}, 'web_search')

print('\nStreaming example (short)')
payload = {
    'model': 'rwkv-latest:thinking',
    'messages': [{'role': 'user', 'content': 'Explain Newton\'s first law in one sentence.'}],
    'stream': True,
    'max_tokens': 64,
    'temperature': 0.2,
}

try:
    # Stream SSE lines until the terminating [DONE] sentinel.
    r = requests.post(BASE, json=payload, headers=headers, stream=True, timeout=120)
    print('Status', r.status_code)
    if r.status_code == 200:
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            print('SSE:', line)
            if line.strip().endswith('[DONE]'):
                break
except Exception as e:
    print('Error in streaming request:', e)

print('\nDone tests')
|
tests/run_local_exec.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
import os

# Clear the ModelScope marker so app.py takes its non-studio code paths locally.
os.environ['MODELSCOPE_ENVIRONMENT'] = ''

# NOTE(review): importing app presumably loads the model at import time —
# requires weights downloaded and a valid config; confirm against app.py.
from app import chat_completions, ChatCompletionRequest

async def run_once():
    # Single non-streaming completion executed directly against the
    # in-process FastAPI handler (no HTTP server needed).
    req = ChatCompletionRequest(model='rwkv-latest', prompt='Who is the president of France today?', stream=False, max_tokens=32, temperature=0.2, include_usage=True)
    res = await chat_completions(req)
    print(res)

if __name__ == '__main__':
    asyncio.run(run_once())
|
utils.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re, os, threading, queue, requests
|
| 2 |
+
from typing import List, Optional, Union
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from pydantic_settings import BaseSettings
|
| 5 |
+
|
| 6 |
+
from api_types import ChatMessage
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_think_response(full_response: str):
    """Split a model response into (reasoning_content, content).

    The reasoning is the text wrapped in <think...>...</think>; the content is
    everything after the closing tag. Returns (None, stripped_text) when no
    think block is present, and (reasoning, "") when the block is unclosed.
    """
    OPEN = "<think"  # model may emit the tag without a closing '>'
    CLOSE = "</think>"
    start = full_response.find(OPEN)
    if start == -1:
        return None, full_response.strip()

    end = full_response.find(CLOSE, start)
    if end == -1:  # unclosed think block: everything after the tag is reasoning
        raw = full_response[start:]
        content = ""
    else:
        # CLOSE is 8 chars; the previous hard-coded +9 skipped one extra
        # character of the content.
        raw = full_response[start : end + len(CLOSE)]
        content = full_response[end + len(CLOSE) :].strip()

    # Strip the tags, including the '>' that the short "<think" prefix misses.
    reasoning = raw[len(OPEN):]
    if reasoning.startswith(">"):
        reasoning = reasoning[1:]
    reasoning_content = reasoning.replace(CLOSE, "").strip()
    return reasoning_content, content
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
    """Render a chat history as "Role: content" paragraphs joined by blank lines.

    Runs of newlines inside each message are collapsed to one. When
    removeThinkingContent is set, <think>...</think> spans are stripped from
    Assistant messages only.
    """
    rendered = []
    for msg in messages:
        role = msg.role.strip().lower().capitalize()
        body = re.sub(r"\n+", "\n", msg.content.strip())
        if removeThinkingContent and role == "Assistant":
            body = remove_nested_think_tags_stack(body)
        rendered.append(f"{role}: {body}")
    return "\n\n".join(rendered)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def remove_nested_think_tags_stack(text):
    """Remove <think>...</think> spans, including nested ones, from *text*.

    A </think> with no pending open tag is kept verbatim; text inside an
    unclosed <think> is dropped.
    """
    depth = 0            # how many unmatched <think> tags are open
    kept = []
    pos = 0
    length = len(text)
    while pos < length:
        if text.startswith("<think>", pos):
            depth += 1
            pos += 7
        elif text.startswith("</think>", pos):
            if depth:
                depth -= 1
            else:
                # Unmatched closing tag: preserve it verbatim.
                kept.append("</think>")
            pos += 8
        else:
            if depth == 0:
                kept.append(text[pos])
            pos += 1
    return "".join(kept)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def format_bytes(size):
    """Format a byte count with binary units, e.g. 2048 -> "2.0000KB".

    The unit scale stops at TB: values of 1024 TB or more are still expressed
    in TB instead of raising KeyError (the original looked up an undefined
    label once n exceeded 4).
    """
    power = 2**10
    n = 0
    power_labels = {0: "", 1: "K", 2: "M", 3: "G", 4: "T"}
    while size > power and n < 4:  # bound n so huge sizes stay in TB
        size /= power
        n += 1
    return f"{size:.4f}{power_labels[n]+'B'}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Bounded queue of log payloads; when full, new entries are dropped (see log()).
LOGGER_QUEUE = queue.Queue(5)


def logger():
    """Background worker: forward queued payloads to the endpoint in LOG_PORT.

    Runs forever; delivery failures are swallowed (best-effort logging).
    """
    print("enable")
    while True:
        item = LOGGER_QUEUE.get()
        try:
            requests.post(
                os.environ.get("LOG_PORT"),
                headers={"Content-Type": "application/json"},
                json=item,
            )
        except Exception:
            pass


if os.environ.get("LOG_PORT"):
    # daemon=True so the pending worker can never block interpreter shutdown.
    threading.Thread(target=logger, daemon=True).start()


def log(item):
    """Queue *item* for the background logger; drop it silently when full.

    put_nowait raises queue.Full on a saturated queue, which previously
    propagated into (and crashed) the caller.
    """
    try:
        LOGGER_QUEUE.put_nowait(item)
    except queue.Full:
        pass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def web_search(query: str, top_k: int = 3) -> str:
    """Perform a simple web search via DuckDuckGo HTML and return top_k results as a combined string.

    This is a lightweight fallback search that does not call external model services —
    it queries a public search endpoint, parses titles/snippets/urls and returns them as
    formatted text to be included into the model's prompt context.

    Returns "" for empty queries, when BeautifulSoup is unavailable, or on any
    request/parse error — callers never see an exception from here.
    """
    if not query or query.strip() == "":
        return ""
    try:
        # bs4 is an optional dependency; degrade to "no results" without it.
        from bs4 import BeautifulSoup
    except Exception:
        return ""
    try:
        # A browser-like UA; the endpoint may block default client UAs.
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        q = query.strip()
        resp = requests.get("https://duckduckgo.com/html/", params={"q": q}, headers=headers, timeout=10)
        soup = BeautifulSoup(resp.text, "html.parser")
        # DuckDuckGo's html structure: results are in `div.result` containers.
        # NOTE(review): these class names track an external page layout and
        # will silently return "" if DuckDuckGo changes its markup.
        results = []
        for r in soup.find_all("div", class_="result", limit=top_k):
            a = r.find("a", class_="result__a") or r.find("a", href=True)
            title = a.get_text(strip=True) if a else ""
            href = a.get("href") if a else ""
            snippet = ""
            s = r.find("a", class_="result__snippet") or r.find("div", class_="result__snippet")
            if s:
                snippet = s.get_text(strip=True)
            results.append(f"{title} - {snippet} - {href}")
        return "\n".join(results)
    except Exception:
        return ""
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def calc(expr: str) -> str:
    """Safely evaluate a simple arithmetic expression and return the result as string.

    Parses with ast so only numeric literals and the operators
    + - * / // % ** ^(xor) and unary minus are permitted; names, calls,
    attributes and non-numeric constants are rejected. Any failure (parse
    error, unsupported node, ZeroDivisionError, ...) is returned as an
    "ERROR: ..." string rather than raised.
    """
    try:
        import ast, operator as op

        # supported operators
        allowed_ops = {
            ast.Add: op.add,
            ast.Sub: op.sub,
            ast.Mult: op.mul,
            ast.Div: op.truediv,
            ast.Pow: op.pow,
            ast.BitXor: op.xor,
            ast.USub: op.neg,
            ast.Mod: op.mod,
            ast.FloorDiv: op.floordiv,
        }

        def _eval(node):
            # ast.Constant replaces the deprecated ast.Num (removed in newer
            # Pythons); accept numeric constants only — strings etc. fall
            # through to the "unsupported" error.
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
                return node.value
            elif isinstance(node, ast.BinOp):
                left = _eval(node.left)
                right = _eval(node.right)
                fn = allowed_ops.get(type(node.op))
                if fn is None:
                    raise ValueError("Unsupported operator")
                return fn(left, right)
            elif isinstance(node, ast.UnaryOp):
                operand = _eval(node.operand)
                fn = allowed_ops.get(type(node.op))
                if fn is None:
                    raise ValueError("Unsupported unary op")
                return fn(operand)
            else:
                raise ValueError("Unsupported expression type")

        node = ast.parse(expr, mode='eval')
        result = _eval(node.body)
        return str(result)
    except Exception as e:
        return f"ERROR: {e}"
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
verify_setup.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple local verification script to ensure the environment is prepared and the model is downloaded.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
def check_venv_python():
    """Print whether the current interpreter runs inside a virtualenv."""
    # Inside a venv, sys.prefix points at the env while base_prefix stays
    # at the system installation.
    inside_venv = sys.prefix != sys.base_prefix
    if inside_venv:
        print(f"Virtualenv detected: {sys.prefix}")
    else:
        print("Not in a virtual environment; consider activating .venv")
|
| 12 |
+
|
| 13 |
+
def check_models_dir():
    """Return True when ./models (relative to the CWD) holds .pth weight files."""
    models_dir = "models"
    if not os.path.exists(models_dir):
        print("Models directory not found. Run download_models.py first.")
        return False
    files = [name for name in os.listdir(models_dir) if name.endswith('.pth')]
    if not files:
        print("No .pth files found in ./models. Run download_models.py to fetch model weights.")
        return False
    print(f"Found model files: {files}")
    return True
|
| 24 |
+
|
| 25 |
+
def check_dependencies():
    """Return True when every key runtime package is importable.

    Prints the list of missing packages (or a success message). Any
    unexpected failure is reported and treated as a failed check.
    """
    try:
        # A bare `import importlib` does not guarantee the `util` submodule
        # attribute exists; import it explicitly before using find_spec.
        import importlib.util
        packages = [
            'fastapi', 'uvicorn', 'rwkv', 'huggingface_hub', 'pydantic', 'loguru'
        ]
        missing = [p for p in packages if importlib.util.find_spec(p) is None]
        if missing:
            print(f"Missing packages: {missing}; install them in your venv")
            return False
        print("All key dependencies found.")
        return True
    except Exception as e:
        print(f"Dependency check failed: {e}")
        return False
|
| 43 |
+
|
| 44 |
+
def main():
    """Run every environment check and print a readiness summary."""
    check_venv_python()
    # Both checks always run so the user sees every problem at once.
    results = [check_dependencies(), check_models_dir()]
    if all(results):
        print("Environment appears configured. You can run: python app.py")
    else:
        print("Fix missing items and re-run verification.")

if __name__ == '__main__':
    main()
|