serenichron commited on
Commit
adcb9bd
·
0 Parent(s):

Initial implementation of ZeroGPU OpenCode Provider

Browse files

- OpenAI-compatible /v1/chat/completions endpoint
- Pass-through model selection (any HF model ID)
- ZeroGPU H200 inference with automatic fallback to HF Serverless
- HF Token authentication required
- SSE streaming support
- Automatic INT4 quantization for 70B+ models

Files changed (13) hide show
  1. .env.template +22 -0
  2. .gitignore +60 -0
  3. CLAUDE.md +186 -0
  4. README.md +171 -0
  5. app.py +524 -0
  6. config.py +159 -0
  7. models.py +335 -0
  8. openai_compat.py +269 -0
  9. requirements.txt +32 -0
  10. tests/__init__.py +1 -0
  11. tests/conftest.py +116 -0
  12. tests/test_models.py +150 -0
  13. tests/test_openai_compat.py +263 -0
.env.template ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace ZeroGPU Space - Environment Variables
2
+ # Copy to .env and fill in values
3
+
4
+ # HuggingFace Token (for gated models access)
5
+ # When deployed to HF Space, the Space's own token is used automatically
6
+ # This is mainly for local development with gated models
7
+ HF_TOKEN=
8
+
9
+ # Fallback Configuration
10
+ # Enable HF Serverless Inference API fallback when ZeroGPU quota exhausted
11
+ FALLBACK_ENABLED=true
12
+
13
+ # Logging
14
+ LOG_LEVEL=INFO
15
+
16
+ # Model Loading
17
+ # Default quantization for large models (none, int8, int4)
18
+ DEFAULT_QUANTIZATION=none
19
+
20
+ # Maximum model size to load without quantization (in billions of parameters)
21
+ # Models larger than this are automatically quantized (INT8; models over ~65B use INT4)
22
+ AUTO_QUANTIZE_THRESHOLD_B=34
.gitignore ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+
30
+ # IDE
31
+ .idea/
32
+ .vscode/
33
+ *.swp
34
+ *.swo
35
+ *~
36
+
37
+ # Testing
38
+ .pytest_cache/
39
+ .coverage
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+
44
+ # Gradio
45
+ flagged/
46
+
47
+ # HuggingFace
48
+ *.bin
49
+ *.safetensors
50
+ *.pt
51
+ *.pth
52
+ .cache/
53
+
54
+ # Logs
55
+ *.log
56
+ logs/
57
+
58
+ # OS
59
+ .DS_Store
60
+ Thumbs.db
CLAUDE.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ HuggingFace ZeroGPU Space serving as an OpenAI-compatible inference provider for opencode. Deployed at `serenichron/opencode-zerogpu`.
8
+
9
+ **Key Features:**
10
+ - OpenAI-compatible `/v1/chat/completions` endpoint
11
+ - Pass-through model selection (any HF model ID)
12
+ - ZeroGPU H200 inference with HF Serverless fallback
13
+ - HF Token authentication required
14
+ - SSE streaming support
15
+
16
+ ## Architecture
17
+
18
+ ```
19
+ ┌─────────────┐ ┌──────────────────────────────────────────────┐
20
+ │ opencode │────▶│ serenichron/opencode-zerogpu (HF Space) │
21
+ │ (client) │ │ │
22
+ └─────────────┘ │ ┌────────────────────────────────────────┐ │
23
+ │ │ app.py (Gradio + FastAPI mount) │ │
24
+ │ │ └─ /v1/chat/completions │ │
25
+ │ │ └─ auth_middleware (HF token) │ │
26
+ │ │ └─ inference_router │ │
27
+ │ │ ├─ ZeroGPU (@spaces.GPU) │ │
28
+ │ │ └─ HF Serverless (fallback) │ │
29
+ │ └────────────────────────────────────────┘ │
30
+ │ │
31
+ │ ┌──────────────┐ ┌───────────────────────┐ │
32
+ │ │ models.py │ │ openai_compat.py │ │
33
+ │ │ - load/unload│ │ - request/response │ │
34
+ │ │ - quantize │ │ - streaming format │ │
35
+ │ └──────────────┘ └───────────────────────┘ │
36
+ └──────────────────────────────────────────────┘
37
+ ```
38
+
39
+ ## Development Commands
40
+
41
+ ### Local Development (CPU/Mock Mode)
42
+ ```bash
43
+ # Install dependencies
44
+ pip install -r requirements.txt
45
+
46
+ # Run locally (ZeroGPU decorator no-ops)
47
+ python app.py
48
+
49
+ # Run with specific port
50
+ gradio app.py --server-port 7860
51
+ ```
52
+
53
+ ### Testing
54
+ ```bash
55
+ # Run all tests
56
+ pytest tests/ -v
57
+
58
+ # Run specific test file
59
+ pytest tests/test_openai_compat.py -v
60
+
61
+ # Run with coverage
62
+ pytest tests/ --cov=. --cov-report=term-missing
63
+ ```
64
+
65
+ ### API Testing
66
+ ```bash
67
+ # Test chat completions endpoint
68
+ curl -X POST http://localhost:7860/v1/chat/completions \
69
+ -H "Content-Type: application/json" \
70
+ -H "Authorization: Bearer $HF_TOKEN" \
71
+ -d '{
72
+ "model": "mistralai/Mistral-7B-Instruct-v0.3",
73
+ "messages": [{"role": "user", "content": "Hello!"}],
74
+ "stream": true
75
+ }'
76
+ ```
77
+
78
+ ### Deployment
79
+ ```bash
80
+ # Push to HuggingFace Space (after git remote setup)
81
+ git push hf main
82
+
83
+ # Or use HF CLI
84
+ huggingface-cli upload serenichron/opencode-zerogpu . --repo-type space
85
+ ```
86
+
87
+ ## Key Files
88
+
89
+ | File | Purpose |
90
+ |------|---------|
91
+ | `app.py` | Main Gradio app with FastAPI mount for OpenAI endpoints |
92
+ | `models.py` | Model loading, unloading, quantization, caching |
93
+ | `openai_compat.py` | OpenAI request/response format conversion |
94
+ | `config.py` | Environment variables, settings, quota tracking |
95
+ | `README.md` | HF Space config (YAML frontmatter) + documentation |
96
+
97
+ ## ZeroGPU Patterns
98
+
99
+ ### GPU Decorator Usage
100
+ ```python
101
+ import spaces
102
+
103
+ # Standard inference (60s default)
104
+ @spaces.GPU
105
+ def generate(prompt, model_id):
106
+ ...
107
+
108
+ # Extended duration for large models
109
+ @spaces.GPU(duration=120)
110
+ def generate_large(prompt, model_id):
111
+ ...
112
+
113
+ # Dynamic duration based on input
114
+ def calc_duration(prompt, max_tokens):
115
+ return min(120, max_tokens // 10)
116
+
117
+ @spaces.GPU(duration=calc_duration)
118
+ def generate_dynamic(prompt, max_tokens):
119
+ ...
120
+ ```
121
+
122
+ ### Model Loading Pattern
123
+ ```python
124
+ import gc
125
+ import torch
126
+
127
+ current_model = None
128
+ current_model_id = None
129
+
130
+ @spaces.GPU
131
+ def load_and_generate(model_id, prompt):
132
+ global current_model, current_model_id
133
+
134
+ if model_id != current_model_id:
135
+ # Cleanup previous model
136
+ if current_model:
137
+ del current_model
138
+ gc.collect()
139
+ torch.cuda.empty_cache()
140
+
141
+ # Load new model
142
+ current_model = AutoModelForCausalLM.from_pretrained(
143
+ model_id,
144
+ torch_dtype=torch.bfloat16,
145
+ device_map="auto"
146
+ )
147
+ current_model_id = model_id
148
+
149
+ return generate(current_model, prompt)
150
+ ```
151
+
152
+ ## Important Constraints
153
+
154
+ 1. **ZeroGPU Compatibility**
155
+ - `torch.compile` NOT supported - use PyTorch AoT instead
156
+ - Gradio SDK only (no Streamlit)
157
+ - GPU allocated only during `@spaces.GPU` decorated functions
158
+
159
+ 2. **Memory Management**
160
+ - H200 provides ~70GB VRAM
161
+ - 70B models require INT4 quantization
162
+ - Always cleanup with `gc.collect()` and `torch.cuda.empty_cache()`
163
+
164
+ 3. **Quota Awareness**
165
+ - PRO plan: 25 min/day H200 compute
166
+ - Track usage, fall back to HF Serverless when exhausted
167
+ - Shorter `duration` = higher queue priority
168
+
169
+ 4. **Authentication**
170
+ - All API requests require `Authorization: Bearer hf_...` header
171
+ - Validate tokens via HuggingFace Hub API
172
+
173
+ ## Environment Variables
174
+
175
+ | Variable | Required | Description |
176
+ |----------|----------|-------------|
177
+ | `HF_TOKEN` | No* | Token for accessing gated models (* Space has its own token) |
178
+ | `FALLBACK_ENABLED` | No | Enable HF Serverless fallback (default: true) |
179
+ | `LOG_LEVEL` | No | Logging verbosity (default: INFO) |
180
+
181
+ ## Testing Strategy
182
+
183
+ 1. **Unit Tests**: Model loading, OpenAI format conversion
184
+ 2. **Integration Tests**: Full API request/response cycle
185
+ 3. **Local Testing**: CPU-only mode (decorator no-ops)
186
+ 4. **Live Testing**: Deploy to Space, test via opencode
README.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: OpenCode ZeroGPU Provider
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ hardware: zero-a10g
12
+ ---
13
+
14
+ # OpenCode ZeroGPU Provider
15
+
16
+ OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode), powered by HuggingFace ZeroGPU (NVIDIA H200).
17
+
18
+ ## Features
19
+
20
+ - **OpenAI-compatible API** - Drop-in replacement for OpenAI endpoints
21
+ - **Pass-through model selection** - Use any HuggingFace model ID
22
+ - **ZeroGPU H200 inference** - 25 min/day of H200 GPU compute (PRO plan)
23
+ - **Automatic fallback** - Falls back to HF Serverless when quota exhausted
24
+ - **SSE streaming** - Real-time token streaming support
25
+ - **Authentication** - Requires valid HuggingFace token
26
+
27
+ ## API Endpoint
28
+
29
+ ```
30
+ POST /v1/chat/completions
31
+ ```
32
+
33
+ ### Request Format
34
+
35
+ ```json
36
+ {
37
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
38
+ "messages": [
39
+ {"role": "system", "content": "You are a helpful assistant."},
40
+ {"role": "user", "content": "Hello!"}
41
+ ],
42
+ "temperature": 0.7,
43
+ "max_tokens": 512,
44
+ "stream": true
45
+ }
46
+ ```
47
+
48
+ ### Headers
49
+
50
+ ```
51
+ Authorization: Bearer hf_YOUR_TOKEN
52
+ Content-Type: application/json
53
+ ```
54
+
55
+ ## Usage with opencode
56
+
57
+ Configure in `~/.config/opencode/opencode.json`:
58
+
59
+ ```json
60
+ {
61
+ "providers": {
62
+ "zerogpu": {
63
+ "npm": "@ai-sdk/openai-compatible",
64
+ "options": {
65
+ "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
66
+ "headers": {
67
+ "Authorization": "Bearer hf_YOUR_TOKEN"
68
+ }
69
+ },
70
+ "models": {
71
+ "llama-8b": {
72
+ "name": "meta-llama/Llama-3.1-8B-Instruct"
73
+ },
74
+ "mistral-7b": {
75
+ "name": "mistralai/Mistral-7B-Instruct-v0.3"
76
+ },
77
+ "qwen-7b": {
78
+ "name": "Qwen/Qwen2.5-7B-Instruct"
79
+ },
80
+ "qwen-14b": {
81
+ "name": "Qwen/Qwen2.5-14B-Instruct"
82
+ }
83
+ }
84
+ }
85
+ }
86
+ }
87
+ ```
88
+
89
+ Then use `/models` in opencode to select a zerogpu model.
90
+
91
+ ## Supported Models
92
+
93
+ Any HuggingFace model that fits in ~70GB VRAM. Examples:
94
+
95
+ | Model | Size | Quantization |
96
+ |-------|------|--------------|
97
+ | `meta-llama/Llama-3.1-8B-Instruct` | 8B | None |
98
+ | `mistralai/Mistral-7B-Instruct-v0.3` | 7B | None |
99
+ | `Qwen/Qwen2.5-7B-Instruct` | 7B | None |
100
+ | `Qwen/Qwen2.5-14B-Instruct` | 14B | None |
101
+ | `Qwen/Qwen2.5-32B-Instruct` | 32B | None |
102
+ | `meta-llama/Llama-3.1-70B-Instruct` | 70B | INT4 (auto) |
103
+
104
+ Models larger than 34B are automatically quantized (INT8 up to ~65B; larger models use INT4).
105
+
106
+ ## VRAM Guidelines
107
+
108
+ | Model Size | FP16 VRAM | INT8 VRAM | INT4 VRAM |
109
+ |------------|-----------|-----------|-----------|
110
+ | 7B | ~14GB | ~7GB | ~3.5GB |
111
+ | 13B | ~26GB | ~13GB | ~6.5GB |
112
+ | 34B | ~68GB | ~34GB | ~17GB |
113
+ | 70B | ~140GB | ~70GB | ~35GB |
114
+
115
+ *70B models require INT4 quantization. Add ~20% overhead for KV cache.*
116
+
117
+ ## Quota Information
118
+
119
+ - **PRO plan**: 25 minutes/day of H200 GPU compute
120
+ - **Priority**: PRO users get highest queue priority
121
+ - **Fallback**: When quota exhausted, falls back to HF Serverless Inference API
122
+
123
+ ## API Endpoints
124
+
125
+ | Endpoint | Method | Description |
126
+ |----------|--------|-------------|
127
+ | `/v1/chat/completions` | POST | Chat completion (OpenAI-compatible) |
128
+ | `/v1/models` | GET | List loaded models |
129
+ | `/health` | GET | Health check and quota status |
130
+
131
+ ## Local Development
132
+
133
+ ```bash
134
+ # Clone the repo
135
+ git clone https://huggingface.co/spaces/serenichron/opencode-zerogpu
136
+
137
+ # Install dependencies
138
+ pip install -r requirements.txt
139
+
140
+ # Run locally (ZeroGPU decorator no-ops)
141
+ python app.py
142
+ ```
143
+
144
+ ## Testing
145
+
146
+ ```bash
147
+ # Run tests
148
+ pytest tests/ -v
149
+
150
+ # Test the API locally
151
+ curl -X POST http://localhost:7860/v1/chat/completions \
152
+ -H "Content-Type: application/json" \
153
+ -H "Authorization: Bearer $HF_TOKEN" \
154
+ -d '{
155
+ "model": "mistralai/Mistral-7B-Instruct-v0.3",
156
+ "messages": [{"role": "user", "content": "Hello!"}],
157
+ "stream": false
158
+ }'
159
+ ```
160
+
161
+ ## Environment Variables
162
+
163
+ | Variable | Required | Description |
164
+ |----------|----------|-------------|
165
+ | `HF_TOKEN` | No* | Token for gated models (* Space uses its own token) |
166
+ | `FALLBACK_ENABLED` | No | Enable HF Serverless fallback (default: true) |
167
+ | `LOG_LEVEL` | No | Logging verbosity (default: INFO) |
168
+
169
+ ## License
170
+
171
+ MIT
app.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode.
3
+
4
+ This Gradio app provides:
5
+ - OpenAI-compatible /v1/chat/completions endpoint
6
+ - Pass-through model selection (any HF model ID)
7
+ - ZeroGPU H200 inference with HF Serverless fallback
8
+ - HF Token authentication
9
+ - SSE streaming support
10
+ """
11
+
12
+ import logging
13
+ import time
14
+ from contextlib import asynccontextmanager
15
+ from typing import Optional
16
+
17
+ import gradio as gr
18
+ import httpx
19
+ from fastapi import FastAPI, Header, HTTPException, Request
20
+ from fastapi.responses import StreamingResponse, JSONResponse
21
+ from huggingface_hub import HfApi
22
+
23
# Import spaces conditionally: on a ZeroGPU Space the real package exists;
# anywhere else we install a do-nothing stand-in so @spaces.GPU still works.
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    class spaces:
        """Local stand-in so ``@spaces.GPU`` decorates without effect off-Space."""

        @staticmethod
        def GPU(fn=None, duration=60):
            # Supports both @spaces.GPU and @spaces.GPU(duration=...) forms.
            return fn if fn is not None else (lambda f: f)
37
+
38
+ from config import get_config, get_quota_tracker
39
+ from models import (
40
+ apply_chat_template,
41
+ generate_text,
42
+ generate_text_stream,
43
+ get_current_model,
44
+ )
45
+ from openai_compat import (
46
+ ChatCompletionRequest,
47
+ InferenceParams,
48
+ create_chat_response,
49
+ create_error_response,
50
+ estimate_tokens,
51
+ stream_response_generator,
52
+ )
53
+
54
logger = logging.getLogger(__name__)

# Process-wide singletons shared by every request handler below.
config = get_config()
quota_tracker = get_quota_tracker()

# HuggingFace Hub client; used only to validate bearer tokens via whoami().
hf_api = HfApi()
61
+
62
+
63
+ # --- Authentication ---
64
+
65
+
66
def validate_hf_token(token: str) -> bool:
    """Check a HuggingFace token against the Hub API.

    Accepts only tokens with the ``hf_`` prefix that the Hub's ``whoami``
    endpoint recognizes; any API failure counts as invalid.
    """
    if not token or not token.startswith("hf_"):
        return False
    try:
        hf_api.whoami(token=token)
    except Exception:
        return False
    return True
76
+
77
+
78
def extract_token(authorization: Optional[str]) -> Optional[str]:
    """Pull the bare token out of an ``Authorization`` header value.

    Strips a leading "Bearer " scheme when present; returns None for a
    missing or empty header. A schemeless value is returned unchanged.
    """
    if not authorization:
        return None
    scheme = "Bearer "
    if authorization.startswith(scheme):
        return authorization[len(scheme):]
    return authorization
87
+
88
+
89
+ # --- ZeroGPU Inference ---
90
+
91
+
92
@spaces.GPU(duration=120)
def zerogpu_generate(
    model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[list[str]],
) -> str:
    """Generate text on the ZeroGPU-allocated GPU.

    Arguments mirror models.generate_text. GPU wall-clock time is charged
    to the quota tracker in a ``finally`` block so that seconds spent
    before an exception still count against the daily budget (the
    original only recorded usage on success).
    """
    start_time = time.time()
    try:
        return generate_text(
            model_id=model_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            stop_sequences=stop_sequences,
        )
    finally:
        # The GPU seconds were spent even if generate_text raised.
        quota_tracker.add_usage(time.time() - start_time)
118
+
119
+
120
@spaces.GPU(duration=120)
def zerogpu_generate_stream(
    model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[list[str]],
):
    """Stream generated tokens from the ZeroGPU-allocated GPU.

    Yields tokens as produced by models.generate_text_stream. Quota is
    recorded in a ``finally`` block so GPU time is charged even when the
    consumer stops iterating early (client disconnect -> GeneratorExit)
    or generation raises; the original post-loop accounting only ran when
    the generator was fully exhausted.
    """
    start_time = time.time()
    try:
        yield from generate_text_stream(
            model_id=model_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            stop_sequences=stop_sequences,
        )
    finally:
        quota_tracker.add_usage(time.time() - start_time)
145
+
146
+
147
+ # --- HF Serverless Fallback ---
148
+
149
+
150
async def serverless_generate(
    model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    token: str,
) -> str:
    """Run one completion through the HF Serverless Inference API.

    Raises HTTPException mirroring the upstream status code on any non-200
    reply, or a 500 when the payload is not the expected
    ``[{"generated_text": ...}]`` list.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "return_full_text": False,
        },
    }
    headers = {"Authorization": f"Bearer {token}"}

    async with httpx.AsyncClient() as client:
        response = await client.post(url, json=payload, headers=headers, timeout=120.0)

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"HF Serverless error: {response.text}",
        )

    result = response.json()
    # Expected shape is a one-element list of {"generated_text": ...}.
    if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
        return result[0]["generated_text"]

    raise HTTPException(
        status_code=500,
        detail=f"Unexpected response format from HF Serverless: {result}",
    )
196
+
197
+
198
+ # --- FastAPI App ---
199
+
200
+
201
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Emit start/stop log lines around the FastAPI application's lifetime."""
    logger.info("Starting ZeroGPU OpenCode Provider")
    # Lazy %-style args: formatting happens only if the record is emitted.
    logger.info("ZeroGPU available: %s", ZEROGPU_AVAILABLE)
    logger.info("Fallback enabled: %s", config.fallback_enabled)
    yield
    logger.info("Shutting down ZeroGPU OpenCode Provider")
209
+
210
+
211
# OpenAI-compatible REST surface; mounted into the Gradio app at startup.
api = FastAPI(
    title="ZeroGPU OpenCode Provider",
    description="OpenAI-compatible API for HuggingFace models on ZeroGPU",
    version="1.0.0",
    lifespan=lifespan,
)
217
+
218
+
219
@api.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    authorization: Optional[str] = Header(None),
):
    """
    OpenAI-compatible chat completions endpoint (streaming and non-streaming).

    Requires a valid HuggingFace token in the Authorization header.
    Prefers ZeroGPU while quota remains; otherwise uses the HF Serverless
    fallback when enabled, or answers 503 when it is not.
    """
    # --- Authentication ---
    bearer = extract_token(authorization)
    if not bearer or not validate_hf_token(bearer):
        return JSONResponse(
            status_code=401,
            content=create_error_response(
                message="Invalid or missing HuggingFace token",
                error_type="authentication_error",
                code="invalid_api_key",
            ).model_dump(),
        )

    params = InferenceParams.from_request(request)

    # Render the chat messages into a single model prompt; failures here are
    # treated as client errors (bad or unloadable model id).
    try:
        prompt = apply_chat_template(params.model_id, params.messages)
    except Exception as e:
        logger.error(f"Failed to apply chat template: {e}")
        return JSONResponse(
            status_code=400,
            content=create_error_response(
                message=f"Failed to load model or apply chat template: {str(e)}",
                error_type="invalid_request_error",
                param="model",
            ).model_dump(),
        )

    prompt_tokens = estimate_tokens(prompt)

    # --- Backend routing ---
    use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
    if not use_zerogpu and not config.fallback_enabled:
        return JSONResponse(
            status_code=503,
            content=create_error_response(
                message="ZeroGPU quota exhausted and fallback is disabled",
                error_type="server_error",
                code="quota_exhausted",
            ).model_dump(),
        )

    try:
        if params.stream:
            if use_zerogpu:
                token_gen = zerogpu_generate_stream(
                    model_id=params.model_id,
                    prompt=prompt,
                    max_new_tokens=params.max_new_tokens,
                    temperature=params.temperature,
                    top_p=params.top_p,
                    stop_sequences=params.stop_sequences,
                )
            else:
                # Serverless has no token stream: fetch the whole completion,
                # then emit it as a single SSE chunk so clients still get a
                # well-formed stream.
                logger.info("Using HF Serverless fallback (no streaming)")
                full_response = await serverless_generate(
                    model_id=params.model_id,
                    prompt=prompt,
                    max_new_tokens=params.max_new_tokens,
                    temperature=params.temperature,
                    top_p=params.top_p,
                    token=bearer,
                )
                token_gen = iter([full_response])

            return StreamingResponse(
                stream_response_generator(params.model_id, token_gen),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # --- Non-streaming path ---
        if use_zerogpu:
            response_text = zerogpu_generate(
                model_id=params.model_id,
                prompt=prompt,
                max_new_tokens=params.max_new_tokens,
                temperature=params.temperature,
                top_p=params.top_p,
                stop_sequences=params.stop_sequences,
            )
        else:
            logger.info("Using HF Serverless fallback")
            response_text = await serverless_generate(
                model_id=params.model_id,
                prompt=prompt,
                max_new_tokens=params.max_new_tokens,
                temperature=params.temperature,
                top_p=params.top_p,
                token=bearer,
            )

        return create_chat_response(
            model=params.model_id,
            content=response_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=estimate_tokens(response_text),
        )

    except Exception as e:
        logger.exception(f"Inference error: {e}")
        return JSONResponse(
            status_code=500,
            content=create_error_response(
                message=f"Inference failed: {str(e)}",
                error_type="server_error",
            ).model_dump(),
        )
353
+
354
+
355
@api.get("/v1/models")
async def list_models(authorization: Optional[str] = Header(None)):
    """List models in OpenAI format (only the currently loaded one, if any)."""
    bearer = extract_token(authorization)
    if not bearer or not validate_hf_token(bearer):
        return JSONResponse(
            status_code=401,
            content=create_error_response(
                message="Invalid or missing HuggingFace token",
                error_type="authentication_error",
                code="invalid_api_key",
            ).model_dump(),
        )

    loaded = get_current_model()
    data = []
    if loaded:
        data.append(
            {
                "id": loaded.model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "huggingface",
            }
        )

    return {"object": "list", "data": data}
383
+
384
+
385
@api.get("/health")
async def health_check():
    """Unauthenticated liveness probe reporting quota and fallback status."""
    return {
        "status": "healthy",
        "zerogpu_available": ZEROGPU_AVAILABLE,
        "quota_remaining_minutes": quota_tracker.remaining_minutes(),
        "fallback_enabled": config.fallback_enabled,
    }
394
+
395
+
396
+ # --- Gradio Interface ---
397
+
398
+
399
def gradio_chat(
    message: str,
    history: list[list[str]],
    model_id: str,
    temperature: float,
    max_tokens: int,
):
    """Gradio ChatInterface handler: stream an assistant reply.

    Rebuilds the OpenAI-style message list from the (user, assistant)
    history pairs, then yields the accumulated response text as each
    token arrives from the GPU stream.
    """
    messages = []
    for user_turn, assistant_turn in history:
        messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    prompt = apply_chat_template(model_id, messages)

    # Stream: yield the growing partial response for live UI updates.
    partial = ""
    for chunk in zerogpu_generate_stream(
        model_id=model_id,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=0.95,
        stop_sequences=None,
    ):
        partial += chunk
        yield partial
430
+
431
+
432
# Gradio Blocks interface: a human-facing playground for the same backend
# that the /v1 API serves. Layout: settings column (left) + chat (right).
with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
    # Static header with a copy-paste opencode provider configuration.
    gr.Markdown(
        """
        # ZeroGPU OpenCode Provider

        OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).

        **API Endpoint:** `/v1/chat/completions`

        ## Usage with opencode

        Configure in `~/.config/opencode/opencode.json`:

        ```json
        {
          "providers": {
            "zerogpu": {
              "npm": "@ai-sdk/openai-compatible",
              "options": {
                "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
                "headers": {
                  "Authorization": "Bearer hf_YOUR_TOKEN"
                }
              },
              "models": {
                "llama-8b": {
                  "name": "meta-llama/Llama-3.1-8B-Instruct"
                }
              }
            }
          }
        }
        ```

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Pass-through model selection: custom values allow any HF model id.
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=[
                    "meta-llama/Llama-3.1-8B-Instruct",
                    "mistralai/Mistral-7B-Instruct-v0.3",
                    "Qwen/Qwen2.5-7B-Instruct",
                    "Qwen/Qwen2.5-14B-Instruct",
                ],
                value="meta-llama/Llama-3.1-8B-Instruct",
                allow_custom_value=True,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
            )
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=64,
                maximum=4096,
                value=512,
                step=64,
            )

            # NOTE(review): this status text is rendered once at import time;
            # it will not refresh as quota is consumed.
            gr.Markdown(
                f"""
                ### Status
                - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
                - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
                """
            )

        with gr.Column(scale=3):
            # Chat UI wired to the streaming handler; the three controls are
            # forwarded as additional positional inputs after (message, history).
            chatbot = gr.ChatInterface(
                fn=gradio_chat,
                additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
                title="",
                examples=[
                    ["Hello! How are you?"],
                    ["Explain quantum computing in simple terms."],
                    ["Write a Python function to calculate fibonacci numbers."],
                ],
            )

# Mount FastAPI to Gradio
# NOTE(review): gr.mount_gradio_app is documented as (app: FastAPI, blocks,
# path) — the argument order here (demo, api) looks swapped; verify the /v1
# routes are actually reachable on the deployed Space.
demo = gr.mount_gradio_app(demo, api, path="/")


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
config.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration and environment handling for ZeroGPU Space."""
2
+
3
+ import os
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def _env_flag(name: str, default: str = "true") -> bool:
    """Read a boolean-ish environment variable ("true"/"false", any case)."""
    return os.getenv(name, default).lower() == "true"


@dataclass
class Config:
    """Application configuration resolved from the process environment."""

    # HuggingFace token used for gated-model access (None when unset).
    hf_token: Optional[str] = field(default_factory=lambda: os.getenv("HF_TOKEN"))

    # Whether to fall back to HF Serverless once the ZeroGPU quota is gone.
    fallback_enabled: bool = field(default_factory=lambda: _env_flag("FALLBACK_ENABLED"))

    # Root logging level name (e.g. "INFO", "DEBUG").
    log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))

    # Quantization mode forced for every model ("none", "int8", "int4").
    default_quantization: str = field(
        default_factory=lambda: os.getenv("DEFAULT_QUANTIZATION", "none")
    )

    # Size (billions of parameters) above which models are auto-quantized.
    auto_quantize_threshold_b: int = field(
        default_factory=lambda: int(os.getenv("AUTO_QUANTIZE_THRESHOLD_B", "34"))
    )

    def __post_init__(self) -> None:
        """Configure root logging once the level is known."""
        logging.basicConfig(
            level=getattr(logging, self.log_level.upper(), logging.INFO),
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
43
+
44
+
45
+ @dataclass
46
+ class QuotaTracker:
47
+ """Track ZeroGPU quota usage for the current session."""
48
+
49
+ # Total seconds used in current day
50
+ seconds_used: float = 0.0
51
+
52
+ # Daily quota in seconds (PRO plan: 25 min = 1500 sec)
53
+ daily_quota_seconds: float = 1500.0
54
+
55
+ # Whether quota is exhausted
56
+ quota_exhausted: bool = False
57
+
58
+ def add_usage(self, seconds: float) -> None:
59
+ """Record GPU usage time."""
60
+ self.seconds_used += seconds
61
+ if self.seconds_used >= self.daily_quota_seconds:
62
+ self.quota_exhausted = True
63
+ logger.warning(
64
+ f"ZeroGPU quota exhausted: {self.seconds_used:.1f}s / {self.daily_quota_seconds:.1f}s"
65
+ )
66
+
67
+ def remaining_seconds(self) -> float:
68
+ """Get remaining quota in seconds."""
69
+ return max(0, self.daily_quota_seconds - self.seconds_used)
70
+
71
+ def remaining_minutes(self) -> float:
72
+ """Get remaining quota in minutes."""
73
+ return self.remaining_seconds() / 60.0
74
+
75
+ def reset(self) -> None:
76
+ """Reset quota (called at day boundary)."""
77
+ self.seconds_used = 0.0
78
+ self.quota_exhausted = False
79
+ logger.info("ZeroGPU quota reset")
80
+
81
+
82
# Module-level singletons: one Config (whose construction also configures
# logging) and one QuotaTracker shared by the whole process.
config = Config()
quota_tracker = QuotaTracker()


def get_config() -> Config:
    """Return the process-wide Config singleton."""
    return config


def get_quota_tracker() -> QuotaTracker:
    """Return the process-wide QuotaTracker singleton."""
    return quota_tracker
97
+
98
+
99
# Known parameter counts (billions) for popular model ids; consulted before
# falling back to name-pattern parsing in estimate_model_size().
MODEL_SIZE_ESTIMATES = {
    # Llama family
    "meta-llama/Llama-3.1-8B-Instruct": 8,
    "meta-llama/Llama-3.1-70B-Instruct": 70,
    "meta-llama/Llama-3.2-1B-Instruct": 1,
    "meta-llama/Llama-3.2-3B-Instruct": 3,
    # Mistral family
    "mistralai/Mistral-7B-Instruct-v0.3": 7,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": 47,  # MoE effective
    # Qwen family
    "Qwen/Qwen2.5-7B-Instruct": 7,
    "Qwen/Qwen2.5-14B-Instruct": 14,
    "Qwen/Qwen2.5-32B-Instruct": 32,
    "Qwen/Qwen2.5-72B-Instruct": 72,
}


def estimate_model_size(model_id: str) -> Optional[int]:
    """Estimate a model's parameter count in billions from its id.

    Known ids come from MODEL_SIZE_ESTIMATES; otherwise the first
    ``<digits>B`` token in the name is used (e.g. "foo-13B" -> 13).
    Returns None when neither source yields a size.
    """
    known = MODEL_SIZE_ESTIMATES.get(model_id)
    if known is not None:
        return known

    # Fallback: parse a size marker out of the model name.
    import re
    m = re.search(r"(\d+)B", model_id, re.IGNORECASE)
    return int(m.group(1)) if m else None
136
+
137
+
138
def should_quantize(model_id: str) -> str:
    """Pick a quantization mode for *model_id*: "none", "int8", or "int4".

    An explicit DEFAULT_QUANTIZATION override wins. Otherwise size-based:
    models whose size cannot be determined stay unquantized, anything over
    65B gets INT4 (to fit ~70GB VRAM), and anything above the configured
    threshold gets INT8.
    """
    override = config.default_quantization
    if override != "none":
        return override

    size = estimate_model_size(model_id)
    if size is None:
        # Can't size the model from its name: don't guess at quantization.
        return "none"
    if size > 65:
        return "int4"
    if size > config.auto_quantize_threshold_b:
        return "int8"
    return "none"
models.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading, caching, and memory management for ZeroGPU inference."""
2
+
3
+ import gc
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional, Generator, Any
7
+
8
+ import torch
9
+ from transformers import (
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ BitsAndBytesConfig,
13
+ TextIteratorStreamer,
14
+ )
15
+ from threading import Thread
16
+
17
+ from config import get_config, should_quantize
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
@dataclass
class LoadedModel:
    """Container for a loaded model and its tokenizer.

    Bundles the pieces inference code needs so a model is never paired with
    the wrong tokenizer.
    """

    model_id: str  # HuggingFace Hub ID the weights were loaded from
    model: Any  # the AutoModelForCausalLM instance
    tokenizer: Any  # the matching AutoTokenizer instance
    quantization: str = "none"  # quantization actually applied: "none"/"int8"/"int4"
30
+
31
+
32
# Global model cache. Only a single model is kept resident at a time due to
# memory constraints; load_model() swaps it out when a different ID is asked for.
_current_model: Optional[LoadedModel] = None
34
+
35
+
36
def get_quantization_config(quantization: str) -> Optional[BitsAndBytesConfig]:
    """Map a quantization level to a BitsAndBytes configuration.

    Returns None for any value other than "int8"/"int4", meaning the model
    should be loaded without bitsandbytes quantization.
    """
    if quantization == "int4":
        # NF4 + double quantization with bf16 compute: the standard
        # quality/memory trade-off for 4-bit loading.
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    if quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    return None
48
+
49
+
50
def clear_gpu_memory() -> None:
    """Release as much GPU memory as possible (Python GC plus CUDA cache flush)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    logger.debug("GPU memory cleared")
57
+
58
+
59
def unload_model() -> None:
    """Drop the cached model (if any) and release its memory."""
    global _current_model

    cached = _current_model
    if cached is None:
        return

    logger.info(f"Unloading model: {cached.model_id}")
    del cached.model
    del cached.tokenizer
    _current_model = None
    clear_gpu_memory()
69
+
70
+
71
def load_model(
    model_id: str,
    quantization: Optional[str] = None,
    force_reload: bool = False,
) -> LoadedModel:
    """
    Load a model from HuggingFace Hub, replacing any previously cached one.

    Only one model is kept resident at a time: requesting a different ID
    first unloads the current model and clears GPU memory.

    Args:
        model_id: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        quantization: Force specific quantization ("none", "int8", "int4").
            If None, auto-determine based on estimated model size.
        force_reload: If True, reload even if already loaded.

    Returns:
        LoadedModel with model and tokenizer ready for inference.
    """
    global _current_model

    # Cache hit: same model already resident and no reload requested.
    if not force_reload and _current_model is not None:
        if _current_model.model_id == model_id:
            logger.debug(f"Model already loaded: {model_id}")
            return _current_model

    # Determine quantization automatically when the caller did not choose.
    if quantization is None:
        quantization = should_quantize(model_id)

    logger.info(f"Loading model: {model_id} (quantization: {quantization})")

    # Unload current model first so two models never coexist in memory.
    unload_model()

    config = get_config()

    # Load tokenizer.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repo. Acceptable only if model IDs come from trusted callers —
    # confirm this matches the Space's threat model.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=config.hf_token,
        trust_remote_code=True,
    )

    # Ensure tokenizer has a pad token (many causal LMs ship without one).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with appropriate configuration.
    quant_config = get_quantization_config(quantization)

    model_kwargs = {
        "token": config.hf_token,
        "trust_remote_code": True,
        "device_map": "auto",
    }

    if quant_config is not None:
        model_kwargs["quantization_config"] = quant_config
    else:
        # Full-precision path: bf16 halves the footprint vs fp32.
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    _current_model = LoadedModel(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
    )

    logger.info(f"Model loaded successfully: {model_id}")
    return _current_model
143
+
144
+
145
def get_current_model() -> Optional[LoadedModel]:
    """Return the currently cached model, or None if nothing is loaded."""
    return _current_model
148
+
149
+
150
def generate_text(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> str:
    """
    Generate text using the specified model (non-streaming).

    Args:
        model_id: HuggingFace model ID
        prompt: Input prompt (already formatted with chat template)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature; <= 0 selects greedy decoding
        top_p: Nucleus sampling probability (sampling mode only)
        top_k: Top-k sampling parameter (sampling mode only)
        repetition_penalty: Penalty for repeating tokens
        stop_sequences: Generation is truncated at the first occurrence of
            any of these strings

    Returns:
        Generated text (without the input prompt).
    """
    loaded = load_model(model_id)

    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        # Leave room for generation; max(1, ...) guards against
        # max_new_tokens >= context window producing a negative max_length.
        max_length=max(1, loaded.tokenizer.model_max_length - max_new_tokens),
    )

    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    # Build generation config. Sampling knobs (temperature/top_p/top_k) are
    # only passed when sampling is enabled: with do_sample=False transformers
    # warns about them, and temperature=0.0 is not a valid sampling value.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    with torch.no_grad():
        outputs = loaded.model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    response = loaded.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Truncate at the first stop sequence, if any appears.
    if stop_sequences:
        for stop_seq in stop_sequences:
            if stop_seq in response:
                response = response.split(stop_seq)[0]

    return response
215
+
216
+
217
def generate_text_stream(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    stop_sequences: Optional[list[str]] = None,
) -> Generator[str, None, None]:
    """
    Generate text with streaming output, yielding text fragments.

    Generation runs in a background thread feeding a TextIteratorStreamer.
    When stop_sequences are given, up to max(len(seq)) - 1 trailing
    characters are held back until it is certain they do not start a stop
    sequence, so no part of a stop sequence is ever emitted. (The previous
    implementation could leak the beginning of a stop sequence that spanned
    a token boundary, because earlier tokens were already yielded.)
    """
    loaded = load_model(model_id)

    inputs = loaded.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        # Guard against max_new_tokens >= context window (negative max_length).
        max_length=max(1, loaded.tokenizer.model_max_length - max_new_tokens),
    )

    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    # Create streamer
    streamer = TextIteratorStreamer(
        loaded.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Build generation config; sampling knobs only when sampling is enabled
    # (temperature=0.0 is invalid for sampling and noisy for greedy mode).
    do_sample = temperature > 0
    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": loaded.tokenizer.pad_token_id,
        "eos_token_id": loaded.tokenizer.eos_token_id,
        "streamer": streamer,
    }
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # Run generation in a separate thread; this thread consumes the streamer.
    thread = Thread(target=loaded.model.generate, kwargs=gen_kwargs)
    thread.start()

    # Hold back enough characters to cover any partially-streamed stop sequence.
    holdback = max((len(s) for s in stop_sequences), default=1) - 1 if stop_sequences else 0

    accumulated = ""
    emitted = 0  # number of characters of `accumulated` already yielded
    stopped = False
    for token in streamer:
        accumulated += token

        if stop_sequences:
            # Find the earliest occurrence of any stop sequence.
            stop_at = None
            for seq in stop_sequences:
                idx = accumulated.find(seq)
                if idx != -1 and (stop_at is None or idx < stop_at):
                    stop_at = idx
            if stop_at is not None:
                # Emit anything before the stop sequence that is still unsent.
                if stop_at > emitted:
                    yield accumulated[emitted:stop_at]
                stopped = True
                break

        safe_end = len(accumulated) - holdback
        if safe_end > emitted:
            yield accumulated[emitted:safe_end]
            emitted = safe_end

    if not stopped and emitted < len(accumulated):
        # Flush the held-back tail once the stream ends without a stop hit.
        yield accumulated[emitted:]

    thread.join()
292
+
293
+
294
def apply_chat_template(
    model_id: str,
    messages: list[dict[str, str]],
    add_generation_prompt: bool = True,
) -> str:
    """
    Format a message list using the model's own chat template.

    Args:
        model_id: HuggingFace model ID
        messages: List of message dicts with "role" and "content"
        add_generation_prompt: Whether to append the assistant turn opener

    Returns:
        Formatted prompt string.
    """
    loaded = load_model(model_id)

    # Prefer the tokenizer's built-in template when one is available.
    if hasattr(loaded.tokenizer, "apply_chat_template"):
        return loaded.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )

    # Fallback: minimal role-labelled transcript (unknown roles are skipped).
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    pieces = [
        f"{role_labels[msg['role']]}: {msg['content']}\n"
        for msg in messages
        if msg["role"] in role_labels
    ]
    if add_generation_prompt:
        pieces.append("Assistant:")
    return "".join(pieces)
openai_compat.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI-compatible API request/response format handling."""
2
+
3
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Generator, Literal, Optional

from pydantic import BaseModel, Field, field_validator
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ # --- Request Models ---
16
+
17
+
18
class ChatMessage(BaseModel):
    """A single message in the conversation (OpenAI chat format)."""

    role: Literal["system", "user", "assistant"]  # tool/function roles not supported
    content: str
23
+
24
+
25
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request.

    `model` is passed through verbatim as a HuggingFace model ID; the
    remaining fields mirror the OpenAI /v1/chat/completions schema.
    """

    model: str = Field(..., description="HuggingFace model ID")
    messages: list[ChatMessage]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=512, ge=1, le=8192)
    stream: bool = False
    # OpenAI allows `stop` to be either a single string or a list of strings;
    # a bare string is normalized to a one-element list below.
    stop: Optional[list[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    n: int = Field(default=1, ge=1, le=1)  # Only support n=1 for now
    user: Optional[str] = None

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize_stop(cls, value):
        """Coerce a bare string into a one-element list (OpenAI compatibility)."""
        if isinstance(value, str):
            return [value]
        return value
39
+
40
+
41
+ # --- Response Models ---
42
+
43
+
44
class ChatCompletionChoice(BaseModel):
    """A single completion choice (this API always returns exactly one)."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "content_filter"] = "stop"
50
+
51
+
52
class ChatCompletionUsage(BaseModel):
    """Token usage statistics reported back to the client."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # prompt_tokens + completion_tokens
58
+
59
+
60
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible non-streaming chat completion response."""

    id: str  # "chatcmpl-..." identifier
    object: str = "chat.completion"
    created: int  # Unix timestamp, seconds
    model: str
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage
69
+
70
+
71
+ # --- Streaming Response Models ---
72
+
73
+
74
class DeltaMessage(BaseModel):
    """Incremental payload for streaming responses; unset fields are None."""

    role: Optional[str] = None  # sent once, in the first chunk of a stream
    content: Optional[str] = None
79
+
80
+
81
class StreamChoice(BaseModel):
    """A single streaming choice carrying a delta instead of a full message."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
87
+
88
+
89
class ChatCompletionChunk(BaseModel):
    """OpenAI-compatible streaming chunk (one SSE `data:` payload)."""

    id: str  # same completion id across all chunks of one stream
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: list[StreamChoice]
97
+
98
+
99
+ # --- Helper Functions ---
100
+
101
+
102
def generate_completion_id() -> str:
    """Return a fresh OpenAI-style completion ID ("chatcmpl-" + 24 hex chars)."""
    suffix = uuid.uuid4().hex[:24]
    return "chatcmpl-" + suffix
105
+
106
+
107
def create_chat_response(
    model: str,
    content: str,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    finish_reason: str = "stop",
) -> ChatCompletionResponse:
    """Assemble a complete (non-streaming) chat completion response."""
    reply = ChatMessage(role="assistant", content=content)
    choice = ChatCompletionChoice(index=0, message=reply, finish_reason=finish_reason)
    usage = ChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return ChatCompletionResponse(
        id=generate_completion_id(),
        created=int(time.time()),
        model=model,
        choices=[choice],
        usage=usage,
    )
132
+
133
+
134
def create_stream_chunk(
    completion_id: str,
    model: str,
    content: Optional[str] = None,
    role: Optional[str] = None,
    finish_reason: Optional[str] = None,
) -> ChatCompletionChunk:
    """Build one streaming chunk carrying a role, a content delta, or a finish."""
    delta = DeltaMessage(role=role, content=content)
    choice = StreamChoice(index=0, delta=delta, finish_reason=finish_reason)
    return ChatCompletionChunk(
        id=completion_id,
        created=int(time.time()),
        model=model,
        choices=[choice],
    )
154
+
155
+
156
def stream_response_generator(
    model: str,
    token_generator: Generator[str, None, None],
) -> Generator[str, None, None]:
    """
    Wrap a raw token stream as SSE-formatted OpenAI streaming chunks.

    Emits a role chunk first, then one chunk per token, then a finish chunk,
    and finally the "[DONE]" end marker.
    """
    completion_id = generate_completion_id()

    def emit(chunk):
        # Each SSE event is "data: <json>" followed by a blank line.
        return f"data: {chunk.model_dump_json()}\n\n"

    # First chunk: announce the assistant role.
    yield emit(create_stream_chunk(completion_id=completion_id, model=model, role="assistant"))

    # One chunk per generated token.
    for token in token_generator:
        yield emit(create_stream_chunk(completion_id=completion_id, model=model, content=token))

    # Final chunk: finish reason, then the end marker.
    yield emit(create_stream_chunk(completion_id=completion_id, model=model, finish_reason="stop"))
    yield "data: [DONE]\n\n"
194
+
195
+
196
def messages_to_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
    """Flatten ChatMessage objects into plain {"role", "content"} dicts."""
    return [dict(role=message.role, content=message.content) for message in messages]
199
+
200
+
201
def estimate_tokens(text: str) -> int:
    """
    Rough token count estimation.

    A simple ~4-characters-per-token heuristic for English text; the real
    count depends on the model's tokenizer. Always returns at least 1.
    """
    approx = len(text) // 4
    return approx if approx >= 1 else 1
209
+
210
+
211
@dataclass
class InferenceParams:
    """Inference parameters extracted from an OpenAI-compatible request."""

    model_id: str  # HF model ID, passed through from request.model
    messages: list[dict[str, str]]  # plain role/content dicts
    max_new_tokens: int
    temperature: float
    top_p: float
    stop_sequences: Optional[list[str]]
    stream: bool  # whether the caller asked for SSE streaming

    @classmethod
    def from_request(cls, request: ChatCompletionRequest) -> "InferenceParams":
        """Extract inference parameters from an OpenAI-compatible request."""
        return cls(
            model_id=request.model,
            messages=messages_to_dicts(request.messages),
            max_new_tokens=request.max_tokens or 512,  # None falls back to 512
            temperature=request.temperature,
            top_p=request.top_p,
            stop_sequences=request.stop,
            stream=request.stream,
        )
235
+
236
+
237
+ # --- Error Responses ---
238
+
239
+
240
class ErrorDetail(BaseModel):
    """Error detail payload for OpenAI-style API error responses."""

    message: str
    type: str  # e.g. "invalid_request_error"
    param: Optional[str] = None  # offending request parameter, when known
    code: Optional[str] = None
247
+
248
+
249
class ErrorResponse(BaseModel):
    """OpenAI-compatible error envelope: {"error": {...}}."""

    error: ErrorDetail
253
+
254
+
255
def create_error_response(
    message: str,
    error_type: str = "invalid_request_error",
    param: Optional[str] = None,
    code: Optional[str] = None,
) -> ErrorResponse:
    """Wrap error details in an OpenAI-style error envelope."""
    detail = ErrorDetail(message=message, type=error_type, param=param, code=code)
    return ErrorResponse(error=detail)
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace ZeroGPU Space - OpenCode Provider
2
+ # For ZeroGPU H200 inference with OpenAI-compatible API
3
+
4
+ # Core Framework
5
+ gradio>=4.44.0
6
+ spaces>=0.30.0
7
+
8
+ # ML/Inference
9
+ torch>=2.0.0
10
+ transformers>=4.45.0
11
+ accelerate>=0.34.0
12
+ bitsandbytes>=0.44.0
13
+ safetensors>=0.4.0
14
+
15
+ # HuggingFace Integration
16
+ huggingface-hub>=0.25.0
17
+
18
+ # API
19
+ fastapi>=0.115.0
20
+ uvicorn>=0.30.0
21
+ pydantic>=2.0.0
22
+ httpx>=0.27.0
23
+ sse-starlette>=2.1.0
24
+
25
+ # Utilities
26
+ python-dotenv>=1.0.0
27
+ typing-extensions>=4.12.0
28
+
29
+ # Testing (dev)
30
+ pytest>=8.0.0
31
+ pytest-asyncio>=0.24.0
32
+ pytest-cov>=5.0.0
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Tests package for ZeroGPU OpenCode Provider
tests/conftest.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test fixtures for ZeroGPU OpenCode Provider tests."""
2
+
3
+ import pytest
4
+ from unittest.mock import MagicMock, patch
5
+
6
+
7
@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer with stubbed call/template/decode behavior."""
    tokenizer = MagicMock()
    tokenizer.pad_token = None
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 2
    tokenizer.model_max_length = 4096

    def mock_apply_chat_template(messages, tokenize=False, add_generation_prompt=True):
        parts = []
        for msg in messages:
            role = msg.get("role", msg.role if hasattr(msg, "role") else "user")
            content = msg.get("content", msg.content if hasattr(msg, "content") else "")
            if role == "system":
                parts.append(f"<|system|>{content}</s>")
            elif role == "user":
                parts.append(f"<|user|>{content}</s>")
            elif role == "assistant":
                parts.append(f"<|assistant|>{content}</s>")
        if add_generation_prompt:
            parts.append("<|assistant|>")
        return "".join(parts)

    tokenizer.apply_chat_template = mock_apply_chat_template

    def mock_call(text, return_tensors=None, truncation=True, max_length=None):
        import torch

        # Simple mock: input_ids sized from the text length (~4 chars/token).
        token_count = max(1, len(text) // 4)
        return {
            "input_ids": torch.ones((1, token_count), dtype=torch.long),
            "attention_mask": torch.ones((1, token_count), dtype=torch.long),
        }

    # BUG FIX: assigning `tokenizer.__call__ = mock_call` has no effect —
    # special methods are looked up on the type, and MagicMock routes calls
    # through side_effect/return_value. side_effect makes `tokenizer(text)`
    # actually invoke mock_call with the real arguments.
    tokenizer.side_effect = mock_call

    def mock_decode(tokens, skip_special_tokens=True):
        return "This is a test response."

    tokenizer.decode = mock_decode

    return tokenizer
52
+
53
+
54
@pytest.fixture
def mock_model():
    """Create a mock model whose generate() returns input length + 20 tokens."""
    import torch

    model = MagicMock()

    def fake_generate(**kwargs):
        prompt_ids = kwargs.get("input_ids", torch.ones((1, 10), dtype=torch.long))
        # Pretend 20 new tokens were generated after the prompt.
        total_len = prompt_ids.shape[1] + 20
        return torch.ones((1, total_len), dtype=torch.long)

    model.generate = fake_generate
    model.device = "cpu"

    return model
72
+
73
+
74
@pytest.fixture
def sample_messages():
    """Minimal two-message conversation (system + user) for chat tests."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
81
+
82
+
83
@pytest.fixture
def sample_request_data():
    """Valid non-streaming request body for the OpenAI-compatible endpoint."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
        "stream": False,
    }
96
+
97
+
98
@pytest.fixture
def sample_streaming_request_data():
    """Valid streaming (SSE) request body for the OpenAI-compatible endpoint."""
    return {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "Tell me a joke."},
        ],
        "temperature": 0.7,
        "max_tokens": 256,
        "stream": True,
    }
110
+
111
+
112
@pytest.fixture(autouse=True)
def mock_torch_cuda():
    """Force torch.cuda.is_available() to False so tests never touch a GPU."""
    with patch("torch.cuda.is_available", return_value=False):
        yield
tests/test_models.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for model loading and inference."""
2
+
3
+ import pytest
4
+ from unittest.mock import patch, MagicMock
5
+ import sys
6
+ import os
7
+
8
+ # Add parent directory to path for imports
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ from config import estimate_model_size, should_quantize
12
+
13
+
14
class TestModelSizeEstimation:
    """Test model size estimation logic in config.estimate_model_size."""

    def test_known_model_size(self):
        """Known models resolve through the MODEL_SIZE_ESTIMATES table."""
        assert estimate_model_size("meta-llama/Llama-3.1-8B-Instruct") == 8
        assert estimate_model_size("meta-llama/Llama-3.1-70B-Instruct") == 70
        assert estimate_model_size("mistralai/Mistral-7B-Instruct-v0.3") == 7

    def test_extract_size_from_name(self):
        """Unknown models fall back to parsing the "<n>B" name marker."""
        assert estimate_model_size("some-org/CustomModel-13B") == 13
        assert estimate_model_size("another/model-2B-test") == 2
        assert estimate_model_size("org/Model-32B-Instruct") == 32

    def test_unknown_model_size(self):
        """Names with no size marker yield None rather than a guess."""
        assert estimate_model_size("unknown/model-without-size") is None
        assert estimate_model_size("org/mystery-model") is None
+
34
+
35
class TestQuantizationDecision:
    """Test automatic quantization decisions in config.should_quantize."""

    def test_small_model_no_quantization(self):
        """Models below the auto-quantize threshold load at full precision."""
        assert should_quantize("meta-llama/Llama-3.1-8B-Instruct") == "none"
        assert should_quantize("mistralai/Mistral-7B-Instruct-v0.3") == "none"

    def test_large_model_int4_quantization(self):
        """70B-class models should use INT4."""
        assert should_quantize("meta-llama/Llama-3.1-70B-Instruct") == "int4"
        assert should_quantize("Qwen/Qwen2.5-72B-Instruct") == "int4"

    def test_unknown_model_no_quantization(self):
        """Models of unknown size are not auto-quantized."""
        assert should_quantize("unknown/mystery-model") == "none"
+
52
+
53
class TestModelLoading:
    """Test model loading functionality (transformers classes are mocked)."""

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_creates_loaded_model(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """Test that load_model returns a LoadedModel instance."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        from models import load_model, unload_model

        # Ensure clean state — the module-level cache survives across tests.
        unload_model()

        loaded = load_model("test-model/test-7B")

        assert loaded.model_id == "test-model/test-7B"
        assert loaded.model is not None
        assert loaded.tokenizer is not None

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_caches_result(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """Test that loading the same model twice uses the in-module cache."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        from models import load_model, unload_model

        # Ensure clean state
        unload_model()

        # First load
        load_model("test-model/test-7B")
        first_call_count = mock_model_class.from_pretrained.call_count

        # Second load (should use cache)
        load_model("test-model/test-7B")
        second_call_count = mock_model_class.from_pretrained.call_count

        # Should not have called from_pretrained again
        assert first_call_count == second_call_count
+
101
+
102
class TestChatTemplate:
    """Test chat template application."""

    @patch("models.load_model")
    def test_apply_chat_template_with_tokenizer_method(self, mock_load_model, mock_tokenizer):
        """Test chat template when tokenizer has apply_chat_template."""
        from models import apply_chat_template, LoadedModel

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=mock_tokenizer,
        )

        messages = [
            {"role": "user", "content": "Hello!"},
        ]

        result = apply_chat_template("test-model", messages)

        assert "<|user|>" in result
        assert "Hello!" in result
        assert "<|assistant|>" in result  # Generation prompt

    @patch("models.load_model")
    def test_apply_chat_template_fallback(self, mock_load_model):
        """Test fallback formatting when tokenizer lacks apply_chat_template."""
        from models import apply_chat_template, LoadedModel

        # Tokenizer without apply_chat_template: deleting the attribute on a
        # MagicMock makes hasattr(...) return False.
        simple_tokenizer = MagicMock()
        del simple_tokenizer.apply_chat_template

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=simple_tokenizer,
        )

        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi!"},
        ]

        result = apply_chat_template("test-model", messages)

        assert "System:" in result
        assert "User:" in result
        assert "Assistant:" in result
tests/test_openai_compat.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for OpenAI-compatible API format handling."""
2
+
3
+ import json
4
+ import pytest
5
+ import sys
6
+ import os
7
+
8
+ # Add parent directory to path for imports
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ from openai_compat import (
12
+ ChatCompletionRequest,
13
+ ChatMessage,
14
+ InferenceParams,
15
+ create_chat_response,
16
+ create_error_response,
17
+ create_stream_chunk,
18
+ estimate_tokens,
19
+ generate_completion_id,
20
+ messages_to_dicts,
21
+ stream_response_generator,
22
+ )
23
+
24
+
25
class TestChatCompletionRequest:
    """Test request parsing."""

    def test_parse_basic_request(self, sample_request_data):
        """Test parsing a basic chat completion request."""
        parsed = ChatCompletionRequest(**sample_request_data)

        assert parsed.model == "meta-llama/Llama-3.1-8B-Instruct"
        # Two messages, system first then user.
        assert [msg.role for msg in parsed.messages] == ["system", "user"]
        assert parsed.temperature == 0.7
        assert parsed.max_tokens == 512
        assert parsed.stream is False

    def test_parse_streaming_request(self, sample_streaming_request_data):
        """Test parsing a streaming request."""
        parsed = ChatCompletionRequest(**sample_streaming_request_data)

        assert parsed.stream is True
        assert parsed.max_tokens == 256

    def test_default_values(self):
        """Test that defaults are applied correctly."""
        parsed = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
        )

        assert parsed.temperature == 0.7
        assert parsed.top_p == 0.95
        assert parsed.max_tokens == 512
        assert parsed.stream is False
        assert parsed.stop is None

    def test_validation_temperature_bounds(self):
        """Test temperature validation."""
        # Both below-zero and above-two temperatures must be rejected.
        for bad_temperature in (-0.5, 2.5):
            with pytest.raises(ValueError):
                ChatCompletionRequest(
                    model="test",
                    messages=[{"role": "user", "content": "Hi"}],
                    temperature=bad_temperature,
                )
76
+
77
+
78
class TestChatCompletionResponse:
    """Test response generation."""

    def test_create_basic_response(self):
        """Test creating a basic chat response."""
        resp = create_chat_response(
            model="test-model",
            content="Hello! How can I help you?",
            prompt_tokens=10,
            completion_tokens=8,
        )

        assert resp.model == "test-model"
        assert resp.object == "chat.completion"
        assert len(resp.choices) == 1

        only_choice = resp.choices[0]
        assert only_choice.message.role == "assistant"
        assert only_choice.message.content == "Hello! How can I help you?"
        assert only_choice.finish_reason == "stop"

        # Usage accounting: total is the sum of prompt and completion.
        usage = resp.usage
        assert (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens) == (10, 8, 18)

    def test_response_has_unique_id(self):
        """Test that each response has a unique ID."""
        first = create_chat_response(model="test", content="Hi")
        second = create_chat_response(model="test", content="Hi")

        assert first.id != second.id
        assert first.id.startswith("chatcmpl-")

    def test_response_serialization(self):
        """Test that response can be serialized to JSON."""
        resp = create_chat_response(
            model="test-model",
            content="Test",
        )

        # Round-trip through JSON to prove the model serializes cleanly.
        payload = json.loads(resp.model_dump_json())

        assert payload["model"] == "test-model"
        assert payload["choices"][0]["message"]["content"] == "Test"
120
+
121
+
122
class TestStreamingResponse:
    """Test streaming response format."""

    def test_create_stream_chunk_with_content(self):
        """Test creating a streaming chunk with content."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            content="Hello",
        )

        assert chunk.id == "test-id"
        assert chunk.object == "chat.completion.chunk"

        delta_choice = chunk.choices[0]
        assert delta_choice.delta.content == "Hello"
        assert delta_choice.finish_reason is None

    def test_create_stream_chunk_with_role(self):
        """Test creating a streaming chunk with role (first chunk)."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            role="assistant",
        )

        # The opening chunk announces the role and carries no content.
        delta = chunk.choices[0].delta
        assert delta.role == "assistant"
        assert delta.content is None

    def test_create_stream_chunk_with_finish_reason(self):
        """Test creating a final streaming chunk."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            finish_reason="stop",
        )

        assert chunk.choices[0].finish_reason == "stop"

    def test_stream_response_generator(self):
        """Test the full streaming response generator."""
        tokens = iter(["Hello", " World", "!"])

        chunks = list(stream_response_generator("test-model", tokens))

        # Expected sequence: role chunk, 3 content chunks, finish chunk, [DONE].
        assert len(chunks) == 6

        def sse_payload(line):
            # Strip the SSE "data: " framing and parse the JSON body.
            return json.loads(line.replace("data: ", "").strip())

        # Opening chunk announces the assistant role.
        assert sse_payload(chunks[0])["choices"][0]["delta"]["role"] == "assistant"

        # First content chunk carries the first token.
        assert sse_payload(chunks[1])["choices"][0]["delta"]["content"] == "Hello"

        # Final data chunk signals the finish reason.
        assert sse_payload(chunks[4])["choices"][0]["finish_reason"] == "stop"

        # Stream ends with the SSE termination sentinel.
        assert chunks[5] == "data: [DONE]\n\n"
185
+
186
+
187
class TestInferenceParams:
    """Test parameter extraction."""

    def test_extract_params_from_request(self, sample_request_data):
        """Test extracting inference parameters from request."""
        parsed_request = ChatCompletionRequest(**sample_request_data)
        extracted = InferenceParams.from_request(parsed_request)

        assert extracted.model_id == "meta-llama/Llama-3.1-8B-Instruct"
        assert len(extracted.messages) == 2
        assert extracted.max_new_tokens == 512
        assert extracted.temperature == 0.7
        assert extracted.stream is False

    def test_messages_to_dicts(self):
        """Test converting ChatMessage objects to dicts."""
        converted = messages_to_dicts(
            [
                ChatMessage(role="user", content="Hello"),
                ChatMessage(role="assistant", content="Hi there!"),
            ]
        )

        # Each ChatMessage collapses to a plain role/content dict.
        assert converted == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
214
+
215
+
216
class TestErrorResponse:
    """Test error response format."""

    def test_create_error_response(self):
        """Test creating an error response."""
        err = create_error_response(
            message="Model not found",
            error_type="invalid_request_error",
            param="model",
        )

        assert err.error.message == "Model not found"
        assert err.error.type == "invalid_request_error"
        assert err.error.param == "model"

    def test_error_response_serialization(self):
        """Test error response JSON serialization."""
        err = create_error_response(
            message="Test error",
            error_type="server_error",
            code="internal_error",
        )

        # Serialize and inspect the nested "error" object directly.
        error_body = json.loads(err.model_dump_json())["error"]

        assert error_body["message"] == "Test error"
        assert error_body["type"] == "server_error"
        assert error_body["code"] == "internal_error"
244
+
245
+
246
class TestUtilityFunctions:
    """Test utility functions."""

    def test_generate_completion_id_format(self):
        """Test completion ID format."""
        first_id = generate_completion_id()
        second_id = generate_completion_id()

        prefix = "chatcmpl-"
        assert first_id.startswith(prefix)
        # Fixed-width suffix: 24 random characters after the prefix.
        assert len(first_id) == len(prefix) + 24
        assert first_id != second_id  # IDs must be unique

    def test_estimate_tokens(self):
        """Test rough token estimation."""
        # Heuristic: roughly 4 characters per token.
        assert estimate_tokens("Hello World!") == 3  # 12 chars / 4 = 3
        assert estimate_tokens("A") == 1  # minimum of 1
        # NOTE(review): this string is 31 chars (not 32 as the original
        # comment claimed); the expected value 8 implies the estimator
        # rounds up — confirm against estimate_tokens' implementation.
        assert estimate_tokens("This is a longer piece of text.") == 8
+ assert estimate_tokens("This is a longer piece of text.") == 8 # 32 / 4 = 8