Ksjsjjdj commited on
Commit
b219d99
·
verified ·
1 Parent(s): eaf7f79

Upload 34 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,47 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ *.gguf* filter=lfs diff=lfs merge=lfs -text
36
+ *.ggml filter=lfs diff=lfs merge=lfs -text
37
+ *.llamafile* filter=lfs diff=lfs merge=lfs -text
38
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
39
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
40
+ *.npy filter=lfs diff=lfs merge=lfs -text
41
+ *.npz filter=lfs diff=lfs merge=lfs -text
42
+ *.pickle filter=lfs diff=lfs merge=lfs -text
43
+ *.pkl filter=lfs diff=lfs merge=lfs -text
44
+ *.tar filter=lfs diff=lfs merge=lfs -text
45
+ *.wasm filter=lfs diff=lfs merge=lfs -text
46
+ *.zst filter=lfs diff=lfs merge=lfs -text
47
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ .cache
13
+
14
+ *pth
15
+ *.pt
16
+ *.st
17
+ *local*
18
+
19
+ dist-frontend/
20
+
21
+ .vscode/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10
Dockerfile ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM node:20-alpine AS FrontendBuilder
2
+
3
+ RUN apk update && apk upgrade && \
4
+ apk add --no-cache bash git openssh curl rust cargo
5
+
6
+ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
7
+ RUN npm install -g pnpm
8
+
9
+ ADD https://api.github.com/repos/SolomonLeon/web-rwkv-realweb/git/refs/heads/ version_1.json
10
+
11
+ WORKDIR /app
12
+ RUN git clone https://github.com/SolomonLeon/web-rwkv-realweb.git /app
13
+
14
+ WORKDIR /app/web-rwkv-wasm
15
+ RUN ["cargo", "install", "wasm-pack", "--locked"]
16
+
17
+ WORKDIR /app
18
+ ENV PATH=/root/.cargo/bin:$PATH
19
+ RUN pnpm install
20
+ RUN if [ "$MODELSCOPE_ENVIRONMENT" = "studio" ]; then \
21
+ pnpm run build --mode target-rwkv-modelscope-space; \
22
+ else \
23
+ pnpm run build --mode target-rwkv-hf-space; \
24
+ fi
25
+
26
+ FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS Backend
27
+
28
+ RUN <<EOF
29
+ apt update
30
+ apt install --no-install-recommends -y \
31
+ build-essential \
32
+ git \
33
+ cuda-nvcc-12-4 \
34
+ cuda-cudart-dev-12-4 \
35
+ python3-dev \
36
+ python3-pip \
37
+ libpython3.10-dev
38
+ apt clean && rm -rf /var/lib/apt/lists/*
39
+ EOF
40
+
41
+
42
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
43
+
44
+ COPY . .
45
+
46
+ RUN useradd -m -u 1000 user
47
+ USER user
48
+
49
+ ENV HOME=/home/user \
50
+ PATH=/usr/local/cuda/bin:/home/user/.local/bin:$PATH \
51
+ LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" \
52
+ CXX=/usr/bin/g++ \
53
+ TORCH_CUDA_ARCH_LIST="7.5"
54
+ WORKDIR $HOME/app
55
+
56
+ COPY --chown=user . $HOME/app
57
+
58
+ COPY --chown=user --from=FrontendBuilder /app/dist $HOME/app/dist-frontend
59
+
60
+ RUN uv sync --frozen --extra cu124
61
+
62
+ CMD ["sh", "-c", "if [ \"$MODELSCOPE_ENVIRONMENT\" = \"studio\" ]; then CONFIG_FILE=\"./config.production-modelscope.yaml\"; else CONFIG_FILE=\"./config.production.yaml\"; fi; uv run --offline --frozen app.py --config_file \"$CONFIG_FILE\""]
README.md CHANGED
@@ -1,12 +1,89 @@
1
- ---
2
- title: Xd
3
- emoji: 🌖
4
- colorFrom: indigo
5
- colorTo: indigo
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- short_description: xd
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ title: RWKV HF Space
4
+ emoji: 🐦‍⬛
5
+ colorFrom: purple
6
+ colorTo: pink
7
+ sdk: docker
8
+ pinned: false
9
+
10
+ ---
11
+
12
+ # Simple RWKV OpenAI-Compatible API
13
+
14
+ ---
15
+
16
+ title: RWKV HF Space
17
+ emoji: 🐦‍⬛
18
+ colorFrom: purple
19
+ colorTo: pink
20
+ sdk: docker
21
+ pinned: false
22
+
23
+ ---
24
+
25
+ # Simple RWKV OpenAI-Compatible API
26
+
27
+ ## Quick Windows Setup (no Docker)
28
+
29
+ This repository was originally packaged with a Dockerfile. It now provides a `setup_windows.ps1` script that mirrors Dockerfile actions and sets up the service locally on Windows (installs Python dependencies, builds the frontend, and downloads the 0.1B model).
30
+
31
+ Prerequisites:
32
+ - Python 3.10+ installed and in PATH
33
+ - Node.js + npm (optional, required for building the frontend)
34
+ - (Optional) NVIDIA GPU and CUDA (for GPU runtime)
35
+
36
+ To set up locally on Windows (CPU-only):
37
+
38
+ ```powershell
39
+ .\setup_windows.ps1 -gpu:$false -buildFrontend:$true -CONFIG_FILE config.production.yaml
40
+ ```
41
+
42
+ If you have a compatible NVIDIA GPU and prefer to install GPU-enabled dependencies, run with the `-gpu` switch.
43
+
44
+ After setup, run the API:
45
+
46
+ ```powershell
47
+ #$env:CONFIG_FILE='config.production.yaml'
48
+ python app.py
49
+ ```
50
+
51
+ The default production config in `config.production.yaml` now contains a single model — the 0.1B model `rwkv7-g1a-0.1b-20250728-ctx4096` — set as default chat and reasoning model.
52
+
53
+ To download models defined in any config:
54
+
55
+ ```powershell
56
+ python download_models.py --config config.production.yaml
57
+ ```
58
+
59
+ This will store the downloaded .pth files under the `DOWNLOAD_MODEL_DIR` specified in the YAML (defaults to `./models`).
60
+
61
+ Advanced features:
62
+ - `reasoning` is performed in-process by the same model (no external reasoning model is used). Use a request model like `rwkv-latest:thinking` or set the reasoning suffix and the requested model will run reasoning in the same model.
63
+ - `web_search` functionality is available at the request level — set `web_search: true` and optionally `search_top_k` to inject search results from DuckDuckGo into the prompt. This is executed by the server and provided to the same model as context.
64
+ - `tools` are executed server-side and results injected into the prompt for the same model. Supported tools: `web_search` and `calc` (calculator). Example of `tools` usage:
65
+
66
+ ```json
67
+ {
68
+ "model": "rwkv-latest",
69
+ "prompt": "Calculate 2+3*4 and tell me the result",
70
+ "tools": [{"name": "calc", "args": {"expression": "2+3*4"}}]
71
+ }
72
+ ```
73
+
74
+ Example: POST with `web_search` and reasoning enabled
75
+
76
+ ```json
77
+ {
78
+ "model": "rwkv-latest:thinking",
79
+ "prompt": "Who is the current president of France?",
80
+ "max_tokens": 32,
81
+ "web_search": true,
82
+ "search_top_k": 3
83
+ }
84
+ ```
85
+
86
+ The server will perform a web search for the prompt, aggregate the top 3 results, and inject those into the prompt, then run the model with reasoning enabled — all using the same model instead of an external reasoning or search model.
87
+
88
+ Streaming behavior:
89
+ - The API streams responses token-by-token by default (`stream: true`) and persists a `state_name` for the generation if requested (or will generate one). Provide `state_name` to resume continuation from where the previous stream stopped. The server stores model state in memory under `(model, state_name)` so subsequent requests with the same `state_name` can continue generation from that exact point.
api_types.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Any, Literal
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
# A single chat message supplied by the client: the author's role and the text.
# Both fields are required (Field() is given no default).
class ChatMessage(BaseModel):
    role: str = Field()
    content: str = Field()
8
+
9
+
10
# Log-probability record for one token, optionally with a list of alternative
# candidates in `top_logprobs` (free-form dicts; their schema is not fixed here).
class Logprob(BaseModel):
    token: str
    logprob: float
    top_logprobs: Optional[List[Dict[str, Any]]] = None
14
+
15
+
16
# Container for per-token log-probabilities attached to a choice; both lists
# are optional and default to None.
class LogprobsContent(BaseModel):
    content: Optional[List[Logprob]] = None
    refusal: Optional[List[Logprob]] = None
19
+
20
+
21
# A function/tool invocation: the function name and its arguments as a string.
# NOTE(review): `arguments` is presumably a JSON-encoded string - confirm with callers.
class FunctionCall(BaseModel):
    name: str
    arguments: str
24
+
25
+
26
# Message payload of a completion choice. Every field is optional so the same
# model can be used for both `message` and `delta` on ChatCompletionChoice.
class ChatCompletionMessage(BaseModel):
    role: Optional[str] = Field(
        None, description="The role of the author of this message"
    )
    content: Optional[str] = Field(None, description="The contents of the message")
    # Reasoning text is carried separately from the final answer text.
    reasoning_content: Optional[str] = Field(
        None, description="The reasoning contents of the message"
    )
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        None, description="Tool calls generated by the model"
    )
37
+
38
+
39
# Breakdown of prompt token usage; `cached_tokens` is required.
class PromptTokensDetails(BaseModel):
    cached_tokens: int
41
+
42
+
43
# Breakdown of completion token usage. All three counters are required.
# Not referenced by Usage in this file (the corresponding field is commented out).
class CompletionTokensDetails(BaseModel):
    reasoning_tokens: int
    accepted_prediction_tokens: int
    rejected_prediction_tokens: int
47
+
48
+
49
# Token accounting for a completion: prompt, completion and total counts, plus
# an optional breakdown of the prompt tokens.
class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # An Optional[...] annotation alone does not make a field optional in
    # pydantic v2; the explicit None default lets callers omit the breakdown.
    prompt_tokens_details: Optional[PromptTokensDetails] = None
    # completion_tokens_details: CompletionTokensDetails
55
+
56
+
57
# One generated choice. `message` carries a full response and `delta` an
# incremental one; both default to None so either can be populated.
class ChatCompletionChoice(BaseModel):
    index: int
    message: Optional[ChatCompletionMessage] = None
    delta: Optional[ChatCompletionMessage] = None
    logprobs: Optional[LogprobsContent] = None
    # Declared with `...`: the value may be None, but callers must always pass
    # the field explicitly.
    finish_reason: Optional[str] = Field(
        ..., description="Reason for stopping: stop, length, content_filter, tool_calls"
    )
65
+
66
+
67
# Full (non-streaming) chat completion response body.
class ChatCompletion(BaseModel):
    id: str = Field(..., description="Unique identifier for the chat completion")
    # Fixed discriminator string for this response type.
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(..., description="Unix timestamp of creation")
    model: str
    choices: List[ChatCompletionChoice]
    usage: Usage
74
+
75
+
76
# One streamed chunk of a chat completion response.
class ChatCompletionChunk(BaseModel):
    id: str = Field(..., description="Unique identifier for the chat completion")
    # Fixed discriminator string for this response type.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(..., description="Unix timestamp of creation")
    model: str
    choices: List[ChatCompletionChoice]
    # Usage is only meaningful on some chunks; defaulting to None means
    # intermediate chunks can be built without passing `usage` explicitly
    # (without a default the field is required under pydantic v2).
    usage: Optional[Usage] = None
app.py ADDED
@@ -0,0 +1,941 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
4
+ from modelscope import patch_hub
5
+
6
+ patch_hub()
7
+
8
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
9
+
10
+
11
+ from config import CONFIG, ModelConfig
12
+ from utils import (
13
+ cleanMessages,
14
+ parse_think_response,
15
+ remove_nested_think_tags_stack,
16
+ format_bytes,
17
+ log,
18
+ )
19
+
20
+ import copy, types, gc, sys, re, time, collections, asyncio
21
+ from huggingface_hub import hf_hub_download
22
+ from loguru import logger
23
+ from rich import print
24
+
25
+ from snowflake import SnowflakeGenerator
26
+
27
+ CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
28
+
29
+ from typing import List, Optional, Union, Any, Dict
30
+ import uuid
31
+ from pydantic import BaseModel, Field, model_validator
32
+ from pydantic_settings import BaseSettings
33
+
34
+
35
+ import numpy as np
36
+ import torch
37
+
38
+
39
+ if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
40
+ logger.info(f"CUDA not found, fall back to cpu")
41
+ CONFIG.STRATEGY = "cpu fp16"
42
+ # Normalize STRATEGY to include precision if missing (e.g., 'cpu' -> 'cpu fp16')
43
+ _s = CONFIG.STRATEGY.lower()
44
+ if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s):
45
+ logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
46
+ CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"
47
+
48
+
49
+ try:
50
+ from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
51
+ except Exception:
52
+ nvmlInit = None
53
+ nvmlDeviceGetHandleByIndex = None
54
+ nvmlDeviceGetMemoryInfo = None
55
+
56
+ if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
57
+ nvmlInit()
58
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
59
+
60
+
61
def logGPUState():
    """Log current GPU memory usage as seen by Torch and by NVML.

    Does nothing unless the configured strategy targets CUDA and the NVML
    bindings were importable at module load.
    """
    # Use the same case-insensitive check as the NVML initialisation above, so
    # a strategy such as "CUDA fp16" is not silently skipped here.
    if "cuda" not in CONFIG.STRATEGY.lower() or nvmlDeviceGetMemoryInfo is None:
        return

    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    logger.info(
        f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
    )
67
+
68
+
69
+ torch.backends.cudnn.benchmark = True
70
+ torch.backends.cudnn.allow_tf32 = True
71
+ torch.backends.cuda.matmul.allow_tf32 = True
72
+ os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
73
+ os.environ["RWKV_JIT_ON"] = "1"
74
+ os.environ["RWKV_CUDA_ON"] = (
75
+ "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
76
+ )
77
+
78
+ from rwkv.model import RWKV
79
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
80
+
81
+ from fastapi import FastAPI, HTTPException
82
+ from starlette.background import BackgroundTask
83
+ from fastapi.responses import StreamingResponse
84
+ from fastapi.middleware.cors import CORSMiddleware
85
+ from fastapi.staticfiles import StaticFiles
86
+ from fastapi.middleware.gzip import GZipMiddleware
87
+
88
+
89
+ from api_types import (
90
+ ChatMessage,
91
+ ChatCompletion,
92
+ ChatCompletionChunk,
93
+ Usage,
94
+ PromptTokensDetails,
95
+ ChatCompletionChoice,
96
+ ChatCompletionMessage,
97
+ )
98
+
99
+
100
class ModelStorage:
    """Holder for one loaded model: its config, the RWKV instance and its pipeline.

    All three slots start as None and are filled in by the model-loading loop.
    """

    # Configuration entry this storage slot was created from.
    MODEL_CONFIG: Optional[ModelConfig] = None

    # Tokenizer/sampling pipeline bound to `model`.
    pipeline: Optional[PIPELINE] = None

    # The loaded RWKV network.
    model: Optional[RWKV] = None
104
+
105
+
106
+ MODEL_STORAGE: Dict[str, ModelStorage] = {}
107
+
108
+ DEFALUT_MODEL_NAME = None
109
+ DEFAULT_REASONING_MODEL_NAME = None
110
+
111
+ # In-memory model state store to support streaming continuation/resume per state_name.
112
+ # Keys: (model_name, state_name) -> model_state object
113
+ STATE_STORE: Dict[tuple, Any] = {}
114
+
115
+ logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
116
+
117
+ logGPUState()
118
+
119
# Enforce a single 0.1b model. If multiple models are present, keep only the
# first one whose service name literally contains '0.1b'; if none matches,
# fall back to the first configured model (with a warning).
if not CONFIG.MODELS:
    # Fail with an explicit message rather than an IndexError further down.
    raise RuntimeError("No models configured: CONFIG.MODELS is empty.")

filtered_models = [m for m in CONFIG.MODELS if '0.1b' in m.SERVICE_NAME]
if not filtered_models:
    logger.warning("No '0.1b' model detected in config; using the first available model. To ensure single 0.1b use, include a model name with '0.1b'.")
    CONFIG.MODELS = [CONFIG.MODELS[0]]
else:
    if len(filtered_models) > 1:
        logger.warning("Multiple '0.1b' models detected; selecting the first one as the single model.")
    CONFIG.MODELS = [filtered_models[0]]

for model_config in CONFIG.MODELS:
    logger.info(f"Load Model - {model_config.SERVICE_NAME}")

    # Download the weights only when no local path was configured.
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
            filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
            local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
        )
    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")

    if model_config.DEFAULT_CHAT:
        if DEFALUT_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME

    if model_config.DEFAULT_REASONING:
        if DEFAULT_REASONING_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
    print(model_config.DEFAULT_SAMPLER)

    storage = ModelStorage()
    storage.MODEL_CONFIG = model_config
    storage.model = RWKV(
        # Strip only a trailing ".pth"; str.replace would also remove the
        # substring from directory names elsewhere in the path.
        model=model_config.MODEL_FILE_PATH.removesuffix(".pth"),
        strategy=CONFIG.STRATEGY,
    )
    storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
    MODEL_STORAGE[model_config.SERVICE_NAME] = storage

    # Case-insensitive, consistent with the other strategy checks in this file.
    if "cuda" in CONFIG.STRATEGY.lower():
        torch.cuda.empty_cache()
        gc.collect()
    logGPUState()


logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
logger.info(
    f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
)

# With exactly one model loaded, it serves as both the default chat model and
# the default reasoning model regardless of the flags in its config.
if len(MODEL_STORAGE) == 1:
    single_name = next(iter(MODEL_STORAGE))
    if DEFALUT_MODEL_NAME != single_name:
        DEFALUT_MODEL_NAME = single_name
        logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`")
    if DEFAULT_REASONING_MODEL_NAME != single_name:
        DEFAULT_REASONING_MODEL_NAME = single_name
        logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
188
+
189
+
190
# Request body for the chat/completion endpoint. Exactly one of `messages` or
# `prompt` must be supplied (enforced by the validator below); everything else
# is optional.
class ChatCompletionRequest(BaseModel):
    model: str = Field(
        "rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )

    # Input: either structured chat messages or a raw prompt string.
    messages: Optional[List[ChatMessage]] = Field(None)
    prompt: Optional[str] = Field(None)

    # Sampling controls; None leaves the choice of value to the server.
    max_tokens: Optional[int] = Field(None)
    temperature: Optional[float] = Field(None)
    top_p: Optional[float] = Field(None)
    presence_penalty: Optional[float] = Field(None)
    count_penalty: Optional[float] = Field(None)
    penalty_decay: Optional[float] = Field(None)

    # Response behaviour.
    stream: Optional[bool] = Field(True, description="Whether to stream token-by-token responses by default")
    state_name: Optional[str] = Field(None)
    include_usage: Optional[bool] = Field(False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    # Server-side augmentation of the prompt.
    web_search: Optional[bool] = Field(False, description="Whether to perform a web search and append results to the prompt")
    search_top_k: Optional[int] = Field(3, description="Number of web search results to retrieve")
    tools: Optional[List[Dict[str, Any]]] = Field(None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Reject raw input that supplies both, or neither, of `messages` and `prompt`.

        Non-dict input is passed through untouched for pydantic to handle.
        """
        if isinstance(data, dict):
            has_messages = data.get("messages") is not None
            has_prompt = data.get("prompt") is not None

            if has_messages and has_prompt:
                raise ValueError("messages and prompt cannot coexist. Choose one.")
            if not (has_messages or has_prompt):
                raise ValueError("Either messages or prompt must be provided.")
        return data
226
+
227
+
228
+ app = FastAPI(title="RWKV OpenAI-Compatible API")
229
+
230
+ app.add_middleware(
231
+ CORSMiddleware,
232
+ allow_origins=["*"],
233
+ allow_credentials=True,
234
+ allow_methods=["*"],
235
+ allow_headers=["*"],
236
+ )
237
+ app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
238
+
239
+
240
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Encode `ctx` and feed it through the requested model in chunks.

    The encoded tokens are appended to `model_tokens` in place. The model is
    run over at most `CONFIG.CHUNK_LEN` tokens per step, yielding control to
    the event loop between steps.

    Returns:
        `(out, model_tokens, model_state)` where `out` is the output of the
        last forward call, or None when `ctx` encodes to no tokens.

    Raises:
        HTTPException: 500 when the requested model is not loaded.
    """
    storage = MODEL_STORAGE.get(request.model)
    if not (storage and storage.pipeline and storage.model):
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    normalized = ctx.replace("\r\n", "\n")
    pending = [int(tok) for tok in storage.pipeline.encode(normalized)]
    model_tokens += pending

    out = None
    while pending:
        head = pending[: CONFIG.CHUNK_LEN]
        out, model_state = storage.model.forward(head, model_state)
        pending = pending[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)

    return out, model_tokens, model_state
261
+
262
+
263
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    """Sample up to `max_tokens` tokens from the requested model, yielding text chunks.

    Args:
        request: The request; supplies the model name, sampling parameters and
            `stop_tokens`.
        out: Model output for the last prefilled token; sampling starts here.
        model_tokens: Running token history; each sampled token is appended in place.
        model_state: Model state matching `out`.
        max_tokens: Upper bound on the number of sampled tokens.

    Yields:
        Dicts with keys `content` (decoded text), `tokens` (token ids covered
        by the chunk), `finish_reason` (None while generating, otherwise
        `"stop:token:<id>"` or `"length"`) and `state` (current model state).
        A chunk normally covers one token, but covers several when a character
        is split across tokens.

    Raises:
        HTTPException: 500 when the requested model is not loaded.
    """
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    temperature = request.temperature if request.temperature is not None else 0.2
    top_p = request.top_p if request.top_p is not None else 0.9
    alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
    alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
    penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5

    args = PIPELINE_ARGS(
        temperature=max(0.2, temperature),
        top_p=top_p,
        alpha_frequency=alpha_frequency,
        alpha_presence=alpha_presence,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],
    )  # stop generation whenever you see any token here

    stop_tokens = request.stop_tokens or []
    occurrence: Dict[int, float] = {}
    out_tokens: List[int] = []
    out_last = 0  # index in out_tokens of the first token not yet emitted as text

    for _ in range(max_tokens):
        # Apply presence/frequency penalties for tokens already generated.
        for n, count in occurrence.items():
            out[n] -= args.alpha_presence + count * args.alpha_frequency
        # out[0] -= 1e10  # disable END_OF_TEXT

        token = ms.pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        # Token 0 as a stop token ends generation before it is fed to the model.
        if token == 0 and token in stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }

            del out
            gc.collect()
            return

        out, model_state = ms.model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        if token in stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }

            del out
            gc.collect()
            return

        # Decay earlier occurrences, then count the new token.
        for seen in occurrence:
            occurrence[seen] *= penalty_decay
        occurrence[token] = occurrence.get(token, 0) + 1

        # Decode everything not yet emitted. A replacement character means the
        # trailing bytes are an incomplete character; hold the tokens back and
        # retry with the next one instead of dropping them, so characters that
        # span several tokens are emitted intact.
        pending = out_tokens[out_last:]
        decoded = ms.pipeline.decode(pending)
        if "\ufffd" in decoded:
            continue

        yield {
            "content": decoded,
            "tokens": pending,
            "finish_reason": None,
            "state": model_state,
        }
        out_last = len(out_tokens)

    else:
        # Token budget exhausted. Carries the same keys as every other chunk so
        # consumers can read `state` unconditionally.
        yield {
            "content": "",
            "tokens": out_tokens[out_last:],
            "finish_reason": "length",
            "state": model_state,
        }
356
+
357
+
358
def _augmentPromptWithTools(request: ChatCompletionRequest, prompt: str) -> str:
    """Run the request's server-side tools (or web search) and prepend their results to `prompt`.

    Tool failures are logged and skipped so a broken tool never fails the request.
    """

    def defaultQuery() -> str:
        return request.prompt if request.prompt else cleanMessages(request.messages or [])

    if request.tools:
        for tool in request.tools:
            name = tool.get('name')
            args = tool.get('args') or {}
            # Each tool is guarded separately so one failure does not skip the rest.
            try:
                if name == 'web_search':
                    from utils import web_search

                    search_q = args.get('query') or defaultQuery()
                    search_top_k = int(args.get('top_k') or request.search_top_k or 3)
                    search_str = web_search(search_q, search_top_k)
                    if search_str:
                        prompt = (f"ToolResults:\n{search_str}\n\nUse these results to answer the prompt.\n\n" + prompt)
                elif name == 'calc' or name == 'calculator':
                    # NOTE(review): `expression` is client-supplied; confirm that
                    # utils.calc does not eval() untrusted input.
                    from utils import calc

                    expr = args.get('expression')
                    if expr:
                        calc_res = calc(expr)
                        prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res}\n\nUse this result to answer the prompt.\n\n" + prompt)
                else:
                    logger.info(f"Unsupported tool requested: {name}")
            except Exception as e:
                logger.warning(f"Tool processing error: {e}")
    elif request.web_search:
        try:
            from utils import web_search

            search_res = web_search(defaultQuery(), int(request.search_top_k or 3))
            if search_res:
                prompt = f"WebSearchResults:\n{search_res}\n\n" + prompt
        except Exception as e:
            # Best effort: answer without search results, but leave a trace.
            logger.warning(f"Web search error: {e}")
    return prompt


async def chatResponse(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
) -> ChatCompletion:
    """Produce a complete (non-streaming) chat completion for `request`.

    Builds the prompt from `messages` or `prompt`, injects tool / web-search
    results, prefills the model and generates until a stop token, a stop
    sequence or the token budget is reached.

    When `request.state_name` names a stored state, the new prompt is fed on
    top of a copy of that state, and the resulting state is stored under the
    same name afterwards.

    Args:
        request: The validated request.
        model_state: Initial model state, used when nothing is resumed.
        completionId: Identifier used in logs and in the response.
        enableReasoning: Whether to open a reasoning section before generation.

    Returns:
        The populated ChatCompletion.

    Raises:
        HTTPException: 400 when the prompt encodes to no tokens; 500 (from the
            helpers) when the model is not loaded.
    """
    createTimestamp = time.time()

    if request.prompt is None:
        prompt = f"{cleanMessages(request.messages or [])}\n\nAssistant:{' <think' if enableReasoning else ''}"
    else:
        prompt = request.prompt.strip()

    prompt = _augmentPromptWithTools(request, prompt)
    logger.info(f"[REQ] {completionId} - prompt - {prompt}")

    # Resume from a stored state when one exists. Work on copies so the stored
    # snapshot stays intact until the new state is committed below.
    stored = (
        STATE_STORE.get((request.model, request.state_name))
        if request.state_name
        else None
    )
    if stored is not None:
        model_state = copy.deepcopy(stored.get('state', model_state))
        model_tokens = list(stored.get('model_tokens', [0]))
        cachedTokenCount = len(model_tokens)
    else:
        model_tokens = [0]
        cachedTokenCount = 0

    # Always prefill: generation needs the output of the last prompt token,
    # which a stored state alone does not provide.
    out, model_tokens, model_state = await runPrefill(
        request, prompt, model_tokens, model_state
    )
    if out is None:
        raise HTTPException(400, "Prompt is empty; nothing to generate from.")

    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    fullResponse = " <think" if enableReasoning else ""
    completionTokenCount = 0
    finishReason = None
    stopWords = request.stop or []

    for chunk in generate(
        request,
        out,
        model_tokens,
        model_state,
        max_tokens=(
            64000
            if "max_tokens" not in request.model_fields_set and enableReasoning
            else (request.max_tokens or 2048)
        ),
    ):
        fullResponse += chunk["content"]
        completionTokenCount += len(chunk["tokens"])
        # Track the latest state so the snapshot saved below includes the
        # generated tokens.
        model_state = chunk.get("state", model_state)

        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]
            break

        # Stop sequences may span several tokens, so check the accumulated
        # text; on a hit, stop generating instead of only recording the reason.
        hitStop = next((w for w in stopWords if w in fullResponse), None)
        if hitStop is not None:
            finishReason = f"stop:words:{hitStop}"
            break

        await asyncio.sleep(0)

    genenrateTime = time.time()

    # Guard against a zero interval on coarse clocks.
    prefillSeconds = max(prefillTime - createTimestamp, 1e-9)
    generateSeconds = max(genenrateTime - prefillTime, 1e-9)

    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / prefillSeconds, 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / generateSeconds, 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    reasoning_content, content = parse_think_response(fullResponse)

    response = ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=cachedTokenCount),
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    role="Assistant",
                    content=content,
                    reasoning_content=reasoning_content if reasoning_content else None,
                    tool_calls=None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )

    # Save state if requested for future resumption.
    # NOTE(review): STATE_STORE is never pruned in the code visible here, so it
    # grows with every distinct state_name.
    if request.state_name:
        STATE_STORE[(request.model, request.state_name)] = {
            'state': model_state,
            'model_tokens': model_tokens,
        }

    return response
503
+
504
+
505
async def chatResponseStream(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
):
    """Stream a chat completion as Server-Sent Events (SSE).

    Yields ``data: <json>\\n\\n`` frames: one initial frame (which also carries
    ``state_name``), one frame per generated piece of text, and a final
    ``data: [DONE]`` frame.

    Args:
        request: The completion request with sampler defaults already merged.
        model_state: Model state to start from (may be ``None``).
        completionId: Identifier used in the response chunks and in the logs.
        enableReasoning: When true, the output is split into
            ``reasoning_content`` (inside ``<think>...</think>``) and
            ``content`` (everything else).

    Yields:
        SSE-formatted strings.
    """
    import json  # function-scope import: only used to serialise the first frame

    # Keep the float start time for throughput maths; the int is what goes on
    # the wire in ``created``.
    startTime = time.time()
    createTimestamp = int(startTime)

    def perSecond(count, elapsed):
        # A coarse clock or an empty generation can give a zero-length
        # interval; report 0.0 instead of raising ZeroDivisionError.
        return round(count / elapsed, 2) if elapsed > 0 else 0.0

    def buildChunk(content, reasoningContent):
        # Reads the token counters and finishReason at call time, so every
        # chunk reflects the state of the stream at the moment it is built.
        return ChatCompletionChunk(
            id=completionId,
            created=createTimestamp,
            model=request.model,
            usage=(
                Usage(
                    prompt_tokens=promptTokenCount,
                    completion_tokens=completionTokenCount,
                    total_tokens=promptTokenCount + completionTokenCount,
                    prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
                )
                if request.include_usage
                else None
            ),
            choices=[
                ChatCompletionChoice(
                    index=0,
                    delta=ChatCompletionMessage(
                        role="Assistant",
                        content=content,
                        reasoning_content=reasoningContent,
                        tool_calls=None,
                    ),
                    logprobs=None,
                    finish_reason=finishReason,
                )
            ],
        )

    if request.prompt is None:
        prompt = f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
    else:
        prompt = request.prompt.strip()

    # Process tools and web_search (tools executed server-side and results injected to prompt)
    if request.tools:
        try:
            for tool in request.tools:
                name = tool.get('name')
                args = tool.get('args', {})
                if name == 'web_search':
                    from utils import web_search

                    search_q = args.get('query') or (
                        request.prompt if request.prompt else cleanMessages(request.messages or [])
                    )
                    search_top_k = int(args.get('top_k') or request.search_top_k or 3)
                    search_str = web_search(search_q, search_top_k)
                    if search_str:
                        prompt = f"WebSearchResults:\n{search_str}\n\n" + prompt
                elif name == 'calc' or name == 'calculator':
                    from utils import calc

                    expr = args.get('expression')
                    if expr:
                        calc_res = calc(expr)
                        prompt = f"CalcResult:{expr} = {calc_res}\n\n" + prompt
                else:
                    logger.info(f"Unsupported tool requested: {name}")
        except Exception as e:
            logger.info(f"Tool processing error: {e}")
    elif request.web_search:
        try:
            from utils import web_search

            search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
            search_res = web_search(search_q, int(request.search_top_k or 3))
            if search_res:
                prompt = f"WebSearchResults:\n{search_res}\n\n" + prompt
        except Exception as e:
            # Web search is best-effort: continue without results, but leave a trace.
            logger.info(f"Web search failed, continuing without results: {e}")

    logger.info(f"[REQ] {completionId} - context\n```{prompt}```")

    # Resume or prefill tokens/state
    state_key = (request.model, request.state_name)
    if request.state_name and state_key in STATE_STORE:
        # NOTE(review): on resume the freshly built prompt is not fed to the
        # model and `out` is None -- confirm generate() copes with that and
        # that skipping the new prompt is intended.
        stored = STATE_STORE[state_key]
        model_state = stored.get('state', model_state)
        model_tokens = stored.get('model_tokens', [0])
        out = None
    else:
        out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)

    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    completionTokenCount = 0
    finishReason = None

    # Attach state_name in the initial chunk so client can save it to continue later.
    # The frame must be real JSON (not a Python dict repr) for SSE clients to parse it.
    firstFrame = buildChunk("", "" if enableReasoning else None).model_dump(mode="json")
    firstFrame['state_name'] = request.state_name
    yield f"data: {json.dumps(firstFrame, ensure_ascii=False)}\n\n"

    buffer = []

    if enableReasoning:
        # The prompt already ends with " <think"; seed the buffer so the tag
        # scanner sees the complete "<think>" once the model emits ">".
        buffer.append("<think")

        streamConfig = {
            "isChecking": False,  # currently buffering a possible <...> tag
            "fullTextCursor": 0,  # index into fullText up to which text was handled
            "in_think": False,  # inside <think>...</think>
            "cacheStr": "",  # text buffered since the last "<"
        }

        for chunk in generate(
            request,
            out,
            model_tokens,
            model_state,
            max_tokens=(
                64000
                if "max_tokens" not in request.model_fields_set and enableReasoning
                else (request.max_tokens or 2048)
            ),
        ):
            completionTokenCount += 1
            # Each token stream is delivered as a decoded character/bytes (maybe 1 or more chars)
            chunkContent: str = chunk["content"]
            buffer.append(chunkContent)

            fullText = "".join(buffer)

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = buildChunk(None, None)
            delta = response.choices[0].delta

            # A "<" may open a think tag: flush the text before it, then start buffering.
            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"] and markStart != -1:
                streamConfig["isChecking"] = True

                pending = fullText[streamConfig["fullTextCursor"] : markStart]
                if streamConfig["in_think"]:
                    delta.reasoning_content = pending
                else:
                    delta.content = pending

                streamConfig["cacheStr"] = ""
                streamConfig["fullTextCursor"] = markStart

            if streamConfig["isChecking"]:
                streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
            else:
                if streamConfig["in_think"]:
                    delta.reasoning_content = chunkContent
                else:
                    delta.content = chunkContent
                streamConfig["fullTextCursor"] = len(fullText)

            # A ">" (or the end of generation) closes the candidate tag: classify it.
            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if (streamConfig["isChecking"] and markEnd != -1) or finishReason is not None:
                streamConfig["isChecking"] = False
                cached = streamConfig["cacheStr"]

                if not streamConfig["in_think"] and cached.find("<think>") != -1:
                    streamConfig["in_think"] = True
                    if delta.reasoning_content is None:
                        delta.reasoning_content = cached.replace("<think>", "")
                elif streamConfig["in_think"] and cached.find("</think>") != -1:
                    streamConfig["in_think"] = False
                    if delta.content is None:
                        delta.content = cached.replace("</think>", "")
                elif streamConfig["in_think"]:
                    # Not a think tag after all: release the buffered text as-is.
                    if delta.reasoning_content is None:
                        delta.reasoning_content = cached
                else:
                    if delta.content is None:
                        delta.content = cached
                streamConfig["fullTextCursor"] = len(fullText)

            if delta.content is not None or delta.reasoning_content is not None:
                # Save model state frequently (after each emitted token) to allow resuming
                if request.state_name:
                    STATE_STORE[(request.model, request.state_name)] = {
                        'state': model_state,
                        'model_tokens': model_tokens,
                    }
                yield f"data: {response.model_dump_json()}\n\n"

            # Check stop sequences. Leave the loop with `break` (not `return`)
            # so the client still receives a finish chunk and the [DONE] frame,
            # and the request is still logged.
            if finishReason is None:
                stopWord = next((w for w in request.stop or [] if w in fullText), None)
                if stopWord is not None:
                    finishReason = f"stop:words:{stopWord}"
                    yield f"data: {buildChunk('', None).model_dump_json()}\n\n"
                    break

            await asyncio.sleep(0)
    else:
        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            buffer.append(chunk["content"])

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = buildChunk(chunk["content"], None)

            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    genenrateTime = time.time()

    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": perSecond(promptTokenCount, prefillTime - startTime),
        "gen_len": completionTokenCount,
        "gen_tps": perSecond(completionTokenCount, genenrateTime - prefillTime),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")
    if request.messages is None:
        request.messages = []
    request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"]))
    log(
        {
            **request.model_dump(),
            **responseLog,
            "completionId": completionId,
            "machineLabel": os.environ.get("MACHINE_LABEL"),
        }
    )

    yield "data: [DONE]\n\n"
854
+
855
+
856
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat-completion request, streaming or non-streaming.

    Resolves the requested model name (``rwkv-latest`` maps to the default
    model; a ``:thinking`` suffix turns on reasoning output), fills every
    request field that is ``None`` from the model's default sampler config,
    and dispatches to the streaming or non-streaming implementation.

    Args:
        request: The incoming completion request.

    Returns:
        A ``StreamingResponse`` when ``request.stream`` is set; otherwise the
        completion as a dict with an extra ``state_name`` key (or the raw
        result if it is not a ``ChatCompletion``).

    Raises:
        HTTPException: 404 when the model is unknown or no default model is
            set; 500 when the selected model has no sampler configuration.
    """
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    modelName = request.model.split(":")[0]
    enableReasoning = ":thinking" in request.model

    if "rwkv-latest" in request.model:
        # Map to the default chat model in all cases. Do not redirect to a separate
        # reasoning model when ':thinking' is used. The same model will be used
        # and reasoning handled in-process by setting enableReasoning=True.
        if DEFALUT_MODEL_NAME is None:
            raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
        ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME)
        if not ms_def or not ms_def.MODEL_CONFIG:
            raise HTTPException(500, "Default sampler config missing for default model")
        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = DEFALUT_MODEL_NAME
    elif modelName in MODEL_STORAGE:
        ms_sel = MODEL_STORAGE.get(modelName)
        if not ms_sel or not ms_sel.MODEL_CONFIG:
            raise HTTPException(500, f"Default sampler config missing for model {modelName}")
        defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = modelName
    else:
        raise HTTPException(404, f"Can not find `{modelName}`")

    async def chatResponseStreamDisconnect():
        logGPUState()

    # Load or initialize model_state based on state_name. A fresh name is
    # generated when the client did not supply one, so it can be returned
    # and used to continue the conversation later.
    state_name = request.state_name
    if state_name is None:
        state_name = str(uuid.uuid4())
        request.state_name = state_name

    model_state = None
    state_key = (request.model, state_name)
    if state_key in STATE_STORE:
        model_state = STATE_STORE[state_key].get('state', None)

    # Fill every field the client left as None from the model's sampler defaults.
    request_dict = request.model_dump()
    for k, v in defaultSamplerConfig.model_dump().items():
        if k in request_dict and request_dict[k] is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)

    # Log the merged request: this is what the model will actually run with.
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if request.stream:
        return StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=BackgroundTask(chatResponseStreamDisconnect),
        )

    r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
    # Attach state_name to non-streaming response as additional metadata
    if isinstance(r, ChatCompletion):
        d = r.model_dump()
        d['state_name'] = state_name
        return d
    return r
929
+
930
+
931
# Serve the built frontend from "/" when its directory is present; otherwise
# run API-only and say so in the log.
if not os.path.exists("dist-frontend"):
    logger.info("dist-frontend not found; skipping static files mount")
else:
    app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
935
+
936
if __name__ == "__main__":
    import uvicorn

    # Fall back to loopback / 7860 when the config leaves HOST or PORT empty.
    uvicorn.run(
        app,
        host=CONFIG.HOST or "127.0.0.1",
        port=CONFIG.PORT or 7860,
    )
app_stderr.log ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ C:\Users\Administrator\Downloads\New folder (3)\RWKV\.venv\Lib\site-packages\torch\cuda\__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
2
+ import pynvml # type: ignore[import]
3
+ 2025-11-23 16:35:08.739 | INFO | __main__:<module>:104 - STRATEGY - cpu fp16
4
+ 2025-11-23 16:35:08.740 | INFO | __main__:<module>:109 - Load Model - rwkv7-g1a-0.1b-20250728-ctx4096
5
+ 2025-11-23 16:35:09.724 | INFO | __main__:<module>:117 - Load Model - Path - models\rwkv7-g1a-0.1b-20250728-ctx4096.pth
6
+ 2025-11-23 16:35:09.724 | INFO | __main__:<module>:133 - Load Model - Loading `rwkv7-g1a-0.1b-20250728-ctx4096`
7
+ 2025-11-23 16:35:15.073 | INFO | __main__:<module>:151 - Load Model - DEFALUT_MODEL_NAME is `rwkv7-g1a-0.1b-20250728-ctx4096`
8
+ 2025-11-23 16:35:15.074 | INFO | __main__:<module>:152 - Load Model - DEFAULT_REASONING_MODEL_NAME is `rwkv7-g1a-0.1b-20250728-ctx4096`
9
+ 2025-11-23 16:35:15.080 | INFO | __main__:<module>:746 - dist-frontend not found; skipping static files mount
10
+ INFO: Started server process [9328]
11
+ INFO: Waiting for application startup.
12
+ INFO: Application startup complete.
13
+ INFO: Uvicorn running on http://0.0.0.0:7860 (Press CTRL+C to quit)
14
+ 2025-11-23 16:35:51.067 | INFO | __main__:chat_completions:698 - [REQ] 7398519686318694400 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'Who is the current president of France?', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 3}
15
+ 2025-11-23 16:35:51.067 | INFO | __main__:chat_completions:729 - [REQ] 7398519686318694400 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Who is the current president of France?', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 3}
16
+ 2025-11-23 16:35:53.728 | INFO | __main__:chatResponse:363 - [REQ] 7398519686318694400 - prompt - Who is the current president of France?
17
+ 2025-11-23 16:36:09.388 | INFO | __main__:chatResponse:402 - [RES] 7398519686318694400 - {'content': '\nThe current president of France is Emmanuel Macron.', 'finish': 'stop:words:\n\n', 'prefill_len': 9, 'prefill_tps': 1.36, 'gen_len': 6, 'gen_tps': 0.51}
18
+ 2025-11-23 16:36:52.165 | INFO | __main__:chat_completions:698 - [REQ] 7398519942582280192 - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096:thinking', 'messages': None, 'prompt': 'Summarize the first paragraph from the search about Python programming', 'max_tokens': 60, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
19
+ 2025-11-23 16:36:52.165 | INFO | __main__:chat_completions:729 - [REQ] 7398519942582280192 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Summarize the first paragraph from the search about Python programming', 'max_tokens': 60, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
20
+ 2025-11-23 16:36:54.650 | INFO | __main__:chatResponse:363 - [REQ] 7398519942582280192 - prompt - Summarize the first paragraph from the search about Python programming
21
+ 2025-11-23 16:38:03.778 | INFO | __main__:chatResponse:402 - [RES] 7398519942582280192 - {'content': ' <think.\nThe first paragraph of the search is about Python programming. It talks about how to use Python for data analysis and machine learning. The second paragraph is about how to use Python for web development. It talks about how to use Python for creating websites and applications. The third', 'finish': 'length', 'prefill_len': 13, 'prefill_tps': 1.65, 'gen_len': 56, 'gen_tps': 0.88}
22
+ 2025-11-23 16:38:05.030 | INFO | __main__:chat_completions:698 - [REQ] 7398520248166686720 - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096:thinking', 'messages': None, 'prompt': 'Tell me a short summary of Python programming', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
23
+ 2025-11-23 16:38:05.033 | INFO | __main__:chat_completions:729 - [REQ] 7398520248166686720 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'Tell me a short summary of Python programming', 'max_tokens': 50, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': True, 'search_top_k': 2}
24
+ 2025-11-23 16:38:06.800 | INFO | __main__:chatResponse:363 - [REQ] 7398520248166686720 - prompt - Tell me a short summary of Python programming
25
+ 2025-11-23 16:38:24.585 | INFO | __main__:chatResponse:402 - [RES] 7398520248166686720 - {'content': ' <think and how it can be used to solve problems.', 'finish': 'stop:words:\n\n', 'prefill_len': 9, 'prefill_tps': 1.55, 'gen_len': 6, 'gen_tps': 0.44}
26
+ 2025-11-23 16:42:18.982 | INFO | __main__:chat_completions:698 - [REQ] 7398521313352130560 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
27
+ 2025-11-23 16:42:18.982 | INFO | __main__:chat_completions:729 - [REQ] 7398521313352130560 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
28
+ 2025-11-23 16:42:18.982 | INFO | __main__:chatResponse:363 - [REQ] 7398521313352130560 - prompt - What is two plus three times four?
29
+ 2025-11-23 16:42:56.030 | INFO | __main__:chatResponse:402 - [RES] 7398521313352130560 - {'content': '\n100\nWhat is the difference between 0.9 and 0.8?\n0.2\nWhat is the sum of', 'finish': 'length', 'prefill_len': 9, 'prefill_tps': 2.17, 'gen_len': 28, 'gen_tps': 0.85}
30
+ 2025-11-23 16:44:08.178 | INFO | __main__:chat_completions:698 - [REQ] 7398521771353350144 - {'model': 'rwkv-latest', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
31
+ 2025-11-23 16:44:08.179 | INFO | __main__:chat_completions:729 - [REQ] 7398521771353350144 - Real - {'model': 'rwkv7-g1a-0.1b-20250728-ctx4096', 'messages': None, 'prompt': 'What is two plus three times four?', 'max_tokens': 32, 'temperature': None, 'top_p': None, 'presence_penalty': None, 'count_penalty': None, 'penalty_decay': None, 'stream': False, 'state_name': None, 'include_usage': False, 'stop': ['\n\n'], 'stop_tokens': [0], 'web_search': False, 'search_top_k': 3}
32
+ 2025-11-23 16:44:08.179 | INFO | __main__:chatResponse:363 - [REQ] 7398521771353350144 - prompt - What is two plus three times four?
33
+ 2025-11-23 16:44:45.828 | INFO | __main__:chatResponse:402 - [RES] 7398521771353350144 - {'content': '\nTwo plus three times four is eight.\nWhat is the sum of the digits of two-digit numbers?\nThe sum of', 'finish': 'length', 'prefill_len': 9, 'prefill_tps': 2.28, 'gen_len': 28, 'gen_tps': 0.83}
app_stdout.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### RWKV-7 "Goose" enabled ###
3
+
4
+ SamplerConfig(
5
+ max_tokens=4096,
6
+ temperature=1.0,
7
+ top_p=0.3,
8
+ presence_penalty=0.5,
9
+ count_penalty=0.5,
10
+ penalty_decay=0.996,
11
+ stop=['\n\n'],
12
+ stop_tokens=[0]
13
+ )
14
+ Loading models\rwkv7-g1a-0.1b-20250728-ctx4096 (cpu fp16)
15
+
16
+ INFO: 127.0.0.1:50012 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
17
+ INFO: 127.0.0.1:50128 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
18
+ INFO: 127.0.0.1:50128 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
19
+ INFO: 127.0.0.1:50763 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
20
+ INFO: 127.0.0.1:50973 - "POST /api/v1/chat/completions HTTP/1.1" 200 OK
21
+ INFO: 127.0.0.1:51134 - "POST /api/v1/chat/completions HTTP/1.1" 422 Unprocessable Entity
22
+ INFO: 127.0.0.1:51144 - "POST /api/v1/chat/completions HTTP/1.1" 422 Unprocessable Entity
config.local.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HOST: "0.0.0.0"
2
+ PORT: 7860
3
+ STRATEGY: "cpu fp16"
4
+ RWKV_CUDA_ON: False
5
+ CHUNK_LEN: 256
6
+ MODELS:
7
+ - SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
8
+ DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
9
+ DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
10
+ DOWNLOAD_MODEL_DIR: "./models"
11
+ REASONING: True
12
+ DEFAULT_CHAT: True
13
+ DEFAULT_REASONING: True
14
+ DEFAULT_SAMPLER:
15
+ max_tokens: 4096
16
+ temperature: 1.0
17
+ top_p: 0.3
18
+ presence_penalty: 0.5
19
+ count_penalty: 0.5
20
+ penalty_decay: 0.996
21
+ stop:
22
+ - "\n\n"
23
+ stop_tokens:
24
+ - 0
config.production-modelscope.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HOST: "0.0.0.0"
2
+ PORT: 7860
3
+ STRATEGY: "cuda fp16"
4
+ RWKV_CUDA_ON: True
5
+ CHUNK_LEN: 256
6
+ MODELS:
7
+ - SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
8
+ DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
9
+ DOWNLOAD_MODEL_REPO_ID: "RWKV/rwkv7-g1"
10
+ DOWNLOAD_MODEL_DIR: "./models"
11
+ REASONING: True
12
+ DEFAULT_CHAT: True
13
+ DEFAULT_REASONING: True
14
+ DEFAULT_SAMPLER:
15
+ max_tokens: 4096
16
+ temperature: 1.0
17
+ top_p: 0.3
18
+ presence_penalty: 0.5
19
+ count_penalty: 0.5
20
+ penalty_decay: 0.996
21
+ stop:
22
+ - "\n\n"
23
+ stop_tokens:
24
+ - 0
config.production.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HOST: "0.0.0.0"
2
+ PORT: 7860
3
+ STRATEGY: "cuda fp16"
4
+ RWKV_CUDA_ON: True
5
+ CHUNK_LEN: 256
6
+ MODELS:
7
+ - SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
8
+ DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
9
+ DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
10
+ DOWNLOAD_MODEL_DIR: "./models"
11
+ REASONING: True
12
+ DEFAULT_CHAT: True
13
+ DEFAULT_REASONING: True
14
+ DEFAULT_SAMPLER:
15
+ max_tokens: 4096
16
+ temperature: 1.0
17
+ top_p: 0.3
18
+ presence_penalty: 0.5
19
+ count_penalty: 0.5
20
+ penalty_decay: 0.996
21
+ stop:
22
+ - "\n\n"
23
+ stop_tokens:
24
+ - 0
config.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+ from typing import List, Optional, Union, Any
4
+
5
+ import sys
6
+
7
+
8
+ from pydantic_settings import BaseSettings
9
+
10
+
11
# Command-line settings. `cli_parse_args=True` is passed so that
# pydantic-settings reads the field below from the process arguments.
# Documented with comments rather than a docstring because the class opts in
# to `cli_use_class_docs_for_groups`, where class docs can feed the CLI help.
class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
    # Path of the YAML configuration file to load.
    CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")


# Instantiated at import time, so importing this module builds the CLI settings.
CLI_CONFIG = CliConfig()
16
+
17
+
18
class SamplerConfig(BaseModel):
    """Default sampler configuration for each model."""

    # Length cap for one generation.
    max_tokens: int = Field(
        default=512, description="Maximum number of tokens to generate."
    )

    # Sampling distribution controls.
    temperature: float = Field(default=1.0, description="Sampling temperature.")
    top_p: float = Field(default=0.3, description="Top-p sampling threshold.")

    # Repetition penalties.
    presence_penalty: float = Field(default=0.5, description="Presence penalty.")
    count_penalty: float = Field(default=0.5, description="Count penalty.")
    penalty_decay: float = Field(default=0.996, description="Penalty decay factor.")

    # Termination conditions.
    stop: List[str] = Field(default=["\n\n"], description="List of stop sequences.")
    stop_tokens: List[int] = Field(default=[0], description="List of stop tokens.")
29
+
30
+
31
class ModelConfig(BaseModel):
    """Configuration for each individual model."""

    # Required name of this model entry.
    SERVICE_NAME: str = Field(..., description="Service name of the model.")

    # Optional path to a model file.
    MODEL_FILE_PATH: Optional[str] = Field(default=None, description="Model file path.")

    # Optional download settings: file name, repository and target directory.
    DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
        default=None, description="Model name, should end with .pth"
    )
    DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
        default=None, description="Model repository ID on Hugging Face Hub."
    )
    DOWNLOAD_MODEL_DIR: Optional[str] = Field(
        default="./models", description="Directory to download the model to."
    )

    # Capability and default-selection flags.
    REASONING: bool = Field(
        default=False, description="Whether reasoning is enabled for this model."
    )
    DEFAULT_CHAT: bool = Field(
        default=False, description="Whether this model is the default chat model."
    )
    DEFAULT_REASONING: bool = Field(
        default=False,
        description="Whether this model is the default reasoning model.",
    )

    # Sampler defaults for this model.
    DEFAULT_SAMPLER: SamplerConfig = Field(
        default=SamplerConfig(),
        description="Default sampler configuration for this model.",
    )
    VOCAB: str = Field(default="rwkv_vocab_v20230424", description="Vocab Name")
58
+
59
+
60
class RootConfig(BaseModel):
    """Root configuration for the RWKV service."""

    # HOST and PORT are optional because they are commented out in the
    # YAML example.
    HOST: Optional[str] = Field(
        default="127.0.0.1", description="Host IP address to bind to."
    )
    PORT: Optional[int] = Field(default=8000, description="Port number to listen on.")

    # Model execution settings.
    STRATEGY: str = Field(
        default="cpu",
        description="Strategy for model execution (e.g., 'cuda fp16').",
    )
    RWKV_CUDA_ON: bool = Field(
        default=False, description="Whether to enable RWKV CUDA kernel."
    )
    CHUNK_LEN: int = Field(default=256, description="Chunk length for processing.")

    # Required list of model entries.
    MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
75
+
76
+
77
import yaml

# Load and validate the service configuration once, at import time.
# Any failure is fatal: report it on stderr and exit with a non-zero status
# so shells and process supervisors do not mistake it for a clean exit.
try:
    with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
        CONFIG = RootConfig.model_validate(yaml.safe_load(f))
except OSError as e:
    # The file is missing or unreadable -- not a validation problem.
    print(f"Failed to read config file {CLI_CONFIG.CONFIG_FILE}: {e}", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    # YAML syntax errors and pydantic validation errors both end up here.
    print(f"Pydantic Model Validation Failed: {e}", file=sys.stderr)
    sys.exit(1)
cuda/gemm_fp16_cublas.cpp ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cublas_v2.h>
2
+ #include <cuda.h>
3
+ #include <cuda_fp16.h>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/extension.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <ATen/cuda/CUDAContext.h>
8
+
9
// Evaluates a cuBLAS call once and throws std::runtime_error when it does not
// return CUBLAS_STATUS_SUCCESS. Wrapped in do/while(0) so the macro behaves as
// a single statement (safe inside un-braced if/else); the caller supplies the
// trailing semicolon.
#define CUBLAS_CHECK(condition)                                              \
  do {                                                                       \
    const cublasStatus_t _cublas_check_status = (condition);                 \
    if (_cublas_check_status != CUBLAS_STATUS_SUCCESS) {                     \
      throw std::runtime_error(                                              \
          "cuBLAS error " +                                                  \
          std::to_string(static_cast<int>(_cublas_check_status)) + " at " +  \
          std::to_string(__LINE__));                                         \
    }                                                                        \
  } while (0)

// Same contract as CUBLAS_CHECK, for CUDA runtime calls returning cudaError_t.
#define CUDA_CHECK(condition)                                                \
  do {                                                                       \
    const cudaError_t _cuda_check_status = (condition);                      \
    if (_cuda_check_status != cudaSuccess) {                                 \
      throw std::runtime_error(                                              \
          "CUDA error " +                                                    \
          std::string(cudaGetErrorString(_cuda_check_status)) + " at " +     \
          std::to_string(__LINE__));                                         \
    }                                                                        \
  } while (0)
22
+
23
+ /*
24
+ NOTE: blas gemm is column-major by default, but we need row-major output.
25
+ The data of row-major, transposed matrix is exactly the same as the
26
+ column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
27
+ */
28
+ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
29
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
30
+ const auto cuda_data_type = CUDA_R_16F;
31
+ const auto cuda_c_data_type =
32
+ c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
33
+ const auto compute_type = CUDA_R_32F;
34
+ const float sp_alpha = 1.f;
35
+ // swap a and b, and use CUBLAS_OP_N. see the notes above
36
+ std::swap(a, b);
37
+ const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
38
+ const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
39
+ // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
40
+ // negative axis is used because of the existence of batch matmul.
41
+ const int m = a.size(-1);
42
+ const int k = a.size(-2);
43
+ const int n = b.size(-2);
44
+ const int cublas_lda = m;
45
+ const int cublas_ldb = k;
46
+ const int cublas_ldc = m;
47
+ cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
48
+
49
+ #if CUDA_VERSION >= 11000
50
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
51
+ #else
52
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
53
+ #endif
54
+ const float sp_beta = 0.f;
55
+ if (a.sizes().size() == 2 && b.sizes().size() == 2) {
56
+ CUBLAS_CHECK(cublasGemmEx(
57
+ cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
58
+ a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
59
+ cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
60
+ compute_type, algo));
61
+ } else {
62
+ // batch matmul
63
+ assert(a.sizes().size() == 3 && b.sizes().size() == 3);
64
+
65
+ const long long int cublas_stride_a = m * k;
66
+ const long long int cublas_stride_b = k * n;
67
+ const long long int cublas_stride_c = m * n;
68
+ CUBLAS_CHECK(cublasGemmStridedBatchedEx(
69
+ cublas_handle, cublas_trans_a, cublas_trans_b, m,
70
+ n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
71
+ cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
72
+ &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
73
+ a.size(0), compute_type, algo));
74
+ }
75
+ }
cuda/operators.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ #include <cuda_fp16.h>
5
+ #define MIN_VALUE (-1e38)
6
+ typedef at::Half fp16;
7
+ __half *cast(fp16 *ptr) {
8
+ return reinterpret_cast<__half *>(ptr);
9
+ }
10
+
11
+ template <typename F>
12
+ __global__ void kernel_wkv_forward(const int B, const int T, const int C,
13
+ const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
14
+ F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
15
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
16
+ const int _b = idx / C;
17
+ const int _c = idx % C;
18
+ const int _offset = _b * T * C + _c;
19
+ const int _state_offset = _b * C + _c;
20
+
21
+ float u = _u[_c];
22
+ float w = _w[_c];
23
+ const F *__restrict__ const k = _k + _offset;
24
+ const F *__restrict__ const v = _v + _offset;
25
+ F *__restrict__ const y = _y + _offset;
26
+
27
+ float aa = _aa[_state_offset];
28
+ float bb = _bb[_state_offset];
29
+ float pp = _pp[_state_offset];
30
+ for (int i = 0; i < T; i++) {
31
+ const int ii = i * C;
32
+ const float kk = float(k[ii]);
33
+ const float vv = float(v[ii]);
34
+ float ww = u + kk;
35
+ float p = max(pp, ww);
36
+ float e1 = exp(pp - p);
37
+ float e2 = exp(ww - p);
38
+ y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
39
+ ww = w + pp;
40
+ p = max(ww, kk);
41
+ e1 = exp(ww - p);
42
+ e2 = exp(kk - p);
43
+ aa = e1 * aa + e2 * vv;
44
+ bb = e1 * bb + e2;
45
+ pp = p;
46
+ }
47
+ _aa[_state_offset] = aa;
48
+ _bb[_state_offset] = bb;
49
+ _pp[_state_offset] = pp;
50
+ }
51
+
52
+ template <typename F>
53
+ void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
54
+ dim3 threadsPerBlock( min(C, 32) );
55
+ assert(B * C % threadsPerBlock.x == 0);
56
+ dim3 numBlocks(B * C / threadsPerBlock.x);
57
+ kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
58
+ }
59
+
60
+ template void cuda_wkv_forward<fp16>(
61
+ int B, int T, int C,
62
+ float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
63
+ float *aa, float *bb, float *pp);
64
+ template void cuda_wkv_forward<float>(
65
+ int B, int T, int C,
66
+ float *w, float *u, float *k, float *v, float *y,
67
+ float *aa, float *bb, float *pp);
68
+
69
+ __global__ void kernel_mm_seq_fp32i8(
70
+ const int B, const int N, const int M,
71
+ const float *__restrict__ const x, const int x_stride,
72
+ const uint8_t *__restrict__ const w, const int w_stride,
73
+ const float *__restrict__ const mx,
74
+ const float *__restrict__ const rx,
75
+ const float *__restrict__ const my,
76
+ const float *__restrict__ const ry,
77
+ float *__restrict__ const y, const int y_stride) {
78
+
79
+ const int i = blockIdx.x * blockDim.x + threadIdx.x;
80
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
81
+
82
+ if (i < B && k < M) {
83
+ float y_local = 0;
84
+ for (int j = 0; j < N; ++j) {
85
+ y_local += x[i * x_stride + j] * (
86
+ (float(w[j * w_stride + k]) + 0.5f)
87
+ * rx[k] * ry[j] + mx[k] + my[j]
88
+ );
89
+ }
90
+ y[i * y_stride + k] = y_local;
91
+ }
92
+ }
93
+
94
+ template <typename F>
95
+ void cuda_mm8_seq(int B, int N, int M,
96
+ F *x, int x_stride,
97
+ uint8_t *w, int w_stride,
98
+ F *mx, F *rx,
99
+ F *my, F *ry,
100
+ F *y, int y_stride);
101
+
102
+ template <>
103
+ void cuda_mm8_seq<float>(int B, int N, int M,
104
+ float *x, int x_stride,
105
+ uint8_t *w, int w_stride,
106
+ float *mx, float *rx,
107
+ float *my, float *ry,
108
+ float *y, int y_stride) {
109
+ dim3 blockSize(1, 128);
110
+ dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
111
+ kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
112
+ B, N, M, x, x_stride, w, w_stride,
113
+ mx, rx, my, ry, y, y_stride);
114
+ }
115
+
116
+ __global__ void kernel_mm_seq_fp16i8(
117
+ const int B, const int N, const int M,
118
+ const __half *__restrict__ const x, const int x_stride,
119
+ const uint8_t *__restrict__ const w, const int w_stride,
120
+ const __half *__restrict__ const mx,
121
+ const __half *__restrict__ const rx,
122
+ const __half *__restrict__ const my,
123
+ const __half *__restrict__ const ry,
124
+ __half *__restrict__ const y, const int y_stride) {
125
+
126
+ const int i = blockIdx.x * blockDim.x + threadIdx.x;
127
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
128
+
129
+ if (i < B && k < M) {
130
+ float y_local = 0;
131
+ for (int j = 0; j < N; ++j) {
132
+ y_local += __half2float(x[i * x_stride + j]) * (
133
+ (float(w[j * w_stride + k]) + 0.5f)
134
+ * __half2float(rx[k]) * __half2float(ry[j])
135
+ + __half2float(mx[k]) + __half2float(my[j])
136
+ );
137
+ }
138
+ y[i * y_stride + k] = __float2half(y_local);
139
+ }
140
+ }
141
+
142
+ template <>
143
+ void cuda_mm8_seq<fp16>(int B, int N, int M,
144
+ fp16 *x, int x_stride,
145
+ uint8_t *w, int w_stride,
146
+ fp16 *mx, fp16 *rx,
147
+ fp16 *my, fp16 *ry,
148
+ fp16 *y, int y_stride) {
149
+ dim3 blockSize(1, 128);
150
+ dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
151
+ kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
152
+ B, N, M, cast(x), x_stride, w, w_stride,
153
+ cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
154
+ }
155
+
156
+ #define MM8_ONE_JSPLIT 24
157
+ #define MM8_ONE_TILE 1024
158
+
159
+ __global__ void kernel_mm_one_fp32i8(
160
+ const int N, const int M,
161
+ const float *__restrict__ const x,
162
+ const uint8_t *__restrict__ const w, const int w_stride,
163
+ const float *__restrict__ const mx,
164
+ const float *__restrict__ const rx,
165
+ const float *__restrict__ const my,
166
+ const float *__restrict__ const ry,
167
+ float *__restrict__ const y) {
168
+
169
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
170
+ const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
171
+ const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
172
+
173
+ if (k < M) {
174
+ float y_local = 0;
175
+ for (int j = j0; j < j1; ++j) {
176
+ y_local += x[j] * (
177
+ (float(w[j * w_stride + k]) + 0.5f)
178
+ * rx[k] * ry[j] + mx[k] + my[j]
179
+ );
180
+ }
181
+ atomicAdd(&y[k], y_local);
182
+ }
183
+ }
184
+
185
+ template <typename F>
186
+ void cuda_mm8_one(int N, int M,
187
+ F *x,
188
+ uint8_t *w, int w_stride,
189
+ F *mx, F *rx,
190
+ F *my, F *ry,
191
+ float *y);
192
+
193
+ template <>
194
+ void cuda_mm8_one<float>(int N, int M,
195
+ float *x,
196
+ uint8_t *w, int w_stride,
197
+ float *mx, float *rx,
198
+ float *my, float *ry,
199
+ float *y) {
200
+ dim3 blockSize(1, MM8_ONE_TILE);
201
+ dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
202
+ kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
203
+ N, M, x, w, w_stride,
204
+ mx, rx, my, ry, y);
205
+ }
206
+
207
+ __global__ void kernel_mm_one_fp16i8(
208
+ const int N, const int M,
209
+ const __half *__restrict__ const x,
210
+ const uint8_t *__restrict__ const w, const int w_stride,
211
+ const __half *__restrict__ const mx,
212
+ const __half *__restrict__ const rx,
213
+ const __half *__restrict__ const my,
214
+ const __half *__restrict__ const ry,
215
+ float *__restrict__ const y) {
216
+
217
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
218
+ const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
219
+ const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
220
+
221
+ if (k < M) {
222
+ float y_local = 0;
223
+ for (int j = j0; j < j1; ++j) {
224
+ y_local += __half2float(x[j]) * (
225
+ (float(w[j * w_stride + k]) + 0.5f)
226
+ * __half2float(rx[k]) * __half2float(ry[j])
227
+ + __half2float(mx[k]) + __half2float(my[j])
228
+ );
229
+ }
230
+ atomicAdd(&y[k], y_local);
231
+ }
232
+ }
233
+
234
+ template <>
235
+ void cuda_mm8_one<fp16>(int N, int M,
236
+ fp16 *x,
237
+ uint8_t *w, int w_stride,
238
+ fp16 *mx, fp16 *rx,
239
+ fp16 *my, fp16 *ry,
240
+ float *y) {
241
+ dim3 blockSize(1, MM8_ONE_TILE);
242
+ dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
243
+ kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
244
+ N, M, cast(x), w, w_stride,
245
+ cast(mx), cast(rx), cast(my), cast(ry), y);
246
+ }
cuda/rwkv5.cu ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ template <typename F>
9
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11
+ F *__restrict__ const _y)
12
+ {
13
+ const int b = blockIdx.x / H;
14
+ const int h = blockIdx.x % H;
15
+ const int i = threadIdx.x;
16
+ _w += h*_N_;
17
+ _u += h*_N_;
18
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
19
+
20
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
21
+
22
+ float state[_N_];
23
+ #pragma unroll
24
+ for (int j = 0; j < _N_; j++)
25
+ state[j] = _state[j];
26
+
27
+ __syncthreads();
28
+ u[i] = float(_u[i]);
29
+ w[i] = _w[i];
30
+ __syncthreads();
31
+
32
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
33
+ {
34
+ __syncthreads();
35
+ r[i] = float(_r[t]);
36
+ k[i] = float(_k[t]);
37
+ __syncthreads();
38
+
39
+ const float v = float(_v[t]);
40
+ float y = 0;
41
+
42
+ #pragma unroll
43
+ for (int j = 0; j < _N_; j+=4)
44
+ {
45
+ const float4& r_ = (float4&)(r[j]);
46
+ const float4& k_ = (float4&)(k[j]);
47
+ const float4& w_ = (float4&)(w[j]);
48
+ const float4& u_ = (float4&)(u[j]);
49
+ float4& s = (float4&)(state[j]);
50
+ float4 x;
51
+
52
+ x.x = k_.x * v;
53
+ x.y = k_.y * v;
54
+ x.z = k_.z * v;
55
+ x.w = k_.w * v;
56
+
57
+ y += r_.x * (u_.x * x.x + s.x);
58
+ y += r_.y * (u_.y * x.y + s.y);
59
+ y += r_.z * (u_.z * x.z + s.z);
60
+ y += r_.w * (u_.w * x.w + s.w);
61
+
62
+ s.x = s.x * w_.x + x.x;
63
+ s.y = s.y * w_.y + x.y;
64
+ s.z = s.z * w_.z + x.z;
65
+ s.w = s.w * w_.w + x.w;
66
+ }
67
+ _y[t] = F(y);
68
+ }
69
+ #pragma unroll
70
+ for (int j = 0; j < _N_; j++)
71
+ _state[j] = state[j];
72
+ }
73
+
74
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
75
+ {
76
+ assert(H*_N_ == C);
77
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
78
+ }
79
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
80
+ {
81
+ assert(H*_N_ == C);
82
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
83
+ }
84
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
85
+ {
86
+ assert(H*_N_ == C);
87
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
88
+ }
cuda/rwkv5_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15
+ }
16
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19
+ }
20
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23
+ }
24
+
25
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26
+ m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
27
+ m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
28
+ m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
29
+ }
30
+ TORCH_LIBRARY(rwkv5, m) {
31
+ m.def("forward_bf16", forward_bf16);
32
+ m.def("forward_fp16", forward_fp16);
33
+ m.def("forward_fp32", forward_fp32);
34
+ }
cuda/rwkv6.cu ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ template <typename F>
9
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11
+ F *__restrict__ const _y)
12
+ {
13
+ const int b = blockIdx.x / H;
14
+ const int h = blockIdx.x % H;
15
+ const int i = threadIdx.x;
16
+ _u += h*_N_;
17
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18
+
19
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
20
+
21
+ float state[_N_];
22
+ #pragma unroll
23
+ for (int j = 0; j < _N_; j++)
24
+ state[j] = _state[j];
25
+
26
+ __syncthreads();
27
+ u[i] = float(_u[i]);
28
+ __syncthreads();
29
+
30
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
31
+ {
32
+ __syncthreads();
33
+ w[i] = _w[t];
34
+ r[i] = float(_r[t]);
35
+ k[i] = float(_k[t]);
36
+ __syncthreads();
37
+
38
+ const float v = float(_v[t]);
39
+ float y = 0;
40
+
41
+ #pragma unroll
42
+ for (int j = 0; j < _N_; j+=4)
43
+ {
44
+ const float4& r_ = (float4&)(r[j]);
45
+ const float4& k_ = (float4&)(k[j]);
46
+ const float4& w_ = (float4&)(w[j]);
47
+ const float4& u_ = (float4&)(u[j]);
48
+ float4& s = (float4&)(state[j]);
49
+ float4 x;
50
+
51
+ x.x = k_.x * v;
52
+ x.y = k_.y * v;
53
+ x.z = k_.z * v;
54
+ x.w = k_.w * v;
55
+
56
+ y += r_.x * (u_.x * x.x + s.x);
57
+ y += r_.y * (u_.y * x.y + s.y);
58
+ y += r_.z * (u_.z * x.z + s.z);
59
+ y += r_.w * (u_.w * x.w + s.w);
60
+
61
+ s.x = s.x * w_.x + x.x;
62
+ s.y = s.y * w_.y + x.y;
63
+ s.z = s.z * w_.z + x.z;
64
+ s.w = s.w * w_.w + x.w;
65
+ }
66
+ _y[t] = F(y);
67
+ }
68
+ #pragma unroll
69
+ for (int j = 0; j < _N_; j++)
70
+ _state[j] = state[j];
71
+ }
72
+
73
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
74
+ {
75
+ assert(H*_N_ == C);
76
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
77
+ }
78
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
79
+ {
80
+ assert(H*_N_ == C);
81
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
82
+ }
83
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
84
+ {
85
+ assert(H*_N_ == C);
86
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
87
+ }
cuda/rwkv6_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15
+ }
16
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19
+ }
20
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23
+ }
24
+
25
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26
+ m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
27
+ m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
28
+ m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
29
+ }
30
+ TORCH_LIBRARY(rwkv6, m) {
31
+ m.def("forward_bf16", forward_bf16);
32
+ m.def("forward_fp16", forward_fp16);
33
+ m.def("forward_fp32", forward_fp32);
34
+ }
cuda/rwkv7.cu ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+
5
+ typedef at::Half fp16;
6
+ typedef at::BFloat16 bf16;
7
+ typedef float fp32;
8
+
9
+ template <typename F>
10
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H,
11
+ float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
12
+ F *__restrict__ const _y)
13
+ {
14
+ const int e = blockIdx.x / H;
15
+ const int h = blockIdx.x % H;
16
+ const int i = threadIdx.x;
17
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18
+
19
+ float state[_N_];
20
+ #pragma unroll
21
+ for (int j = 0; j < _N_; j++)
22
+ state[j] = _state[j];
23
+
24
+ __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
25
+
26
+ for (int _t = 0; _t < T; _t++)
27
+ {
28
+ const int t = e*T*C + h*_N_ + i + _t * C;
29
+ __syncthreads();
30
+ r[i] = float(_r[t]);
31
+ w[i] = __expf(-__expf(float(_w[t])));
32
+ k[i] = float(_k[t]);
33
+ a[i] = float(_a[t]);
34
+ b[i] = float(_b[t]);
35
+ __syncthreads();
36
+
37
+ float sa = 0;
38
+ #pragma unroll
39
+ for (int j = 0; j < _N_; j++)
40
+ {
41
+ sa += a[j] * state[j];
42
+ }
43
+
44
+ float vv = float(_v[t]);
45
+ float y = 0;
46
+ #pragma unroll
47
+ for (int j = 0; j < _N_; j++)
48
+ {
49
+ float& s = state[j];
50
+ s = s * w[j] + k[j] * vv + sa * b[j];
51
+ y += s * r[j];
52
+ }
53
+ _y[t] = F(y);
54
+ }
55
+ #pragma unroll
56
+ for (int j = 0; j < _N_; j++)
57
+ _state[j] = state[j];
58
+ }
59
+
60
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
61
+ {
62
+ assert(H*_N_ == C);
63
+ assert(B == 1); // only for B=1
64
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
65
+ }
66
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
67
+ {
68
+ assert(H*_N_ == C);
69
+ assert(B == 1); // only for B=1
70
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
71
+ }
72
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
73
+ {
74
+ assert(H*_N_ == C);
75
+ assert(B == 1); // only for B=1
76
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
77
+ }
cuda/rwkv7_op.cpp ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
3
+
4
+ typedef at::Half fp16;
5
+ typedef at::BFloat16 bf16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
13
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
14
+ }
15
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
16
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
17
+ }
18
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
19
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
20
+ }
21
+
22
+ TORCH_LIBRARY(wkv7s, m) {
23
+ m.def("forward_bf16", forward_bf16);
24
+ m.def("forward_fp16", forward_fp16);
25
+ m.def("forward_fp32", forward_fp32);
26
+ }
cuda/wrapper.cpp ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <iostream>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+
6
+ typedef at::Half fp16;
7
+
8
+ template <typename F>
9
+ void cuda_wkv_forward(int B, int T, int C,
10
+ float *w, float *u, F *k, F *v, F *y,
11
+ float *aa, float *bb, float *pp);
12
+ template <typename F>
13
+ void cuda_mm8_seq(int B, int N, int M,
14
+ F *x, int x_stride,
15
+ uint8_t *w, int w_stride,
16
+ F *mx, F *rx,
17
+ F *my, F *ry,
18
+ F *y, int y_stride);
19
+ template <typename F>
20
+ void cuda_mm8_one(int N, int M,
21
+ F *x,
22
+ uint8_t *w, int w_stride,
23
+ F *mx, F *rx,
24
+ F *my, F *ry,
25
+ float *y);
26
+
27
+ void wkv_forward(int64_t B, int64_t T, int64_t C,
28
+ torch::Tensor &w, torch::Tensor &u,
29
+ torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
30
+ torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
32
+ switch (k.scalar_type()) {
33
+ case c10::ScalarType::Half:
34
+ cuda_wkv_forward(B, T, C,
35
+ w.data_ptr<float>(), u.data_ptr<float>(),
36
+ k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
37
+ aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
38
+ break;
39
+ case c10::ScalarType::Float:
40
+ cuda_wkv_forward(B, T, C,
41
+ w.data_ptr<float>(), u.data_ptr<float>(),
42
+ k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
43
+ aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
44
+ break;
45
+ default:
46
+ assert(false && "Only FP16 and FP32 are currently supported");
47
+ }
48
+ }
49
+
50
+ void mm8_seq(int64_t B, int64_t N, int64_t M,
51
+ torch::Tensor &x, torch::Tensor &w,
52
+ torch::Tensor &mx, torch::Tensor &rx,
53
+ torch::Tensor &my, torch::Tensor &ry,
54
+ torch::Tensor &y) {
55
+ assert(x.stride(1) == 1);
56
+ assert(w.stride(1) == 1);
57
+ assert(mx.stride(0) == 1 && rx.stride(0) == 1);
58
+ assert(my.stride(0) == 1 && ry.stride(0) == 1);
59
+ assert(y.stride(1) == 1);
60
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
61
+ switch (x.scalar_type()) {
62
+ case c10::ScalarType::Half:
63
+ cuda_mm8_seq(
64
+ B, N, M,
65
+ x.data_ptr<fp16>(), x.stride(0),
66
+ w.data_ptr<uint8_t>(), w.stride(0),
67
+ mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
68
+ my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
69
+ y.data_ptr<fp16>(), y.stride(0));
70
+ break;
71
+ case c10::ScalarType::Float:
72
+ cuda_mm8_seq(
73
+ B, N, M,
74
+ x.data_ptr<float>(), x.stride(0),
75
+ w.data_ptr<uint8_t>(), w.stride(0),
76
+ mx.data_ptr<float>(), rx.data_ptr<float>(),
77
+ my.data_ptr<float>(), ry.data_ptr<float>(),
78
+ y.data_ptr<float>(), y.stride(0));
79
+ break;
80
+ default:
81
+ assert(false && "Only FP16 and FP32 are currently supported");
82
+ }
83
+ }
84
+ void mm8_one(int64_t N, int64_t M,
85
+ torch::Tensor &x, torch::Tensor &w,
86
+ torch::Tensor &mx, torch::Tensor &rx,
87
+ torch::Tensor &my, torch::Tensor &ry,
88
+ torch::Tensor &y) {
89
+ assert(x.stride(0) == 1);
90
+ assert(w.stride(1) == 1);
91
+ assert(mx.stride(0) == 1 && rx.stride(0) == 1);
92
+ assert(my.stride(0) == 1 && ry.stride(0) == 1);
93
+ assert(y.stride(0) == 1);
94
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
95
+ switch (x.scalar_type()) {
96
+ case c10::ScalarType::Half:
97
+ cuda_mm8_one(
98
+ N, M,
99
+ x.data_ptr<fp16>(),
100
+ w.data_ptr<uint8_t>(), w.stride(0),
101
+ mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
102
+ my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
103
+ y.data_ptr<float>());
104
+ break;
105
+ case c10::ScalarType::Float:
106
+ cuda_mm8_one(
107
+ N, M,
108
+ x.data_ptr<float>(),
109
+ w.data_ptr<uint8_t>(), w.stride(0),
110
+ mx.data_ptr<float>(), rx.data_ptr<float>(),
111
+ my.data_ptr<float>(), ry.data_ptr<float>(),
112
+ y.data_ptr<float>());
113
+ break;
114
+ default:
115
+ assert(false && "Only FP16 and FP32 are currently supported");
116
+ }
117
+ }
118
+
119
+ using torch::Tensor;
120
+
121
+ #ifndef DISABLE_CUBLAS_GEMM
122
+ void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
123
+ #endif
124
+
125
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
126
+ m.def("wkv_forward", &wkv_forward, "wkv forward");
127
+ m.def("mm8_seq", &mm8_seq, "mm8 seq");
128
+ m.def("mm8_one", &mm8_one, "mm8 one");
129
+ #ifndef DISABLE_CUBLAS_GEMM
130
+ m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
131
+ #endif
132
+ }
133
+
134
+ TORCH_LIBRARY(rwkv, m) {
135
+ m.def("wkv_forward", wkv_forward);
136
+ m.def("mm8_seq", mm8_seq);
137
+ m.def("mm8_one", mm8_one);
138
+ #ifndef DISABLE_CUBLAS_GEMM
139
+ m.def("gemm_fp16_cublas", gemm_fp16_cublas);
140
+ #endif
141
+ }
download_models.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download model weights listed in a config YAML (replicates Dockerfile download behavior without Docker).
4
+
5
+ Usage:
6
+ python download_models.py --config config.production.yaml
7
+
8
+ This script uses huggingface_hub.hf_hub_download to download specified .pth files to the
9
+ model's DOWNLOAD_MODEL_DIR (or ./models by default).
10
+ """
11
+ import argparse
12
+ import os
13
+ import yaml
14
+ import time
15
+ from huggingface_hub import hf_hub_download
16
+
17
+
18
def main():
    """Download every model listed under ``MODELS`` in the YAML config.

    Command-line options:
        --config  Path to the YAML config (default ``config.production.yaml``).
        --token   Optional Hugging Face token; an empty value is treated as
                  "no token".

    Each entry needs ``DOWNLOAD_MODEL_REPO_ID`` and ``DOWNLOAD_MODEL_FILE_NAME``;
    ``DOWNLOAD_MODEL_DIR`` defaults to ``./models``. Entries missing the repo
    or file name are skipped. Each download is retried up to five times with a
    linearly growing delay; a file that still fails is reported and skipped so
    the remaining models are still attempted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.production.yaml")
    parser.add_argument("--token", default=None, help="Hugging Face token (optional)")
    args = parser.parse_args()

    # A wrapper script may pass an empty string when no token is configured.
    token = args.token or None

    with open(args.config, "r", encoding="utf-8") as f:
        # An empty YAML document parses to None; treat it as an empty mapping.
        cfg = yaml.safe_load(f) or {}

    # `MODELS:` with no value parses to None, so normalise that to a list too.
    models = cfg.get("MODELS") or []
    if not models:
        print("No models found in config. Nothing to download.")
        return

    max_attempts = 5
    failed = []

    for m in models:
        repo_id = m.get("DOWNLOAD_MODEL_REPO_ID")
        filename = m.get("DOWNLOAD_MODEL_FILE_NAME")
        local_dir = m.get("DOWNLOAD_MODEL_DIR", "./models")

        if repo_id is None or filename is None:
            print(f"Skipping model with incomplete download info: {m}")
            continue

        os.makedirs(local_dir, exist_ok=True)
        print(f"Downloading {filename} from repo {repo_id} into {local_dir} ...")

        for attempt in range(1, max_attempts + 1):
            try:
                path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=local_dir,
                    token=token,
                )
                print(f"Downloaded file to {path}")
                break
            except Exception as e:  # best-effort: report, back off, retry
                print(f"Attempt {attempt} failed to download {filename} from {repo_id}: {e}")
                if attempt < max_attempts:
                    print(f"Retrying in {attempt*5} seconds...")
                    time.sleep(attempt * 5)
                else:
                    print(f"Failed after {max_attempts} attempts. Skipping {filename}.")
                    failed.append(filename)

    if failed:
        print(f"Finished with {len(failed)} failed download(s): {failed}")
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()
models/.cache/huggingface/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *
models/.cache/huggingface/download/rwkv7-g1a-0.1b-20250728-ctx4096.pth.metadata ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 8c8cdf8c605dc7dfdccb676b9d0c482ba002f710
2
+ 964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
3
+ 1763947179.4323187
models/rwkv7-g1a-0.1b-20250728-ctx4096.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
3
+ size 382223868
pyproject.toml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "rwkv-hf-space"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "fastapi[standard]>=0.115.11",
9
+ "huggingface-hub>=0.29.1",
10
+ "loguru>=0.7.3",
11
+ "ninja>=1.11.1.3",
12
+ "numpy>=2.2.3",
13
+ "pydantic>=2.10.6",
14
+ "pydantic-settings>=2.8.1",
15
+ "pynvml>=12.0.0",
16
+ "rich>=13.9.4",
17
+ "rwkv>=0.8.30",
18
+ "setuptools>=75.8.2",
19
+ "snowflake-id>=1.0.2",
20
+ "modelscope>=1.23.0",
21
+ "transformers",
22
+ ]
23
+
24
+ [project.optional-dependencies]
25
+ cpu = ["torch>=2.6.0"]
26
+ cu124 = ["torch>=2.6.0"]
27
+
28
+ [tool.uv]
29
+ conflicts = [[{ extra = "cpu" }, { extra = "cu124" }, { extra = "cu113" }]]
30
+
31
+ [tool.uv.sources]
32
+ torch = [
33
+ { index = "pytorch-cpu", extra = "cpu" },
34
+ { index = "pytorch-cu124", extra = "cu124" },
35
+ ]
36
+
37
+ [[tool.uv.index]]
38
+ name = "pytorch-cpu"
39
+ url = "https://download.pytorch.org/whl/cpu"
40
+ explicit = true
41
+
42
+ [[tool.uv.index]]
43
+ name = "pytorch-cu124"
44
+ url = "https://download.pytorch.org/whl/cu124"
45
+ explicit = true
run_windows.ps1 ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launches the RWKV FastAPI app from the local virtual environment.
# -CONFIG_FILE selects the YAML config exported to the app via $env:CONFIG_FILE.
Param(
    [string]$CONFIG_FILE = 'config.production.yaml'
)

# The venv is created by setup_windows.ps1; refuse to start without it.
$venvIsPresent = Test-Path .\.venv\Scripts\Activate.ps1
if (-not $venvIsPresent) {
    Write-Host "Virtualenv not found. Run setup_windows.ps1 first." -ForegroundColor Red
    exit 1
}

.\.venv\Scripts\Activate.ps1

# Hand the chosen config path to the Python process through the environment.
$env:CONFIG_FILE = $CONFIG_FILE

Write-Host "Starting the RWKV FastAPI app using $CONFIG_FILE..." -ForegroundColor Green
python app.py
setup_windows.ps1 ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Local Windows setup for the RWKV app:
#   1. creates ./.venv and installs the project (CPU or cu124 extra),
#   2. downloads the models listed in the config,
#   3. copies the config to config.local.yaml (forcing a CPU strategy when
#      -gpu is not given),
#   4. optionally builds the web frontend into ./dist-frontend.
Param(
    [switch]$gpu,
    [string]$CONFIG_FILE = 'config.production.yaml',
    [switch]$buildFrontend
)

Write-Host "Starting RWKV local setup for Windows..." -ForegroundColor Green

if (-not (Get-Command python -ErrorAction SilentlyContinue)) {
    Write-Host "Python not found. Please install Python 3.10+ and add it to PATH." -ForegroundColor Red
    exit 1
}

Write-Host "Creating virtual environment (./.venv) ..."
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install --upgrade pip setuptools wheel

if ($gpu) {
    Write-Host "GPU support requested. Installing GPU dependencies (cu124) ..." -ForegroundColor Yellow
    pip install -e .[cu124]
} else {
    Write-Host "Installing CPU-only dependencies ..." -ForegroundColor Yellow
    pip install -e .[cpu]
}

Write-Host "Installing extra tooling (huggingface_hub for downloads) ..."
pip install huggingface-hub
pip install beautifulsoup4

Write-Host "Downloading models from config: $CONFIG_FILE" -ForegroundColor Green
# Only pass --token when HF_TOKEN is set: PowerShell drops a $null argument,
# which would leave a bare `--token` that argparse rejects.
$downloadArgs = @('.\download_models.py', '--config', $CONFIG_FILE)
if ($env:HF_TOKEN) {
    $downloadArgs += @('--token', $env:HF_TOKEN)
}
python @downloadArgs

# Ensure the config file is available as config.local.yaml that the app's default reads
if (Test-Path $CONFIG_FILE) {
    Write-Host "Copying $CONFIG_FILE to config.local.yaml so the application uses it by default..." -ForegroundColor Green
    Copy-Item -Force $CONFIG_FILE config.local.yaml

    # If GPU not requested, set STRATEGY to CPU in config.local.yaml
    if (-not $gpu) {
        Write-Host "GPU not requested. Setting STRATEGY to 'cpu fp16' in config.local.yaml..." -ForegroundColor Yellow
        try {
            $yaml = (Get-Content config.local.yaml -Raw)
            # Replace the STRATEGY line with CPU + fp16 precision
            $yaml = $yaml -replace '(?m)^STRATEGY:.*$', 'STRATEGY: "cpu fp16"'
            $yaml | Out-File -Encoding utf8 config.local.yaml -Force
        } catch {
            Write-Host "Warning: failed to modify config.local.yaml; please set STRATEGY manually" -ForegroundColor Yellow
        }
    }
}

if ($buildFrontend) {
    if (-not (Get-Command pnpm -ErrorAction SilentlyContinue)) {
        Write-Host "pnpm not found. Installing pnpm globally via npm..." -ForegroundColor Yellow
        npm install -g pnpm
    }

    if (-not (Test-Path .\web-frontend)) {
        Write-Host "Cloning web frontend repo..."
        git clone https://github.com/SolomonLeon/web-rwkv-realweb.git web-frontend
    }

    Push-Location web-frontend
    pnpm install
    if ($env:MODELSCOPE_ENVIRONMENT -eq "studio") {
        pnpm run build --mode target-rwkv-modelscope-space
    } else {
        pnpm run build --mode target-rwkv-hf-space
    }
    Pop-Location

    # Copy dist to the project's dist-frontend
    if (Test-Path .\web-frontend\dist) {
        Remove-Item -Recurse -Force .\dist-frontend -ErrorAction SilentlyContinue
        Copy-Item -Recurse .\web-frontend\dist .\dist-frontend
        Write-Host "Frontend built and copied to ./dist-frontend" -ForegroundColor Green
    }
}

# `n is PowerShell's newline escape and `$ keeps "$env:CONFIG_FILE" literal,
# so the hint prints the command to type instead of the variable's value.
Write-Host "Setup complete. Run the app with:`n`$env:CONFIG_FILE='$CONFIG_FILE'`npython app.py" -ForegroundColor Cyan
tests/api_test.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests, json, time
2
+
3
# Manual smoke test for a locally running server: sends three non-streaming
# requests and one streaming request to the chat-completions endpoint and
# prints whatever comes back. Nothing is asserted; errors are printed.
BASE = 'http://127.0.0.1:7860/api/v1/chat/completions'

headers = {'Content-Type': 'application/json'}


def _post_and_print(payload, error_prefix):
    """POST ``payload`` to BASE and print the status plus the (JSON) body.

    Any failure of the request itself is printed with ``error_prefix`` rather
    than raised, so the remaining examples still run.
    """
    try:
        r = requests.post(BASE, json=payload, timeout=120)
        print('Status', r.status_code)
        try:
            print(json.dumps(r.json(), indent=2))
        except Exception:
            # Body was not JSON; show a bounded excerpt of the raw text.
            print('Non-JSON response:', r.text[:1000])
    except Exception as e:
        print(error_prefix, e)


print('Non-streaming example')
_post_and_print(
    {
        'model': 'rwkv-latest',
        'prompt': 'Who is the president of France today?',
        'stream': False,
        'max_tokens': 64,
        'temperature': 0.2,
        'include_usage': True,
    },
    'Error in non-stream request:',
)

print('\nTools: calc example')
_post_and_print(
    {
        'model': 'rwkv-latest',
        'prompt': 'Calculate 2+3*4 and explain the result.',
        'stream': False,
        'tools': [{'name': 'calc', 'args': {'expression': '2+3*4'}}],
        'include_usage': True,
    },
    'Error in calc tool request:',
)

print('\nTools: web_search example')
_post_and_print(
    {
        'model': 'rwkv-latest',
        'prompt': 'Who is the current president of France?',
        'stream': False,
        'web_search': True,
        'search_top_k': 2,
        'include_usage': True,
    },
    'Error in web_search request:',
)

print('\nStreaming example (short)')
payload = {
    'model': 'rwkv-latest:thinking',
    'messages': [{'role': 'user', 'content': 'Explain Newton\'s first law in one sentence.'}],
    'stream': True,
    'max_tokens': 64,
    'temperature': 0.2,
}

try:
    # The context manager releases the connection even when we stop reading
    # early at the [DONE] marker.
    with requests.post(BASE, json=payload, headers=headers, stream=True, timeout=120) as r:
        print('Status', r.status_code)
        if r.status_code == 200:
            for line in r.iter_lines(decode_unicode=True):
                if not line:
                    continue
                print('SSE:', line)
                if line.strip().endswith('[DONE]'):
                    break
except Exception as e:
    print('Error in streaming request:', e)

print('\nDone tests')
tests/run_local_exec.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+
4
+ os.environ['MODELSCOPE_ENVIRONMENT'] = ''
5
+
6
+ from app import chat_completions, ChatCompletionRequest
7
+
8
async def run_once():
    """Send one non-streaming completion request to the handler and print the result."""
    request = ChatCompletionRequest(
        model='rwkv-latest',
        prompt='Who is the president of France today?',
        stream=False,
        max_tokens=32,
        temperature=0.2,
        include_usage=True,
    )
    print(await chat_completions(request))
12
+
13
+ if __name__ == '__main__':
14
+ asyncio.run(run_once())
utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, os, threading, queue, requests
2
+ from typing import List, Optional, Union
3
+ from pydantic import BaseModel, Field
4
+ from pydantic_settings import BaseSettings
5
+
6
+ from api_types import ChatMessage
7
+
8
+
9
def parse_think_response(full_response: str):
    """Split a model response into its reasoning part and its answer part.

    Args:
        full_response: Raw model output, optionally containing a
            ``<think>...</think>`` section.

    Returns:
        A ``(reasoning, content)`` tuple:

        * no ``<think`` marker: ``(None, full_response.strip())``;
        * marker present but never closed: everything after the marker is
          reasoning and ``content`` is ``""``;
        * otherwise the text between the tags is the reasoning and the text
          after ``</think>`` is the content.

        Text that precedes the opening marker is discarded.
    """
    open_marker = "<think"  # also matches a tag cut off before its ">"
    close_tag = "</think>"

    think_start = full_response.find(open_marker)
    if think_start == -1:
        return None, full_response.strip()

    # Search for the closing tag only after the opening marker.
    think_end = full_response.find(close_tag, think_start)
    if think_end == -1:  # unclosed reasoning section
        reasoning = full_response[think_start:]
        content = ""
    else:
        # Use the tag's real length; a hard-coded 9 dropped the first
        # character of the answer.
        tag_end = think_end + len(close_tag)
        reasoning = full_response[think_start:tag_end]
        content = full_response[tag_end:].strip()

    # Strip the tags but keep their content. The full "<think>" is removed
    # first so no stray ">" is left behind; the bare marker covers a tag
    # that was truncated mid-way.
    reasoning_content = (
        reasoning.replace("<think>", "")
        .replace(open_marker, "")
        .replace(close_tag, "")
        .strip()
    )
    return reasoning_content, content
25
+
26
+
27
def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
    """Render chat messages as a single ``Role: content`` prompt string.

    For each message the role is stripped and capitalised, the content is
    stripped, and runs of newlines are collapsed to one. When
    ``removeThinkingContent`` is true, ``<think>`` sections are removed from
    messages whose normalised role is ``Assistant``. Rendered messages are
    joined with a blank line.
    """

    def _render(message):
        role = message.role.strip().lower().capitalize()
        body = re.sub(r"\n+", "\n", message.content.strip())
        if removeThinkingContent and role == "Assistant":
            body = remove_nested_think_tags_stack(body)
        return f"{role}: {body}"

    return "\n\n".join(_render(message) for message in messages)
38
+
39
+
40
def remove_nested_think_tags_stack(text):
    """Return ``text`` with every ``<think>...</think>`` section removed.

    Sections may be nested; everything from an opening tag to its matching
    closing tag is dropped. A closing tag with no open section is kept
    verbatim, and an unclosed opening tag discards the rest of the text.
    """
    open_tag = "<think>"
    close_tag = "</think>"

    depth = 0  # number of currently open <think> sections
    kept = []  # pieces collected outside any section; joined once at the end
    i = 0
    length = len(text)
    while i < length:
        if text.startswith(open_tag, i):
            depth += 1
            i += len(open_tag)
        elif text.startswith(close_tag, i):
            if depth:
                depth -= 1
            else:
                # Unmatched closing tag: preserve it in the output.
                kept.append(close_tag)
            i += len(close_tag)
        else:
            if depth == 0:
                kept.append(text[i])
            i += 1
    return "".join(kept)
61
+
62
+
63
def format_bytes(size):
    """Format a byte count with a binary-scaled unit, e.g. ``1536 -> '1.5000KB'``.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit (TB) is reached, then rendered with four decimals. Values
    beyond the TB range stay expressed in TB instead of raising ``KeyError``.
    """
    power = 2**10
    power_labels = ("", "K", "M", "G", "T")
    n = 0
    # `>=` so that exactly 1024 becomes 1.0000KB; the bound on n keeps the
    # label lookup in range for very large inputs.
    while size >= power and n < len(power_labels) - 1:
        size /= power
        n += 1
    return f"{size:.4f}{power_labels[n]}B"
71
+
72
+
73
# Small bounded buffer between callers of log() and the background sender.
LOGGER_QUEUE = queue.Queue(5)


def logger():
    """Forward queued items to the URL in the ``LOG_PORT`` environment variable.

    Runs forever: blocks on the queue, POSTs each item as JSON, and ignores
    any delivery error so that logging can never take the application down.
    """
    print("enable")
    while True:
        item = LOGGER_QUEUE.get()
        try:
            requests.post(
                os.environ.get("LOG_PORT"),
                headers={"Content-Type": "application/json"},
                json=item,
                # Without a timeout a stalled endpoint would block this
                # thread forever and the queue would never drain.
                timeout=10,
            )
        except Exception:
            # Deliberately best-effort: a lost log entry is acceptable.
            pass


if os.environ.get("LOG_PORT"):
    # Daemon thread: the endless loop must not keep the process alive at exit.
    threading.Thread(target=logger, daemon=True).start()


def log(item):
    """Queue ``item`` for background delivery without blocking.

    If the queue is full (slow endpoint, or no sender thread because
    ``LOG_PORT`` is unset) the item is dropped instead of raising
    ``queue.Full`` into the caller.
    """
    try:
        LOGGER_QUEUE.put_nowait(item)
    except queue.Full:
        pass
96
+
97
+
98
def web_search(query: str, top_k: int = 3) -> str:
    """Search DuckDuckGo's HTML endpoint and return up to ``top_k`` hits as text.

    Each hit is rendered as ``"<title> - <snippet> - <href>"`` and hits are
    joined with newlines so the text can be placed into a prompt. An empty
    string is returned for a blank query, when BeautifulSoup cannot be
    imported, or when the request or parsing fails for any reason.
    """
    if not query or not query.strip():
        return ""
    try:
        from bs4 import BeautifulSoup
    except Exception:
        return ""

    def _describe(container):
        # Prefer the dedicated result link; fall back to any anchor with an href.
        link = container.find("a", class_="result__a") or container.find("a", href=True)
        title = link.get_text(strip=True) if link else ""
        href = link.get("href") if link else ""
        snippet_node = container.find("a", class_="result__snippet") or container.find(
            "div", class_="result__snippet"
        )
        snippet = snippet_node.get_text(strip=True) if snippet_node else ""
        return f"{title} - {snippet} - {href}"

    try:
        response = requests.get(
            "https://duckduckgo.com/html/",
            params={"q": query.strip()},
            headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"},
            timeout=10,
        )
        page = BeautifulSoup(response.text, "html.parser")
        # Results are looked up as `div.result` containers.
        hits = page.find_all("div", class_="result", limit=top_k)
        return "\n".join(_describe(hit) for hit in hits)
    except Exception:
        return ""
130
+
131
+
132
def calc(expr: str) -> str:
    """Evaluate a simple arithmetic expression and return the result as a string.

    The expression is parsed with ``ast`` and walked by hand; only numeric
    literals, the binary operators ``+ - * / ** ^ % //`` and unary ``+``/``-``
    are accepted. Names, calls, attributes and every other construct are
    rejected. Integer powers whose result would be very large are refused so
    that untrusted input cannot tie up the process.

    Returns:
        ``str(result)`` on success, or ``"ERROR: <reason>"`` on any failure
        (syntax error, unsupported construct, division by zero, ...).
    """
    try:
        import ast
        import operator as op

        # Upper bound on the approximate bit size of an integer power result.
        max_pow_bits = 100_000

        def _bounded_pow(base, exponent):
            if (
                isinstance(base, int)
                and isinstance(exponent, int)
                and exponent > 0
                and base.bit_length() * exponent > max_pow_bits
            ):
                raise ValueError("Exponent too large")
            return op.pow(base, exponent)

        # supported operators
        allowed_ops = {
            ast.Add: op.add,
            ast.Sub: op.sub,
            ast.Mult: op.mul,
            ast.Div: op.truediv,
            ast.Pow: _bounded_pow,
            ast.BitXor: op.xor,
            ast.USub: op.neg,
            ast.UAdd: op.pos,
            ast.Mod: op.mod,
            ast.FloorDiv: op.floordiv,
        }

        def _eval(node):
            # ast.Constant replaces the deprecated ast.Num / `.n`. The exact
            # type check keeps bools and strings out, as ast.Num did.
            if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
                return node.value
            if isinstance(node, ast.BinOp):
                left = _eval(node.left)
                right = _eval(node.right)
                op_type = type(node.op)
                if op_type in allowed_ops:
                    return allowed_ops[op_type](left, right)
                raise ValueError("Unsupported operator")
            if isinstance(node, ast.UnaryOp):
                operand = _eval(node.operand)
                op_type = type(node.op)
                if op_type in allowed_ops:
                    return allowed_ops[op_type](operand)
                raise ValueError("Unsupported unary op")
            raise ValueError("Unsupported expression type")

        node = ast.parse(expr, mode='eval')
        result = _eval(node.body)
        return str(result)
    except Exception as e:
        # Errors are reported in-band so callers always get a string back.
        return f"ERROR: {e}"
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
verify_setup.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple local verification script to ensure the environment is prepared and the model is downloaded.
3
+ """
4
+ import os
5
+ import sys
6
+
7
def check_venv_python():
    """Print whether the running interpreter belongs to a virtual environment."""
    # Inside a venv, sys.prefix points at the venv and differs from base_prefix.
    inside_venv = sys.prefix != sys.base_prefix
    message = (
        f"Virtualenv detected: {sys.prefix}"
        if inside_venv
        else "Not in a virtual environment; consider activating .venv"
    )
    print(message)
12
+
13
def check_models_dir():
    """Check that ``./models`` is a directory containing at least one ``.pth`` file.

    Prints what was found (or what is missing) and returns ``True`` only when
    at least one weight file is present.
    """
    models_dir = "models"
    # isdir rather than exists: a plain file named "models" would otherwise
    # reach os.listdir and raise NotADirectoryError.
    if not os.path.isdir(models_dir):
        print("Models directory not found. Run download_models.py first.")
        return False
    # Sorted so the report is stable regardless of filesystem ordering.
    files = sorted(f for f in os.listdir(models_dir) if f.endswith('.pth'))
    if not files:
        print("No .pth files found in ./models. Run download_models.py to fetch model weights.")
        return False
    print(f"Found model files: {files}")
    return True
24
+
25
def check_dependencies():
    """Check that the key runtime packages can be located by the import system.

    Returns ``True`` when every package is found. Otherwise prints the
    missing names (or the error that interrupted the check) and returns
    ``False``.
    """
    try:
        # Import the submodule explicitly: a bare `import importlib` does not
        # guarantee that `importlib.util` is available as an attribute.
        import importlib.util

        packages = [
            'fastapi', 'uvicorn', 'rwkv', 'huggingface_hub', 'pydantic', 'loguru'
        ]
        missing = [p for p in packages if importlib.util.find_spec(p) is None]
        if missing:
            print(f"Missing packages: {missing}; install them in your venv")
            return False
        print("All key dependencies found.")
        return True
    except Exception as e:
        print(f"Dependency check failed: {e}")
        return False
43
+
44
def main():
    """Run all environment checks and print an overall verdict."""
    check_venv_python()
    # Both checks always run so that every problem is reported in one pass.
    outcomes = [check_dependencies(), check_models_dir()]
    if all(outcomes):
        print("Environment appears configured. You can run: python app.py")
        return
    print("Fix missing items and re-run verification.")
52
+
53
+ if __name__ == '__main__':
54
+ main()