# GLM-4.6V / GLM-4.5V Usage
## Launch commands for SGLang
Below are suggested launch commands tailored to different hardware and precision modes.
### FP8 (quantised) mode
For memory-efficient, latency-optimized deployments (e.g., on H100 or H200) where the FP8 checkpoint is supported:
```bash
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6V-FP8 \
--tp 2 \
--ep 2 \
--host 0.0.0.0 \
--port 30000 \
--keep-mm-feature-on-device
```
### Non-FP8 (BF16 / full precision) mode
For deployments on A100/H100 that use BF16 (or where the FP8 checkpoint is not used):
```bash
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6V \
--tp 4 \
--ep 4 \
--host 0.0.0.0 \
--port 30000
```
## Hardware-specific notes / recommendations
- On H100 with FP8: Use the FP8 checkpoint for best memory efficiency.
- On A100 / H100 with BF16 (non-FP8): It is recommended to use `--mm-max-concurrent-calls` to control parallel throughput and GPU memory usage during image/video inference (see the sketch after this list).
- On H200 & B200: The model runs out of the box, supporting the full context length plus concurrent image and video processing.
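For example, a BF16 launch that caps multimodal concurrency could look like the sketch below. The flags are the ones described on this page, while the specific values (8 concurrent calls, a 60-second per-request timeout) are illustrative assumptions to tune for your workload:
```bash
# BF16 launch with multimodal concurrency limits.
# The values 8 and 60 below are illustrative assumptions, not recommendations.
python3 -m sglang.launch_server \
--model-path zai-org/GLM-4.6V \
--tp 4 \
--ep 4 \
--host 0.0.0.0 \
--port 30000 \
--mm-max-concurrent-calls 8 \
--mm-per-request-timeout 60
```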
## Sending Image/Video Requests
### Image input:
```python
import requests

url = "http://localhost:30000/v1/chat/completions"

data = {
    "model": "zai-org/GLM-4.6V",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    "max_tokens": 300,
}

response = requests.post(url, json=data)
print(response.text)
```
### Video Input:
```python
import requests

url = "http://localhost:30000/v1/chat/completions"

data = {
    "model": "zai-org/GLM-4.6V",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s happening in this video?"},
                {
                    "type": "video_url",
                    "video_url": {
                        "url": "https://github.com/sgl-project/sgl-test-files/raw/refs/heads/main/videos/jobs_presenting_ipod.mp4"
                    },
                },
            ],
        }
    ],
    "max_tokens": 300,
}

response = requests.post(url, json=data)
print(response.text)
```
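The `/v1/chat/completions` endpoint returns OpenAI-compatible JSON, so rather than printing the raw body you can pull out the assistant's reply directly. A minimal sketch, assuming the request succeeded:
```python
# Extract the assistant's reply from the OpenAI-compatible response body.
result = response.json()
print(result["choices"][0]["message"]["content"])
```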
## Important Server Parameters and Flags
When launching the model server for **multimodal support**, you can use the following command-line arguments to fine-tune performance and behavior:
- `--mm-attention-backend`: Specifies the multimodal attention backend, e.g. `fa3` (FlashAttention 3).
- `--mm-max-concurrent-calls <value>`: Specifies the **maximum number of concurrent asynchronous multimodal data processing calls** allowed on the server. Use this to control parallel throughput and GPU memory usage during image/video inference.
- `--mm-per-request-timeout <seconds>`: Defines the **timeout duration (in seconds)** for each multimodal request. If a request exceeds this time limit (e.g., for very large video inputs), it will be automatically terminated.
- `--keep-mm-feature-on-device`: Instructs the server to **retain multimodal feature tensors on the GPU** after processing. This avoids device-to-host (D2H) memory copies and improves performance for repeated or high-frequency inference workloads.
- `--mm-enable-dp-encoder`: Places the ViT encoder in data parallel while keeping the LLM in tensor parallel, which consistently lowers TTFT and boosts end-to-end throughput.
- `SGLANG_USE_CUDA_IPC_TRANSPORT=1`: Enables a shared-memory-pool-based CUDA IPC transport for multimodal data, which can significantly improve end-to-end latency.
### Example usage with the above optimizations:
```bash
SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
SGLANG_VLM_CACHE_SIZE_MB=0 \
python -m sglang.launch_server \
--model-path zai-org/GLM-4.6V \
--host 0.0.0.0 \
--port 30000 \
--trust-remote-code \
--tp-size 8 \
--enable-cache-report \
--log-level info \
--max-running-requests 64 \
--mem-fraction-static 0.65 \
--chunked-prefill-size 8192 \
--attention-backend fa3 \
--mm-attention-backend fa3 \
--mm-enable-dp-encoder \
--enable-metrics
```
### Thinking Budget for GLM-4.5V / GLM-4.6V
In SGLang, a thinking budget can be implemented with a `CustomLogitProcessor`.
Launch the server with the `--enable-custom-logit-processor` flag, then pass `Glm4MoeThinkingBudgetLogitProcessor` in the request, similar to the `GLM-4.6` example in [glm45.md](./glm45.md).
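A minimal sketch of such a request follows, assuming `Glm4MoeThinkingBudgetLogitProcessor` is importable from `sglang.srt.sampling.custom_logit_processor`, is serialized with `to_str()`, and reads its budget from a `thinking_budget` key in `custom_params`; the example in [glm45.md](./glm45.md) is authoritative for the exact request fields:
```python
import requests

# Assumed import path; see glm45.md for the authoritative example.
from sglang.srt.sampling.custom_logit_processor import (
    Glm4MoeThinkingBudgetLogitProcessor,
)

url = "http://localhost:30000/v1/chat/completions"

data = {
    "model": "zai-org/GLM-4.6V",
    "messages": [{"role": "user", "content": "Solve this step by step: 17 * 24 = ?"}],
    "max_tokens": 1024,
    # Serialized processor plus its parameters. The server must have been
    # launched with --enable-custom-logit-processor. The `thinking_budget`
    # key and the value of 512 tokens are assumptions for illustration.
    "custom_logit_processor": Glm4MoeThinkingBudgetLogitProcessor().to_str(),
    "custom_params": {"thinking_budget": 512},
}

response = requests.post(url, json=data)
print(response.text)
```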