Spaces:
Sleeping
Sleeping
Commit
·
adcb9bd
0
Parent(s):
Initial implementation of ZeroGPU OpenCode Provider
Browse files- OpenAI-compatible /v1/chat/completions endpoint
- Pass-through model selection (any HF model ID)
- ZeroGPU H200 inference with automatic fallback to HF Serverless
- HF Token authentication required
- SSE streaming support
- Automatic INT4 quantization for 70B+ models
- .env.template +22 -0
- .gitignore +60 -0
- CLAUDE.md +186 -0
- README.md +171 -0
- app.py +524 -0
- config.py +159 -0
- models.py +335 -0
- openai_compat.py +269 -0
- requirements.txt +32 -0
- tests/__init__.py +1 -0
- tests/conftest.py +116 -0
- tests/test_models.py +150 -0
- tests/test_openai_compat.py +263 -0
.env.template
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace ZeroGPU Space - Environment Variables
|
| 2 |
+
# Copy to .env and fill in values
|
| 3 |
+
|
| 4 |
+
# HuggingFace Token (for gated models access)
|
| 5 |
+
# When deployed to HF Space, the Space's own token is used automatically
|
| 6 |
+
# This is mainly for local development with gated models
|
| 7 |
+
HF_TOKEN=
|
| 8 |
+
|
| 9 |
+
# Fallback Configuration
|
| 10 |
+
# Enable HF Serverless Inference API fallback when ZeroGPU quota exhausted
|
| 11 |
+
FALLBACK_ENABLED=true
|
| 12 |
+
|
| 13 |
+
# Logging
|
| 14 |
+
LOG_LEVEL=INFO
|
| 15 |
+
|
| 16 |
+
# Model Loading
|
| 17 |
+
# Default quantization for large models (none, int8, int4)
|
| 18 |
+
DEFAULT_QUANTIZATION=none
|
| 19 |
+
|
| 20 |
+
# Maximum model size to load without quantization (in billions of parameters)
|
| 21 |
+
# Models larger than this will automatically use INT4 quantization
|
| 22 |
+
AUTO_QUANTIZE_THRESHOLD_B=34
|
.gitignore
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.env
|
| 25 |
+
.venv
|
| 26 |
+
env/
|
| 27 |
+
venv/
|
| 28 |
+
ENV/
|
| 29 |
+
|
| 30 |
+
# IDE
|
| 31 |
+
.idea/
|
| 32 |
+
.vscode/
|
| 33 |
+
*.swp
|
| 34 |
+
*.swo
|
| 35 |
+
*~
|
| 36 |
+
|
| 37 |
+
# Testing
|
| 38 |
+
.pytest_cache/
|
| 39 |
+
.coverage
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
|
| 44 |
+
# Gradio
|
| 45 |
+
flagged/
|
| 46 |
+
|
| 47 |
+
# HuggingFace
|
| 48 |
+
*.bin
|
| 49 |
+
*.safetensors
|
| 50 |
+
*.pt
|
| 51 |
+
*.pth
|
| 52 |
+
.cache/
|
| 53 |
+
|
| 54 |
+
# Logs
|
| 55 |
+
*.log
|
| 56 |
+
logs/
|
| 57 |
+
|
| 58 |
+
# OS
|
| 59 |
+
.DS_Store
|
| 60 |
+
Thumbs.db
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
HuggingFace ZeroGPU Space serving as an OpenAI-compatible inference provider for opencode. Deployed at `serenichron/opencode-zerogpu`.
|
| 8 |
+
|
| 9 |
+
**Key Features:**
|
| 10 |
+
- OpenAI-compatible `/v1/chat/completions` endpoint
|
| 11 |
+
- Pass-through model selection (any HF model ID)
|
| 12 |
+
- ZeroGPU H200 inference with HF Serverless fallback
|
| 13 |
+
- HF Token authentication required
|
| 14 |
+
- SSE streaming support
|
| 15 |
+
|
| 16 |
+
## Architecture
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
┌─────────────┐ ┌──────────────────────────────────────────────┐
|
| 20 |
+
│ opencode │────▶│ serenichron/opencode-zerogpu (HF Space) │
|
| 21 |
+
│ (client) │ │ │
|
| 22 |
+
└─────────────┘ │ ┌────────────────────────────────────────┐ │
|
| 23 |
+
│ │ app.py (Gradio + FastAPI mount) │ │
|
| 24 |
+
│ │ └─ /v1/chat/completions │ │
|
| 25 |
+
│ │ └─ auth_middleware (HF token) │ │
|
| 26 |
+
│ │ └─ inference_router │ │
|
| 27 |
+
│ │ ├─ ZeroGPU (@spaces.GPU) │ │
|
| 28 |
+
│ │ └─ HF Serverless (fallback) │ │
|
| 29 |
+
│ └────────────────────────────────────────┘ │
|
| 30 |
+
│ │
|
| 31 |
+
│ ┌──────────────┐ ┌───────────────────────┐ │
|
| 32 |
+
│ │ models.py │ │ openai_compat.py │ │
|
| 33 |
+
│ │ - load/unload│ │ - request/response │ │
|
| 34 |
+
│ │ - quantize │ │ - streaming format │ │
|
| 35 |
+
│ └──────────────┘ └───────────────────────┘ │
|
| 36 |
+
└──────────────────────────────────────────────┘
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Development Commands
|
| 40 |
+
|
| 41 |
+
### Local Development (CPU/Mock Mode)
|
| 42 |
+
```bash
|
| 43 |
+
# Install dependencies
|
| 44 |
+
pip install -r requirements.txt
|
| 45 |
+
|
| 46 |
+
# Run locally (ZeroGPU decorator no-ops)
|
| 47 |
+
python app.py
|
| 48 |
+
|
| 49 |
+
# Run with specific port
|
| 50 |
+
gradio app.py --server-port 7860
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Testing
|
| 54 |
+
```bash
|
| 55 |
+
# Run all tests
|
| 56 |
+
pytest tests/ -v
|
| 57 |
+
|
| 58 |
+
# Run specific test file
|
| 59 |
+
pytest tests/test_openai_compat.py -v
|
| 60 |
+
|
| 61 |
+
# Run with coverage
|
| 62 |
+
pytest tests/ --cov=. --cov-report=term-missing
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### API Testing
|
| 66 |
+
```bash
|
| 67 |
+
# Test chat completions endpoint
|
| 68 |
+
curl -X POST http://localhost:7860/v1/chat/completions \
|
| 69 |
+
-H "Content-Type: application/json" \
|
| 70 |
+
-H "Authorization: Bearer $HF_TOKEN" \
|
| 71 |
+
-d '{
|
| 72 |
+
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
| 73 |
+
"messages": [{"role": "user", "content": "Hello!"}],
|
| 74 |
+
"stream": true
|
| 75 |
+
}'
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Deployment
|
| 79 |
+
```bash
|
| 80 |
+
# Push to HuggingFace Space (after git remote setup)
|
| 81 |
+
git push hf main
|
| 82 |
+
|
| 83 |
+
# Or use HF CLI
|
| 84 |
+
huggingface-cli upload serenichron/opencode-zerogpu . --repo-type space
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## Key Files
|
| 88 |
+
|
| 89 |
+
| File | Purpose |
|
| 90 |
+
|------|---------|
|
| 91 |
+
| `app.py` | Main Gradio app with FastAPI mount for OpenAI endpoints |
|
| 92 |
+
| `models.py` | Model loading, unloading, quantization, caching |
|
| 93 |
+
| `openai_compat.py` | OpenAI request/response format conversion |
|
| 94 |
+
| `config.py` | Environment variables, settings, quota tracking |
|
| 95 |
+
| `README.md` | HF Space config (YAML frontmatter) + documentation |
|
| 96 |
+
|
| 97 |
+
## ZeroGPU Patterns
|
| 98 |
+
|
| 99 |
+
### GPU Decorator Usage
|
| 100 |
+
```python
|
| 101 |
+
import spaces
|
| 102 |
+
|
| 103 |
+
# Standard inference (60s default)
|
| 104 |
+
@spaces.GPU
|
| 105 |
+
def generate(prompt, model_id):
|
| 106 |
+
...
|
| 107 |
+
|
| 108 |
+
# Extended duration for large models
|
| 109 |
+
@spaces.GPU(duration=120)
|
| 110 |
+
def generate_large(prompt, model_id):
|
| 111 |
+
...
|
| 112 |
+
|
| 113 |
+
# Dynamic duration based on input
|
| 114 |
+
def calc_duration(prompt, max_tokens):
|
| 115 |
+
return min(120, max_tokens // 10)
|
| 116 |
+
|
| 117 |
+
@spaces.GPU(duration=calc_duration)
|
| 118 |
+
def generate_dynamic(prompt, max_tokens):
|
| 119 |
+
...
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Model Loading Pattern
|
| 123 |
+
```python
|
| 124 |
+
import gc
|
| 125 |
+
import torch
|
| 126 |
+
|
| 127 |
+
current_model = None
|
| 128 |
+
current_model_id = None
|
| 129 |
+
|
| 130 |
+
@spaces.GPU
|
| 131 |
+
def load_and_generate(model_id, prompt):
|
| 132 |
+
global current_model, current_model_id
|
| 133 |
+
|
| 134 |
+
if model_id != current_model_id:
|
| 135 |
+
# Cleanup previous model
|
| 136 |
+
if current_model:
|
| 137 |
+
del current_model
|
| 138 |
+
gc.collect()
|
| 139 |
+
torch.cuda.empty_cache()
|
| 140 |
+
|
| 141 |
+
# Load new model
|
| 142 |
+
current_model = AutoModelForCausalLM.from_pretrained(
|
| 143 |
+
model_id,
|
| 144 |
+
torch_dtype=torch.bfloat16,
|
| 145 |
+
device_map="auto"
|
| 146 |
+
)
|
| 147 |
+
current_model_id = model_id
|
| 148 |
+
|
| 149 |
+
return generate(current_model, prompt)
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
## Important Constraints
|
| 153 |
+
|
| 154 |
+
1. **ZeroGPU Compatibility**
|
| 155 |
+
- `torch.compile` NOT supported - use PyTorch AoT instead
|
| 156 |
+
- Gradio SDK only (no Streamlit)
|
| 157 |
+
- GPU allocated only during `@spaces.GPU` decorated functions
|
| 158 |
+
|
| 159 |
+
2. **Memory Management**
|
| 160 |
+
- H200 provides ~70GB VRAM
|
| 161 |
+
- 70B models require INT4 quantization
|
| 162 |
+
- Always cleanup with `gc.collect()` and `torch.cuda.empty_cache()`
|
| 163 |
+
|
| 164 |
+
3. **Quota Awareness**
|
| 165 |
+
- PRO plan: 25 min/day H200 compute
|
| 166 |
+
- Track usage, fall back to HF Serverless when exhausted
|
| 167 |
+
- Shorter `duration` = higher queue priority
|
| 168 |
+
|
| 169 |
+
4. **Authentication**
|
| 170 |
+
- All API requests require `Authorization: Bearer hf_...` header
|
| 171 |
+
- Validate tokens via HuggingFace Hub API
|
| 172 |
+
|
| 173 |
+
## Environment Variables
|
| 174 |
+
|
| 175 |
+
| Variable | Required | Description |
|
| 176 |
+
|----------|----------|-------------|
|
| 177 |
+
| `HF_TOKEN` | No* | Token for accessing gated models (* Space has its own token) |
|
| 178 |
+
| `FALLBACK_ENABLED` | No | Enable HF Serverless fallback (default: true) |
|
| 179 |
+
| `LOG_LEVEL` | No | Logging verbosity (default: INFO) |
|
| 180 |
+
|
| 181 |
+
## Testing Strategy
|
| 182 |
+
|
| 183 |
+
1. **Unit Tests**: Model loading, OpenAI format conversion
|
| 184 |
+
2. **Integration Tests**: Full API request/response cycle
|
| 185 |
+
3. **Local Testing**: CPU-only mode (decorator no-ops)
|
| 186 |
+
4. **Live Testing**: Deploy to Space, test via opencode
|
README.md
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: OpenCode ZeroGPU Provider
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
hardware: zero-a10g
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# OpenCode ZeroGPU Provider
|
| 15 |
+
|
| 16 |
+
OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode), powered by HuggingFace ZeroGPU (NVIDIA H200).
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- **OpenAI-compatible API** - Drop-in replacement for OpenAI endpoints
|
| 21 |
+
- **Pass-through model selection** - Use any HuggingFace model ID
|
| 22 |
+
- **ZeroGPU H200 inference** - 25 min/day of H200 GPU compute (PRO plan)
|
| 23 |
+
- **Automatic fallback** - Falls back to HF Serverless when quota exhausted
|
| 24 |
+
- **SSE streaming** - Real-time token streaming support
|
| 25 |
+
- **Authentication** - Requires valid HuggingFace token
|
| 26 |
+
|
| 27 |
+
## API Endpoint
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
POST /v1/chat/completions
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Request Format
|
| 34 |
+
|
| 35 |
+
```json
|
| 36 |
+
{
|
| 37 |
+
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
| 38 |
+
"messages": [
|
| 39 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 40 |
+
{"role": "user", "content": "Hello!"}
|
| 41 |
+
],
|
| 42 |
+
"temperature": 0.7,
|
| 43 |
+
"max_tokens": 512,
|
| 44 |
+
"stream": true
|
| 45 |
+
}
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### Headers
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
Authorization: Bearer hf_YOUR_TOKEN
|
| 52 |
+
Content-Type: application/json
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Usage with opencode
|
| 56 |
+
|
| 57 |
+
Configure in `~/.config/opencode/opencode.json`:
|
| 58 |
+
|
| 59 |
+
```json
|
| 60 |
+
{
|
| 61 |
+
"providers": {
|
| 62 |
+
"zerogpu": {
|
| 63 |
+
"npm": "@ai-sdk/openai-compatible",
|
| 64 |
+
"options": {
|
| 65 |
+
"baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
|
| 66 |
+
"headers": {
|
| 67 |
+
"Authorization": "Bearer hf_YOUR_TOKEN"
|
| 68 |
+
}
|
| 69 |
+
},
|
| 70 |
+
"models": {
|
| 71 |
+
"llama-8b": {
|
| 72 |
+
"name": "meta-llama/Llama-3.1-8B-Instruct"
|
| 73 |
+
},
|
| 74 |
+
"mistral-7b": {
|
| 75 |
+
"name": "mistralai/Mistral-7B-Instruct-v0.3"
|
| 76 |
+
},
|
| 77 |
+
"qwen-7b": {
|
| 78 |
+
"name": "Qwen/Qwen2.5-7B-Instruct"
|
| 79 |
+
},
|
| 80 |
+
"qwen-14b": {
|
| 81 |
+
"name": "Qwen/Qwen2.5-14B-Instruct"
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Then use `/models` in opencode to select a zerogpu model.
|
| 90 |
+
|
| 91 |
+
## Supported Models
|
| 92 |
+
|
| 93 |
+
Any HuggingFace model that fits in ~70GB VRAM. Examples:
|
| 94 |
+
|
| 95 |
+
| Model | Size | Quantization |
|
| 96 |
+
|-------|------|--------------|
|
| 97 |
+
| `meta-llama/Llama-3.1-8B-Instruct` | 8B | None |
|
| 98 |
+
| `mistralai/Mistral-7B-Instruct-v0.3` | 7B | None |
|
| 99 |
+
| `Qwen/Qwen2.5-7B-Instruct` | 7B | None |
|
| 100 |
+
| `Qwen/Qwen2.5-14B-Instruct` | 14B | None |
|
| 101 |
+
| `Qwen/Qwen2.5-32B-Instruct` | 32B | None |
|
| 102 |
+
| `meta-llama/Llama-3.1-70B-Instruct` | 70B | INT4 (auto) |
|
| 103 |
+
|
| 104 |
+
Models larger than 34B are automatically quantized to INT4.
|
| 105 |
+
|
| 106 |
+
## VRAM Guidelines
|
| 107 |
+
|
| 108 |
+
| Model Size | FP16 VRAM | INT8 VRAM | INT4 VRAM |
|
| 109 |
+
|------------|-----------|-----------|-----------|
|
| 110 |
+
| 7B | ~14GB | ~7GB | ~3.5GB |
|
| 111 |
+
| 13B | ~26GB | ~13GB | ~6.5GB |
|
| 112 |
+
| 34B | ~68GB | ~34GB | ~17GB |
|
| 113 |
+
| 70B | ~140GB | ~70GB | ~35GB |
|
| 114 |
+
|
| 115 |
+
*70B models require INT4 quantization. Add ~20% overhead for KV cache.*
|
| 116 |
+
|
| 117 |
+
## Quota Information
|
| 118 |
+
|
| 119 |
+
- **PRO plan**: 25 minutes/day of H200 GPU compute
|
| 120 |
+
- **Priority**: PRO users get highest queue priority
|
| 121 |
+
- **Fallback**: When quota exhausted, falls back to HF Serverless Inference API
|
| 122 |
+
|
| 123 |
+
## API Endpoints
|
| 124 |
+
|
| 125 |
+
| Endpoint | Method | Description |
|
| 126 |
+
|----------|--------|-------------|
|
| 127 |
+
| `/v1/chat/completions` | POST | Chat completion (OpenAI-compatible) |
|
| 128 |
+
| `/v1/models` | GET | List loaded models |
|
| 129 |
+
| `/health` | GET | Health check and quota status |
|
| 130 |
+
|
| 131 |
+
## Local Development
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
# Clone the repo
|
| 135 |
+
git clone https://huggingface.co/spaces/serenichron/opencode-zerogpu
|
| 136 |
+
|
| 137 |
+
# Install dependencies
|
| 138 |
+
pip install -r requirements.txt
|
| 139 |
+
|
| 140 |
+
# Run locally (ZeroGPU decorator no-ops)
|
| 141 |
+
python app.py
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## Testing
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
# Run tests
|
| 148 |
+
pytest tests/ -v
|
| 149 |
+
|
| 150 |
+
# Test the API locally
|
| 151 |
+
curl -X POST http://localhost:7860/v1/chat/completions \
|
| 152 |
+
-H "Content-Type: application/json" \
|
| 153 |
+
-H "Authorization: Bearer $HF_TOKEN" \
|
| 154 |
+
-d '{
|
| 155 |
+
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
| 156 |
+
"messages": [{"role": "user", "content": "Hello!"}],
|
| 157 |
+
"stream": false
|
| 158 |
+
}'
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
## Environment Variables
|
| 162 |
+
|
| 163 |
+
| Variable | Required | Description |
|
| 164 |
+
|----------|----------|-------------|
|
| 165 |
+
| `HF_TOKEN` | No* | Token for gated models (* Space uses its own token) |
|
| 166 |
+
| `FALLBACK_ENABLED` | No | Enable HF Serverless fallback (default: true) |
|
| 167 |
+
| `LOG_LEVEL` | No | Logging verbosity (default: INFO) |
|
| 168 |
+
|
| 169 |
+
## License
|
| 170 |
+
|
| 171 |
+
MIT
|
app.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode.
|
| 3 |
+
|
| 4 |
+
This Gradio app provides:
|
| 5 |
+
- OpenAI-compatible /v1/chat/completions endpoint
|
| 6 |
+
- Pass-through model selection (any HF model ID)
|
| 7 |
+
- ZeroGPU H200 inference with HF Serverless fallback
|
| 8 |
+
- HF Token authentication
|
| 9 |
+
- SSE streaming support
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import time
|
| 14 |
+
from contextlib import asynccontextmanager
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import gradio as gr
|
| 18 |
+
import httpx
|
| 19 |
+
from fastapi import FastAPI, Header, HTTPException, Request
|
| 20 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
| 21 |
+
from huggingface_hub import HfApi
|
| 22 |
+
|
| 23 |
+
# Import spaces conditionally (no-op when not on ZeroGPU)
|
| 24 |
+
try:
|
| 25 |
+
import spaces
|
| 26 |
+
ZEROGPU_AVAILABLE = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
ZEROGPU_AVAILABLE = False
|
| 29 |
+
|
| 30 |
+
# Create a no-op decorator for local development
|
| 31 |
+
class spaces:
|
| 32 |
+
@staticmethod
|
| 33 |
+
def GPU(fn=None, duration=60):
|
| 34 |
+
if fn is None:
|
| 35 |
+
return lambda f: f
|
| 36 |
+
return fn
|
| 37 |
+
|
| 38 |
+
from config import get_config, get_quota_tracker
|
| 39 |
+
from models import (
|
| 40 |
+
apply_chat_template,
|
| 41 |
+
generate_text,
|
| 42 |
+
generate_text_stream,
|
| 43 |
+
get_current_model,
|
| 44 |
+
)
|
| 45 |
+
from openai_compat import (
|
| 46 |
+
ChatCompletionRequest,
|
| 47 |
+
InferenceParams,
|
| 48 |
+
create_chat_response,
|
| 49 |
+
create_error_response,
|
| 50 |
+
estimate_tokens,
|
| 51 |
+
stream_response_generator,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
config = get_config()
|
| 57 |
+
quota_tracker = get_quota_tracker()
|
| 58 |
+
|
| 59 |
+
# HuggingFace API for token validation
|
| 60 |
+
hf_api = HfApi()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# --- Authentication ---
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def validate_hf_token(token: str) -> bool:
|
| 67 |
+
"""Validate a HuggingFace token by checking with the API."""
|
| 68 |
+
if not token or not token.startswith("hf_"):
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
hf_api.whoami(token=token)
|
| 73 |
+
return True
|
| 74 |
+
except Exception:
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def extract_token(authorization: Optional[str]) -> Optional[str]:
|
| 79 |
+
"""Extract the token from the Authorization header."""
|
| 80 |
+
if not authorization:
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
if authorization.startswith("Bearer "):
|
| 84 |
+
return authorization[7:]
|
| 85 |
+
|
| 86 |
+
return authorization
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# --- ZeroGPU Inference ---
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@spaces.GPU(duration=120)
|
| 93 |
+
def zerogpu_generate(
|
| 94 |
+
model_id: str,
|
| 95 |
+
prompt: str,
|
| 96 |
+
max_new_tokens: int,
|
| 97 |
+
temperature: float,
|
| 98 |
+
top_p: float,
|
| 99 |
+
stop_sequences: Optional[list[str]],
|
| 100 |
+
) -> str:
|
| 101 |
+
"""Generate text using ZeroGPU (H200 GPU)."""
|
| 102 |
+
start_time = time.time()
|
| 103 |
+
|
| 104 |
+
result = generate_text(
|
| 105 |
+
model_id=model_id,
|
| 106 |
+
prompt=prompt,
|
| 107 |
+
max_new_tokens=max_new_tokens,
|
| 108 |
+
temperature=temperature,
|
| 109 |
+
top_p=top_p,
|
| 110 |
+
stop_sequences=stop_sequences,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Track quota usage
|
| 114 |
+
duration = time.time() - start_time
|
| 115 |
+
quota_tracker.add_usage(duration)
|
| 116 |
+
|
| 117 |
+
return result
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@spaces.GPU(duration=120)
|
| 121 |
+
def zerogpu_generate_stream(
|
| 122 |
+
model_id: str,
|
| 123 |
+
prompt: str,
|
| 124 |
+
max_new_tokens: int,
|
| 125 |
+
temperature: float,
|
| 126 |
+
top_p: float,
|
| 127 |
+
stop_sequences: Optional[list[str]],
|
| 128 |
+
):
|
| 129 |
+
"""Generate text with streaming using ZeroGPU (H200 GPU)."""
|
| 130 |
+
start_time = time.time()
|
| 131 |
+
|
| 132 |
+
for token in generate_text_stream(
|
| 133 |
+
model_id=model_id,
|
| 134 |
+
prompt=prompt,
|
| 135 |
+
max_new_tokens=max_new_tokens,
|
| 136 |
+
temperature=temperature,
|
| 137 |
+
top_p=top_p,
|
| 138 |
+
stop_sequences=stop_sequences,
|
| 139 |
+
):
|
| 140 |
+
yield token
|
| 141 |
+
|
| 142 |
+
# Track quota usage
|
| 143 |
+
duration = time.time() - start_time
|
| 144 |
+
quota_tracker.add_usage(duration)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# --- HF Serverless Fallback ---
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
async def serverless_generate(
|
| 151 |
+
model_id: str,
|
| 152 |
+
prompt: str,
|
| 153 |
+
max_new_tokens: int,
|
| 154 |
+
temperature: float,
|
| 155 |
+
top_p: float,
|
| 156 |
+
token: str,
|
| 157 |
+
) -> str:
|
| 158 |
+
"""Generate text using HuggingFace Serverless Inference API."""
|
| 159 |
+
url = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 160 |
+
|
| 161 |
+
payload = {
|
| 162 |
+
"inputs": prompt,
|
| 163 |
+
"parameters": {
|
| 164 |
+
"max_new_tokens": max_new_tokens,
|
| 165 |
+
"temperature": temperature,
|
| 166 |
+
"top_p": top_p,
|
| 167 |
+
"return_full_text": False,
|
| 168 |
+
},
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
async with httpx.AsyncClient() as client:
|
| 172 |
+
response = await client.post(
|
| 173 |
+
url,
|
| 174 |
+
json=payload,
|
| 175 |
+
headers={"Authorization": f"Bearer {token}"},
|
| 176 |
+
timeout=120.0,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if response.status_code != 200:
|
| 180 |
+
raise HTTPException(
|
| 181 |
+
status_code=response.status_code,
|
| 182 |
+
detail=f"HF Serverless error: {response.text}",
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
result = response.json()
|
| 186 |
+
|
| 187 |
+
# Handle different response formats
|
| 188 |
+
if isinstance(result, list) and len(result) > 0:
|
| 189 |
+
if "generated_text" in result[0]:
|
| 190 |
+
return result[0]["generated_text"]
|
| 191 |
+
|
| 192 |
+
raise HTTPException(
|
| 193 |
+
status_code=500,
|
| 194 |
+
detail=f"Unexpected response format from HF Serverless: {result}",
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# --- FastAPI App ---
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@asynccontextmanager
|
| 202 |
+
async def lifespan(app: FastAPI):
|
| 203 |
+
"""Application lifespan events."""
|
| 204 |
+
logger.info("Starting ZeroGPU OpenCode Provider")
|
| 205 |
+
logger.info(f"ZeroGPU available: {ZEROGPU_AVAILABLE}")
|
| 206 |
+
logger.info(f"Fallback enabled: {config.fallback_enabled}")
|
| 207 |
+
yield
|
| 208 |
+
logger.info("Shutting down ZeroGPU OpenCode Provider")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
api = FastAPI(
|
| 212 |
+
title="ZeroGPU OpenCode Provider",
|
| 213 |
+
description="OpenAI-compatible API for HuggingFace models on ZeroGPU",
|
| 214 |
+
version="1.0.0",
|
| 215 |
+
lifespan=lifespan,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@api.post("/v1/chat/completions")
|
| 220 |
+
async def chat_completions(
|
| 221 |
+
request: ChatCompletionRequest,
|
| 222 |
+
authorization: Optional[str] = Header(None),
|
| 223 |
+
):
|
| 224 |
+
"""
|
| 225 |
+
OpenAI-compatible chat completions endpoint.
|
| 226 |
+
|
| 227 |
+
Supports both streaming and non-streaming responses.
|
| 228 |
+
"""
|
| 229 |
+
# Validate authentication
|
| 230 |
+
token = extract_token(authorization)
|
| 231 |
+
if not token or not validate_hf_token(token):
|
| 232 |
+
return JSONResponse(
|
| 233 |
+
status_code=401,
|
| 234 |
+
content=create_error_response(
|
| 235 |
+
message="Invalid or missing HuggingFace token",
|
| 236 |
+
error_type="authentication_error",
|
| 237 |
+
code="invalid_api_key",
|
| 238 |
+
).model_dump(),
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# Extract inference parameters
|
| 242 |
+
params = InferenceParams.from_request(request)
|
| 243 |
+
|
| 244 |
+
# Apply chat template
|
| 245 |
+
try:
|
| 246 |
+
prompt = apply_chat_template(params.model_id, params.messages)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
logger.error(f"Failed to apply chat template: {e}")
|
| 249 |
+
return JSONResponse(
|
| 250 |
+
status_code=400,
|
| 251 |
+
content=create_error_response(
|
| 252 |
+
message=f"Failed to load model or apply chat template: {str(e)}",
|
| 253 |
+
error_type="invalid_request_error",
|
| 254 |
+
param="model",
|
| 255 |
+
).model_dump(),
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
prompt_tokens = estimate_tokens(prompt)
|
| 259 |
+
|
| 260 |
+
# Determine inference method
|
| 261 |
+
use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
|
| 262 |
+
|
| 263 |
+
if not use_zerogpu and not config.fallback_enabled:
|
| 264 |
+
return JSONResponse(
|
| 265 |
+
status_code=503,
|
| 266 |
+
content=create_error_response(
|
| 267 |
+
message="ZeroGPU quota exhausted and fallback is disabled",
|
| 268 |
+
error_type="server_error",
|
| 269 |
+
code="quota_exhausted",
|
| 270 |
+
).model_dump(),
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
try:
|
| 274 |
+
if params.stream:
|
| 275 |
+
# Streaming response
|
| 276 |
+
if use_zerogpu:
|
| 277 |
+
token_gen = zerogpu_generate_stream(
|
| 278 |
+
model_id=params.model_id,
|
| 279 |
+
prompt=prompt,
|
| 280 |
+
max_new_tokens=params.max_new_tokens,
|
| 281 |
+
temperature=params.temperature,
|
| 282 |
+
top_p=params.top_p,
|
| 283 |
+
stop_sequences=params.stop_sequences,
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
# Fallback doesn't support streaming, so generate full response
|
| 287 |
+
# and simulate streaming
|
| 288 |
+
logger.info("Using HF Serverless fallback (no streaming)")
|
| 289 |
+
full_response = await serverless_generate(
|
| 290 |
+
model_id=params.model_id,
|
| 291 |
+
prompt=prompt,
|
| 292 |
+
max_new_tokens=params.max_new_tokens,
|
| 293 |
+
temperature=params.temperature,
|
| 294 |
+
top_p=params.top_p,
|
| 295 |
+
token=token,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
def simulate_stream():
|
| 299 |
+
# Yield the full response as a single chunk
|
| 300 |
+
yield full_response
|
| 301 |
+
|
| 302 |
+
token_gen = simulate_stream()
|
| 303 |
+
|
| 304 |
+
return StreamingResponse(
|
| 305 |
+
stream_response_generator(params.model_id, token_gen),
|
| 306 |
+
media_type="text/event-stream",
|
| 307 |
+
headers={
|
| 308 |
+
"Cache-Control": "no-cache",
|
| 309 |
+
"Connection": "keep-alive",
|
| 310 |
+
"X-Accel-Buffering": "no",
|
| 311 |
+
},
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
# Non-streaming response
|
| 315 |
+
if use_zerogpu:
|
| 316 |
+
response_text = zerogpu_generate(
|
| 317 |
+
model_id=params.model_id,
|
| 318 |
+
prompt=prompt,
|
| 319 |
+
max_new_tokens=params.max_new_tokens,
|
| 320 |
+
temperature=params.temperature,
|
| 321 |
+
top_p=params.top_p,
|
| 322 |
+
stop_sequences=params.stop_sequences,
|
| 323 |
+
)
|
| 324 |
+
else:
|
| 325 |
+
logger.info("Using HF Serverless fallback")
|
| 326 |
+
response_text = await serverless_generate(
|
| 327 |
+
model_id=params.model_id,
|
| 328 |
+
prompt=prompt,
|
| 329 |
+
max_new_tokens=params.max_new_tokens,
|
| 330 |
+
temperature=params.temperature,
|
| 331 |
+
top_p=params.top_p,
|
| 332 |
+
token=token,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
completion_tokens = estimate_tokens(response_text)
|
| 336 |
+
|
| 337 |
+
return create_chat_response(
|
| 338 |
+
model=params.model_id,
|
| 339 |
+
content=response_text,
|
| 340 |
+
prompt_tokens=prompt_tokens,
|
| 341 |
+
completion_tokens=completion_tokens,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
except Exception as e:
|
| 345 |
+
logger.exception(f"Inference error: {e}")
|
| 346 |
+
return JSONResponse(
|
| 347 |
+
status_code=500,
|
| 348 |
+
content=create_error_response(
|
| 349 |
+
message=f"Inference failed: {str(e)}",
|
| 350 |
+
error_type="server_error",
|
| 351 |
+
).model_dump(),
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@api.get("/v1/models")
async def list_models(authorization: Optional[str] = Header(None)):
    """Return the OpenAI-style model list.

    Contains at most one entry: the model currently resident in memory.
    Requires a valid HuggingFace token in the Authorization header.
    """
    token = extract_token(authorization)
    if not token or not validate_hf_token(token):
        error = create_error_response(
            message="Invalid or missing HuggingFace token",
            error_type="authentication_error",
            code="invalid_api_key",
        )
        return JSONResponse(status_code=401, content=error.model_dump())

    loaded = get_current_model()
    entries = []
    if loaded:
        entries.append(
            {
                "id": loaded.model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "huggingface",
            }
        )

    return {"object": "list", "data": entries}
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
@api.get("/health")
async def health_check():
    """Liveness probe: reports ZeroGPU availability, remaining quota, and fallback state."""
    report = {
        "status": "healthy",
        "zerogpu_available": ZEROGPU_AVAILABLE,
        "quota_remaining_minutes": quota_tracker.remaining_minutes(),
        "fallback_enabled": config.fallback_enabled,
    }
    return report
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# --- Gradio Interface ---
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def gradio_chat(
    message: str,
    history: list[list[str]],
    model_id: str,
    temperature: float,
    max_tokens: int,
):
    """Stream a chat reply for the Gradio UI.

    Rebuilds the conversation from *history* (a list of
    [user, assistant] pairs), formats it with the model's chat
    template, and yields the progressively accumulated assistant
    response so the chat panel updates token by token.
    """
    # Reconstruct the message list from the chat history pairs.
    conversation = []
    for user_turn, assistant_turn in history:
        conversation.append({"role": "user", "content": user_turn})
        if assistant_turn:
            conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": message})

    formatted_prompt = apply_chat_template(model_id, conversation)

    # Stream and accumulate; Gradio expects the full partial text each yield.
    partial = ""
    for piece in zerogpu_generate_stream(
        model_id=model_id,
        prompt=formatted_prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.95,
        stop_sequences=None,
    ):
        partial += piece
        yield partial
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# Gradio Blocks interface
|
| 433 |
+
# Gradio Blocks UI: landing-page docs, model/sampling controls, and a chat
# panel driven by gradio_chat. The FastAPI app is mounted underneath so the
# Space serves /v1/* API routes alongside this UI.
with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
    gr.Markdown(
        """
# ZeroGPU OpenCode Provider

OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).

**API Endpoint:** `/v1/chat/completions`

## Usage with opencode

Configure in `~/.config/opencode/opencode.json`:

```json
{
  "providers": {
    "zerogpu": {
      "npm": "@ai-sdk/openai-compatible",
      "options": {
        "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
        "headers": {
          "Authorization": "Bearer hf_YOUR_TOKEN"
        }
      },
      "models": {
        "llama-8b": {
          "name": "meta-llama/Llama-3.1-8B-Instruct"
        }
      }
    }
  }
}
```

---
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Model picker; allow_custom_value lets users type any HF model ID.
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=[
                    "meta-llama/Llama-3.1-8B-Instruct",
                    "mistralai/Mistral-7B-Instruct-v0.3",
                    "Qwen/Qwen2.5-7B-Instruct",
                    "Qwen/Qwen2.5-14B-Instruct",
                ],
                value="meta-llama/Llama-3.1-8B-Instruct",
                allow_custom_value=True,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
            )
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=64,
                maximum=4096,
                value=512,
                step=64,
            )

            # Status panel; rendered once at import time (f-string), so it
            # reflects startup state, not live quota.
            gr.Markdown(
                f"""
### Status
- **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
- **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
"""
            )

        with gr.Column(scale=3):
            # Chat panel; the extra inputs are forwarded to gradio_chat.
            chatbot = gr.ChatInterface(
                fn=gradio_chat,
                additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
                title="",
                examples=[
                    ["Hello! How are you?"],
                    ["Explain quantum computing in simple terms."],
                    ["Write a Python function to calculate fibonacci numbers."],
                ],
            )

# Mount FastAPI to Gradio so API routes and UI share one server.
demo = gr.mount_gradio_app(demo, api, path="/")


if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard HF Spaces binding.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
config.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration and environment handling for ZeroGPU Space."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class Config:
    """Application configuration loaded from environment variables.

    All values are resolved lazily per instance via default factories,
    so each Config() re-reads the environment.
    """

    # HuggingFace token for gated models (None when unset)
    hf_token: Optional[str] = field(default_factory=lambda: os.getenv("HF_TOKEN"))

    # Fall back to HF Serverless when ZeroGPU quota is exhausted
    fallback_enabled: bool = field(
        default_factory=lambda: os.getenv("FALLBACK_ENABLED", "true").lower() == "true"
    )

    # Logging level name (e.g. "INFO", "DEBUG")
    log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))

    # Forced quantization mode: "none", "int8", or "int4"
    default_quantization: str = field(
        default_factory=lambda: os.getenv("DEFAULT_QUANTIZATION", "none")
    )
    # Models above this many billions of parameters auto-quantize
    auto_quantize_threshold_b: int = field(
        default_factory=lambda: int(os.getenv("AUTO_QUANTIZE_THRESHOLD_B", "34"))
    )

    def __post_init__(self):
        """Configure root logging from the requested level (INFO on bad names)."""
        resolved_level = getattr(logging, self.log_level.upper(), logging.INFO)
        logging.basicConfig(
            level=resolved_level,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class QuotaTracker:
    """Track ZeroGPU quota usage for the current session."""

    # GPU seconds consumed so far today
    seconds_used: float = 0.0
    # Daily allowance in seconds (PRO plan: 25 min = 1500 sec)
    daily_quota_seconds: float = 1500.0
    # Latched True once usage reaches the daily allowance
    quota_exhausted: bool = False

    def add_usage(self, seconds: float) -> None:
        """Record GPU usage time; latch exhaustion when the quota is spent."""
        self.seconds_used += seconds
        if self.seconds_used < self.daily_quota_seconds:
            return
        self.quota_exhausted = True
        logger.warning(
            f"ZeroGPU quota exhausted: {self.seconds_used:.1f}s / {self.daily_quota_seconds:.1f}s"
        )

    def remaining_seconds(self) -> float:
        """Remaining quota in seconds, floored at zero."""
        left = self.daily_quota_seconds - self.seconds_used
        return left if left > 0 else 0

    def remaining_minutes(self) -> float:
        """Remaining quota expressed in minutes."""
        return self.remaining_seconds() / 60.0

    def reset(self) -> None:
        """Reset usage and the exhaustion latch (called at day boundary)."""
        self.seconds_used = 0.0
        self.quota_exhausted = False
        logger.info("ZeroGPU quota reset")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Global configuration instance, created once at import time.
config = Config()

# Global quota tracker, shared by all requests in this process.
quota_tracker = QuotaTracker()


def get_config() -> Config:
    """Get the global configuration instance."""
    return config


def get_quota_tracker() -> QuotaTracker:
    """Get the global quota tracker instance."""
    return quota_tracker
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Model size estimates (parameters in billions) for well-known checkpoints.
# Checked before any name parsing; this is also where MoE models whose
# names understate their effective size (e.g. Mixtral) get corrected.
MODEL_SIZE_ESTIMATES = {
    # Llama family
    "meta-llama/Llama-3.1-8B-Instruct": 8,
    "meta-llama/Llama-3.1-70B-Instruct": 70,
    "meta-llama/Llama-3.2-1B-Instruct": 1,
    "meta-llama/Llama-3.2-3B-Instruct": 3,
    # Mistral family
    "mistralai/Mistral-7B-Instruct-v0.3": 7,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": 47,  # MoE effective
    # Qwen family
    "Qwen/Qwen2.5-7B-Instruct": 7,
    "Qwen/Qwen2.5-14B-Instruct": 14,
    "Qwen/Qwen2.5-32B-Instruct": 32,
    "Qwen/Qwen2.5-72B-Instruct": 72,
}


def estimate_model_size(model_id: str) -> Optional[int]:
    """
    Estimate model size in billions of parameters from model ID.

    Checks the known-model table first, then parses a "<number>B" token
    out of the name. Decimal sizes such as "1.5B" are handled (the
    previous integer-only pattern misread "1.5B" as 5) and rounded to
    the nearest whole billion, with a minimum of 1.

    Returns None if size cannot be determined.
    """
    # Known models take precedence over name parsing.
    known = MODEL_SIZE_ESTIMATES.get(model_id)
    if known is not None:
        return known

    # Parse e.g. "8B", "70b", "1.5B" from the model name.
    import re

    match = re.search(r"(\d+(?:\.\d+)?)\s*B", model_id, re.IGNORECASE)
    if match:
        return max(1, round(float(match.group(1))))

    return None
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def should_quantize(model_id: str) -> str:
    """
    Determine if a model should be quantized and which method to use.

    Returns: "none", "int8", or "int4"
    """
    # An explicit override in the configuration wins over size heuristics.
    forced = config.default_quantization
    if forced != "none":
        return forced

    size = estimate_model_size(model_id)
    if size is None:
        # Unknown size: be conservative and do not auto-quantize.
        return "none"

    if size > 65:
        # 70B-class models need INT4 to fit in 70GB VRAM.
        return "int4"
    if size > config.auto_quantize_threshold_b:
        # Large (but not huge) models get INT8.
        return "int8"
    return "none"
|
models.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading, caching, and memory management for ZeroGPU inference."""
|
| 2 |
+
|
| 3 |
+
import gc
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional, Generator, Any
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoModelForCausalLM,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
BitsAndBytesConfig,
|
| 13 |
+
TextIteratorStreamer,
|
| 14 |
+
)
|
| 15 |
+
from threading import Thread
|
| 16 |
+
|
| 17 |
+
from config import get_config, should_quantize
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class LoadedModel:
    """Container for a loaded model and its tokenizer."""

    model_id: str  # HuggingFace model ID this pair was loaded from
    model: Any  # causal-LM instance returned by AutoModelForCausalLM
    tokenizer: Any  # matching AutoTokenizer instance
    quantization: str = "none"  # "none", "int8", or "int4"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Global model cache (single model at a time due to memory constraints)
# NOTE(review): mutated without a lock by load_model/unload_model —
# assumes requests do not load models concurrently; confirm.
_current_model: Optional[LoadedModel] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_quantization_config(quantization: str) -> Optional[BitsAndBytesConfig]:
    """Get BitsAndBytes configuration for the specified quantization level.

    Returns None for "none" (or any unrecognized value), meaning the
    model is loaded without bitsandbytes quantization.
    """
    if quantization == "int4":
        # NF4 + double quantization, computing in bfloat16.
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    if quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def clear_gpu_memory() -> None:
    """Clear GPU memory by running garbage collection and emptying the CUDA cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    logger.debug("GPU memory cleared")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def unload_model() -> None:
    """Unload the currently loaded model (if any) and free its memory."""
    global _current_model

    if _current_model is None:
        return

    logger.info(f"Unloading model: {_current_model.model_id}")
    # Drop the heavyweight references before collecting.
    del _current_model.model
    del _current_model.tokenizer
    _current_model = None
    clear_gpu_memory()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_model(
    model_id: str,
    quantization: Optional[str] = None,
    force_reload: bool = False,
) -> LoadedModel:
    """
    Load a model from HuggingFace Hub.

    Only one model is kept resident at a time: loading a different model
    first unloads the current one and clears GPU memory.

    Args:
        model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        quantization: Force specific quantization ("none", "int8", "int4")
            If None, auto-determine based on model size
        force_reload: If True, reload even if already loaded

    Returns:
        LoadedModel with model and tokenizer ready for inference
    """
    global _current_model

    # Fast path: the requested model is already resident.
    if not force_reload and _current_model is not None:
        if _current_model.model_id == model_id:
            logger.debug(f"Model already loaded: {model_id}")
            return _current_model

    # Determine quantization from size heuristics when not forced.
    if quantization is None:
        quantization = should_quantize(model_id)

    logger.info(f"Loading model: {model_id} (quantization: {quantization})")

    # Unload current model first so both never coexist in memory.
    unload_model()

    config = get_config()

    # Load tokenizer.
    # NOTE(review): trust_remote_code=True executes code from the model
    # repo — acceptable only because model IDs come from authenticated
    # callers; confirm this trust model.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=config.hf_token,
        trust_remote_code=True,
    )

    # Ensure tokenizer has a pad token (generation requires a pad id).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with the appropriate precision/quantization configuration.
    quant_config = get_quantization_config(quantization)

    model_kwargs = {
        "token": config.hf_token,
        "trust_remote_code": True,
        "device_map": "auto",  # let accelerate place layers on available devices
    }

    if quant_config is not None:
        model_kwargs["quantization_config"] = quant_config
    else:
        # Full-precision path: bfloat16 halves memory vs float32.
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    _current_model = LoadedModel(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
    )

    logger.info(f"Model loaded successfully: {model_id}")
    return _current_model
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_current_model() -> Optional[LoadedModel]:
    """Get the currently loaded model, if any (None before the first load)."""
    return _current_model
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def generate_text(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> str:
    """
    Generate text using the specified model.

    Args:
        model_id: HuggingFace model ID
        prompt: Input prompt (already formatted with chat template)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0 selects greedy decoding)
        top_p: Nucleus sampling probability
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for repeating tokens
        stop_sequences: Additional stop sequences; the output is
            truncated at the first occurrence of any of them

    Returns:
        Generated text (without the input prompt)
    """
    loaded = load_model(model_id)

    # Reserve room for the new tokens; guard against tokenizers whose
    # model_max_length is smaller than max_new_tokens (a non-positive
    # max_length would otherwise be passed to the tokenizer).
    max_input_length = max(1, loaded.tokenizer.model_max_length - max_new_tokens)
    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    # Build generation config. Sampling knobs are only passed when
    # sampling is enabled: with do_sample=False transformers ignores
    # them and warns (and temperature=0.0 is invalid for sampling).
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    with torch.no_grad():
        outputs = loaded.model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens (skip the echoed prompt).
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    response = loaded.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Truncate at the first stop sequence that appears.
    if stop_sequences:
        for stop_seq in stop_sequences:
            if stop_seq in response:
                response = response.split(stop_seq)[0]

    return response
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def generate_text_stream(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> Generator[str, None, None]:
    """
    Generate text using streaming output.

    Yields text chunks as they are generated. When *stop_sequences* is
    given, up to len(longest stop) - 1 trailing characters are held back
    before being yielded so that a stop sequence spanning token
    boundaries can be removed in full. (The previous implementation
    could leak a partial stop sequence to the caller, because text
    already yielded in earlier tokens cannot be retracted.)
    """
    loaded = load_model(model_id)

    # Guard against tokenizers whose model_max_length < max_new_tokens.
    max_input_length = max(1, loaded.tokenizer.model_max_length - max_new_tokens)
    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    # The streamer decouples generation (background thread) from yielding.
    streamer = TextIteratorStreamer(
        loaded.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Sampling knobs only when sampling is enabled (see generate_text).
    do_sample = temperature > 0
    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
        "streamer": streamer,
    }
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # Run generation in a separate thread; we consume the streamer here.
    thread = Thread(target=loaded.model.generate, kwargs=gen_kwargs)
    thread.start()

    try:
        if not stop_sequences:
            # No stop handling needed: pass chunks straight through.
            yield from streamer
            return

        # Hold back enough trailing text that no prefix of a stop
        # sequence can escape to the caller before we detect it.
        holdback = max(len(s) for s in stop_sequences) - 1
        pending = ""
        for token in streamer:
            pending += token
            # Earliest stop-sequence position within the buffered text.
            cut = min(
                (pending.index(s) for s in stop_sequences if s in pending),
                default=-1,
            )
            if cut >= 0:
                if cut > 0:
                    yield pending[:cut]
                return
            # Flush everything that can no longer be part of a stop seq.
            safe = len(pending) - holdback
            if safe > 0:
                yield pending[:safe]
                pending = pending[safe:]
        # Generation finished with no stop hit: flush the remainder.
        if pending:
            yield pending
    finally:
        # Wait for the generation thread even on early exit; generation
        # runs to completion regardless, as in the original code.
        thread.join()
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def apply_chat_template(
    model_id: str,
    messages: list[dict[str, str]],
    add_generation_prompt: bool = True,
) -> str:
    """
    Apply the model's chat template to format messages.

    Args:
        model_id: HuggingFace model ID
        messages: List of message dicts with "role" and "content"
        add_generation_prompt: Whether to add the generation prompt

    Returns:
        Formatted prompt string
    """
    loaded = load_model(model_id)
    tokenizer = loaded.tokenizer

    # Prefer the tokenizer's own chat template when available.
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )

    # Fallback: simple "Role: content" line formatting. Roles other than
    # system/user/assistant are silently skipped, as before.
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    pieces = []
    for message in messages:
        label = role_labels.get(message["role"])
        if label is not None:
            pieces.append(f"{label}: {message['content']}\n")

    if add_generation_prompt:
        pieces.append("Assistant:")

    return "".join(pieces)
|
openai_compat.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenAI-compatible API request/response format handling."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
import uuid
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Optional, Generator, Literal
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# --- Request Models ---
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ChatMessage(BaseModel):
    """A single message in the conversation (OpenAI message shape)."""

    role: Literal["system", "user", "assistant"]  # tool/function roles not supported
    content: str
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request.

    `model` carries a raw HuggingFace model ID rather than an OpenAI
    model name; all other fields mirror the OpenAI schema.
    """

    model: str = Field(..., description="HuggingFace model ID")
    messages: list[ChatMessage]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=512, ge=1, le=8192)
    stream: bool = False  # True => SSE streaming response
    stop: Optional[list[str]] = None  # extra stop sequences
    # NOTE(review): penalties are validated but not visibly forwarded to
    # inference in the handler — confirm before documenting as supported.
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    n: int = Field(default=1, ge=1, le=1)  # Only support n=1 for now
    user: Optional[str] = None  # accepted for OpenAI compatibility
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# --- Response Models ---
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ChatCompletionChoice(BaseModel):
    """A single completion choice (this server emits exactly one)."""

    index: int  # position in the choices list
    message: ChatMessage
    finish_reason: Literal["stop", "length", "content_filter"] = "stop"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ChatCompletionUsage(BaseModel):
    """Token usage statistics.

    NOTE(review): the handler fills these from estimate_tokens(), so the
    counts are approximations, not exact tokenizer counts.
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # prompt_tokens + completion_tokens
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response (non-streaming)."""

    id: str  # "chatcmpl-..." identifier
    object: str = "chat.completion"
    created: int  # Unix timestamp, seconds
    model: str  # model ID echoed back from the request
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# --- Streaming Response Models ---
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class DeltaMessage(BaseModel):
    """Incremental content for streaming responses.

    Both fields are optional: a chunk may carry only a role, only
    content, or (for the terminal chunk) neither.
    """

    role: Optional[str] = None
    content: Optional[str] = None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class StreamChoice(BaseModel):
    """A single streaming choice.

    ``finish_reason`` is None for intermediate chunks and set on the last
    data chunk of the stream.
    """

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class ChatCompletionChunk(BaseModel):
    """OpenAI-compatible streaming chunk."""

    id: str                                  # shared completion ID for the stream
    object: str = "chat.completion.chunk"    # fixed OpenAI object tag
    created: int                             # Unix timestamp (seconds)
    model: str
    choices: list[StreamChoice]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# --- Helper Functions ---
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def generate_completion_id() -> str:
    """Generate a unique completion ID in the OpenAI ``chatcmpl-`` format.

    The suffix is the first 24 hex digits of a random UUID4.
    """
    suffix = uuid.uuid4().hex[:24]
    return "chatcmpl-" + suffix
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_chat_response(
    model: str,
    content: str,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    finish_reason: str = "stop",
) -> ChatCompletionResponse:
    """Build a complete (non-streaming) chat completion response.

    Wraps ``content`` in a single assistant choice and attaches usage
    statistics; ``total_tokens`` is derived from the two supplied counts.
    """
    choice = ChatCompletionChoice(
        index=0,
        message=ChatMessage(role="assistant", content=content),
        finish_reason=finish_reason,
    )
    usage = ChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(
        id=generate_completion_id(),
        created=int(time.time()),
        model=model,
        choices=[choice],
        usage=usage,
    )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def create_stream_chunk(
    completion_id: str,
    model: str,
    content: Optional[str] = None,
    role: Optional[str] = None,
    finish_reason: Optional[str] = None,
) -> ChatCompletionChunk:
    """Create a single streaming chunk.

    Exactly one of ``role`` (first chunk), ``content`` (token chunks) or
    ``finish_reason`` (final chunk) is normally supplied; the others stay
    None in the resulting delta.
    """
    delta = DeltaMessage(role=role, content=content)
    choice = StreamChoice(index=0, delta=delta, finish_reason=finish_reason)
    return ChatCompletionChunk(
        id=completion_id,
        created=int(time.time()),
        model=model,
        choices=[choice],
    )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def stream_response_generator(
    model: str,
    token_generator: Generator[str, None, None],
) -> Generator[str, None, None]:
    """Convert a token generator into SSE-formatted streaming strings.

    Emits, in order: one chunk announcing the assistant role, one chunk per
    generated token, a final chunk carrying ``finish_reason="stop"``, and
    the ``data: [DONE]`` sentinel expected by OpenAI streaming clients.
    """
    completion_id = generate_completion_id()

    def as_sse(chunk: ChatCompletionChunk) -> str:
        # Each SSE event is a "data:" line followed by a blank line.
        return f"data: {chunk.model_dump_json()}\n\n"

    # Opening chunk: announce the assistant role.
    yield as_sse(
        create_stream_chunk(completion_id=completion_id, model=model, role="assistant")
    )

    # One chunk per generated token.
    for piece in token_generator:
        yield as_sse(
            create_stream_chunk(completion_id=completion_id, model=model, content=piece)
        )

    # Closing chunk with the finish reason, then the end-of-stream marker.
    yield as_sse(
        create_stream_chunk(
            completion_id=completion_id, model=model, finish_reason="stop"
        )
    )
    yield "data: [DONE]\n\n"
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def messages_to_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
    """Convert Pydantic ChatMessage objects to plain role/content dicts."""
    converted = []
    for message in messages:
        converted.append({"role": message.role, "content": message.content})
    return converted
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of ``text``.

    Actual counts depend on the model's tokenizer; this uses the common
    heuristic of ~4 characters per token for English text and never
    returns less than 1.
    """
    approx = len(text) // 4
    return approx if approx > 0 else 1
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@dataclass
class InferenceParams:
    """Extracted inference parameters from request."""

    model_id: str                        # HuggingFace model ID to run
    messages: list[dict[str, str]]       # chat history as plain role/content dicts
    max_new_tokens: int                  # generation budget
    temperature: float
    top_p: float
    stop_sequences: Optional[list[str]]  # optional stop strings from the request
    stream: bool                         # whether SSE streaming was requested

    @classmethod
    def from_request(cls, request: ChatCompletionRequest) -> "InferenceParams":
        """Extract inference parameters from an OpenAI-compatible request.

        ``max_tokens`` falls back to 512 when absent on the request.
        """
        return cls(
            model_id=request.model,
            messages=messages_to_dicts(request.messages),
            max_new_tokens=request.max_tokens or 512,
            temperature=request.temperature,
            top_p=request.top_p,
            stop_sequences=request.stop,
            stream=request.stream,
        )
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# --- Error Responses ---
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class ErrorDetail(BaseModel):
    """Error detail for API error responses.

    Field names follow the OpenAI error object: ``type`` categorizes the
    error, ``param`` names the offending request field, ``code`` is an
    optional machine-readable identifier.
    """

    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ErrorResponse(BaseModel):
    """OpenAI-compatible error response envelope (a single ``error`` key)."""

    error: ErrorDetail
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def create_error_response(
    message: str,
    error_type: str = "invalid_request_error",
    param: Optional[str] = None,
    code: Optional[str] = None,
) -> ErrorResponse:
    """Build an OpenAI-style error envelope from the given details."""
    detail = ErrorDetail(
        message=message,
        type=error_type,
        param=param,
        code=code,
    )
    return ErrorResponse(error=detail)
|
requirements.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace ZeroGPU Space - OpenCode Provider
|
| 2 |
+
# For ZeroGPU H200 inference with OpenAI-compatible API
|
| 3 |
+
|
| 4 |
+
# Core Framework
|
| 5 |
+
gradio>=4.44.0
|
| 6 |
+
spaces>=0.30.0
|
| 7 |
+
|
| 8 |
+
# ML/Inference
|
| 9 |
+
torch>=2.0.0
|
| 10 |
+
transformers>=4.45.0
|
| 11 |
+
accelerate>=0.34.0
|
| 12 |
+
bitsandbytes>=0.44.0
|
| 13 |
+
safetensors>=0.4.0
|
| 14 |
+
|
| 15 |
+
# HuggingFace Integration
|
| 16 |
+
huggingface-hub>=0.25.0
|
| 17 |
+
|
| 18 |
+
# API
|
| 19 |
+
fastapi>=0.115.0
|
| 20 |
+
uvicorn>=0.30.0
|
| 21 |
+
pydantic>=2.0.0
|
| 22 |
+
httpx>=0.27.0
|
| 23 |
+
sse-starlette>=2.1.0
|
| 24 |
+
|
| 25 |
+
# Utilities
|
| 26 |
+
python-dotenv>=1.0.0
|
| 27 |
+
typing-extensions>=4.12.0
|
| 28 |
+
|
| 29 |
+
# Testing (dev)
|
| 30 |
+
pytest>=8.0.0
|
| 31 |
+
pytest-asyncio>=0.24.0
|
| 32 |
+
pytest-cov>=5.0.0
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Tests package for ZeroGPU OpenCode Provider
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test fixtures for ZeroGPU OpenCode Provider tests."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from unittest.mock import MagicMock, patch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer for testing.

    Mimics the subset of the HF tokenizer API the code under test touches:
    pad/eos tokens, ``apply_chat_template``, calling the tokenizer on text,
    and ``decode``.
    """
    tokenizer = MagicMock()
    # No pad token, so code paths that fall back to the EOS token are exercised.
    tokenizer.pad_token = None
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 2
    tokenizer.model_max_length = 4096

    def mock_apply_chat_template(messages, tokenize=False, add_generation_prompt=True):
        # Accepts either dicts or objects exposing .role/.content attributes.
        parts = []
        for msg in messages:
            role = msg.get("role", msg.role if hasattr(msg, "role") else "user")
            content = msg.get("content", msg.content if hasattr(msg, "content") else "")
            if role == "system":
                parts.append(f"<|system|>{content}</s>")
            elif role == "user":
                parts.append(f"<|user|>{content}</s>")
            elif role == "assistant":
                parts.append(f"<|assistant|>{content}</s>")
        if add_generation_prompt:
            # Trailing open assistant tag signals "generate from here".
            parts.append("<|assistant|>")
        return "".join(parts)

    tokenizer.apply_chat_template = mock_apply_chat_template

    def mock_call(text, return_tensors=None, truncation=True, max_length=None):
        import torch
        # Simple mock: return input_ids based on text length
        token_count = max(1, len(text) // 4)
        return {
            "input_ids": torch.ones((1, token_count), dtype=torch.long),
            "attention_mask": torch.ones((1, token_count), dtype=torch.long),
        }

    # NOTE(review): assigning ``__call__`` on a MagicMock *instance* does not
    # change call behavior (dunders resolve on the type); ``tokenizer(...)``
    # actually returns ``tokenizer.return_value`` set below, i.e. a fixed
    # result for the text "test" regardless of input — confirm this is the
    # intended behavior for the tests.
    tokenizer.__call__ = mock_call
    tokenizer.return_value = mock_call("test")

    def mock_decode(tokens, skip_special_tokens=True):
        return "This is a test response."

    tokenizer.decode = mock_decode

    return tokenizer
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@pytest.fixture
def mock_model():
    """Create a mock model for testing."""
    import torch

    fake = MagicMock()

    def fake_generate(**kwargs):
        # Echo the prompt length and append 20 "generated" token positions.
        prompt_ids = kwargs.get("input_ids", torch.ones((1, 10), dtype=torch.long))
        prompt_len = prompt_ids.shape[1]
        return torch.ones((1, prompt_len + 20), dtype=torch.long)

    fake.generate = fake_generate
    fake.device = "cpu"

    return fake
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@pytest.fixture
def sample_messages():
    """Sample chat messages for testing (system prompt + one user turn)."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@pytest.fixture
def sample_request_data():
    """Sample request data for OpenAI-compatible endpoint (non-streaming)."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
        "stream": False,
    }
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@pytest.fixture
def sample_streaming_request_data():
    """Sample streaming request data (``stream=True``, single user turn)."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "Tell me a joke."},
        ],
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": True,
    }
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@pytest.fixture(autouse=True)
def mock_torch_cuda():
    """Mock CUDA availability for tests.

    Applied to every test (autouse) so code under test always sees
    ``torch.cuda.is_available() == False`` and takes CPU paths.
    """
    with patch("torch.cuda.is_available", return_value=False):
        yield
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for model loading and inference."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from unittest.mock import patch, MagicMock
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Add parent directory to path for imports
|
| 9 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
|
| 11 |
+
from config import estimate_model_size, should_quantize
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestModelSizeEstimation:
    """Test model size estimation logic (config.estimate_model_size)."""

    def test_known_model_size(self):
        """Test size estimation for known models."""
        assert estimate_model_size("meta-llama/Llama-3.1-8B-Instruct") == 8
        assert estimate_model_size("meta-llama/Llama-3.1-70B-Instruct") == 70
        assert estimate_model_size("mistralai/Mistral-7B-Instruct-v0.3") == 7

    def test_extract_size_from_name(self):
        """Test size extraction from the ``<N>B`` model-name pattern."""
        assert estimate_model_size("some-org/CustomModel-13B") == 13
        assert estimate_model_size("another/model-2B-test") == 2
        assert estimate_model_size("org/Model-32B-Instruct") == 32

    def test_unknown_model_size(self):
        """Models with no recognizable size marker return None."""
        assert estimate_model_size("unknown/model-without-size") is None
        assert estimate_model_size("org/mystery-model") is None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestQuantizationDecision:
    """Test automatic quantization decisions (config.should_quantize)."""

    def test_small_model_no_quantization(self):
        """Small models should not be quantized."""
        assert should_quantize("meta-llama/Llama-3.1-8B-Instruct") == "none"
        assert should_quantize("mistralai/Mistral-7B-Instruct-v0.3") == "none"

    def test_large_model_int4_quantization(self):
        """70B+ models should use INT4."""
        assert should_quantize("meta-llama/Llama-3.1-70B-Instruct") == "int4"
        assert should_quantize("Qwen/Qwen2.5-72B-Instruct") == "int4"

    def test_unknown_model_no_quantization(self):
        """Models of unknown size should not be auto-quantized."""
        assert should_quantize("unknown/mystery-model") == "none"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TestModelLoading:
    """Test model loading functionality (models.load_model cache behavior)."""

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_creates_loaded_model(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """Test that load_model returns a LoadedModel instance."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        # Imported inside the test so the patches above are active first.
        from models import load_model, unload_model

        # Ensure clean state (load_model caches globally between tests).
        unload_model()

        loaded = load_model("test-model/test-7B")

        assert loaded.model_id == "test-model/test-7B"
        assert loaded.model is not None
        assert loaded.tokenizer is not None

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_caches_result(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """Test that loading the same model twice uses cache."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        from models import load_model, unload_model

        # Ensure clean state
        unload_model()

        # First load
        load_model("test-model/test-7B")
        first_call_count = mock_model_class.from_pretrained.call_count

        # Second load (should use cache)
        load_model("test-model/test-7B")
        second_call_count = mock_model_class.from_pretrained.call_count

        # Should not have called from_pretrained again
        assert first_call_count == second_call_count
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class TestChatTemplate:
    """Test chat template application (models.apply_chat_template)."""

    @patch("models.load_model")
    def test_apply_chat_template_with_tokenizer_method(self, mock_load_model, mock_tokenizer):
        """Test chat template when tokenizer has apply_chat_template."""
        from models import apply_chat_template, LoadedModel

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=mock_tokenizer,
        )

        messages = [
            {"role": "user", "content": "Hello!"},
        ]

        result = apply_chat_template("test-model", messages)

        # Tags come from the mock_tokenizer fixture's template.
        assert "<|user|>" in result
        assert "Hello!" in result
        assert "<|assistant|>" in result  # Generation prompt

    @patch("models.load_model")
    def test_apply_chat_template_fallback(self, mock_load_model):
        """Test fallback formatting when tokenizer lacks apply_chat_template."""
        from models import apply_chat_template, LoadedModel

        # Tokenizer without apply_chat_template: deleting the attribute makes
        # hasattr(...) checks in the code under test return False.
        simple_tokenizer = MagicMock()
        del simple_tokenizer.apply_chat_template

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=simple_tokenizer,
        )

        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi!"},
        ]

        result = apply_chat_template("test-model", messages)

        # Fallback uses plain "Role:" prefixes instead of special tokens.
        assert "System:" in result
        assert "User:" in result
        assert "Assistant:" in result
|
tests/test_openai_compat.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for OpenAI-compatible API format handling."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import pytest
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Add parent directory to path for imports
|
| 9 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
|
| 11 |
+
from openai_compat import (
|
| 12 |
+
ChatCompletionRequest,
|
| 13 |
+
ChatMessage,
|
| 14 |
+
InferenceParams,
|
| 15 |
+
create_chat_response,
|
| 16 |
+
create_error_response,
|
| 17 |
+
create_stream_chunk,
|
| 18 |
+
estimate_tokens,
|
| 19 |
+
generate_completion_id,
|
| 20 |
+
messages_to_dicts,
|
| 21 |
+
stream_response_generator,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TestChatCompletionRequest:
    """Test request parsing and validation for ChatCompletionRequest."""

    def test_parse_basic_request(self, sample_request_data):
        """Test parsing a basic chat completion request."""
        request = ChatCompletionRequest(**sample_request_data)

        assert request.model == "meta-llama/Llama-3.1-8B-Instruct"
        assert len(request.messages) == 2
        assert request.messages[0].role == "system"
        assert request.messages[1].role == "user"
        assert request.temperature == 0.7
        assert request.max_tokens == 512
        assert request.stream is False

    def test_parse_streaming_request(self, sample_streaming_request_data):
        """Test parsing a streaming request."""
        request = ChatCompletionRequest(**sample_streaming_request_data)

        assert request.stream is True
        assert request.max_tokens == 256

    def test_default_values(self):
        """Test that defaults are applied correctly."""
        minimal_request = {
            "model": "test-model",
            "messages": [{"role": "user", "content": "Hi"}],
        }
        request = ChatCompletionRequest(**minimal_request)

        assert request.temperature == 0.7
        assert request.top_p == 0.95
        assert request.max_tokens == 512
        assert request.stream is False
        assert request.stop is None

    def test_validation_temperature_bounds(self):
        """Out-of-range temperatures are rejected by the Field bounds (0-2)."""
        with pytest.raises(ValueError):
            ChatCompletionRequest(
                model="test",
                messages=[{"role": "user", "content": "Hi"}],
                temperature=-0.5,
            )

        with pytest.raises(ValueError):
            ChatCompletionRequest(
                model="test",
                messages=[{"role": "user", "content": "Hi"}],
                temperature=2.5,
            )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class TestChatCompletionResponse:
    """Test response generation via create_chat_response."""

    def test_create_basic_response(self):
        """Test creating a basic chat response."""
        response = create_chat_response(
            model="test-model",
            content="Hello! How can I help you?",
            prompt_tokens=10,
            completion_tokens=8,
        )

        assert response.model == "test-model"
        assert response.object == "chat.completion"
        assert len(response.choices) == 1
        assert response.choices[0].message.role == "assistant"
        assert response.choices[0].message.content == "Hello! How can I help you?"
        assert response.choices[0].finish_reason == "stop"
        assert response.usage.prompt_tokens == 10
        assert response.usage.completion_tokens == 8
        # total_tokens is derived as prompt + completion.
        assert response.usage.total_tokens == 18

    def test_response_has_unique_id(self):
        """Test that each response has a unique ID."""
        response1 = create_chat_response(model="test", content="Hi")
        response2 = create_chat_response(model="test", content="Hi")

        assert response1.id != response2.id
        assert response1.id.startswith("chatcmpl-")

    def test_response_serialization(self):
        """Test that response can be serialized to JSON."""
        response = create_chat_response(
            model="test-model",
            content="Test",
        )

        json_str = response.model_dump_json()
        parsed = json.loads(json_str)

        assert parsed["model"] == "test-model"
        assert parsed["choices"][0]["message"]["content"] == "Test"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TestStreamingResponse:
    """Test streaming response format (chunks and SSE generator)."""

    def test_create_stream_chunk_with_content(self):
        """Test creating a streaming chunk with content."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            content="Hello",
        )

        assert chunk.id == "test-id"
        assert chunk.object == "chat.completion.chunk"
        assert chunk.choices[0].delta.content == "Hello"
        assert chunk.choices[0].finish_reason is None

    def test_create_stream_chunk_with_role(self):
        """Test creating a streaming chunk with role (first chunk)."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            role="assistant",
        )

        assert chunk.choices[0].delta.role == "assistant"
        assert chunk.choices[0].delta.content is None

    def test_create_stream_chunk_with_finish_reason(self):
        """Test creating a final streaming chunk."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            finish_reason="stop",
        )

        assert chunk.choices[0].finish_reason == "stop"

    def test_stream_response_generator(self):
        """Test the full streaming response generator end to end."""
        def token_gen():
            yield "Hello"
            yield " World"
            yield "!"

        chunks = list(stream_response_generator("test-model", token_gen()))

        # Should have: role chunk, 3 content chunks, finish chunk, [DONE]
        assert len(chunks) == 6

        # First chunk has role
        first_data = json.loads(chunks[0].replace("data: ", "").strip())
        assert first_data["choices"][0]["delta"]["role"] == "assistant"

        # Content chunks
        second_data = json.loads(chunks[1].replace("data: ", "").strip())
        assert second_data["choices"][0]["delta"]["content"] == "Hello"

        # Last data chunk has finish reason
        last_data = json.loads(chunks[4].replace("data: ", "").strip())
        assert last_data["choices"][0]["finish_reason"] == "stop"

        # Very last is [DONE]
        assert chunks[5] == "data: [DONE]\n\n"
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class TestInferenceParams:
    """Test parameter extraction into InferenceParams."""

    def test_extract_params_from_request(self, sample_request_data):
        """Test extracting inference parameters from request."""
        request = ChatCompletionRequest(**sample_request_data)
        params = InferenceParams.from_request(request)

        assert params.model_id == "meta-llama/Llama-3.1-8B-Instruct"
        assert len(params.messages) == 2
        assert params.max_new_tokens == 512
        assert params.temperature == 0.7
        assert params.stream is False

    def test_messages_to_dicts(self):
        """Test converting ChatMessage objects to dicts."""
        messages = [
            ChatMessage(role="user", content="Hello"),
            ChatMessage(role="assistant", content="Hi there!"),
        ]

        dicts = messages_to_dicts(messages)

        assert dicts == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class TestErrorResponse:
    """Test the OpenAI-style error response format."""

    def test_create_error_response(self):
        """Test creating an error response."""
        error = create_error_response(
            message="Model not found",
            error_type="invalid_request_error",
            param="model",
        )

        assert error.error.message == "Model not found"
        assert error.error.type == "invalid_request_error"
        assert error.error.param == "model"

    def test_error_response_serialization(self):
        """Test error response JSON serialization."""
        error = create_error_response(
            message="Test error",
            error_type="server_error",
            code="internal_error",
        )

        parsed = json.loads(error.model_dump_json())

        assert parsed["error"]["message"] == "Test error"
        assert parsed["error"]["type"] == "server_error"
        assert parsed["error"]["code"] == "internal_error"
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class TestUtilityFunctions:
    """Test utility functions (completion IDs and token estimation)."""

    def test_generate_completion_id_format(self):
        """Completion IDs are unique and follow ``chatcmpl-`` + 24 hex chars."""
        id1 = generate_completion_id()
        id2 = generate_completion_id()

        assert id1.startswith("chatcmpl-")
        assert len(id1) == len("chatcmpl-") + 24
        assert id1 != id2  # Should be unique

    def test_estimate_tokens(self):
        """Token estimates use the ~4 chars/token heuristic with a floor of 1."""
        # ~4 chars per token
        assert estimate_tokens("Hello World!") == 3  # 12 chars / 4 = 3
        assert estimate_tokens("A") == 1  # Min 1
        # BUGFIX: the string below is 31 characters, so 31 // 4 == 7; the
        # previous expected value of 8 assumed 32 characters and failed.
        assert estimate_tokens("This is a longer piece of text.") == 7  # 31 / 4 = 7
|