Initial Commit
Browse files- .DS_Store +0 -0
- CONFIGURATION_GUIDE.md +270 -0
- Dockerfile +67 -0
- HUGGINGFACE_DEPLOYMENT.md +295 -0
- README.md +215 -10
- app/__init__.py +1 -0
- app/api/__init__.py +1 -0
- app/api/v1/__init__.py +1 -0
- app/api/v1/endpoints.py +363 -0
- app/core/__init__.py +1 -0
- app/core/config.py +178 -0
- app/core/logging.py +85 -0
- app/main.py +187 -0
- app/models/__init__.py +1 -0
- app/models/schemas.py +215 -0
- app/services/__init__.py +1 -0
- app/services/chat_manager.py +398 -0
- app/services/model_backends/__init__.py +1 -0
- app/services/model_backends/anthropic_api.py +319 -0
- app/services/model_backends/base.py +222 -0
- app/services/model_backends/google_api.py +304 -0
- app/services/model_backends/hf_api.py +303 -0
- app/services/model_backends/local_hf.py +330 -0
- app/services/model_backends/minimax_api.py +341 -0
- app/services/model_backends/openai_api.py +291 -0
- app/services/model_manager.py +382 -0
- app/services/session_manager.py +400 -0
- app/utils/__init__.py +1 -0
- app/utils/helpers.py +309 -0
- examples/test_backends.py +292 -0
- requirements.txt +28 -0
- setup_huggingface.sh +242 -0
- tests/__init__.py +1 -0
- tests/test_api.py +313 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
CONFIGURATION_GUIDE.md
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sema Chat API Configuration Guide
|
| 2 |
+
|
| 3 |
+
## **MiniMax Integration**
|
| 4 |
+
|
| 5 |
+
### Configuration
|
| 6 |
+
```bash
|
| 7 |
+
MODEL_TYPE=minimax
|
| 8 |
+
MODEL_NAME=MiniMax-M1
|
| 9 |
+
MINIMAX_API_KEY=your_minimax_api_key
|
| 10 |
+
MINIMAX_API_URL=https://api.minimax.chat/v1/text/chatcompletion_v2
|
| 11 |
+
MINIMAX_MODEL_VERSION=abab6.5s-chat
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
### Features
|
| 15 |
+
- ✅ **Reasoning Capabilities**: Shows model's thinking process
|
| 16 |
+
- ✅ **Streaming Support**: Real-time response generation
|
| 17 |
+
- ✅ **Custom API Integration**: Direct integration with MiniMax API
|
| 18 |
+
- ✅ **Reasoning Content**: Displays both reasoning and final response
|
| 19 |
+
|
| 20 |
+
### Example Usage
|
| 21 |
+
```bash
|
| 22 |
+
curl -X POST "http://localhost:7860/api/v1/chat" \
|
| 23 |
+
-H "Content-Type: application/json" \
|
| 24 |
+
-d '{
|
| 25 |
+
"message": "Solve this math problem: 2x + 5 = 15",
|
| 26 |
+
"session_id": "minimax-test"
|
| 27 |
+
}'
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
**Response includes reasoning:**
|
| 31 |
+
```json
|
| 32 |
+
{
|
| 33 |
+
"message": "[Reasoning: I need to solve for x. First, subtract 5 from both sides: 2x = 10. Then divide by 2: x = 5]\n\nTo solve 2x + 5 = 15:\n1. Subtract 5 from both sides: 2x = 10\n2. Divide by 2: x = 5\n\nTherefore, x = 5.",
|
| 34 |
+
"session_id": "minimax-test",
|
| 35 |
+
"model_name": "MiniMax-M1"
|
| 36 |
+
}
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## **Gemma Integration**
|
| 42 |
+
|
| 43 |
+
### Option 1: Local Gemma (Free Tier)
|
| 44 |
+
```bash
|
| 45 |
+
MODEL_TYPE=local
|
| 46 |
+
MODEL_NAME=google/gemma-2b-it
|
| 47 |
+
DEVICE=auto
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Option 2: Gemma via HuggingFace API
|
| 51 |
+
```bash
|
| 52 |
+
MODEL_TYPE=hf_api
|
| 53 |
+
MODEL_NAME=google/gemma-2b-it
|
| 54 |
+
HF_API_TOKEN=your_hf_token
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Option 3: Gemma via Google AI Studio
|
| 58 |
+
```bash
|
| 59 |
+
MODEL_TYPE=google
|
| 60 |
+
MODEL_NAME=gemma-2-9b-it
|
| 61 |
+
GOOGLE_API_KEY=your_google_api_key
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Available Gemma Models
|
| 65 |
+
- **gemma-2-2b-it** (2B parameters, instruction-tuned)
|
| 66 |
+
- **gemma-2-9b-it** (9B parameters, instruction-tuned)
|
| 67 |
+
- **gemma-2-27b-it** (27B parameters, instruction-tuned)
|
| 68 |
+
- **gemini-1.5-flash** (Fast, efficient)
|
| 69 |
+
- **gemini-1.5-pro** (Most capable)
|
| 70 |
+
|
| 71 |
+
### Example Usage
|
| 72 |
+
```bash
|
| 73 |
+
curl -X POST "http://localhost:7860/api/v1/chat" \
|
| 74 |
+
-H "Content-Type: application/json" \
|
| 75 |
+
-d '{
|
| 76 |
+
"message": "Explain quantum computing in simple terms",
|
| 77 |
+
"session_id": "gemma-test",
|
| 78 |
+
"temperature": 0.7
|
| 79 |
+
}'
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## **Complete Backend Comparison**
|
| 85 |
+
|
| 86 |
+
| Backend | Cost | Setup | Streaming | Special Features |
|
| 87 |
+
|---------|------|-------|-----------|------------------|
|
| 88 |
+
| **Local** | Free | Medium | ✅ | Offline, Private |
|
| 89 |
+
| **HF API** | Free/Paid | Easy | ✅ | Many models |
|
| 90 |
+
| **OpenAI** | Paid | Easy | ✅ | High quality |
|
| 91 |
+
| **Anthropic** | Paid | Easy | ✅ | Long context |
|
| 92 |
+
| **MiniMax** | Paid | Easy | ✅ | Reasoning |
|
| 93 |
+
| **Google** | Free/Paid | Easy | ✅ | Multimodal |
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## π§ **Configuration Examples**
|
| 98 |
+
|
| 99 |
+
### Free Tier Setup (HuggingFace Spaces)
|
| 100 |
+
```bash
|
| 101 |
+
# Best for free deployment
|
| 102 |
+
MODEL_TYPE=local
|
| 103 |
+
MODEL_NAME=TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
| 104 |
+
DEVICE=cpu
|
| 105 |
+
MAX_NEW_TOKENS=256
|
| 106 |
+
TEMPERATURE=0.7
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### Production Setup (API-based)
|
| 110 |
+
```bash
|
| 111 |
+
# Best for production with fallbacks
|
| 112 |
+
MODEL_TYPE=openai
|
| 113 |
+
MODEL_NAME=gpt-3.5-turbo
|
| 114 |
+
OPENAI_API_KEY=your_key
|
| 115 |
+
|
| 116 |
+
# Fallback configuration
|
| 117 |
+
FALLBACK_MODEL_TYPE=hf_api
|
| 118 |
+
FALLBACK_MODEL_NAME=microsoft/DialoGPT-medium
|
| 119 |
+
HF_API_TOKEN=your_token
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Research Setup (Multiple Models)
|
| 123 |
+
```bash
|
| 124 |
+
# Primary: Latest Gemini
|
| 125 |
+
MODEL_TYPE=google
|
| 126 |
+
MODEL_NAME=gemini-1.5-pro
|
| 127 |
+
GOOGLE_API_KEY=your_key
|
| 128 |
+
|
| 129 |
+
# For reasoning tasks
|
| 130 |
+
REASONING_MODEL_TYPE=minimax
|
| 131 |
+
REASONING_MODEL_NAME=MiniMax-M1
|
| 132 |
+
MINIMAX_API_KEY=your_key
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## π― **Model Selection Guide**
|
| 138 |
+
|
| 139 |
+
### For **Free Deployment** (HuggingFace Spaces):
|
| 140 |
+
1. **TinyLlama/TinyLlama-1.1B-Chat-v1.0** - Smallest, fastest
|
| 141 |
+
2. **microsoft/DialoGPT-medium** - Better conversations
|
| 142 |
+
3. **Qwen/Qwen2.5-0.5B-Instruct** - Good instruction following
|
| 143 |
+
|
| 144 |
+
### For **Reasoning Tasks**:
|
| 145 |
+
1. **MiniMax M1** - Shows thinking process
|
| 146 |
+
2. **Claude-3 Opus** - Deep reasoning
|
| 147 |
+
3. **GPT-4** - Complex problem solving
|
| 148 |
+
|
| 149 |
+
### For **Conversations**:
|
| 150 |
+
1. **Claude-3 Haiku** - Natural, fast
|
| 151 |
+
2. **GPT-3.5-turbo** - Balanced cost/quality
|
| 152 |
+
3. **Gemini-1.5-flash** - Fast, capable
|
| 153 |
+
|
| 154 |
+
### For **Multilingual**:
|
| 155 |
+
1. **Gemma-2-9b-it** - Good multilingual
|
| 156 |
+
2. **GPT-4** - Excellent multilingual
|
| 157 |
+
3. **Local models** - Depends on training
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## π **Dynamic Model Switching**
|
| 162 |
+
|
| 163 |
+
The API supports runtime model switching:
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
# Switch to MiniMax for reasoning
|
| 167 |
+
POST /api/v1/model/switch
|
| 168 |
+
{
|
| 169 |
+
"model_type": "minimax",
|
| 170 |
+
"model_name": "MiniMax-M1"
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# Switch back to fast model
|
| 174 |
+
POST /api/v1/model/switch
|
| 175 |
+
{
|
| 176 |
+
"model_type": "google",
|
| 177 |
+
"model_name": "gemini-1.5-flash"
|
| 178 |
+
}
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
## π§ͺ **Testing Your Setup**
|
| 184 |
+
|
| 185 |
+
### Test All Backends
|
| 186 |
+
```bash
|
| 187 |
+
python examples/test_backends.py
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
### Test Specific Backend
|
| 191 |
+
```bash
|
| 192 |
+
# Test MiniMax
|
| 193 |
+
MINIMAX_API_KEY=your_key python -c "
|
| 194 |
+
import asyncio
|
| 195 |
+
from app.services.model_backends.minimax_api import MiniMaxAPIBackend
|
| 196 |
+
from app.models.schemas import ChatMessage
|
| 197 |
+
|
| 198 |
+
async def test():
|
| 199 |
+
backend = MiniMaxAPIBackend('MiniMax-M1', api_key='your_key', api_url='your_url')
|
| 200 |
+
await backend.load_model()
|
| 201 |
+
messages = [ChatMessage(role='user', content='Hello')]
|
| 202 |
+
response = await backend.generate_response(messages)
|
| 203 |
+
print(response.message)
|
| 204 |
+
|
| 205 |
+
asyncio.run(test())
|
| 206 |
+
"
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
### Test Gemma
|
| 210 |
+
```bash
|
| 211 |
+
# Test local Gemma
|
| 212 |
+
MODEL_TYPE=local MODEL_NAME=google/gemma-2b-it python tests/test_api.py
|
| 213 |
+
|
| 214 |
+
# Test Gemma via Google AI
|
| 215 |
+
MODEL_TYPE=google MODEL_NAME=gemma-2-9b-it GOOGLE_API_KEY=your_key python tests/test_api.py
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
## π **Deployment Examples**
|
| 221 |
+
|
| 222 |
+
### HuggingFace Spaces (Free)
|
| 223 |
+
```yaml
|
| 224 |
+
# In your Space settings
|
| 225 |
+
MODEL_TYPE: local
|
| 226 |
+
MODEL_NAME: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
| 227 |
+
DEVICE: cpu
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
### HuggingFace Spaces (With API)
|
| 231 |
+
```yaml
|
| 232 |
+
# In your Space settings
|
| 233 |
+
MODEL_TYPE: google
|
| 234 |
+
MODEL_NAME: gemma-2-9b-it
|
| 235 |
+
GOOGLE_API_KEY: your_secret_key
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
### Docker Deployment
|
| 239 |
+
```bash
|
| 240 |
+
docker run -e MODEL_TYPE=minimax \
|
| 241 |
+
-e MINIMAX_API_KEY=your_key \
|
| 242 |
+
-e MINIMAX_API_URL=your_url \
|
| 243 |
+
-p 8000:7860 \
|
| 244 |
+
sema-chat-api
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
## π‘ **Pro Tips**
|
| 250 |
+
|
| 251 |
+
1. **Start Small**: Begin with TinyLlama for testing
|
| 252 |
+
2. **Use APIs for Production**: More reliable than local models
|
| 253 |
+
3. **Enable Streaming**: Better user experience
|
| 254 |
+
4. **Monitor Usage**: Track API costs and limits
|
| 255 |
+
5. **Have Fallbacks**: Configure multiple backends
|
| 256 |
+
6. **Test Thoroughly**: Use the provided test scripts
|
| 257 |
+
|
| 258 |
+
---
|
| 259 |
+
|
| 260 |
+
## π **Getting API Keys**
|
| 261 |
+
|
| 262 |
+
- **HuggingFace**: https://huggingface.co/settings/tokens
|
| 263 |
+
- **OpenAI**: https://platform.openai.com/api-keys
|
| 264 |
+
- **Anthropic**: https://console.anthropic.com/
|
| 265 |
+
- **Google AI**: https://aistudio.google.com/
|
| 266 |
+
- **MiniMax**: Contact MiniMax for API access
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
**Your architecture is now ready for both MiniMax and Gemma! π**
|
Dockerfile
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sema Chat API Dockerfile
|
| 2 |
+
# Multi-stage build for optimized production image
|
| 3 |
+
|
| 4 |
+
FROM python:3.11-slim as builder
|
| 5 |
+
|
| 6 |
+
# Set environment variables
|
| 7 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 8 |
+
PYTHONUNBUFFERED=1 \
|
| 9 |
+
PIP_NO_CACHE_DIR=1 \
|
| 10 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 11 |
+
|
| 12 |
+
# Install system dependencies
|
| 13 |
+
RUN apt-get update && apt-get install -y \
|
| 14 |
+
build-essential \
|
| 15 |
+
curl \
|
| 16 |
+
git \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
# Create and activate virtual environment
|
| 20 |
+
RUN python -m venv /opt/venv
|
| 21 |
+
ENV PATH="/opt/venv/bin:$PATH"
|
| 22 |
+
|
| 23 |
+
# Copy requirements and install Python dependencies
|
| 24 |
+
COPY requirements.txt .
|
| 25 |
+
RUN pip install --upgrade pip && \
|
| 26 |
+
pip install -r requirements.txt
|
| 27 |
+
|
| 28 |
+
# Production stage
|
| 29 |
+
FROM python:3.11-slim
|
| 30 |
+
|
| 31 |
+
# Set environment variables
|
| 32 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 33 |
+
PYTHONUNBUFFERED=1 \
|
| 34 |
+
PATH="/opt/venv/bin:$PATH" \
|
| 35 |
+
PYTHONPATH="/app"
|
| 36 |
+
|
| 37 |
+
# Install runtime dependencies
|
| 38 |
+
RUN apt-get update && apt-get install -y \
|
| 39 |
+
curl \
|
| 40 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 41 |
+
|
| 42 |
+
# Copy virtual environment from builder stage
|
| 43 |
+
COPY --from=builder /opt/venv /opt/venv
|
| 44 |
+
|
| 45 |
+
# Create app directory and user
|
| 46 |
+
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
| 47 |
+
WORKDIR /app
|
| 48 |
+
|
| 49 |
+
# Copy application code
|
| 50 |
+
COPY . .
|
| 51 |
+
|
| 52 |
+
# Create necessary directories
|
| 53 |
+
RUN mkdir -p logs && \
|
| 54 |
+
chown -R appuser:appuser /app
|
| 55 |
+
|
| 56 |
+
# Switch to non-root user
|
| 57 |
+
USER appuser
|
| 58 |
+
|
| 59 |
+
# Expose port
|
| 60 |
+
EXPOSE 7860
|
| 61 |
+
|
| 62 |
+
# Health check
|
| 63 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
| 64 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 65 |
+
|
| 66 |
+
# Default command
|
| 67 |
+
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
HUGGINGFACE_DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace Spaces Deployment Guide
|
| 2 |
+
|
| 3 |
+
## **Quick Setup for Gemma**
|
| 4 |
+
|
| 5 |
+
### Step 1: Create Your HuggingFace Space
|
| 6 |
+
1. Go to [HuggingFace Spaces](https://huggingface.co/spaces)
|
| 7 |
+
2. Click **"Create new Space"**
|
| 8 |
+
3. Choose:
|
| 9 |
+
- **Space name**: `your-username/sema-chat-gemma`
|
| 10 |
+
- **License**: MIT
|
| 11 |
+
- **Space SDK**: Docker
|
| 12 |
+
- **Space hardware**: CPU basic (free) or T4 small (paid)
|
| 13 |
+
|
| 14 |
+
### Step 2: Clone and Upload Files
|
| 15 |
+
```bash
|
| 16 |
+
# Clone your new space
|
| 17 |
+
git clone https://huggingface.co/spaces/your-username/sema-chat-gemma
|
| 18 |
+
cd sema-chat-gemma
|
| 19 |
+
|
| 20 |
+
# Copy all files from backend/sema-chat/
|
| 21 |
+
cp -r /path/to/sema/backend/sema-chat/* .
|
| 22 |
+
|
| 23 |
+
# Add and commit
|
| 24 |
+
git add .
|
| 25 |
+
git commit -m "Initial Sema Chat API with Gemma support"
|
| 26 |
+
git push
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Step 3: Configure Environment Variables
|
| 30 |
+
In your Space settings, add these environment variables:
|
| 31 |
+
|
| 32 |
+
#### **Option A: Local Gemma (Free Tier)**
|
| 33 |
+
```
|
| 34 |
+
MODEL_TYPE=local
|
| 35 |
+
MODEL_NAME=google/gemma-2b-it
|
| 36 |
+
DEVICE=cpu
|
| 37 |
+
TEMPERATURE=0.7
|
| 38 |
+
MAX_NEW_TOKENS=256
|
| 39 |
+
DEBUG=false
|
| 40 |
+
ENVIRONMENT=production
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
#### **Option B: Gemma via Google AI Studio (Recommended)**
|
| 44 |
+
```
|
| 45 |
+
MODEL_TYPE=google
|
| 46 |
+
MODEL_NAME=gemma-2-9b-it
|
| 47 |
+
GOOGLE_API_KEY=your_google_api_key_here
|
| 48 |
+
TEMPERATURE=0.7
|
| 49 |
+
MAX_NEW_TOKENS=512
|
| 50 |
+
DEBUG=false
|
| 51 |
+
ENVIRONMENT=production
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
#### **Option C: Gemma via HuggingFace API**
|
| 55 |
+
```
|
| 56 |
+
MODEL_TYPE=hf_api
|
| 57 |
+
MODEL_NAME=google/gemma-2b-it
|
| 58 |
+
HF_API_TOKEN=your_hf_token_here
|
| 59 |
+
TEMPERATURE=0.7
|
| 60 |
+
MAX_NEW_TOKENS=512
|
| 61 |
+
DEBUG=false
|
| 62 |
+
ENVIRONMENT=production
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## π **Getting API Keys**
|
| 68 |
+
|
| 69 |
+
### Google AI Studio API Key (Recommended)
|
| 70 |
+
1. Go to [Google AI Studio](https://aistudio.google.com/)
|
| 71 |
+
2. Sign in with your Google account
|
| 72 |
+
3. Click **"Get API Key"**
|
| 73 |
+
4. Create a new API key
|
| 74 |
+
5. Copy the key and add it to your Space settings
|
| 75 |
+
|
| 76 |
+
### HuggingFace API Token (Alternative)
|
| 77 |
+
1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens)
|
| 78 |
+
2. Click **"New token"**
|
| 79 |
+
3. Choose **"Read"** access
|
| 80 |
+
4. Copy the token and add it to your Space settings
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## π **Required Files Structure**
|
| 85 |
+
|
| 86 |
+
Make sure your Space has these files:
|
| 87 |
+
```
|
| 88 |
+
your-space/
|
| 89 |
+
βββ app/ # Main application code
|
| 90 |
+
βββ requirements.txt # Python dependencies
|
| 91 |
+
βββ Dockerfile # Container configuration
|
| 92 |
+
βββ README.md # Space documentation
|
| 93 |
+
βββ .gitignore # Git ignore file
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## π³ **Dockerfile Configuration**
|
| 99 |
+
|
| 100 |
+
Your Dockerfile should be:
|
| 101 |
+
```dockerfile
|
| 102 |
+
FROM python:3.11-slim
|
| 103 |
+
|
| 104 |
+
# Set environment variables
|
| 105 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 106 |
+
PYTHONUNBUFFERED=1 \
|
| 107 |
+
PYTHONPATH="/app"
|
| 108 |
+
|
| 109 |
+
# Install system dependencies
|
| 110 |
+
RUN apt-get update && apt-get install -y \
|
| 111 |
+
curl \
|
| 112 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 113 |
+
|
| 114 |
+
# Set working directory
|
| 115 |
+
WORKDIR /app
|
| 116 |
+
|
| 117 |
+
# Copy requirements and install dependencies
|
| 118 |
+
COPY requirements.txt .
|
| 119 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 120 |
+
|
| 121 |
+
# Copy application code
|
| 122 |
+
COPY . .
|
| 123 |
+
|
| 124 |
+
# Create non-root user
|
| 125 |
+
RUN useradd -m -u 1000 user
|
| 126 |
+
USER user
|
| 127 |
+
|
| 128 |
+
# Expose port 7860 (HuggingFace Spaces standard)
|
| 129 |
+
EXPOSE 7860
|
| 130 |
+
|
| 131 |
+
# Health check
|
| 132 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
| 133 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 134 |
+
|
| 135 |
+
# Start the application
|
| 136 |
+
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## π― **Recommended Configuration for First Version**
|
| 142 |
+
|
| 143 |
+
For your first deployment, I recommend **Google AI Studio** with Gemma:
|
| 144 |
+
|
| 145 |
+
### Environment Variables:
|
| 146 |
+
```
|
| 147 |
+
MODEL_TYPE=google
|
| 148 |
+
MODEL_NAME=gemma-2-9b-it
|
| 149 |
+
GOOGLE_API_KEY=your_api_key_here
|
| 150 |
+
TEMPERATURE=0.7
|
| 151 |
+
MAX_NEW_TOKENS=512
|
| 152 |
+
DEBUG=false
|
| 153 |
+
ENVIRONMENT=production
|
| 154 |
+
ENABLE_STREAMING=true
|
| 155 |
+
RATE_LIMIT=30
|
| 156 |
+
SESSION_TIMEOUT=30
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Why This Setup?
|
| 160 |
+
- ✅ **Fast deployment** - No model download needed
|
| 161 |
+
- ✅ **Reliable** - Google's infrastructure
|
| 162 |
+
- ✅ **Cost-effective** - Free tier available
|
| 163 |
+
- ✅ **Good performance** - Gemma 2 9B is capable
|
| 164 |
+
- ✅ **Streaming support** - Real-time responses
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## π§ͺ **Testing Your Deployment**
|
| 169 |
+
|
| 170 |
+
### 1. Check Health
|
| 171 |
+
```bash
|
| 172 |
+
curl https://your-username-sema-chat-gemma.hf.space/health
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### 2. Test Chat
|
| 176 |
+
```bash
|
| 177 |
+
curl -X POST "https://your-username-sema-chat-gemma.hf.space/api/v1/chat" \
|
| 178 |
+
-H "Content-Type: application/json" \
|
| 179 |
+
-d '{
|
| 180 |
+
"message": "Hello! Can you introduce yourself?",
|
| 181 |
+
"session_id": "test-session"
|
| 182 |
+
}'
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
### 3. Test Streaming
|
| 186 |
+
```bash
|
| 187 |
+
curl -N -H "Accept: text/event-stream" \
|
| 188 |
+
"https://your-username-sema-chat-gemma.hf.space/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
### 4. Access Swagger UI
|
| 192 |
+
Visit: `https://your-username-sema-chat-gemma.hf.space/`
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
## π§ **Troubleshooting**
|
| 197 |
+
|
| 198 |
+
### Common Issues:
|
| 199 |
+
|
| 200 |
+
#### 1. **Space Won't Start**
|
| 201 |
+
- Check logs in Space settings
|
| 202 |
+
- Verify all required files are present
|
| 203 |
+
- Check Dockerfile syntax
|
| 204 |
+
|
| 205 |
+
#### 2. **Model Loading Fails**
|
| 206 |
+
- Verify API key is correct
|
| 207 |
+
- Check model name spelling
|
| 208 |
+
- Try a smaller model first
|
| 209 |
+
|
| 210 |
+
#### 3. **API Errors**
|
| 211 |
+
- Check environment variables
|
| 212 |
+
- Verify network connectivity
|
| 213 |
+
- Review application logs
|
| 214 |
+
|
| 215 |
+
#### 4. **Slow Responses**
|
| 216 |
+
- Use smaller model (gemma-2-2b-it)
|
| 217 |
+
- Reduce MAX_NEW_TOKENS
|
| 218 |
+
- Enable streaming for better UX
|
| 219 |
+
|
| 220 |
+
### Debug Commands:
|
| 221 |
+
```bash
|
| 222 |
+
# Check environment variables
|
| 223 |
+
curl https://your-space.hf.space/api/v1/model/info
|
| 224 |
+
|
| 225 |
+
# Check detailed health
|
| 226 |
+
curl https://your-space.hf.space/api/v1/health
|
| 227 |
+
|
| 228 |
+
# View logs in Space settings
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
## π **Step-by-Step Deployment**
|
| 234 |
+
|
| 235 |
+
### 1. Prepare Your Space
|
| 236 |
+
```bash
|
| 237 |
+
# Create and clone your space
|
| 238 |
+
git clone https://huggingface.co/spaces/your-username/sema-chat-gemma
|
| 239 |
+
cd sema-chat-gemma
|
| 240 |
+
|
| 241 |
+
# Copy files
|
| 242 |
+
cp -r ../sema/backend/sema-chat/* .
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
### 2. Set Environment Variables
|
| 246 |
+
Go to your Space settings and add:
|
| 247 |
+
```
|
| 248 |
+
MODEL_TYPE=google
|
| 249 |
+
MODEL_NAME=gemma-2-9b-it
|
| 250 |
+
GOOGLE_API_KEY=your_key_here
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
### 3. Deploy
|
| 254 |
+
```bash
|
| 255 |
+
git add .
|
| 256 |
+
git commit -m "Deploy Sema Chat with Gemma"
|
| 257 |
+
git push
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
### 4. Wait for Build
|
| 261 |
+
- Space will automatically build (5-10 minutes)
|
| 262 |
+
- Check build logs for any errors
|
| 263 |
+
- Once running, test the endpoints
|
| 264 |
+
|
| 265 |
+
### 5. Share Your Space
|
| 266 |
+
Your API will be available at:
|
| 267 |
+
`https://your-username-sema-chat-gemma.hf.space/`
|
| 268 |
+
|
| 269 |
+
---
|
| 270 |
+
|
| 271 |
+
## π‘ **Pro Tips**
|
| 272 |
+
|
| 273 |
+
1. **Start with Google AI Studio** - Easiest setup
|
| 274 |
+
2. **Use environment variables** - Never hardcode API keys
|
| 275 |
+
3. **Enable streaming** - Better user experience
|
| 276 |
+
4. **Monitor usage** - Check API quotas
|
| 277 |
+
5. **Test thoroughly** - Use the provided test scripts
|
| 278 |
+
6. **Document your API** - Swagger UI is auto-generated
|
| 279 |
+
|
| 280 |
+
---
|
| 281 |
+
|
| 282 |
+
## π **You're Ready!**
|
| 283 |
+
|
| 284 |
+
With this setup, you'll have a production-ready chatbot API with:
|
| 285 |
+
- ✅ Gemma 2 9B model via Google AI Studio
|
| 286 |
+
- ✅ Streaming responses
|
| 287 |
+
- ✅ Session management
|
| 288 |
+
- ✅ Rate limiting
|
| 289 |
+
- ✅ Health monitoring
|
| 290 |
+
- ✅ Interactive Swagger UI
|
| 291 |
+
|
| 292 |
+
**Your Space URL will be:**
|
| 293 |
+
`https://your-username-sema-chat-gemma.hf.space/`
|
| 294 |
+
|
| 295 |
+
Happy deploying! π
|
README.md
CHANGED
|
@@ -1,12 +1,217 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
+
# Sema Chat API
|
| 2 |
+
|
| 3 |
+
Modern chatbot API with streaming capabilities, flexible model backends, and production-ready features. Built with FastAPI and designed for rapid GenAI advancements.
|
| 4 |
+
|
| 5 |
+
## π Quick Start with Gemma
|
| 6 |
+
|
| 7 |
+
### Option 1: Automated HuggingFace Spaces Deployment
|
| 8 |
+
```bash
|
| 9 |
+
cd backend/sema-chat
|
| 10 |
+
./setup_huggingface.sh
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
### Option 2: Manual Local Setup
|
| 14 |
+
```bash
|
| 15 |
+
cd backend/sema-chat
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy and configure environment
|
| 19 |
+
cp .env.example .env
|
| 20 |
+
|
| 21 |
+
# For Gemma via Google AI Studio (Recommended)
|
| 22 |
+
# Edit .env:
|
| 23 |
+
MODEL_TYPE=google
|
| 24 |
+
MODEL_NAME=gemma-2-9b-it
|
| 25 |
+
GOOGLE_API_KEY=your_google_api_key
|
| 26 |
+
|
| 27 |
+
# Run the API
|
| 28 |
+
uvicorn app.main:app --reload --host 0.0.0.0 --port 7860
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Option 3: Local Gemma (Free, No API Key)
|
| 32 |
+
```bash
|
| 33 |
+
# Edit .env:
|
| 34 |
+
MODEL_TYPE=local
|
| 35 |
+
MODEL_NAME=google/gemma-2b-it
|
| 36 |
+
DEVICE=cpu
|
| 37 |
+
|
| 38 |
+
# Run (will download model on first run)
|
| 39 |
+
uvicorn app.main:app --reload --host 0.0.0.0 --port 7860
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## π Access Your API
|
| 43 |
+
|
| 44 |
+
Once running, access:
|
| 45 |
+
- **Swagger UI**: http://localhost:7860/
|
| 46 |
+
- **Health Check**: http://localhost:7860/api/v1/health
|
| 47 |
+
- **Chat Endpoint**: http://localhost:7860/api/v1/chat
|
| 48 |
+
|
| 49 |
+
## π§ͺ Quick Test
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
# Test chat
|
| 53 |
+
curl -X POST "http://localhost:7860/api/v1/chat" \
|
| 54 |
+
-H "Content-Type: application/json" \
|
| 55 |
+
-d '{
|
| 56 |
+
"message": "Hello! Can you introduce yourself?",
|
| 57 |
+
"session_id": "test-session"
|
| 58 |
+
}'
|
| 59 |
+
|
| 60 |
+
# Test streaming
|
| 61 |
+
curl -N -H "Accept: text/event-stream" \
|
| 62 |
+
"http://localhost:7860/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## π― Features
|
| 66 |
+
|
| 67 |
+
### Core Capabilities
|
| 68 |
+
- ✅ **Real-time Streaming**: Server-Sent Events and WebSocket support
|
| 69 |
+
- ✅ **Multiple Model Backends**: Local, HuggingFace API, OpenAI, Anthropic, Google AI, MiniMax
|
| 70 |
+
- ✅ **Session Management**: Persistent conversation contexts
|
| 71 |
+
- ✅ **Rate Limiting**: Built-in protection with configurable limits
|
| 72 |
+
- ✅ **Health Monitoring**: Comprehensive health checks and metrics
|
| 73 |
+
|
| 74 |
+
### Supported Models
|
| 75 |
+
- **Local**: TinyLlama, DialoGPT, Gemma, Qwen
|
| 76 |
+
- **Google AI**: Gemma-2-9b-it, Gemini-1.5-flash, Gemini-1.5-pro
|
| 77 |
+
- **OpenAI**: GPT-3.5-turbo, GPT-4, GPT-4-turbo
|
| 78 |
+
- **Anthropic**: Claude-3-haiku, Claude-3-sonnet, Claude-3-opus
|
| 79 |
+
- **HuggingFace API**: Any model via Inference API
|
| 80 |
+
- **MiniMax**: M1 model with reasoning capabilities
|
| 81 |
+
|
| 82 |
+
## π§ Configuration
|
| 83 |
+
|
| 84 |
+
### Environment Variables
|
| 85 |
+
```bash
|
| 86 |
+
# Model Backend (local, google, openai, anthropic, hf_api, minimax)
|
| 87 |
+
MODEL_TYPE=google
|
| 88 |
+
MODEL_NAME=gemma-2-9b-it
|
| 89 |
+
|
| 90 |
+
# API Keys (as needed)
|
| 91 |
+
GOOGLE_API_KEY=your_key
|
| 92 |
+
OPENAI_API_KEY=your_key
|
| 93 |
+
ANTHROPIC_API_KEY=your_key
|
| 94 |
+
HF_API_TOKEN=your_token
|
| 95 |
+
MINIMAX_API_KEY=your_key
|
| 96 |
+
|
| 97 |
+
# Generation Settings
|
| 98 |
+
TEMPERATURE=0.7
|
| 99 |
+
MAX_NEW_TOKENS=512
|
| 100 |
+
TOP_P=0.9
|
| 101 |
+
|
| 102 |
+
# Server Settings
|
| 103 |
+
HOST=0.0.0.0
|
| 104 |
+
PORT=7860
|
| 105 |
+
DEBUG=false
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## π Documentation
|
| 109 |
+
|
| 110 |
+
- **[Configuration Guide](CONFIGURATION_GUIDE.md)** - Detailed setup for all backends
|
| 111 |
+
- **[HuggingFace Deployment](HUGGINGFACE_DEPLOYMENT.md)** - Step-by-step deployment guide
|
| 112 |
+
- **[API Documentation](http://localhost:7860/)** - Interactive Swagger UI
|
| 113 |
+
|
| 114 |
+
## π§ͺ Testing
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
# Run comprehensive tests
|
| 118 |
+
python tests/test_api.py
|
| 119 |
+
|
| 120 |
+
# Test different backends
|
| 121 |
+
python examples/test_backends.py
|
| 122 |
+
|
| 123 |
+
# Test specific backend
|
| 124 |
+
python examples/test_backends.py --backend google
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
## π Deployment
|
| 128 |
+
|
| 129 |
+
### HuggingFace Spaces (Recommended)
|
| 130 |
+
1. Run the setup script: `./setup_huggingface.sh`
|
| 131 |
+
2. Create your Space on HuggingFace
|
| 132 |
+
3. Push the generated code
|
| 133 |
+
4. Set environment variables in Space settings
|
| 134 |
+
5. Your API will be live at: `https://username-spacename.hf.space/`
|
| 135 |
+
|
| 136 |
+
### Docker
|
| 137 |
+
```bash
|
| 138 |
+
docker build -t sema-chat-api .
|
| 139 |
+
docker run -e MODEL_TYPE=google \
|
| 140 |
+
-e GOOGLE_API_KEY=your_key \
|
| 141 |
+
-p 7860:7860 \
|
| 142 |
+
sema-chat-api
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## π API Endpoints
|
| 146 |
+
|
| 147 |
+
### Chat
|
| 148 |
+
- **`POST /api/v1/chat`** - Send chat message
|
| 149 |
+
- **`GET /api/v1/chat/stream`** - Streaming chat (SSE)
|
| 150 |
+
- **`WebSocket /api/v1/chat/ws`** - Real-time WebSocket chat
|
| 151 |
+
|
| 152 |
+
### Sessions
|
| 153 |
+
- **`GET /api/v1/sessions/{id}`** - Get conversation history
|
| 154 |
+
- **`DELETE /api/v1/sessions/{id}`** - Clear conversation
|
| 155 |
+
- **`GET /api/v1/sessions`** - List active sessions
|
| 156 |
+
|
| 157 |
+
### System
|
| 158 |
+
- **`GET /api/v1/health`** - Comprehensive health check
|
| 159 |
+
- **`GET /api/v1/model/info`** - Current model information
|
| 160 |
+
- **`GET /api/v1/status`** - Basic status
|
| 161 |
+
|
| 162 |
+
## π‘ Why This Architecture?
|
| 163 |
+
|
| 164 |
+
1. **Future-Proof**: Modular design adapts to rapid GenAI advancements
|
| 165 |
+
2. **Flexible**: Switch between local models and APIs with environment variables
|
| 166 |
+
3. **Production-Ready**: Rate limiting, monitoring, error handling built-in
|
| 167 |
+
4. **Cost-Effective**: Start free with local models, scale with APIs
|
| 168 |
+
5. **Developer-Friendly**: Comprehensive docs, tests, and examples
|
| 169 |
+
|
| 170 |
+
## π οΈ Development
|
| 171 |
+
|
| 172 |
+
### Project Structure
|
| 173 |
+
```
|
| 174 |
+
app/
|
| 175 |
+
βββ main.py # FastAPI application
|
| 176 |
+
βββ api/v1/endpoints.py # API routes
|
| 177 |
+
βββ core/
|
| 178 |
+
β βββ config.py # Environment-based configuration
|
| 179 |
+
β βββ logging.py # Structured logging
|
| 180 |
+
βββ models/schemas.py # Pydantic request/response models
|
| 181 |
+
βββ services/
|
| 182 |
+
β βββ chat_manager.py # Chat orchestration
|
| 183 |
+
β βββ model_manager.py # Backend selection
|
| 184 |
+
β βββ session_manager.py # Conversation management
|
| 185 |
+
β βββ model_backends/ # Model implementations
|
| 186 |
+
βββ utils/helpers.py # Utility functions
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
### Adding New Backends
|
| 190 |
+
1. Create new backend in `app/services/model_backends/`
|
| 191 |
+
2. Inherit from `ModelBackend` base class
|
| 192 |
+
3. Implement required methods
|
| 193 |
+
4. Add to `ModelManager._create_backend()`
|
| 194 |
+
5. Update configuration and documentation
|
| 195 |
+
|
| 196 |
+
## π€ Contributing
|
| 197 |
+
|
| 198 |
+
1. Fork the repository
|
| 199 |
+
2. Create a feature branch
|
| 200 |
+
3. Add tests for new functionality
|
| 201 |
+
4. Ensure all tests pass
|
| 202 |
+
5. Submit a pull request
|
| 203 |
+
|
| 204 |
+
## π License
|
| 205 |
+
|
| 206 |
+
MIT License - see LICENSE file for details.
|
| 207 |
+
|
| 208 |
+
## π Acknowledgments
|
| 209 |
+
|
| 210 |
+
- **HuggingFace** for model hosting and Spaces platform
|
| 211 |
+
- **Google** for Gemma models and AI Studio
|
| 212 |
+
- **FastAPI** for the excellent web framework
|
| 213 |
+
- **OpenAI, Anthropic, MiniMax** for their APIs
|
| 214 |
+
|
| 215 |
---
|
| 216 |
|
| 217 |
+
**Ready to chat? Deploy your Sema Chat API today! ππ¬**
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Sema Chat API Package
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API package
|
app/api/v1/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API v1 package
|
app/api/v1/endpoints.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API v1 endpoints for Sema Chat API
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import List, Optional, Dict, Any
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter, HTTPException, Request, Query, WebSocket, WebSocketDisconnect
|
| 12 |
+
from fastapi.responses import StreamingResponse
|
| 13 |
+
from sse_starlette.sse import EventSourceResponse
|
| 14 |
+
from slowapi import Limiter
|
| 15 |
+
from slowapi.util import get_remote_address
|
| 16 |
+
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
| 17 |
+
from fastapi.responses import Response
|
| 18 |
+
|
| 19 |
+
from ...models.schemas import (
|
| 20 |
+
ChatRequest, ChatResponse, StreamChunk, ConversationHistory,
|
| 21 |
+
HealthResponse, ErrorResponse, ModelInfo, SessionInfo
|
| 22 |
+
)
|
| 23 |
+
from ...services.chat_manager import get_chat_manager
|
| 24 |
+
from ...services.model_manager import get_model_manager
|
| 25 |
+
from ...services.session_manager import get_session_manager
|
| 26 |
+
from ...services.model_backends.base import ModelBackendError, ModelNotLoadedError, GenerationError
|
| 27 |
+
from ...core.config import settings
|
| 28 |
+
from ...core.logging import get_logger
|
| 29 |
+
|
| 30 |
+
# Initialize router and rate limiter
router = APIRouter()
limiter = Limiter(key_func=get_remote_address)
logger = get_logger()

# WebSocket connection manager
class ConnectionManager:
    """Tracks open WebSocket connections for the chat WebSocket endpoint."""

    def __init__(self):
        # Every connection that has completed the accept() handshake.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking a connection.

        Guarded membership check: the websocket handler may call this twice
        (once from the WebSocketDisconnect path, once from the generic
        error path), and an unguarded list.remove() would raise ValueError
        on the second call.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a single text frame to one connection."""
        await websocket.send_text(message)

manager = ConnectionManager()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@router.post("/chat", response_model=ChatResponse)
@limiter.limit(f"{settings.rate_limit}/minute")  # per-client rate limit
async def chat(request: ChatRequest, req: Request):
    """
    Send a chat message and get a complete (non-streaming) response.

    Args:
        request: Validated chat payload (message, session_id, generation options).
        req: Raw Starlette request; unused directly but required by slowapi's
            rate limiter to resolve the client address.

    Returns:
        ChatResponse produced by the chat manager, with generation_time set.

    Raises:
        HTTPException: 503 when the model backend is not loaded, 500 for
            generation failures or any unexpected error.
    """
    start_time = time.time()

    try:
        chat_manager = await get_chat_manager()
        response = await chat_manager.process_chat_request(request)

        # Add timing information: keep the backend-reported generation_time
        # when the response already carries one, otherwise fall back to the
        # total request wall time measured here.
        total_time = time.time() - start_time
        response.generation_time = getattr(response, 'generation_time', total_time)

        logger.info("chat_request_completed",
                    session_id=request.session_id,
                    message_length=len(request.message),
                    response_length=len(response.message),
                    total_time=total_time)

        return response

    except ModelNotLoadedError as e:
        # Backend exists but has no loaded model -> service unavailable.
        logger.error("model_not_loaded", error=str(e), session_id=request.session_id)
        raise HTTPException(status_code=503, detail="Model not available")

    except GenerationError as e:
        logger.error("generation_error", error=str(e), session_id=request.session_id)
        raise HTTPException(status_code=500, detail="Failed to generate response")

    except Exception as e:
        # Catch-all boundary: log with context, return an opaque 500.
        logger.error("chat_request_failed", error=str(e), session_id=request.session_id)
        raise HTTPException(status_code=500, detail="Internal server error")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@router.get("/chat/stream")
@limiter.limit(f"{settings.rate_limit}/minute")
async def chat_stream(
    message: str = Query(..., description="Chat message"),
    session_id: str = Query(..., description="Session ID"),
    system_prompt: Optional[str] = Query(None, description="Custom system prompt"),
    temperature: Optional[float] = Query(None, ge=0.0, le=1.0, description="Temperature"),
    max_tokens: Optional[int] = Query(None, ge=1, le=2048, description="Max tokens"),
    req: Request = None
):
    """
    Send a chat message and get a streaming response via Server-Sent Events.

    Each generated chunk is emitted as an SSE "chunk" event whose data is a
    JSON document; a final "done" event marks normal completion, and failures
    during generation surface as an "error" event on the open stream.

    Raises:
        HTTPException: 500 when the stream cannot be set up at all.
    """
    import json  # local import: only needed here, to serialize SSE payloads

    try:
        # Build a streaming ChatRequest from the query parameters.
        chat_request = ChatRequest(
            message=message,
            session_id=session_id,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True
        )

        chat_manager = await get_chat_manager()

        async def event_generator():
            try:
                async for chunk in chat_manager.process_streaming_chat_request(chat_request):
                    chunk_data = {
                        "content": chunk.content,
                        "session_id": chunk.session_id,
                        "message_id": chunk.message_id,
                        "chunk_id": chunk.chunk_id,
                        "is_final": chunk.is_final,
                        "timestamp": chunk.timestamp.isoformat()
                    }

                    # Fix: serialize the payload to JSON explicitly.
                    # sse-starlette stringifies non-str data, so yielding a
                    # raw dict would send a Python dict repr that browser
                    # EventSource clients cannot JSON-parse.
                    yield {
                        "event": "chunk",
                        "data": json.dumps(chunk_data)
                    }

                    if chunk.is_final:
                        yield {
                            "event": "done",
                            "data": json.dumps({"message": "Stream completed"})
                        }
                        break

            except Exception as e:
                # Report generation failures in-band; the HTTP response has
                # already started, so we cannot raise an HTTPException here.
                logger.error("streaming_error", error=str(e), session_id=session_id)
                yield {
                    "event": "error",
                    "data": json.dumps({"error": str(e)})
                }

        return EventSourceResponse(event_generator())

    except Exception as e:
        logger.error("stream_setup_failed", error=str(e), session_id=session_id)
        raise HTTPException(status_code=500, detail="Failed to setup stream")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@router.websocket("/chat/ws")
async def websocket_chat(websocket: WebSocket):
    """
    WebSocket endpoint for real-time chat.

    Protocol: the client sends JSON objects with at least "message" and
    "session_id" (optionally "system_prompt", "temperature", "max_tokens");
    the server replies with a sequence of {"type": "chunk", ...} frames and
    keeps the connection open for further messages. Validation problems and
    generation failures are reported as JSON error frames rather than
    closing the socket.
    """
    await manager.connect(websocket)
    # Remember the last session_id seen so disconnect/error logs have context.
    session_id = None

    try:
        while True:
            # Receive message from client (blocks until a JSON frame arrives).
            data = await websocket.receive_json()

            # Extract request data
            message = data.get("message")
            session_id = data.get("session_id")
            system_prompt = data.get("system_prompt")
            temperature = data.get("temperature")
            max_tokens = data.get("max_tokens")

            if not message or not session_id:
                # Soft validation failure: tell the client and keep listening.
                await websocket.send_json({
                    "error": "Message and session_id are required"
                })
                continue

            # Create chat request (always streaming over WebSocket).
            chat_request = ChatRequest(
                message=message,
                session_id=session_id,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )

            # Process streaming request
            chat_manager = await get_chat_manager()

            try:
                async for chunk in chat_manager.process_streaming_chat_request(chat_request):
                    await websocket.send_json({
                        "type": "chunk",
                        "content": chunk.content,
                        "session_id": chunk.session_id,
                        "message_id": chunk.message_id,
                        "chunk_id": chunk.chunk_id,
                        "is_final": chunk.is_final,
                        "timestamp": chunk.timestamp.isoformat()
                    })

                    if chunk.is_final:
                        # End of this reply; loop back to await the next message.
                        break

            except Exception as e:
                # Generation failed mid-reply: report in-band, keep the socket.
                logger.error("websocket_generation_error", error=str(e), session_id=session_id)
                await websocket.send_json({
                    "type": "error",
                    "error": str(e)
                })

    except WebSocketDisconnect:
        # Normal client disconnect.
        manager.disconnect(websocket)
        logger.info("websocket_disconnected", session_id=session_id)
    except Exception as e:
        # Unexpected failure (e.g. send on a closed socket): log and untrack.
        logger.error("websocket_error", error=str(e), session_id=session_id)
        manager.disconnect(websocket)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@router.get("/sessions/{session_id}", response_model=ConversationHistory)
async def get_session(session_id: str):
    """Return the stored conversation history for *session_id*.

    Responds 404 when the session is unknown and 500 on lookup failures.
    """
    try:
        chat_service = await get_chat_manager()
        conversation = await chat_service.get_conversation_history(session_id)
    except Exception as exc:
        logger.error("get_session_failed", error=str(exc), session_id=session_id)
        raise HTTPException(status_code=500, detail="Failed to get session")

    if not conversation:
        raise HTTPException(status_code=404, detail="Session not found")
    return conversation
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@router.delete("/sessions/{session_id}")
async def clear_session(session_id: str):
    """Delete the conversation history for *session_id*.

    Responds 404 when the session is unknown and 500 on failures.
    """
    try:
        chat_service = await get_chat_manager()
        cleared = await chat_service.clear_conversation(session_id)
    except Exception as exc:
        logger.error("clear_session_failed", error=str(exc), session_id=session_id)
        raise HTTPException(status_code=500, detail="Failed to clear session")

    if not cleared:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"message": "Session cleared successfully"}
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@router.get("/sessions", response_model=List[Dict[str, Any]])
async def get_active_sessions():
    """List metadata for all currently active chat sessions."""
    try:
        chat_service = await get_chat_manager()
        return await chat_service.get_active_sessions()
    except Exception as exc:
        logger.error("get_active_sessions_failed", error=str(exc))
        raise HTTPException(status_code=500, detail="Failed to get active sessions")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@router.get("/model/info", response_model=ModelInfo)
async def get_model_info():
    """Describe the currently configured model backend."""
    try:
        backend_manager = await get_model_manager()
        details = backend_manager.get_model_info()

        payload = ModelInfo(
            name=details["name"],
            type=details["type"],
            loaded=details["loaded"],
            parameters=details.get("parameters"),
            capabilities=details.get("capabilities", []),
        )
    except Exception as exc:
        logger.error("get_model_info_failed", error=str(exc))
        raise HTTPException(status_code=500, detail="Failed to get model info")
    return payload
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Comprehensive health check endpoint.

    Aggregates the chat manager's nested health report (overall, model
    manager, session manager). Designed to never raise: on any failure it
    returns an "unhealthy" payload instead of an HTTP error.
    """
    try:
        chat_manager = await get_chat_manager()
        health_data = await chat_manager.health_check()

        # Extract key information from the nested report; each sub-dict
        # may be absent, so default to empty dicts.
        overall_status = health_data.get("overall", {})
        model_info = health_data.get("model_manager", {})
        session_info = health_data.get("session_manager", {})

        return HealthResponse(
            status=overall_status.get("status", "unknown"),
            version=settings.app_version,
            model_type=settings.model_type,
            model_name=settings.model_name,
            model_loaded=model_info.get("status") == "healthy",
            # Simplified uptime — NOTE(review): this is the current epoch
            # time, not seconds since process start; confirm intent.
            uptime=time.time(),
            active_sessions=session_info.get("active_sessions", 0),
            timestamp=datetime.utcnow()
        )

    except Exception as e:
        logger.error("health_check_failed", error=str(e))
        # Degraded response: the health endpoint itself must not error out.
        return HealthResponse(
            status="unhealthy",
            version=settings.app_version,
            model_type=settings.model_type,
            model_name=settings.model_name,
            model_loaded=False,
            uptime=time.time(),
            active_sessions=0,
            timestamp=datetime.utcnow()
        )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@router.get("/status")
async def status():
    """Lightweight liveness probe with service name, version and timestamp."""
    now = datetime.utcnow().isoformat()
    return {
        "status": "ok",
        "service": "sema-chat-api",
        "version": settings.app_version,
        "timestamp": now,
    }
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@router.get("/metrics")
async def metrics():
    """Expose Prometheus metrics; responds 404 when metrics are disabled."""
    if settings.enable_metrics:
        return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
    raise HTTPException(status_code=404, detail="Metrics not enabled")
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Core configuration and utilities
|
app/core/config.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for Sema Chat API
|
| 3 |
+
Environment-driven settings for flexible model backends
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from pydantic import BaseSettings, Field
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables from .env file
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Every field can be overridden via an environment variable (or a .env
    file); see the `env=` name on each field.
    """

    # =========================================================================
    # APPLICATION SETTINGS
    # =========================================================================

    app_name: str = Field(default="Sema Chat API", env="APP_NAME")
    app_version: str = Field(default="1.0.0", env="APP_VERSION")
    environment: str = Field(default="development", env="ENVIRONMENT")
    debug: bool = Field(default=True, env="DEBUG")

    # =========================================================================
    # SERVER SETTINGS
    # =========================================================================

    host: str = Field(default="0.0.0.0", env="HOST")
    port: int = Field(default=7860, env="PORT")
    cors_origins: List[str] = Field(default=["*"], env="CORS_ORIGINS")

    # =========================================================================
    # MODEL CONFIGURATION
    # =========================================================================

    # Backend selector: local, hf_api, openai, anthropic, minimax, google.
    model_type: str = Field(default="local", env="MODEL_TYPE")
    model_name: str = Field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", env="MODEL_NAME")

    # Local model settings
    device: str = Field(default="auto", env="DEVICE")
    max_length: int = Field(default=2048, env="MAX_LENGTH")
    temperature: float = Field(default=0.7, env="TEMPERATURE")
    top_p: float = Field(default=0.9, env="TOP_P")
    top_k: int = Field(default=50, env="TOP_K")
    max_new_tokens: int = Field(default=512, env="MAX_NEW_TOKENS")

    # =========================================================================
    # API KEYS AND TOKENS
    # =========================================================================

    # HuggingFace
    hf_api_token: Optional[str] = Field(default=None, env="HF_API_TOKEN")
    hf_inference_url: str = Field(
        default="https://api-inference.huggingface.co/models/",
        env="HF_INFERENCE_URL"
    )

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None, env="OPENAI_API_KEY")
    openai_org_id: Optional[str] = Field(default=None, env="OPENAI_ORG_ID")

    # Anthropic
    anthropic_api_key: Optional[str] = Field(default=None, env="ANTHROPIC_API_KEY")

    # MiniMax
    minimax_api_key: Optional[str] = Field(default=None, env="MINIMAX_API_KEY")
    minimax_api_url: Optional[str] = Field(default=None, env="MINIMAX_API_URL")
    minimax_model_version: Optional[str] = Field(default=None, env="MINIMAX_MODEL_VERSION")

    # Google AI Studio
    google_api_key: Optional[str] = Field(default=None, env="GOOGLE_API_KEY")

    # =========================================================================
    # RATE LIMITING AND PERFORMANCE
    # =========================================================================

    rate_limit: int = Field(default=60, env="RATE_LIMIT")  # requests per minute
    max_concurrent_streams: int = Field(default=10, env="MAX_CONCURRENT_STREAMS")
    stream_delay: float = Field(default=0.01, env="STREAM_DELAY")

    # =========================================================================
    # SESSION MANAGEMENT
    # =========================================================================

    session_timeout: int = Field(default=30, env="SESSION_TIMEOUT")  # minutes
    max_sessions_per_user: int = Field(default=5, env="MAX_SESSIONS_PER_USER")
    max_messages_per_session: int = Field(default=100, env="MAX_MESSAGES_PER_SESSION")

    # =========================================================================
    # STREAMING SETTINGS
    # =========================================================================

    enable_streaming: bool = Field(default=True, env="ENABLE_STREAMING")

    # =========================================================================
    # LOGGING AND MONITORING
    # =========================================================================

    log_level: str = Field(default="INFO", env="LOG_LEVEL")
    structured_logging: bool = Field(default=True, env="STRUCTURED_LOGGING")
    log_file: Optional[str] = Field(default=None, env="LOG_FILE")

    enable_metrics: bool = Field(default=True, env="ENABLE_METRICS")
    metrics_path: str = Field(default="/metrics", env="METRICS_PATH")

    # =========================================================================
    # EXTERNAL SERVICES
    # =========================================================================

    redis_url: Optional[str] = Field(default=None, env="REDIS_URL")

    # =========================================================================
    # SECURITY
    # =========================================================================

    api_key: Optional[str] = Field(default=None, env="API_KEY")
    jwt_secret: Optional[str] = Field(default=None, env="JWT_SECRET")

    # =========================================================================
    # SYSTEM PROMPTS
    # =========================================================================

    system_prompt: str = Field(
        default="You are a helpful, harmless, and honest AI assistant. Respond in a friendly and professional manner.",
        env="SYSTEM_PROMPT"
    )

    system_prompt_chat: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CHAT")
    system_prompt_code: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CODE")
    system_prompt_creative: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CREATIVE")

    class Config:
        env_file = ".env"
        case_sensitive = False

    def get_system_prompt(self, prompt_type: str = "default") -> str:
        """Return the system prompt for *prompt_type*.

        Falls back to the default prompt when the type-specific prompt is
        unset or the type is unknown.
        """
        if prompt_type == "chat" and self.system_prompt_chat:
            return self.system_prompt_chat
        if prompt_type == "code" and self.system_prompt_code:
            return self.system_prompt_code
        if prompt_type == "creative" and self.system_prompt_creative:
            return self.system_prompt_creative
        return self.system_prompt

    def is_local_model(self) -> bool:
        """Check if using the local model backend."""
        return self.model_type.lower() == "local"

    def is_api_model(self) -> bool:
        """Check if using an API-based model backend.

        Fix: previously omitted "minimax" and "google", both of which are
        API-backed (they require API keys per validate_model_config).
        """
        return self.model_type.lower() in {"hf_api", "openai", "anthropic", "minimax", "google"}

    def validate_model_config(self) -> bool:
        """Return True when the credentials required by model_type are set.

        Unknown/local model types need no credentials and always validate.
        """
        required_credentials = {
            "hf_api": (self.hf_api_token,),
            "openai": (self.openai_api_key,),
            "anthropic": (self.anthropic_api_key,),
            "minimax": (self.minimax_api_key, self.minimax_api_url),
            "google": (self.google_api_key,),
        }
        return all(required_credentials.get(self.model_type, ()))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Global settings instance, created once at import time. All modules share
# this object via `from ...core.config import settings`.
settings = Settings()


def get_settings() -> Settings:
    """Get application settings.

    Provided as a callable so FastAPI dependencies (Depends) can inject and
    override settings in tests instead of importing the module global.
    """
    return settings
|
app/core/logging.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Structured logging configuration for Sema Chat API
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Any, Dict
|
| 8 |
+
import structlog
|
| 9 |
+
from .config import settings
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def configure_logging():
    """Configure structured logging for the application.

    Sets up structlog (JSON output when settings.structured_logging is
    true, console rendering otherwise), routes stdlib logging to stdout at
    settings.log_level, and optionally attaches a file handler when
    settings.log_file is set.
    """

    # Configure structlog: processor order matters — filtering and metadata
    # first, the renderer (JSON or console) last.
    structlog.configure(
        processors=[
            structlog.stdlib.filter_by_level,
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            structlog.stdlib.PositionalArgumentsFormatter(),
            structlog.processors.TimeStamper(fmt="iso"),
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
            structlog.processors.UnicodeDecoder(),
            structlog.processors.JSONRenderer() if settings.structured_logging else structlog.dev.ConsoleRenderer(),
        ],
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Configure standard logging; "%(message)s" only, since structlog
    # already renders the full event line.
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=getattr(logging, settings.log_level.upper()),
    )

    # Configure file logging if specified.
    # NOTE(review): calling configure_logging() more than once would attach
    # a duplicate file handler to the root logger — confirm it is only
    # invoked once at startup.
    if settings.log_file:
        file_handler = logging.FileHandler(settings.log_file)
        file_handler.setLevel(getattr(logging, settings.log_level.upper()))

        if settings.structured_logging:
            file_handler.setFormatter(logging.Formatter('%(message)s'))
        else:
            file_handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )

        logging.getLogger().addHandler(file_handler)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_logger(name: str = None) -> structlog.BoundLogger:
    """Return a structured logger instance.

    When *name* is None, structlog derives a logger name from the calling
    module (the annotation should really be ``Optional[str]``).
    """
    return structlog.get_logger(name)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class LoggerMixin:
    """Mixin that equips any class with a class-named structured logger."""

    @property
    def logger(self) -> structlog.BoundLogger:
        """Structured logger named after the concrete class."""
        return get_logger(type(self).__name__)

    def log_info(self, message: str, **kwargs: Any):
        """Emit an info-level event with structured context."""
        self.logger.info(message, **kwargs)

    def log_error(self, message: str, **kwargs: Any):
        """Emit an error-level event with structured context."""
        self.logger.error(message, **kwargs)

    def log_warning(self, message: str, **kwargs: Any):
        """Emit a warning-level event with structured context."""
        self.logger.warning(message, **kwargs)

    def log_debug(self, message: str, **kwargs: Any):
        """Emit a debug-level event with structured context."""
        self.logger.debug(message, **kwargs)
|
app/main.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sema Chat API - Main Application
|
| 3 |
+
Modern chatbot API with streaming capabilities and flexible model backends
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, Request
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
| 9 |
+
from fastapi.responses import RedirectResponse
|
| 10 |
+
from slowapi import _rate_limit_exceeded_handler
|
| 11 |
+
from slowapi.errors import RateLimitExceeded
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
from .core.config import settings
|
| 15 |
+
from .core.logging import configure_logging, get_logger
|
| 16 |
+
from .services.chat_manager import initialize_chat_manager, shutdown_chat_manager
|
| 17 |
+
from .api.v1.endpoints import router as v1_router, limiter
|
| 18 |
+
|
| 19 |
+
# Configure logging
|
| 20 |
+
configure_logging()
|
| 21 |
+
logger = get_logger()
|
| 22 |
+
|
| 23 |
+
# Application startup time
|
| 24 |
+
startup_time = time.time()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_application() -> FastAPI:
    """Create and configure the FastAPI application.

    Wires together rate limiting, CORS, trusted-host checking (production
    only), a request-timing middleware, the v1 API router, and a root
    redirect to the interactive docs.

    Returns:
        The configured FastAPI instance.
    """

    # Create FastAPI app
    app = FastAPI(
        title=settings.app_name,
        description="Modern chatbot API with streaming capabilities and flexible model backends",
        version=settings.app_version,
        # Serve Swagger at a fixed path in every environment. The previous
        # conditional ("/docs" in debug, "/" otherwise) conflicted with the
        # root() redirect below: in production the Swagger route at "/"
        # shadowed the redirect and "/docs" — the URL advertised at startup
        # — returned 404. A root redirect to "/docs" preserves the
        # "docs at root" behavior for HuggingFace Spaces.
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json"
    )

    # Add rate limiting (limiter is shared with the v1 endpoints module)
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Add trusted host middleware for production
    if settings.environment == "production":
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["*"]  # TODO: restrict to real hostnames in production
        )

    # Add request timing middleware
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        # id(request) is only unique while the request object is alive —
        # adequate for debugging, not a durable request ID.
        response.headers["X-Request-ID"] = str(id(request))
        return response

    # Include API routes
    app.include_router(v1_router, prefix="/api/v1", tags=["Chat API v1"])

    # Root redirect for HuggingFace Spaces
    @app.get("/", include_in_schema=False)
    async def root():
        """Redirect root to docs for HuggingFace Spaces"""
        return RedirectResponse(url="/docs")

    return app
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Create the application instance (module-level so ASGI servers can import
# "app.main:app").
app = create_application()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@app.on_event("startup")
async def startup_event():
    """Initialize the application on startup.

    Logs a structured startup event, prints human-readable banners, and
    brings up the chat manager (which in turn initializes the model and
    session managers).  Raises if initialization fails so the server does
    not start half-configured.
    """
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions
    # in favor of lifespan handlers — worth migrating when convenient.
    logger.info("application_startup",
                version=settings.app_version,
                environment=settings.environment,
                model_type=settings.model_type,
                model_name=settings.model_name)

    print(f"\nπ Starting {settings.app_name} v{settings.app_version}")
    print(f"π Environment: {settings.environment}")
    print(f"π€ Model Backend: {settings.model_type}")
    print(f"π― Model: {settings.model_name}")
    print("π Initializing chat services...")

    try:
        # Initialize chat manager (which initializes model and session managers)
        success = await initialize_chat_manager()

        if success:
            logger.info("chat_services_initialized")
            print("β Chat services initialized successfully")
            # Port 7860 is hardcoded in these URLs — presumably the
            # HuggingFace Spaces convention; confirm against deployment.
            print(f"π API Documentation: http://localhost:7860/docs")
            print(f"π‘ WebSocket Chat: ws://localhost:7860/api/v1/chat/ws")
            print(f"π Streaming Chat: http://localhost:7860/api/v1/chat/stream")
            print(f"π¬ Regular Chat: http://localhost:7860/api/v1/chat")
            print(f"β€οΈ Health Check: http://localhost:7860/api/v1/health")

            if settings.enable_metrics:
                print(f"π Metrics: http://localhost:7860/api/v1/metrics")

            print("\nπ Sema Chat API is ready for conversations!")
            print("=" * 60)
        else:
            logger.error("chat_services_initialization_failed")
            print("β Failed to initialize chat services")
            print("π§ Please check your configuration and try again")
            raise RuntimeError("Chat services initialization failed")

    except Exception as e:
        # Re-raise so the ASGI server aborts startup instead of serving
        # requests without a working model backend.
        logger.error("startup_failed", error=str(e))
        print(f"π₯ Startup failed: {e}")
        raise
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on application shutdown.

    Best-effort: shutdown failures are logged and printed but not
    re-raised, since the process is terminating anyway.
    """
    logger.info("application_shutdown")
    print("\nπ Shutting down Sema Chat API...")

    try:
        await shutdown_chat_manager()
        print("β Chat services shutdown complete")
    except Exception as e:
        logger.error("shutdown_failed", error=str(e))
        print(f"β οΈ Shutdown warning: {e}")

    print("π Goodbye!\n")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# Health check endpoint at app level
|
| 148 |
+
@app.get("/health", tags=["Health"])
async def app_health():
    """Simple app-level health check: static metadata plus uptime."""
    seconds_up = time.time() - startup_time
    payload = {
        "status": "healthy",
        "service": "sema-chat-api",
        "version": settings.app_version,
        "uptime_seconds": seconds_up,
        "model_type": settings.model_type,
        "model_name": settings.model_name,
    }
    return payload
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Status endpoint
|
| 163 |
+
@app.get("/status", tags=["Health"])
async def app_status():
    """Simple status endpoint returning static service metadata only."""
    return {
        "status": "ok",
        "service": "sema-chat-api",
        "version": settings.app_version,
        "environment": settings.environment
    }
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
    # Local entry point for running without an external ASGI server.
    import uvicorn

    # Port 7860 is hardcoded here and in the startup banner — presumably
    # the HuggingFace Spaces convention; confirm before reusing elsewhere.
    print(f"π Starting Sema Chat API on {settings.host}:7860")
    print(f"π§ Debug mode: {settings.debug}")
    print(f"π€ Model: {settings.model_type}/{settings.model_name}")

    uvicorn.run(
        "app.main:app",        # import string (not the app object) so reload works
        host=settings.host,
        port=7860,
        reload=settings.debug,
        log_level=settings.log_level.lower()
    )
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Pydantic models and schemas
|
app/models/schemas.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for request/response validation
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional, Dict, Any, Union
|
| 6 |
+
from pydantic import BaseModel, Field, validator
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import uuid
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ChatMessage(BaseModel):
    """Individual chat message model used for session history and model input."""

    role: str = Field(..., description="Message role: 'user' or 'assistant'")
    content: str = Field(..., description="Message content")
    # NOTE(review): datetime.utcnow yields naive timestamps; confirm nothing
    # downstream compares them against timezone-aware datetimes.
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Message timestamp")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional message metadata")

    # 'system' is accepted here even though the role field's description
    # above only mentions 'user'/'assistant'.
    @validator('role')
    def validate_role(cls, v):
        if v not in ['user', 'assistant', 'system']:
            raise ValueError('Role must be user, assistant, or system')
        return v
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ChatRequest(BaseModel):
    """Chat request model.

    Carries the user message plus optional per-request overrides for the
    system prompt, sampling temperature, and token budget; `stream`
    selects streaming delivery.
    """

    message: str = Field(
        ...,
        description="User message",
        min_length=1,
        max_length=4000,
        example="Hello, how are you today?"
    )
    session_id: str = Field(
        ...,
        description="Session identifier for conversation context",
        example="user-123-session"
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="Custom system prompt for this conversation",
        max_length=1000
    )
    # 0.0 is a valid value (ge=0.0); None means "use the server default".
    temperature: Optional[float] = Field(
        default=None,
        description="Sampling temperature (0.0 to 1.0)",
        ge=0.0,
        le=1.0
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum tokens to generate",
        ge=1,
        le=2048
    )
    stream: Optional[bool] = Field(
        default=False,
        description="Whether to stream the response"
    )
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ChatResponse(BaseModel):
    """Chat response model returned by the non-streaming endpoint."""

    message: str = Field(..., description="Assistant response message")
    session_id: str = Field(..., description="Session identifier")
    message_id: str = Field(..., description="Unique message identifier")
    # NOTE(review): field names starting with "model_" trigger protected-
    # namespace warnings under pydantic v2 — confirm the pinned version.
    model_name: str = Field(..., description="Model used for generation")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Response timestamp")
    generation_time: float = Field(..., description="Time taken to generate response (seconds)")
    token_count: Optional[int] = Field(default=None, description="Number of tokens in response")
    finish_reason: Optional[str] = Field(default=None, description="Reason generation finished")

    class Config:
        json_schema_extra = {
            "example": {
                "message": "Hello! I'm doing well, thank you for asking. How can I help you today?",
                "session_id": "user-123-session",
                "message_id": "msg-456-789",
                "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "timestamp": "2024-01-15T10:30:00Z",
                "generation_time": 1.234,
                "token_count": 25,
                "finish_reason": "stop"
            }
        }
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class StreamChunk(BaseModel):
    """Streaming response chunk model; chunks share a message_id and are
    ordered by chunk_id, with is_final marking the last one."""

    content: str = Field(..., description="Chunk content")
    session_id: str = Field(..., description="Session identifier")
    message_id: str = Field(..., description="Message identifier")
    chunk_id: int = Field(..., description="Chunk sequence number")
    is_final: bool = Field(default=False, description="Whether this is the final chunk")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Chunk timestamp")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ConversationHistory(BaseModel):
    """Conversation history model: full message list for one session."""

    session_id: str = Field(..., description="Session identifier")
    messages: List[ChatMessage] = Field(..., description="List of messages in conversation")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
    # Kept as an explicit field rather than derived from len(messages).
    message_count: int = Field(..., description="Total number of messages")

    class Config:
        json_schema_extra = {
            "example": {
                "session_id": "user-123-session",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello!",
                        "timestamp": "2024-01-15T10:30:00Z"
                    },
                    {
                        "role": "assistant",
                        "content": "Hello! How can I help you today?",
                        "timestamp": "2024-01-15T10:30:01Z"
                    }
                ],
                "created_at": "2024-01-15T10:30:00Z",
                "updated_at": "2024-01-15T10:30:01Z",
                "message_count": 2
            }
        }
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SessionInfo(BaseModel):
    """Session information model: metadata only, no message bodies."""

    session_id: str = Field(..., description="Session identifier")
    created_at: datetime = Field(..., description="Session creation time")
    updated_at: datetime = Field(..., description="Last activity time")
    message_count: int = Field(..., description="Number of messages in session")
    model_name: str = Field(..., description="Model used in this session")
    is_active: bool = Field(..., description="Whether session is active")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class HealthResponse(BaseModel):
    """Health check response model for the API-level health endpoint."""

    status: str = Field(..., description="API health status")
    version: str = Field(..., description="API version")
    model_type: str = Field(..., description="Current model backend type")
    model_name: str = Field(..., description="Current model name")
    model_loaded: bool = Field(..., description="Whether model is loaded and ready")
    uptime: float = Field(..., description="API uptime in seconds")
    active_sessions: int = Field(..., description="Number of active chat sessions")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Health check timestamp")

    class Config:
        json_schema_extra = {
            "example": {
                "status": "healthy",
                "version": "1.0.0",
                "model_type": "local",
                "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "model_loaded": True,
                "uptime": 3600.5,
                "active_sessions": 5,
                "timestamp": "2024-01-15T10:30:00Z"
            }
        }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ErrorResponse(BaseModel):
    """Error response model: machine-readable error type plus human message."""

    error: str = Field(..., description="Error type")
    message: str = Field(..., description="Error message")
    details: Optional[Dict[str, Any]] = Field(default=None, description="Additional error details")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Error timestamp")
    request_id: Optional[str] = Field(default=None, description="Request identifier for debugging")

    class Config:
        json_schema_extra = {
            "example": {
                "error": "validation_error",
                "message": "Message content is required",
                "details": {"field": "message", "constraint": "min_length"},
                "timestamp": "2024-01-15T10:30:00Z",
                "request_id": "req-123-456"
            }
        }
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ModelInfo(BaseModel):
    """Model information model describing the active backend."""

    name: str = Field(..., description="Model name")
    type: str = Field(..., description="Model backend type")
    loaded: bool = Field(..., description="Whether model is loaded")
    parameters: Optional[Dict[str, Any]] = Field(default=None, description="Model parameters")
    # default_factory=list is the documented pydantic idiom for mutable
    # defaults; Field(default=[]) relied on pydantic copying the shared
    # list object.
    capabilities: List[str] = Field(default_factory=list, description="Model capabilities")

    class Config:
        json_schema_extra = {
            "example": {
                "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "type": "local",
                "loaded": True,
                "parameters": {
                    "temperature": 0.7,
                    "max_tokens": 512,
                    "top_p": 0.9
                },
                "capabilities": ["chat", "instruction_following", "streaming"]
            }
        }
|
app/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Services package
|
app/services/chat_manager.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat Manager - Main orchestrator for chat functionality
|
| 3 |
+
Coordinates between model backends and session management
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import AsyncGenerator, List, Optional, Dict, Any
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from ..core.config import settings
|
| 12 |
+
from ..core.logging import LoggerMixin
|
| 13 |
+
from ..models.schemas import (
|
| 14 |
+
ChatMessage, ChatRequest, ChatResponse, StreamChunk,
|
| 15 |
+
ConversationHistory, ErrorResponse
|
| 16 |
+
)
|
| 17 |
+
from .model_manager import get_model_manager
|
| 18 |
+
from .session_manager import get_session_manager
|
| 19 |
+
from .model_backends.base import ModelBackendError, ModelNotLoadedError, GenerationError
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ChatManager(LoggerMixin):
|
| 23 |
+
"""
|
| 24 |
+
Main chat service that orchestrates conversation handling
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
    # Flipped to True once both the model and session managers come up
    # successfully (see initialize()).
    self.is_initialized = False
    # Count of in-flight streaming responses, bounded by the limit below.
    self.active_streams = 0
    # Upper bound on simultaneous streams, taken from app settings.
    self.max_concurrent_streams = settings.max_concurrent_streams
|
| 31 |
+
|
| 32 |
+
async def initialize(self) -> bool:
    """Initialize the chat manager.

    Brings up the model manager first, then the session manager.  Returns
    False (rather than raising) on any failure so the caller decides how
    to surface the error.
    """
    try:
        self.log_info("Initializing chat manager")

        # Initialize model manager
        model_manager = await get_model_manager()
        if not await model_manager.initialize():
            self.log_error("Failed to initialize model manager")
            return False

        # Initialize session manager
        session_manager = await get_session_manager()
        if not await session_manager.initialize():
            self.log_error("Failed to initialize session manager")
            return False

        self.is_initialized = True
        self.log_info("Chat manager initialized successfully")
        return True

    except Exception as e:
        self.log_error("Chat manager initialization failed", error=str(e))
        return False
|
| 56 |
+
|
| 57 |
+
async def shutdown(self):
    """Shutdown the chat manager.

    Best-effort: failures are logged but not re-raised, since this runs
    during application teardown.
    """
    try:
        self.log_info("Shutting down chat manager")

        # Shutdown managers (model first, mirroring initialize())
        model_manager = await get_model_manager()
        await model_manager.shutdown()

        session_manager = await get_session_manager()
        await session_manager.shutdown()

        self.is_initialized = False
        self.log_info("Chat manager shutdown complete")

    except Exception as e:
        self.log_error("Chat manager shutdown failed", error=str(e))
|
| 74 |
+
|
| 75 |
+
async def process_chat_request(self, request: ChatRequest) -> ChatResponse:
    """
    Process a non-streaming chat request.

    Persists the user message, runs the model over the full session
    history, persists the assistant reply, and returns the response.

    Args:
        request: Chat request containing message and parameters

    Returns:
        ChatResponse: Complete response

    Raises:
        RuntimeError: If the manager has not been initialized.
        ModelNotLoadedError: If the model backend is not ready.
        ModelBackendError: Propagated from the backend on failure.
    """
    if not self.is_initialized:
        raise RuntimeError("Chat manager not initialized")

    start_time = time.time()

    try:
        # Get managers
        model_manager = await get_model_manager()
        session_manager = await get_session_manager()

        if not model_manager.is_ready():
            raise ModelNotLoadedError("Model not ready for inference")

        # Ensure session exists (assumes create_session is a no-op for an
        # existing session — confirm in session_manager)
        await session_manager.create_session(request.session_id)

        # Add user message to session
        user_message = ChatMessage(
            role="user",
            content=request.message,
            timestamp=datetime.utcnow(),
            metadata={"session_id": request.session_id}
        )
        await session_manager.add_message(request.session_id, user_message)

        # Get conversation history (system prompt + prior messages)
        messages = await self._prepare_messages_for_model(
            request.session_id,
            request.system_prompt
        )

        # Generate response. Use `is not None` rather than `or` so an
        # explicit, schema-valid temperature of 0.0 is honored instead of
        # being silently replaced by the server default.
        backend = model_manager.get_backend()
        response = await backend.generate_response(
            messages=messages,
            temperature=request.temperature if request.temperature is not None else settings.temperature,
            max_tokens=request.max_tokens if request.max_tokens is not None else settings.max_new_tokens
        )

        # Add assistant message to session
        assistant_message = ChatMessage(
            role="assistant",
            content=response.message,
            timestamp=datetime.utcnow(),
            metadata={"session_id": request.session_id, "message_id": response.message_id}
        )
        await session_manager.add_message(request.session_id, assistant_message)

        # Update response with correct session info
        response.session_id = request.session_id

        self.log_info("Chat request processed",
                      session_id=request.session_id,
                      generation_time=response.generation_time,
                      total_time=time.time() - start_time)

        return response

    except ModelBackendError as e:
        self.log_error("Model backend error", error=str(e), session_id=request.session_id)
        raise
    except Exception as e:
        self.log_error("Chat request processing failed", error=str(e), session_id=request.session_id)
        raise
|
| 149 |
+
|
| 150 |
+
async def process_streaming_chat_request(
    self,
    request: ChatRequest
) -> AsyncGenerator[StreamChunk, None]:
    """
    Process a streaming chat request.

    Persists the user message, streams model chunks to the caller while
    accumulating them, then persists the complete assistant reply.

    Args:
        request: Chat request containing message and parameters

    Yields:
        StreamChunk: Response chunks

    Raises:
        RuntimeError: If not initialized or the concurrent-stream limit
            is reached.
        ModelNotLoadedError: If the model backend is not ready.
        ModelBackendError: Propagated from the backend on failure.
    """
    if not self.is_initialized:
        raise RuntimeError("Chat manager not initialized")

    # Check concurrent stream limit. There is no await between this check
    # and the increment below, so the limit cannot be oversubscribed
    # within a single event loop.
    if self.active_streams >= self.max_concurrent_streams:
        raise RuntimeError(f"Maximum concurrent streams ({self.max_concurrent_streams}) exceeded")

    self.active_streams += 1

    try:
        # Get managers
        model_manager = await get_model_manager()
        session_manager = await get_session_manager()

        if not model_manager.is_ready():
            raise ModelNotLoadedError("Model not ready for inference")

        # Ensure session exists
        await session_manager.create_session(request.session_id)

        # Add user message to session
        user_message = ChatMessage(
            role="user",
            content=request.message,
            timestamp=datetime.utcnow(),
            metadata={"session_id": request.session_id}
        )
        await session_manager.add_message(request.session_id, user_message)

        # Get conversation history
        messages = await self._prepare_messages_for_model(
            request.session_id,
            request.system_prompt
        )

        # Generate streaming response, accumulating the full text so the
        # assistant turn can be stored in the session afterwards.
        backend = model_manager.get_backend()
        full_response = ""
        message_id = None

        async for chunk in backend.generate_stream(
            messages=messages,
            # `is not None` rather than `or`, so an explicit temperature
            # of 0.0 is honored instead of replaced by the default.
            temperature=request.temperature if request.temperature is not None else settings.temperature,
            max_tokens=request.max_tokens if request.max_tokens is not None else settings.max_new_tokens
        ):
            if message_id is None:
                message_id = chunk.message_id

            full_response += chunk.content
            yield chunk

        # Add complete assistant message to session (skip empty streams)
        if full_response.strip():
            assistant_message = ChatMessage(
                role="assistant",
                content=full_response.strip(),
                timestamp=datetime.utcnow(),
                metadata={"session_id": request.session_id, "message_id": message_id}
            )
            await session_manager.add_message(request.session_id, assistant_message)

        self.log_info("Streaming chat request processed",
                      session_id=request.session_id,
                      response_length=len(full_response))

    except ModelBackendError as e:
        self.log_error("Model backend error in streaming", error=str(e), session_id=request.session_id)
        raise
    except Exception as e:
        self.log_error("Streaming chat request failed", error=str(e), session_id=request.session_id)
        raise
    finally:
        # Always release the stream slot, even if the client disconnects
        # or generation fails mid-stream.
        self.active_streams -= 1
|
| 236 |
+
|
| 237 |
+
async def _prepare_messages_for_model(
|
| 238 |
+
self,
|
| 239 |
+
session_id: str,
|
| 240 |
+
custom_system_prompt: Optional[str] = None
|
| 241 |
+
) -> List[ChatMessage]:
|
| 242 |
+
"""
|
| 243 |
+
Prepare messages for model input, including system prompt and history
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
session_id: Session identifier
|
| 247 |
+
custom_system_prompt: Optional custom system prompt
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
List of ChatMessage objects ready for model input
|
| 251 |
+
"""
|
| 252 |
+
session_manager = await get_session_manager()
|
| 253 |
+
|
| 254 |
+
# Get conversation history
|
| 255 |
+
history_messages = await session_manager.get_session_messages(session_id)
|
| 256 |
+
|
| 257 |
+
# Prepare messages list
|
| 258 |
+
messages = []
|
| 259 |
+
|
| 260 |
+
# Add system prompt if provided
|
| 261 |
+
system_prompt = custom_system_prompt or settings.get_system_prompt()
|
| 262 |
+
if system_prompt:
|
| 263 |
+
messages.append(ChatMessage(
|
| 264 |
+
role="system",
|
| 265 |
+
content=system_prompt,
|
| 266 |
+
timestamp=datetime.utcnow()
|
| 267 |
+
))
|
| 268 |
+
|
| 269 |
+
# Add conversation history
|
| 270 |
+
messages.extend(history_messages)
|
| 271 |
+
|
| 272 |
+
return messages
|
| 273 |
+
|
| 274 |
+
async def get_conversation_history(self, session_id: str) -> Optional[ConversationHistory]:
|
| 275 |
+
"""
|
| 276 |
+
Get conversation history for a session
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
session_id: Session identifier
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
ConversationHistory or None if session not found
|
| 283 |
+
"""
|
| 284 |
+
try:
|
| 285 |
+
session_manager = await get_session_manager()
|
| 286 |
+
return await session_manager.get_session(session_id)
|
| 287 |
+
|
| 288 |
+
except Exception as e:
|
| 289 |
+
self.log_error("Failed to get conversation history", error=str(e), session_id=session_id)
|
| 290 |
+
return None
|
| 291 |
+
|
| 292 |
+
async def clear_conversation(self, session_id: str) -> bool:
|
| 293 |
+
"""
|
| 294 |
+
Clear conversation history for a session
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
session_id: Session identifier
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
bool: True if cleared successfully
|
| 301 |
+
"""
|
| 302 |
+
try:
|
| 303 |
+
session_manager = await get_session_manager()
|
| 304 |
+
return await session_manager.delete_session(session_id)
|
| 305 |
+
|
| 306 |
+
except Exception as e:
|
| 307 |
+
self.log_error("Failed to clear conversation", error=str(e), session_id=session_id)
|
| 308 |
+
return False
|
| 309 |
+
|
| 310 |
+
async def get_active_sessions(self) -> List[Dict[str, Any]]:
|
| 311 |
+
"""Get information about active chat sessions"""
|
| 312 |
+
try:
|
| 313 |
+
session_manager = await get_session_manager()
|
| 314 |
+
sessions = await session_manager.get_active_sessions()
|
| 315 |
+
|
| 316 |
+
return [
|
| 317 |
+
{
|
| 318 |
+
"session_id": session.session_id,
|
| 319 |
+
"created_at": session.created_at.isoformat(),
|
| 320 |
+
"updated_at": session.updated_at.isoformat(),
|
| 321 |
+
"message_count": session.message_count,
|
| 322 |
+
"model_name": session.model_name,
|
| 323 |
+
"is_active": session.is_active
|
| 324 |
+
}
|
| 325 |
+
for session in sessions
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
except Exception as e:
|
| 329 |
+
self.log_error("Failed to get active sessions", error=str(e))
|
| 330 |
+
return []
|
| 331 |
+
|
| 332 |
+
async def health_check(self) -> Dict[str, Any]:
|
| 333 |
+
"""Perform a comprehensive health check"""
|
| 334 |
+
try:
|
| 335 |
+
health_status = {
|
| 336 |
+
"chat_manager": {
|
| 337 |
+
"status": "healthy" if self.is_initialized else "unhealthy",
|
| 338 |
+
"initialized": self.is_initialized,
|
| 339 |
+
"active_streams": self.active_streams,
|
| 340 |
+
"max_concurrent_streams": self.max_concurrent_streams
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# Check model manager
|
| 345 |
+
model_manager = await get_model_manager()
|
| 346 |
+
model_health = await model_manager.health_check()
|
| 347 |
+
health_status["model_manager"] = model_health
|
| 348 |
+
|
| 349 |
+
# Check session manager
|
| 350 |
+
session_manager = await get_session_manager()
|
| 351 |
+
active_sessions = await session_manager.get_active_sessions()
|
| 352 |
+
health_status["session_manager"] = {
|
| 353 |
+
"status": "healthy",
|
| 354 |
+
"active_sessions": len(active_sessions),
|
| 355 |
+
"storage_type": "redis" if session_manager.use_redis else "memory"
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
# Overall status
|
| 359 |
+
overall_healthy = (
|
| 360 |
+
self.is_initialized and
|
| 361 |
+
model_health.get("status") == "healthy"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
health_status["overall"] = {
|
| 365 |
+
"status": "healthy" if overall_healthy else "unhealthy",
|
| 366 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
return health_status
|
| 370 |
+
|
| 371 |
+
except Exception as e:
|
| 372 |
+
self.log_error("Health check failed", error=str(e))
|
| 373 |
+
return {
|
| 374 |
+
"overall": {
|
| 375 |
+
"status": "unhealthy",
|
| 376 |
+
"error": str(e),
|
| 377 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# Global chat manager instance — module-level singleton so every caller
# shares the same stream counters and initialization state.
chat_manager = ChatManager()


async def get_chat_manager() -> ChatManager:
    """Get the global chat manager instance.

    Returns:
        ChatManager: the module-level singleton (never None; may be
        uninitialized until initialize_chat_manager() has run).
    """
    return chat_manager


async def initialize_chat_manager() -> bool:
    """Initialize the global chat manager.

    Returns:
        bool: result of ChatManager.initialize().
    """
    return await chat_manager.initialize()


async def shutdown_chat_manager():
    """Shutdown the global chat manager and release its resources."""
    await chat_manager.shutdown()
|
app/services/model_backends/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Model backends package
|
app/services/model_backends/anthropic_api.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anthropic API backend
|
| 3 |
+
Uses Anthropic's API for Claude models
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import anthropic
|
| 12 |
+
|
| 13 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 14 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 15 |
+
from ...core.config import settings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AnthropicAPIBackend(ModelBackend):
    """Anthropic API backend for Claude models.

    Wraps the async Anthropic SDK: load_model() creates the client and
    verifies connectivity; generate_response()/generate_stream() translate
    ChatMessage history into Messages API calls.
    """

    def __init__(self, model_name: str, **kwargs):
        """Configure the backend; no network I/O until load_model().

        Args:
            model_name: Anthropic model identifier (e.g. a claude-3 variant).
            **kwargs: optional 'api_key' plus generation-parameter overrides.
        """
        super().__init__(model_name, **kwargs)
        # SDK client; created in load_model(), None until then.
        self.client = None
        # An explicit kwarg overrides the configured key.
        self.api_key = kwargs.get('api_key', settings.anthropic_api_key)
        self.capabilities = ["chat", "streaming", "api_based", "long_context"]

        # Generation parameters (defaults come from app settings).
        self.parameters = {
            'temperature': kwargs.get('temperature', settings.temperature),
            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
            'top_p': kwargs.get('top_p', settings.top_p),
        }

    async def load_model(self) -> bool:
        """Initialize the Anthropic API client and verify connectivity.

        Returns:
            bool: True on success.

        Raises:
            ModelLoadError: if the key is missing or the connection test
                fails (any inner error is re-wrapped into ModelLoadError).
        """
        try:
            if not self.api_key:
                raise ModelLoadError("Anthropic API key is required")

            self.log_info("Initializing Anthropic API client", model=self.model_name)

            # Initialize the Anthropic client
            self.client = anthropic.AsyncAnthropic(
                api_key=self.api_key
            )

            # Test the connection (issues one tiny billable request).
            await self._test_connection()

            self.is_loaded = True
            self.log_info("Anthropic API client initialized successfully", model=self.model_name)

            return True

        except Exception as e:
            # NOTE: this also re-wraps the ModelLoadError raised above.
            self.log_error("Failed to initialize Anthropic API client", error=str(e), model=self.model_name)
            raise ModelLoadError(f"Failed to initialize Anthropic API for {self.model_name}: {str(e)}")

    async def unload_model(self) -> bool:
        """Close the API client and mark the backend unloaded.

        Returns:
            bool: True on success, False if cleanup raised.
        """
        try:
            if self.client:
                await self.client.close()
                self.client = None
            self.is_loaded = False
            self.log_info("Anthropic API client cleaned up", model=self.model_name)
            return True

        except Exception as e:
            self.log_error("Failed to cleanup Anthropic API client", error=str(e), model=self.model_name)
            return False

    async def _test_connection(self):
        """Test the Anthropic API connection with a minimal request.

        Raises on any API failure so load_model() can abort.
        """
        try:
            # Simple test request (5 tokens max to keep cost negligible).
            response = await self.client.messages.create(
                model=self.model_name,
                max_tokens=5,
                temperature=0.1,
                messages=[{"role": "user", "content": "Hello"}]
            )

            self.log_info("Anthropic API connection test successful", model=self.model_name)

        except Exception as e:
            self.log_error("Anthropic API connection test failed", error=str(e), model=self.model_name)
            raise

    def _format_messages_for_api(self, messages: List[ChatMessage]) -> tuple:
        """Format messages for Anthropic API (separate system and messages).

        The Messages API takes the system prompt as a top-level parameter,
        not as a message role. If several system messages are present, the
        last one wins.

        Returns:
            (system_message_or_None, list_of_role_content_dicts)
        """
        system_message = None
        formatted_messages = []

        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                formatted_messages.append({
                    "role": msg.role,
                    "content": msg.content
                })

        return system_message, formatted_messages

    async def generate_response(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> ChatResponse:
        """Generate a complete (non-streaming) response via the Anthropic API.

        Args:
            messages: conversation history, last entry is the current turn.
            temperature: sampling temperature (clamped by validate_parameters).
            max_tokens: generation cap (clamped by validate_parameters).

        Returns:
            ChatResponse with the model text, timing, and token usage.

        Raises:
            ModelNotLoadedError: if load_model() has not succeeded.
            GenerationError: wrapping any API failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("Anthropic API client not initialized")

        start_time = time.time()
        message_id = str(uuid.uuid4())

        try:
            # Validate parameters
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            # Format messages
            system_message, api_messages = self._format_messages_for_api(messages)

            # Prepare request parameters
            request_params = {
                "model": self.model_name,
                "messages": api_messages,
                "max_tokens": params['max_tokens'],
                "temperature": params['temperature'],
                "top_p": params.get('top_p', 0.9),
                "stream": False
            }

            # Add system message if present
            if system_message:
                request_params["system"] = system_message

            # Make API call
            response = await self.client.messages.create(**request_params)

            # Extract response; getattr guards against SDK shape changes.
            response_text = response.content[0].text if response.content else ""
            finish_reason = getattr(response, 'stop_reason', 'stop')
            token_count = getattr(response.usage, 'output_tokens', None) if hasattr(response, 'usage') else None

            generation_time = time.time() - start_time

            # session_id is recovered from the last message's metadata.
            return ChatResponse(
                message=response_text.strip(),
                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
                message_id=message_id,
                model_name=self.model_name,
                generation_time=generation_time,
                token_count=token_count,
                finish_reason=finish_reason
            )

        except Exception as e:
            self.log_error("Anthropic API generation failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate response via Anthropic API: {str(e)}")

    async def generate_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response using Anthropic API.

        Yields one StreamChunk per text delta, then a final empty chunk
        with is_final=True.

        Raises:
            ModelNotLoadedError: if load_model() has not succeeded.
            GenerationError: wrapping any API failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("Anthropic API client not initialized")

        message_id = str(uuid.uuid4())
        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
        chunk_id = 0

        try:
            # Validate parameters
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            # Format messages
            system_message, api_messages = self._format_messages_for_api(messages)

            # Prepare request parameters
            request_params = {
                "model": self.model_name,
                "messages": api_messages,
                "max_tokens": params['max_tokens'],
                "temperature": params['temperature'],
                "top_p": params.get('top_p', 0.9),
                "stream": True
            }

            # Add system message if present
            if system_message:
                request_params["system"] = system_message

            # Create streaming request; with stream=True the awaited call
            # returns an async-iterable of SSE events.
            stream = await self.client.messages.create(**request_params)

            # Process streaming chunks: only text deltas are forwarded.
            async for chunk in stream:
                if chunk.type == "content_block_delta":
                    if hasattr(chunk.delta, 'text') and chunk.delta.text:
                        yield StreamChunk(
                            content=chunk.delta.text,
                            session_id=session_id,
                            message_id=message_id,
                            chunk_id=chunk_id,
                            is_final=False
                        )
                        chunk_id += 1

                        # Add small delay to prevent overwhelming the client
                        await asyncio.sleep(settings.stream_delay)

                elif chunk.type == "message_stop":
                    break

            # Send final (empty) sentinel chunk so consumers can close out.
            yield StreamChunk(
                content="",
                session_id=session_id,
                message_id=message_id,
                chunk_id=chunk_id,
                is_final=True
            )

        except Exception as e:
            self.log_error("Anthropic API streaming failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate streaming response via Anthropic API: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model (static metadata only)."""
        return {
            "name": self.model_name,
            "type": "anthropic_api",
            "loaded": self.is_loaded,
            "provider": "Anthropic",
            "capabilities": self.capabilities,
            "parameters": self.parameters,
            "requires_api_key": True,
            "api_key_configured": bool(self.api_key),
            "context_window": self._get_context_window()
        }

    def _get_context_window(self) -> int:
        """Get the context window size for the model.

        Unknown model names fall back to a conservative 100k tokens.
        """
        context_windows = {
            "claude-3-haiku-20240307": 200000,
            "claude-3-sonnet-20240229": 200000,
            "claude-3-opus-20240229": 200000,
            "claude-3-5-sonnet-20241022": 200000,
            "claude-3-5-haiku-20241022": 200000,
        }
        return context_windows.get(self.model_name, 100000)

    async def health_check(self) -> Dict[str, Any]:
        """Perform a health check on the Anthropic API.

        Issues one tiny real generation (5 tokens, 10s timeout); the result
        dict always carries "status" plus diagnostic fields.
        """
        try:
            if not self.is_loaded:
                return {
                    "status": "unhealthy",
                    "reason": "client_not_initialized",
                    "model_name": self.model_name
                }

            # Test API connectivity with a minimal round trip.
            start_time = time.time()
            test_messages = [ChatMessage(role="user", content="Hi")]

            try:
                response = await asyncio.wait_for(
                    self.generate_response(
                        test_messages,
                        temperature=0.1,
                        max_tokens=5
                    ),
                    timeout=10.0
                )

                response_time = time.time() - start_time

                return {
                    "status": "healthy",
                    "model_name": self.model_name,
                    "response_time": response_time,
                    "provider": "Anthropic",
                    "context_window": self._get_context_window()
                }

            except asyncio.TimeoutError:
                return {
                    "status": "unhealthy",
                    "reason": "api_timeout",
                    "model_name": self.model_name,
                    "provider": "Anthropic"
                }

        except Exception as e:
            self.log_error("Anthropic API health check failed", error=str(e), model=self.model_name)
            return {
                "status": "unhealthy",
                "reason": "api_error",
                "error": str(e),
                "model_name": self.model_name,
                "provider": "Anthropic"
            }
|
app/services/model_backends/base.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Abstract base class for model backends
|
| 3 |
+
Defines the interface that all model backends must implement
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 8 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 9 |
+
from ...core.logging import LoggerMixin
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelBackend(ABC, LoggerMixin):
    """Abstract base class for all model backends.

    Concrete backends (local HF, OpenAI, Anthropic, ...) implement the
    abstract load/unload/generate methods; the base supplies parameter
    validation, capability queries, and a default health check.
    """

    def __init__(self, model_name: str, **kwargs):
        self.model_name = model_name
        # Flipped to True by concrete load_model() implementations.
        self.is_loaded = False
        # Capability tags (e.g. "chat", "streaming") set by subclasses.
        self.capabilities = []
        # Backend-specific generation defaults set by subclasses.
        self.parameters = {}

    @abstractmethod
    async def load_model(self) -> bool:
        """
        Load the model and prepare it for inference

        Returns:
            bool: True if model loaded successfully, False otherwise
        """
        pass

    @abstractmethod
    async def unload_model(self) -> bool:
        """
        Unload the model and free resources

        Returns:
            bool: True if model unloaded successfully, False otherwise
        """
        pass

    @abstractmethod
    async def generate_response(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> ChatResponse:
        """
        Generate a complete response (non-streaming)

        Args:
            messages: List of conversation messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional generation parameters

        Returns:
            ChatResponse: Complete response
        """
        pass

    @abstractmethod
    async def generate_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> AsyncGenerator[StreamChunk, None]:
        """
        Generate a streaming response

        Args:
            messages: List of conversation messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional generation parameters

        Yields:
            StreamChunk: Response chunks
        """
        pass

    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the current model

        Returns:
            Dict containing model information
        """
        pass

    def supports_streaming(self) -> bool:
        """Check if this backend supports streaming"""
        return "streaming" in self.capabilities

    def supports_chat(self) -> bool:
        """Check if this backend supports chat/conversation"""
        return "chat" in self.capabilities

    def is_model_loaded(self) -> bool:
        """Check if model is loaded and ready"""
        return self.is_loaded

    def format_messages_for_model(self, messages: List[ChatMessage]) -> Any:
        """
        Format messages for the specific model format
        Override in subclasses if needed

        Args:
            messages: List of ChatMessage objects

        Returns:
            Formatted messages for the model (default: OpenAI-style
            role/content dicts)
        """
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def validate_parameters(self, **kwargs) -> Dict[str, Any]:
        """
        Validate and normalize generation parameters

        Args:
            **kwargs: Generation parameters

        Returns:
            Dict of validated parameters (always contains temperature,
            max_tokens, top_p, top_k — extra kwargs are dropped)
        """
        validated = {}

        # Temperature clamped to [0.0, 1.0].
        temperature = kwargs.get('temperature', 0.7)
        validated['temperature'] = max(0.0, min(1.0, float(temperature)))

        # Max tokens clamped to [1, 2048].
        max_tokens = kwargs.get('max_tokens', 512)
        validated['max_tokens'] = max(1, min(2048, int(max_tokens)))

        # Top-p clamped to [0.0, 1.0].
        top_p = kwargs.get('top_p', 0.9)
        validated['top_p'] = max(0.0, min(1.0, float(top_p)))

        # Top-k floored at 1 (no upper bound).
        top_k = kwargs.get('top_k', 50)
        validated['top_k'] = max(1, int(top_k))

        return validated

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check on the model backend

        Runs a small real generation with a 10s timeout; subclasses may
        override with a cheaper probe.

        Returns:
            Dict containing health status
        """
        try:
            if not self.is_loaded:
                return {
                    "status": "unhealthy",
                    "reason": "model_not_loaded",
                    "model_name": self.model_name
                }

            # Try a simple generation to test the model
            test_messages = [
                ChatMessage(role="user", content="Hello")
            ]

            # Use a timeout for the health check
            # (local import keeps the base module import-light)
            import asyncio
            try:
                response = await asyncio.wait_for(
                    self.generate_response(
                        test_messages,
                        temperature=0.1,
                        max_tokens=10
                    ),
                    timeout=10.0
                )

                return {
                    "status": "healthy",
                    "model_name": self.model_name,
                    "response_time": getattr(response, 'generation_time', 0.0)
                }

            except asyncio.TimeoutError:
                return {
                    "status": "unhealthy",
                    "reason": "timeout",
                    "model_name": self.model_name
                }

        except Exception as e:
            self.log_error("Health check failed", error=str(e), model=self.model_name)
            return {
                "status": "unhealthy",
                "reason": "generation_error",
                "error": str(e),
                "model_name": self.model_name
            }
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class ModelBackendError(Exception):
    """Root of the backend exception hierarchy; catch this for any backend failure."""


class ModelLoadError(ModelBackendError):
    """Raised when a backend fails to load or initialize its model."""


class GenerationError(ModelBackendError):
    """Raised when text generation fails at inference time."""


class ModelNotLoadedError(ModelBackendError):
    """Raised when inference is attempted before the model has been loaded."""
|
app/services/model_backends/google_api.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Google AI Studio API backend
|
| 3 |
+
Uses Google's AI Studio API for Gemma and other Google models
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
import json
|
| 10 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import httpx
|
| 13 |
+
|
| 14 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 15 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 16 |
+
from ...core.config import settings
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GoogleAIBackend(ModelBackend):
    """Google AI Studio API backend for Gemma and other Google models.

    Wraps the Generative Language REST API (v1beta) and exposes the common
    ModelBackend interface: load/unload, one-shot generation and streaming
    generation. All calls require a Google AI Studio API key.
    """

    def __init__(self, model_name: str, **kwargs):
        super().__init__(model_name, **kwargs)
        self.api_key = kwargs.get('api_key', settings.google_api_key)
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"
        self.capabilities = ["chat", "streaming", "api_based"]

        # Default generation parameters; individual requests may override them.
        self.parameters = {
            'temperature': kwargs.get('temperature', settings.temperature),
            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
            'top_p': kwargs.get('top_p', settings.top_p),
            'top_k': kwargs.get('top_k', settings.top_k),
        }

    def _request_headers(self) -> Dict[str, str]:
        """Build headers for every API call.

        The key is sent in the ``x-goog-api-key`` header rather than as a
        ``?key=`` query parameter so it cannot leak through URLs recorded in
        access logs, proxies, or exception messages that include the URL.
        """
        return {
            'Content-Type': 'application/json',
            'x-goog-api-key': self.api_key,
        }

    @staticmethod
    def _session_id_of(messages: List[ChatMessage]) -> str:
        """Return the session id from the last message's metadata, if any."""
        if messages and messages[-1].metadata:
            return messages[-1].metadata.get('session_id', 'unknown')
        return 'unknown'

    @staticmethod
    def _generation_config(params: Dict[str, Any]) -> Dict[str, Any]:
        """Map validated backend parameters onto the API's generationConfig."""
        return {
            "maxOutputTokens": params['max_tokens'],
            "temperature": params['temperature'],
            "topP": params.get('top_p', 0.9),
            "topK": params.get('top_k', 40),
        }

    @staticmethod
    def _extract_candidate_text(payload: Dict[str, Any]) -> Optional[str]:
        """Concatenate the text parts of the first candidate, or None."""
        candidates = payload.get('candidates')
        if candidates:
            candidate = candidates[0]
            if 'content' in candidate and 'parts' in candidate['content']:
                parts = candidate['content']['parts']
                return ''.join(part.get('text', '') for part in parts)
        return None

    async def load_model(self) -> bool:
        """Initialize the Google AI API client.

        Verifies the API key and connectivity with a tiny test request.
        Raises ModelLoadError when the key is missing or the test fails.
        """
        try:
            if not self.api_key:
                raise ModelLoadError("Google AI API key is required")

            self.log_info("Initializing Google AI API client", model=self.model_name)

            # Fail fast: confirm key + model name before accepting traffic.
            await self._test_connection()

            self.is_loaded = True
            self.log_info("Google AI API client initialized successfully", model=self.model_name)
            return True

        except Exception as e:
            self.log_error("Failed to initialize Google AI API client", error=str(e), model=self.model_name)
            raise ModelLoadError(f"Failed to initialize Google AI API for {self.model_name}: {str(e)}")

    async def unload_model(self) -> bool:
        """Clean up the API client (no persistent resources are held)."""
        try:
            self.is_loaded = False
            self.log_info("Google AI API client cleaned up", model=self.model_name)
            return True
        except Exception as e:
            self.log_error("Failed to cleanup Google AI API client", error=str(e), model=self.model_name)
            return False

    async def _test_connection(self):
        """Issue a minimal generateContent request to verify connectivity."""
        try:
            url = f"{self.base_url}/models/{self.model_name}:generateContent"
            test_data = {
                "contents": [
                    {"parts": [{"text": "Hello"}]}
                ],
                "generationConfig": {
                    "maxOutputTokens": 5,
                    "temperature": 0.1
                }
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    url,
                    headers=self._request_headers(),
                    json=test_data,
                    timeout=10.0
                )

            if response.status_code != 200:
                raise Exception(f"API test failed with status {response.status_code}: {response.text}")

            self.log_info("Google AI API connection test successful", model=self.model_name)

        except Exception as e:
            self.log_error("Google AI API connection test failed", error=str(e), model=self.model_name)
            raise

    def _format_messages_for_api(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """Format chat messages into the Google AI request shape.

        System messages become ``systemInstruction``; user messages keep role
        "user"; assistant messages are mapped to the API's "model" role.
        """
        contents = []
        system_instruction = None

        for msg in messages:
            if msg.role == "system":
                system_instruction = msg.content
            elif msg.role == "user":
                contents.append({"role": "user", "parts": [{"text": msg.content}]})
            elif msg.role == "assistant":
                contents.append({"role": "model", "parts": [{"text": msg.content}]})

        result = {"contents": contents}
        if system_instruction:
            result["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        return result

    async def generate_response(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> ChatResponse:
        """Generate a complete response using Google AI API.

        Raises ModelNotLoadedError when load_model() has not succeeded and
        GenerationError on any API failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("Google AI API client not initialized")

        start_time = time.time()
        message_id = str(uuid.uuid4())

        try:
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            api_data = self._format_messages_for_api(messages)
            api_data["generationConfig"] = self._generation_config(params)

            url = f"{self.base_url}/models/{self.model_name}:generateContent"
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    url,
                    headers=self._request_headers(),
                    json=api_data,
                    timeout=30.0
                )

            if response.status_code != 200:
                raise GenerationError(f"API request failed with status {response.status_code}: {response.text}")

            response_data = response.json()

            # Fall back to the raw payload string when no candidate text exists
            # (e.g. safety-blocked responses), matching the previous behavior.
            response_text = self._extract_candidate_text(response_data)
            if response_text is None:
                response_text = str(response_data)

            generation_time = time.time() - start_time

            return ChatResponse(
                message=response_text.strip(),
                session_id=self._session_id_of(messages),
                message_id=message_id,
                model_name=self.model_name,
                generation_time=generation_time,
                token_count=len(response_text.split()),  # Approximate token count
                finish_reason="stop"
            )

        except Exception as e:
            self.log_error("Google AI API generation failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate response via Google AI API: {str(e)}")

    async def generate_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response using Google AI API.

        Uses ``streamGenerateContent?alt=sse`` so the server emits
        server-sent events (``data: {...}`` lines). Without ``alt=sse`` the
        endpoint returns one chunked JSON *array*, which cannot be parsed
        line-by-line — the previous implementation silently dropped chunks.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("Google AI API client not initialized")

        message_id = str(uuid.uuid4())
        session_id = self._session_id_of(messages)
        chunk_id = 0

        try:
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            api_data = self._format_messages_for_api(messages)
            api_data["generationConfig"] = self._generation_config(params)

            url = f"{self.base_url}/models/{self.model_name}:streamGenerateContent"

            async with httpx.AsyncClient() as client:
                async with client.stream(
                    'POST',
                    f"{url}?alt=sse",
                    headers=self._request_headers(),
                    json=api_data,
                    timeout=60.0
                ) as response:

                    if response.status_code != 200:
                        raise GenerationError(f"Streaming request failed with status {response.status_code}")

                    async for line in response.aiter_lines():
                        line = line.strip()
                        # SSE payload lines start with "data:"; skip blanks
                        # and keep-alive/comment lines.
                        if not line.startswith("data:"):
                            continue
                        payload = line[len("data:"):].strip()
                        try:
                            data = json.loads(payload)
                        except json.JSONDecodeError:
                            continue

                        content = self._extract_candidate_text(data)
                        if content:
                            yield StreamChunk(
                                content=content,
                                session_id=session_id,
                                message_id=message_id,
                                chunk_id=chunk_id,
                                is_final=False
                            )
                            chunk_id += 1
                            # Throttle delivery slightly between chunks.
                            await asyncio.sleep(settings.stream_delay)

            # Final empty chunk signals end-of-stream to consumers.
            yield StreamChunk(
                content="",
                session_id=session_id,
                message_id=message_id,
                chunk_id=chunk_id,
                is_final=True
            )

        except Exception as e:
            self.log_error("Google AI API streaming failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate streaming response via Google AI API: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model."""
        return {
            "name": self.model_name,
            "type": "google_ai",
            "loaded": self.is_loaded,
            "provider": "Google AI Studio",
            "capabilities": self.capabilities,
            "parameters": self.parameters,
            "requires_api_key": True,
            "api_key_configured": bool(self.api_key),
            "base_url": self.base_url
        }
|
app/services/model_backends/hf_api.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HuggingFace Inference API backend
|
| 3 |
+
Uses HuggingFace's hosted inference API for model access
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
import json
|
| 10 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import httpx
|
| 13 |
+
from huggingface_hub import InferenceClient
|
| 14 |
+
|
| 15 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 16 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 17 |
+
from ...core.config import settings
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class HuggingFaceAPIBackend(ModelBackend):
    """HuggingFace Inference API backend.

    Uses ``huggingface_hub.InferenceClient`` (a synchronous client) and runs
    all blocking calls on the default executor so the event loop stays free.
    """

    def __init__(self, model_name: str, **kwargs):
        super().__init__(model_name, **kwargs)
        self.client = None
        self.api_token = kwargs.get('api_token', settings.hf_api_token)
        self.inference_url = kwargs.get('inference_url', settings.hf_inference_url)
        self.capabilities = ["chat", "streaming", "api_based"]

        # Default generation parameters; individual requests may override them.
        self.parameters = {
            'temperature': kwargs.get('temperature', settings.temperature),
            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
            'top_p': kwargs.get('top_p', settings.top_p),
        }

    @staticmethod
    def _session_id_of(messages: List[ChatMessage]) -> str:
        """Return the session id from the last message's metadata, if any."""
        if messages and messages[-1].metadata:
            return messages[-1].metadata.get('session_id', 'unknown')
        return 'unknown'

    async def load_model(self) -> bool:
        """Initialize the HuggingFace API client and verify connectivity."""
        try:
            if not self.api_token:
                raise ModelLoadError("HuggingFace API token is required")

            self.log_info("Initializing HuggingFace API client", model=self.model_name)

            self.client = InferenceClient(
                model=self.model_name,
                token=self.api_token
            )

            # Fail fast with a tiny request before accepting traffic.
            await self._test_connection()

            self.is_loaded = True
            self.log_info("HuggingFace API client initialized successfully", model=self.model_name)
            return True

        except Exception as e:
            self.log_error("Failed to initialize HuggingFace API client", error=str(e), model=self.model_name)
            raise ModelLoadError(f"Failed to initialize HuggingFace API for {self.model_name}: {str(e)}")

    async def unload_model(self) -> bool:
        """Drop the API client reference."""
        try:
            self.client = None
            self.is_loaded = False
            self.log_info("HuggingFace API client cleaned up", model=self.model_name)
            return True
        except Exception as e:
            self.log_error("Failed to cleanup API client", error=str(e), model=self.model_name)
            return False

    async def _test_connection(self):
        """Test the API connection with a minimal chat completion."""
        try:
            test_messages = [{"role": "user", "content": "Hello"}]

            # The InferenceClient is synchronous; run it on the executor so
            # the event loop is not blocked. get_running_loop() is the
            # correct call inside a coroutine (get_event_loop() is
            # deprecated here since Python 3.10).
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: self.client.chat_completion(
                    messages=test_messages,
                    max_tokens=10,
                    temperature=0.1
                )
            )

            self.log_info("API connection test successful", model=self.model_name)

        except Exception as e:
            self.log_error("API connection test failed", error=str(e), model=self.model_name)
            raise

    def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
        """Format messages as plain role/content dicts for the API."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate_response(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> ChatResponse:
        """Generate a complete response using HuggingFace API.

        Raises ModelNotLoadedError before load_model() succeeds and
        GenerationError on API failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("HuggingFace API client not initialized")

        start_time = time.time()
        message_id = str(uuid.uuid4())

        try:
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            api_messages = self._format_messages_for_api(messages)

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.client.chat_completion(
                    messages=api_messages,
                    max_tokens=params['max_tokens'],
                    temperature=params['temperature'],
                    top_p=params.get('top_p', 0.9),
                    stream=False
                )
            )

            # OpenAI-compatible shape; fall back to str() for unknown payloads.
            if hasattr(response, 'choices') and response.choices:
                response_text = response.choices[0].message.content
                finish_reason = getattr(response.choices[0], 'finish_reason', 'stop')
            else:
                response_text = str(response)
                finish_reason = 'unknown'

            generation_time = time.time() - start_time

            return ChatResponse(
                message=response_text.strip(),
                session_id=self._session_id_of(messages),
                message_id=message_id,
                model_name=self.model_name,
                generation_time=generation_time,
                token_count=len(response_text.split()),  # Approximate token count
                finish_reason=finish_reason
            )

        except Exception as e:
            self.log_error("HuggingFace API generation failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate response via HuggingFace API: {str(e)}")

    async def generate_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response using HuggingFace API.

        Chunks are pulled from the client's *synchronous* iterator via the
        executor so network waits between chunks do not block the event loop
        (iterating it directly in the coroutine would).
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("HuggingFace API client not initialized")

        message_id = str(uuid.uuid4())
        session_id = self._session_id_of(messages)
        chunk_id = 0

        try:
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            api_messages = self._format_messages_for_api(messages)

            loop = asyncio.get_running_loop()

            def open_stream():
                return self.client.chat_completion(
                    messages=api_messages,
                    max_tokens=params['max_tokens'],
                    temperature=params['temperature'],
                    top_p=params.get('top_p', 0.9),
                    stream=True
                )

            stream = await loop.run_in_executor(None, open_stream)

            # Drain the sync iterator one item at a time on the executor;
            # a sentinel marks exhaustion (next() with default never raises).
            iterator = iter(stream)
            exhausted = object()
            while True:
                chunk = await loop.run_in_executor(None, next, iterator, exhausted)
                if chunk is exhausted:
                    break

                if hasattr(chunk, 'choices') and chunk.choices:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, 'content') and delta.content:
                        yield StreamChunk(
                            content=delta.content,
                            session_id=session_id,
                            message_id=message_id,
                            chunk_id=chunk_id,
                            is_final=False
                        )
                        chunk_id += 1

                # Add small delay to prevent overwhelming the client
                await asyncio.sleep(settings.stream_delay)

            # Final empty chunk signals end-of-stream to consumers.
            yield StreamChunk(
                content="",
                session_id=session_id,
                message_id=message_id,
                chunk_id=chunk_id,
                is_final=True
            )

        except Exception as e:
            self.log_error("HuggingFace API streaming failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate streaming response via HuggingFace API: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model."""
        return {
            "name": self.model_name,
            "type": "huggingface_api",
            "loaded": self.is_loaded,
            "api_endpoint": self.inference_url,
            "capabilities": self.capabilities,
            "parameters": self.parameters,
            "requires_token": True,
            "token_configured": bool(self.api_token)
        }

    async def health_check(self) -> Dict[str, Any]:
        """Perform a health check by issuing a tiny timed generation."""
        try:
            if not self.is_loaded:
                return {
                    "status": "unhealthy",
                    "reason": "client_not_initialized",
                    "model_name": self.model_name
                }

            start_time = time.time()
            test_messages = [ChatMessage(role="user", content="Hi")]

            try:
                await asyncio.wait_for(
                    self.generate_response(
                        test_messages,
                        temperature=0.1,
                        max_tokens=5
                    ),
                    timeout=15.0
                )

                response_time = time.time() - start_time

                return {
                    "status": "healthy",
                    "model_name": self.model_name,
                    "response_time": response_time,
                    "api_endpoint": self.inference_url
                }

            except asyncio.TimeoutError:
                return {
                    "status": "unhealthy",
                    "reason": "api_timeout",
                    "model_name": self.model_name,
                    "api_endpoint": self.inference_url
                }

        except Exception as e:
            self.log_error("HuggingFace API health check failed", error=str(e), model=self.model_name)
            return {
                "status": "unhealthy",
                "reason": "api_error",
                "error": str(e),
                "model_name": self.model_name,
                "api_endpoint": self.inference_url
            }
|
app/services/model_backends/local_hf.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Local HuggingFace model backend
|
| 3 |
+
Loads and runs models locally using transformers library
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
TextIteratorStreamer,
|
| 16 |
+
GenerationConfig
|
| 17 |
+
)
|
| 18 |
+
from threading import Thread
|
| 19 |
+
from queue import Queue
|
| 20 |
+
|
| 21 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 22 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 23 |
+
from ...core.config import settings
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LocalHuggingFaceBackend(ModelBackend):
|
| 27 |
+
"""Local HuggingFace model backend using transformers"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, model_name: str, **kwargs):
|
| 30 |
+
super().__init__(model_name, **kwargs)
|
| 31 |
+
self.tokenizer = None
|
| 32 |
+
self.model = None
|
| 33 |
+
self.device = kwargs.get('device', settings.device)
|
| 34 |
+
self.capabilities = ["chat", "streaming", "instruction_following"]
|
| 35 |
+
|
| 36 |
+
# Generation parameters
|
| 37 |
+
self.parameters = {
|
| 38 |
+
'temperature': kwargs.get('temperature', settings.temperature),
|
| 39 |
+
'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
|
| 40 |
+
'top_p': kwargs.get('top_p', settings.top_p),
|
| 41 |
+
'top_k': kwargs.get('top_k', settings.top_k),
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
    async def load_model(self) -> bool:
        """Load the HuggingFace model and tokenizer into memory.

        Resolves the target device, loads the tokenizer and model weights,
        and marks the backend ready. Returns True on success; raises
        ModelLoadError (wrapping the original error) on any failure.
        """
        try:
            self.log_info("Loading local HuggingFace model", model=self.model_name)

            # Resolve "auto" to a concrete device based on CUDA availability.
            if self.device == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            self.log_info("Using device", device=self.device)

            # Load tokenizer. Left padding keeps prompts right-aligned, which
            # is what batched generation with decoder-only models expects.
            self.log_info("Loading tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="left"
            )

            # Some tokenizers ship without a pad token; reuse EOS so padding works.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model weights: fp16 with automatic device placement on GPU,
            # fp32 on CPU; low_cpu_mem_usage reduces peak RAM while loading.
            self.log_info("Loading model")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                low_cpu_mem_usage=True
            )

            # device_map is None on CPU, so the model must be moved explicitly.
            if self.device == "cpu":
                self.model = self.model.to(self.device)

            # Inference only: disables dropout and other training-time behavior.
            self.model.eval()

            self.is_loaded = True
            self.log_info("Model loaded successfully",
                         model=self.model_name,
                         device=self.device,
                         parameters=self.model.num_parameters() if hasattr(self.model, 'num_parameters') else 'unknown')

            return True

        except Exception as e:
            self.log_error("Failed to load model", error=str(e), model=self.model_name)
            raise ModelLoadError(f"Failed to load model {self.model_name}: {str(e)}")
|
| 94 |
+
|
| 95 |
+
async def unload_model(self) -> bool:
|
| 96 |
+
"""Unload the model and free memory"""
|
| 97 |
+
try:
|
| 98 |
+
if self.model is not None:
|
| 99 |
+
del self.model
|
| 100 |
+
self.model = None
|
| 101 |
+
|
| 102 |
+
if self.tokenizer is not None:
|
| 103 |
+
del self.tokenizer
|
| 104 |
+
self.tokenizer = None
|
| 105 |
+
|
| 106 |
+
# Clear CUDA cache if using GPU
|
| 107 |
+
if torch.cuda.is_available():
|
| 108 |
+
torch.cuda.empty_cache()
|
| 109 |
+
|
| 110 |
+
self.is_loaded = False
|
| 111 |
+
self.log_info("Model unloaded successfully", model=self.model_name)
|
| 112 |
+
return True
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
self.log_error("Failed to unload model", error=str(e), model=self.model_name)
|
| 116 |
+
return False
|
| 117 |
+
|
| 118 |
+
def _prepare_chat_input(self, messages: List[ChatMessage]) -> str:
|
| 119 |
+
"""Prepare chat messages for the model"""
|
| 120 |
+
if not self.tokenizer:
|
| 121 |
+
raise ModelNotLoadedError("Tokenizer not loaded")
|
| 122 |
+
|
| 123 |
+
# Check if tokenizer has chat template
|
| 124 |
+
if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template:
|
| 125 |
+
# Use the model's chat template
|
| 126 |
+
formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
|
| 127 |
+
return self.tokenizer.apply_chat_template(
|
| 128 |
+
formatted_messages,
|
| 129 |
+
tokenize=False,
|
| 130 |
+
add_generation_prompt=True
|
| 131 |
+
)
|
| 132 |
+
else:
|
| 133 |
+
# Fallback to simple concatenation
|
| 134 |
+
chat_text = ""
|
| 135 |
+
for msg in messages:
|
| 136 |
+
if msg.role == "system":
|
| 137 |
+
chat_text += f"System: {msg.content}\n"
|
| 138 |
+
elif msg.role == "user":
|
| 139 |
+
chat_text += f"User: {msg.content}\n"
|
| 140 |
+
elif msg.role == "assistant":
|
| 141 |
+
chat_text += f"Assistant: {msg.content}\n"
|
| 142 |
+
|
| 143 |
+
chat_text += "Assistant: "
|
| 144 |
+
return chat_text
|
| 145 |
+
|
| 146 |
+
async def generate_response(
    self,
    messages: List[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 512,
    **kwargs
) -> ChatResponse:
    """Generate a complete (non-streaming) response with the local model.

    Args:
        messages: Conversation history; the last message's metadata may
            carry a ``session_id`` used to tag the response.
        temperature: Sampling temperature (validated before use).
        max_tokens: Upper bound on newly generated tokens.
        **kwargs: Extra sampling parameters forwarded to
            ``validate_parameters`` (e.g. ``top_p``, ``top_k``).

    Returns:
        ChatResponse: decoded completion plus timing and token metadata.

    Raises:
        ModelNotLoadedError: if no model is currently loaded.
        GenerationError: if tokenization or generation fails.
    """
    if not self.is_loaded:
        raise ModelNotLoadedError("Model not loaded")

    start_time = time.time()
    message_id = str(uuid.uuid4())

    try:
        # Validate parameters
        params = self.validate_parameters(
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Prepare input
        chat_input = self._prepare_chat_input(messages)

        # Tokenize input; cap the prompt so prompt + completion stays
        # within the model's context window (settings.max_length).
        inputs = self.tokenizer(
            chat_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=settings.max_length - params['max_tokens']
        ).to(self.device)

        # Generate response (inference only, so gradients are disabled)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=params['max_tokens'],
                temperature=params['temperature'],
                top_p=params['top_p'],
                top_k=params['top_k'],
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )

        # Decode response: slice off the prompt tokens so only the newly
        # generated continuation is returned to the caller.
        input_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_length:]
        response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        generation_time = time.time() - start_time

        return ChatResponse(
            message=response_text.strip(),
            session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
            message_id=message_id,
            model_name=self.model_name,
            generation_time=generation_time,
            token_count=len(generated_tokens),
            finish_reason="stop"
        )

    except Exception as e:
        self.log_error("Generation failed", error=str(e), model=self.model_name)
        raise GenerationError(f"Failed to generate response: {str(e)}")
|
| 215 |
+
|
| 216 |
+
async def generate_stream(
    self,
    messages: List[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 512,
    **kwargs
) -> AsyncGenerator[StreamChunk, None]:
    """Generate a streaming response with the local model.

    Runs ``model.generate`` in a worker thread and consumes a
    ``TextIteratorStreamer`` to yield text chunks as they are produced.
    A final empty chunk with ``is_final=True`` terminates the stream.

    Args:
        messages: Conversation history; the last message's metadata may
            carry a ``session_id``.
        temperature: Sampling temperature (validated before use).
        max_tokens: Upper bound on newly generated tokens.
        **kwargs: Extra sampling parameters forwarded to
            ``validate_parameters``.

    Yields:
        StreamChunk: incremental decoded text, then one final marker chunk.

    Raises:
        ModelNotLoadedError: if no model is currently loaded.
        GenerationError: if setup or generation fails.
    """
    if not self.is_loaded:
        raise ModelNotLoadedError("Model not loaded")

    message_id = str(uuid.uuid4())
    # Session id travels in the last message's metadata, if present.
    session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
    chunk_id = 0

    try:
        # Validate parameters
        params = self.validate_parameters(
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Prepare input
        chat_input = self._prepare_chat_input(messages)

        # Tokenize input; cap the prompt so prompt + completion stays
        # within the model's context window.
        inputs = self.tokenizer(
            chat_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=settings.max_length - params['max_tokens']
        ).to(self.device)

        # Create streamer that skips the prompt and special tokens.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # Generation parameters
        generation_kwargs = {
            **inputs,
            'max_new_tokens': params['max_tokens'],
            'temperature': params['temperature'],
            'top_p': params['top_p'],
            'top_k': params['top_k'],
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'repetition_penalty': 1.1,
            'no_repeat_ngram_size': 3,
            'streamer': streamer
        }

        # Start generation in a separate thread so this coroutine can
        # drain the streamer while tokens are still being produced.
        generation_thread = Thread(
            target=self.model.generate,
            kwargs=generation_kwargs
        )
        generation_thread.start()

        # Stream the response.
        # NOTE(review): iterating TextIteratorStreamer is a *blocking*
        # iterator; between chunks this can stall the event loop.
        # Consider draining it via a thread executor — TODO confirm.
        for chunk_text in streamer:
            if chunk_text:  # Skip empty chunks
                yield StreamChunk(
                    content=chunk_text,
                    session_id=session_id,
                    message_id=message_id,
                    chunk_id=chunk_id,
                    is_final=False
                )
                chunk_id += 1

                # Add small delay to prevent overwhelming the client
                await asyncio.sleep(settings.stream_delay)

        # Send final chunk
        yield StreamChunk(
            content="",
            session_id=session_id,
            message_id=message_id,
            chunk_id=chunk_id,
            is_final=True
        )

        # Wait for generation thread to complete
        generation_thread.join()

    except Exception as e:
        self.log_error("Streaming generation failed", error=str(e), model=self.model_name)
        raise GenerationError(f"Failed to generate streaming response: {str(e)}")
|
| 310 |
+
|
| 311 |
+
def get_model_info(self) -> Dict[str, Any]:
    """Describe the backend: identity, load state, and key config values.

    Returns:
        dict: static metadata plus, when a model is loaded and exposes a
        ``config`` object, a summary of its architecture dimensions.
    """
    info: Dict[str, Any] = {
        "name": self.model_name,
        "type": "local_huggingface",
        "loaded": self.is_loaded,
        "device": self.device,
        "capabilities": self.capabilities,
        "parameters": self.parameters,
    }

    if self.model and hasattr(self.model, 'config'):
        cfg = self.model.config
        # Missing attributes map to None so the payload shape is stable.
        info["model_config"] = {
            key: getattr(cfg, attr, None)
            for key, attr in (
                ("vocab_size", "vocab_size"),
                ("hidden_size", "hidden_size"),
                ("num_layers", "num_hidden_layers"),
                ("num_attention_heads", "num_attention_heads"),
            )
        }

    return info
|
app/services/model_backends/minimax_api.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MiniMax API backend
|
| 3 |
+
Uses MiniMax's API for their M1 model with reasoning capabilities
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
import json
|
| 10 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import httpx
|
| 13 |
+
|
| 14 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 15 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 16 |
+
from ...core.config import settings
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MiniMaxAPIBackend(ModelBackend):
    """MiniMax API backend for the M1 model.

    Speaks an OpenAI-compatible chat-completions protocol over HTTP via
    ``httpx``, including SSE streaming.  Reasoning content returned by
    the API is surfaced inline, wrapped in bracketed markers.
    """

    def __init__(self, model_name: str, **kwargs):
        super().__init__(model_name, **kwargs)
        # Connection settings; explicit kwargs override global config.
        self.api_url = kwargs.get('api_url', settings.minimax_api_url)
        self.api_key = kwargs.get('api_key', settings.minimax_api_key)
        self.model_version = kwargs.get('model_version', settings.minimax_model_version)
        self.capabilities = ["chat", "streaming", "reasoning", "api_based"]

        # Default generation parameters (re-validated per request).
        self.parameters = {
            'temperature': kwargs.get('temperature', settings.temperature),
            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
            'top_p': kwargs.get('top_p', settings.top_p),
        }

    def _headers(self) -> Dict[str, str]:
        """Return the HTTP headers shared by every MiniMax request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    async def load_model(self) -> bool:
        """Initialize the MiniMax API client and verify connectivity.

        Raises:
            ModelLoadError: if credentials are missing or the test
                request fails.
        """
        try:
            if not self.api_key or not self.api_url:
                raise ModelLoadError("MiniMax API key and URL are required")

            self.log_info("Initializing MiniMax API client", model=self.model_name)

            # Fail fast: a cheap round-trip validates URL + credentials.
            await self._test_connection()

            self.is_loaded = True
            self.log_info("MiniMax API client initialized successfully", model=self.model_name)

            return True

        except Exception as e:
            self.log_error("Failed to initialize MiniMax API client", error=str(e), model=self.model_name)
            raise ModelLoadError(f"Failed to initialize MiniMax API for {self.model_name}: {str(e)}")

    async def unload_model(self) -> bool:
        """Mark the backend as unloaded (no persistent client to close)."""
        try:
            self.is_loaded = False
            self.log_info("MiniMax API client cleaned up", model=self.model_name)
            return True

        except Exception as e:
            self.log_error("Failed to cleanup MiniMax API client", error=str(e), model=self.model_name)
            return False

    async def _test_connection(self):
        """Send a minimal completion request to verify the API works.

        Raises:
            Exception: propagated when the request errors or returns a
                non-200 status.
        """
        try:
            test_data = {
                'model': self.model_version,
                'messages': [{"role": "user", "content": "Hello"}],
                'stream': False,
                'max_tokens': 5,
                'temperature': 0.1
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.api_url,
                    headers=self._headers(),
                    json=test_data,
                    timeout=10.0
                )

            if response.status_code != 200:
                raise Exception(f"API test failed with status {response.status_code}")

            self.log_info("MiniMax API connection test successful", model=self.model_name)

        except Exception as e:
            self.log_error("MiniMax API connection test failed", error=str(e), model=self.model_name)
            raise

    def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
        """Convert ChatMessage objects to the API's role/content dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate_response(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> ChatResponse:
        """Generate a complete (non-streaming) response via MiniMax.

        Reasoning content, when present alongside main content, is
        prepended inside a ``[Reasoning: ...]`` marker.

        Raises:
            ModelNotLoadedError: if ``load_model`` has not succeeded.
            GenerationError: on any request or parsing failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("MiniMax API client not initialized")

        start_time = time.time()
        message_id = str(uuid.uuid4())

        try:
            # Validate parameters
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            request_data = {
                'model': self.model_version,
                'messages': self._format_messages_for_api(messages),
                'stream': False,
                'max_tokens': params['max_tokens'],
                'temperature': params['temperature'],
                'top_p': params.get('top_p', 0.9)
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.api_url,
                    headers=self._headers(),
                    json=request_data,
                    timeout=30.0
                )

            if response.status_code != 200:
                raise GenerationError(f"API request failed with status {response.status_code}")

            response_data = response.json()

            # Extract response text.  The API may return `content`,
            # `reasoning_content`, or both — and either may be null.
            if 'choices' in response_data and response_data['choices']:
                choice = response_data['choices'][0]
                if 'message' in choice:
                    response_text = choice['message'].get('content', '')
                    reasoning_content = choice['message'].get('reasoning_content', '')

                    # Combine reasoning and main content if both exist
                    if reasoning_content and response_text:
                        full_response = f"[Reasoning: {reasoning_content}]\n\n{response_text}"
                    else:
                        full_response = response_text or reasoning_content
                else:
                    full_response = str(response_data)
            else:
                full_response = str(response_data)

            # Guard against null content AND null reasoning: `x or y`
            # yields None when both are falsy, and None has no .strip().
            full_response = full_response or ""

            generation_time = time.time() - start_time

            return ChatResponse(
                message=full_response.strip(),
                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
                message_id=message_id,
                model_name=self.model_name,
                generation_time=generation_time,
                token_count=len(full_response.split()),  # Approximate token count
                finish_reason="stop"
            )

        except Exception as e:
            self.log_error("MiniMax API generation failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate response via MiniMax API: {str(e)}")

    async def generate_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 512,
        **kwargs
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response via MiniMax server-sent events.

        Yields reasoning chunks wrapped in ``[Thinking: ...]`` markers,
        plain content chunks, and a final empty chunk with
        ``is_final=True``.

        Raises:
            ModelNotLoadedError: if ``load_model`` has not succeeded.
            GenerationError: on any request or streaming failure.
        """
        if not self.is_loaded:
            raise ModelNotLoadedError("MiniMax API client not initialized")

        message_id = str(uuid.uuid4())
        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
        chunk_id = 0

        try:
            # Validate parameters
            params = self.validate_parameters(
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            request_data = {
                'model': self.model_version,
                'messages': self._format_messages_for_api(messages),
                'stream': True,
                'max_tokens': params['max_tokens'],
                'temperature': params['temperature'],
                'top_p': params.get('top_p', 0.9)
            }

            async with httpx.AsyncClient() as client:
                async with client.stream(
                    'POST',
                    self.api_url,
                    headers=self._headers(),
                    json=request_data,
                    timeout=60.0
                ) as response:

                    if response.status_code != 200:
                        raise GenerationError(f"Streaming request failed with status {response.status_code}")

                    async for line in response.aiter_lines():
                        if not line.startswith('data:'):
                            continue

                        payload = line[5:].strip()  # Remove 'data:' prefix
                        # Skip keep-alives and the SSE end-of-stream sentinel.
                        if not payload or payload == '[DONE]':
                            continue

                        try:
                            data = json.loads(payload)
                        except json.JSONDecodeError:
                            continue

                        if 'choices' not in data:
                            continue

                        choice = data['choices'][0]

                        # Incremental delta frames.
                        if 'delta' in choice:
                            delta = choice['delta']
                            reasoning_content = delta.get('reasoning_content', '')
                            content = delta.get('content', '')

                            if reasoning_content:
                                yield StreamChunk(
                                    content=f"[Thinking: {reasoning_content}]",
                                    session_id=session_id,
                                    message_id=message_id,
                                    chunk_id=chunk_id,
                                    is_final=False
                                )
                                chunk_id += 1

                            if content:
                                yield StreamChunk(
                                    content=content,
                                    session_id=session_id,
                                    message_id=message_id,
                                    chunk_id=chunk_id,
                                    is_final=False
                                )
                                chunk_id += 1

                        # Some responses deliver one complete message frame.
                        elif 'message' in choice:
                            message_data = choice['message']
                            reasoning_content = message_data.get('reasoning_content', '')
                            main_content = message_data.get('content', '')

                            if reasoning_content:
                                yield StreamChunk(
                                    content=f"\n[Final reasoning: {reasoning_content}]\n",
                                    session_id=session_id,
                                    message_id=message_id,
                                    chunk_id=chunk_id,
                                    is_final=False
                                )
                                chunk_id += 1

                            if main_content:
                                yield StreamChunk(
                                    content=main_content,
                                    session_id=session_id,
                                    message_id=message_id,
                                    chunk_id=chunk_id,
                                    is_final=False
                                )
                                chunk_id += 1

                        # Pace delivery so clients are not overwhelmed.
                        await asyncio.sleep(settings.stream_delay)

            # Send final chunk
            yield StreamChunk(
                content="",
                session_id=session_id,
                message_id=message_id,
                chunk_id=chunk_id,
                is_final=True
            )

        except Exception as e:
            self.log_error("MiniMax API streaming failed", error=str(e), model=self.model_name)
            raise GenerationError(f"Failed to generate streaming response via MiniMax API: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model and its configuration."""
        return {
            "name": self.model_name,
            "type": "minimax_api",
            "loaded": self.is_loaded,
            "provider": "MiniMax",
            "model_version": self.model_version,
            "capabilities": self.capabilities,
            "parameters": self.parameters,
            "requires_api_key": True,
            "api_key_configured": bool(self.api_key),
            "api_url": self.api_url
        }
|
app/services/model_backends/openai_api.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI API backend
|
| 3 |
+
Uses OpenAI's API for model access (GPT-3.5, GPT-4, etc.)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from typing import AsyncGenerator, List, Dict, Any, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import openai
|
| 12 |
+
|
| 13 |
+
from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
|
| 14 |
+
from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
|
| 15 |
+
from ...core.config import settings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OpenAIAPIBackend(ModelBackend):
|
| 19 |
+
"""OpenAI API backend"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, model_name: str, **kwargs):
    """Configure the backend; kwargs override settings-derived defaults."""
    super().__init__(model_name, **kwargs)
    self.client = None
    # Credentials: explicit kwargs win over global configuration.
    self.api_key = kwargs.get('api_key', settings.openai_api_key)
    self.org_id = kwargs.get('org_id', settings.openai_org_id)
    self.capabilities = ["chat", "streaming", "api_based", "function_calling"]

    # Default generation parameters (re-validated per request).
    defaults = {
        'temperature': settings.temperature,
        'max_tokens': settings.max_new_tokens,
        'top_p': settings.top_p,
    }
    self.parameters = {name: kwargs.get(name, value) for name, value in defaults.items()}
|
| 34 |
+
|
| 35 |
+
async def load_model(self) -> bool:
    """Create the async OpenAI client and verify it can reach the API.

    Raises:
        ModelLoadError: when the key is missing or the test call fails.
    """
    try:
        if not self.api_key:
            raise ModelLoadError("OpenAI API key is required")

        self.log_info("Initializing OpenAI API client", model=self.model_name)

        # Build the SDK client with our credentials.
        self.client = openai.AsyncOpenAI(
            api_key=self.api_key,
            organization=self.org_id,
        )

        # Fail fast: one cheap round-trip validates key + model access.
        await self._test_connection()

        self.is_loaded = True
        self.log_info("OpenAI API client initialized successfully", model=self.model_name)

        return True

    except Exception as e:
        self.log_error("Failed to initialize OpenAI API client", error=str(e), model=self.model_name)
        raise ModelLoadError(f"Failed to initialize OpenAI API for {self.model_name}: {str(e)}")
|
| 60 |
+
|
| 61 |
+
async def unload_model(self) -> bool:
    """Close the SDK client and mark the backend as unloaded."""
    try:
        client, self.client = self.client, None
        if client:
            # Release pooled HTTP connections held by the SDK.
            await client.close()
        self.is_loaded = False
        self.log_info("OpenAI API client cleaned up", model=self.model_name)
        return True

    except Exception as e:
        self.log_error("Failed to cleanup OpenAI API client", error=str(e), model=self.model_name)
        return False
|
| 74 |
+
|
| 75 |
+
async def _test_connection(self):
    """Issue a tiny completion request to confirm the API is reachable.

    Raises:
        Exception: propagated unchanged when the request fails.
    """
    try:
        # Minimal, cheap probe: one short turn, 5-token budget.
        await self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=5,
            temperature=0.1,
        )

        self.log_info("OpenAI API connection test successful", model=self.model_name)

    except Exception as e:
        self.log_error("OpenAI API connection test failed", error=str(e), model=self.model_name)
        raise
|
| 91 |
+
|
| 92 |
+
def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
|
| 93 |
+
"""Format messages for OpenAI API"""
|
| 94 |
+
formatted = []
|
| 95 |
+
for msg in messages:
|
| 96 |
+
formatted.append({
|
| 97 |
+
"role": msg.role,
|
| 98 |
+
"content": msg.content
|
| 99 |
+
})
|
| 100 |
+
return formatted
|
| 101 |
+
|
| 102 |
+
async def generate_response(
    self,
    messages: List[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 512,
    **kwargs
) -> ChatResponse:
    """Generate a complete (non-streaming) response using the OpenAI API.

    Args:
        messages: Conversation history; the last message's metadata may
            carry a ``session_id`` used to tag the response.
        temperature: Sampling temperature (validated before use).
        max_tokens: Upper bound on completion tokens.
        **kwargs: Extra sampling parameters forwarded to
            ``validate_parameters`` (e.g. ``top_p``).

    Returns:
        ChatResponse: completion text plus timing and usage metadata.

    Raises:
        ModelNotLoadedError: if ``load_model`` has not succeeded.
        GenerationError: if the API call fails for any reason.
    """
    if not self.is_loaded:
        raise ModelNotLoadedError("OpenAI API client not initialized")

    start_time = time.time()
    message_id = str(uuid.uuid4())

    try:
        # Validate parameters
        params = self.validate_parameters(
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Format messages
        api_messages = self._format_messages_for_api(messages)

        # Make API call
        response = await self.client.chat.completions.create(
            model=self.model_name,
            messages=api_messages,
            max_tokens=params['max_tokens'],
            temperature=params['temperature'],
            top_p=params.get('top_p', 0.9),
            stream=False
        )

        # Extract response text, stop reason, and token usage (usage
        # may be absent, in which case token_count is None).
        response_text = response.choices[0].message.content
        finish_reason = response.choices[0].finish_reason
        token_count = response.usage.completion_tokens if response.usage else None

        generation_time = time.time() - start_time

        return ChatResponse(
            message=response_text.strip(),
            session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
            message_id=message_id,
            model_name=self.model_name,
            generation_time=generation_time,
            token_count=token_count,
            finish_reason=finish_reason
        )

    except Exception as e:
        self.log_error("OpenAI API generation failed", error=str(e), model=self.model_name)
        raise GenerationError(f"Failed to generate response via OpenAI API: {str(e)}")
|
| 157 |
+
|
| 158 |
+
async def generate_stream(
    self,
    messages: List[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 512,
    **kwargs
) -> AsyncGenerator[StreamChunk, None]:
    """Generate a streaming response using the OpenAI API.

    Consumes the SDK's async chunk stream and re-emits each content
    delta as a ``StreamChunk``; a final empty chunk carries
    ``is_final=True``.

    Args:
        messages: Conversation history; the last message's metadata may
            carry a ``session_id``.
        temperature: Sampling temperature (validated before use).
        max_tokens: Upper bound on completion tokens.
        **kwargs: Extra sampling parameters forwarded to
            ``validate_parameters``.

    Yields:
        StreamChunk: incremental content deltas, then one final marker.

    Raises:
        ModelNotLoadedError: if ``load_model`` has not succeeded.
        GenerationError: if the streaming call fails for any reason.
    """
    if not self.is_loaded:
        raise ModelNotLoadedError("OpenAI API client not initialized")

    message_id = str(uuid.uuid4())
    # Session id travels in the last message's metadata, if present.
    session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
    chunk_id = 0

    try:
        # Validate parameters
        params = self.validate_parameters(
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Format messages
        api_messages = self._format_messages_for_api(messages)

        # Create streaming request
        stream = await self.client.chat.completions.create(
            model=self.model_name,
            messages=api_messages,
            max_tokens=params['max_tokens'],
            temperature=params['temperature'],
            top_p=params.get('top_p', 0.9),
            stream=True
        )

        # Process streaming chunks; only frames with a content delta
        # are forwarded to the caller.
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content

                yield StreamChunk(
                    content=content,
                    session_id=session_id,
                    message_id=message_id,
                    chunk_id=chunk_id,
                    is_final=False
                )
                chunk_id += 1

                # Add small delay to prevent overwhelming the client
                await asyncio.sleep(settings.stream_delay)

            # Stop once the API signals the completion has finished.
            if chunk.choices and chunk.choices[0].finish_reason:
                break

        # Send final chunk
        yield StreamChunk(
            content="",
            session_id=session_id,
            message_id=message_id,
            chunk_id=chunk_id,
            is_final=True
        )

    except Exception as e:
        self.log_error("OpenAI API streaming failed", error=str(e), model=self.model_name)
        raise GenerationError(f"Failed to generate streaming response via OpenAI API: {str(e)}")
|
| 227 |
+
|
| 228 |
+
def get_model_info(self) -> Dict[str, Any]:
    """Describe the configured OpenAI model and its runtime state."""
    info: Dict[str, Any] = {
        "name": self.model_name,
        "type": "openai_api",
        "loaded": self.is_loaded,
        "provider": "OpenAI",
    }
    info["capabilities"] = self.capabilities
    info["parameters"] = self.parameters
    # Hosted API: a key is always required, even if not yet configured.
    info["requires_api_key"] = True
    info["api_key_configured"] = bool(self.api_key)
    info["organization"] = self.org_id
    return info
|
| 241 |
+
|
| 242 |
+
async def health_check(self) -> Dict[str, Any]:
    """Probe the OpenAI API with a tiny generation request.

    Never raises: every failure mode is folded into an "unhealthy" report
    so callers can use the result directly in a status endpoint.
    """
    try:
        if not self.is_loaded:
            return {
                "status": "unhealthy",
                "reason": "client_not_initialized",
                "model_name": self.model_name
            }

        started = time.time()
        probe = [ChatMessage(role="user", content="Hi")]

        try:
            # Tiny, cheap request; bounded so a hung API can't stall us.
            await asyncio.wait_for(
                self.generate_response(
                    probe,
                    temperature=0.1,
                    max_tokens=5
                ),
                timeout=10.0
            )
        except asyncio.TimeoutError:
            return {
                "status": "unhealthy",
                "reason": "api_timeout",
                "model_name": self.model_name,
                "provider": "OpenAI"
            }

        return {
            "status": "healthy",
            "model_name": self.model_name,
            "response_time": time.time() - started,
            "provider": "OpenAI"
        }

    except Exception as e:
        self.log_error("OpenAI API health check failed", error=str(e), model=self.model_name)
        return {
            "status": "unhealthy",
            "reason": "api_error",
            "error": str(e),
            "model_name": self.model_name,
            "provider": "OpenAI"
        }
|
app/services/model_manager.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Manager - Central hub for managing different model backends
|
| 3 |
+
Handles backend selection based on environment configuration
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Dict, Any
|
| 7 |
+
from ..core.config import settings
|
| 8 |
+
from ..core.logging import LoggerMixin
|
| 9 |
+
from .model_backends.base import ModelBackend, ModelBackendError
|
| 10 |
+
from .model_backends.local_hf import LocalHuggingFaceBackend
|
| 11 |
+
from .model_backends.hf_api import HuggingFaceAPIBackend
|
| 12 |
+
from .model_backends.openai_api import OpenAIAPIBackend
|
| 13 |
+
from .model_backends.anthropic_api import AnthropicAPIBackend
|
| 14 |
+
from .model_backends.minimax_api import MiniMaxAPIBackend
|
| 15 |
+
from .model_backends.google_api import GoogleAIBackend
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ModelManager(LoggerMixin):
    """Central coordinator for model backends.

    Owns at most one active :class:`ModelBackend` at a time and handles
    its creation, loading, hot-swapping, and teardown based on the values
    read from ``settings``.
    """

    def __init__(self):
        self.current_backend: Optional[ModelBackend] = None
        self.backend_type = settings.model_type.lower()
        self.model_name = settings.model_name
        self.is_initialized = False

    async def initialize(self) -> bool:
        """Create the configured backend and load its model.

        Returns:
            bool: True when the backend was created and its model loaded.
        """
        try:
            self.log_info("Initializing model manager",
                          backend_type=self.backend_type,
                          model_name=self.model_name)

            # Refuse to start with an incomplete/invalid configuration.
            if not settings.validate_model_config():
                self.log_error("Invalid model configuration", backend_type=self.backend_type)
                return False

            backend = self._create_backend()
            if backend is None:
                self.log_error("Failed to create backend", backend_type=self.backend_type)
                return False

            if not await backend.load_model():
                self.log_error("Failed to load model", backend_type=self.backend_type)
                return False

            self.current_backend = backend
            self.is_initialized = True
            self.log_info("Model manager initialized successfully",
                          backend_type=self.backend_type,
                          model_name=self.model_name)
            return True

        except Exception as exc:
            self.log_error("Model manager initialization failed",
                           error=str(exc),
                           backend_type=self.backend_type)
            return False

    def _create_backend(self) -> Optional[ModelBackend]:
        """Instantiate the backend matching ``self.backend_type`` (None if unknown)."""
        # Lazy factories: settings are only read when the chosen one is called.
        factories = {
            "local": lambda: LocalHuggingFaceBackend(
                model_name=self.model_name,
                device=settings.device,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p,
                top_k=settings.top_k
            ),
            "hf_api": lambda: HuggingFaceAPIBackend(
                model_name=self.model_name,
                api_token=settings.hf_api_token,
                inference_url=settings.hf_inference_url,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p
            ),
            "openai": lambda: OpenAIAPIBackend(
                model_name=self.model_name,
                api_key=settings.openai_api_key,
                org_id=settings.openai_org_id,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p
            ),
            "anthropic": lambda: AnthropicAPIBackend(
                model_name=self.model_name,
                api_key=settings.anthropic_api_key,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p
            ),
            "minimax": lambda: MiniMaxAPIBackend(
                model_name=self.model_name,
                api_key=settings.minimax_api_key,
                api_url=settings.minimax_api_url,
                model_version=settings.minimax_model_version,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p
            ),
            "google": lambda: GoogleAIBackend(
                model_name=self.model_name,
                api_key=settings.google_api_key,
                temperature=settings.temperature,
                max_tokens=settings.max_new_tokens,
                top_p=settings.top_p,
                top_k=settings.top_k
            ),
        }

        try:
            factory = factories.get(self.backend_type)
            if factory is None:
                self.log_error("Unsupported backend type", backend_type=self.backend_type)
                return None
            return factory()
        except Exception as exc:
            self.log_error("Failed to create backend", error=str(exc), backend_type=self.backend_type)
            return None

    async def shutdown(self) -> bool:
        """Unload the active backend and release its resources.

        Returns:
            bool: True on success (including when nothing was loaded).
        """
        try:
            if self.current_backend is None:
                return True

            unloaded = await self.current_backend.unload_model()
            self.current_backend = None
            self.is_initialized = False
            self.log_info("Model manager shutdown successfully")
            return unloaded

        except Exception as exc:
            self.log_error("Model manager shutdown failed", error=str(exc))
            return False

    def get_backend(self) -> Optional[ModelBackend]:
        """Return the active backend instance, or None when not initialized."""
        return self.current_backend

    def is_ready(self) -> bool:
        """True when a backend is attached and reports its model as loaded."""
        if not self.is_initialized or self.current_backend is None:
            return False
        return self.current_backend.is_model_loaded()

    def get_model_info(self) -> Dict[str, Any]:
        """Describe the active model, or a placeholder when uninitialized."""
        if self.current_backend is None:
            return {
                "status": "not_initialized",
                "backend_type": self.backend_type,
                "model_name": self.model_name,
                "is_ready": False
            }

        info = self.current_backend.get_model_info()
        info["is_ready"] = self.is_ready()
        info["manager_initialized"] = self.is_initialized
        return info

    async def health_check(self) -> Dict[str, Any]:
        """Run a health probe, delegating to the active backend when ready."""
        if not self.is_ready():
            return {
                "status": "unhealthy",
                "reason": "manager_not_ready",
                "backend_type": self.backend_type,
                "model_name": self.model_name,
                "is_initialized": self.is_initialized,
                "backend_loaded": self.current_backend is not None
            }

        # Backend-specific probe, annotated with manager-level context.
        report = await self.current_backend.health_check()
        report["manager_status"] = "healthy"
        report["backend_type"] = self.backend_type
        report["is_ready"] = self.is_ready()
        return report

    async def switch_model(self, new_model_name: str, new_backend_type: Optional[str] = None) -> bool:
        """Swap in a different model (and optionally backend), rolling back on failure.

        Args:
            new_model_name: Name of the model to switch to.
            new_backend_type: Optional new backend type.

        Returns:
            bool: True if the switch succeeded, False otherwise.
        """
        try:
            self.log_info("Switching model",
                          current_model=self.model_name,
                          new_model=new_model_name,
                          current_backend=self.backend_type,
                          new_backend=new_backend_type)

            if self.current_backend is not None:
                await self.current_backend.unload_model()
                self.current_backend = None

            # Remember the previous configuration so we can roll back.
            prior_model = self.model_name
            prior_backend = self.backend_type

            self.model_name = new_model_name
            if new_backend_type:
                self.backend_type = new_backend_type.lower()

            if await self.initialize():
                self.log_info("Model switch successful",
                              new_model=new_model_name,
                              new_backend=self.backend_type)
                return True

            self.log_warning("Model switch failed, rolling back",
                             failed_model=new_model_name,
                             rollback_model=prior_model)
            self.model_name = prior_model
            self.backend_type = prior_backend
            await self.initialize()  # best-effort restore of the previous state
            return False

        except Exception as exc:
            self.log_error("Model switch failed", error=str(exc))
            return False

    def get_supported_backends(self) -> Dict[str, Dict[str, Any]]:
        """Static catalog of supported backend types and example models."""
        return {
            "local": {
                "name": "Local HuggingFace",
                "description": "Run models locally using transformers",
                "requires": ["model_name", "device"],
                "capabilities": ["chat", "streaming", "offline"],
                "example_models": [
                    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    "microsoft/DialoGPT-medium",
                    "Qwen/Qwen2.5-0.5B-Instruct"
                ]
            },
            "hf_api": {
                "name": "HuggingFace Inference API",
                "description": "Use HuggingFace's hosted inference API",
                "requires": ["model_name", "hf_api_token"],
                "capabilities": ["chat", "streaming", "serverless"],
                "example_models": [
                    "microsoft/DialoGPT-large",
                    "microsoft/phi-2",
                    "google/gemma-2b-it"
                ]
            },
            "openai": {
                "name": "OpenAI API",
                "description": "Use OpenAI's GPT models",
                "requires": ["model_name", "openai_api_key"],
                "capabilities": ["chat", "streaming", "function_calling"],
                "example_models": [
                    "gpt-3.5-turbo",
                    "gpt-4",
                    "gpt-4-turbo"
                ]
            },
            "anthropic": {
                "name": "Anthropic API",
                "description": "Use Anthropic's Claude models",
                "requires": ["model_name", "anthropic_api_key"],
                "capabilities": ["chat", "streaming", "long_context"],
                "example_models": [
                    "claude-3-haiku-20240307",
                    "claude-3-sonnet-20240229",
                    "claude-3-opus-20240229"
                ]
            },
            "minimax": {
                "name": "MiniMax API",
                "description": "Use MiniMax's M1 model with reasoning capabilities",
                "requires": ["model_name", "minimax_api_key", "minimax_api_url"],
                "capabilities": ["chat", "streaming", "reasoning"],
                "example_models": [
                    "MiniMax-M1"
                ]
            },
            "google": {
                "name": "Google AI Studio",
                "description": "Use Google's Gemma and other models via AI Studio",
                "requires": ["model_name", "google_api_key"],
                "capabilities": ["chat", "streaming", "multimodal"],
                "example_models": [
                    "gemini-1.5-flash",
                    "gemini-1.5-pro",
                    "gemma-2-9b-it",
                    "gemma-2-27b-it"
                ]
            }
        }
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Module-level singleton shared by the whole application.
model_manager = ModelManager()


async def get_model_manager() -> "ModelManager":
    """Return the process-wide :class:`ModelManager` singleton."""
    return model_manager


async def initialize_model_manager() -> bool:
    """Initialize the global model manager; True on success."""
    return await model_manager.initialize()


async def shutdown_model_manager() -> bool:
    """Shut down the global model manager; True on success."""
    return await model_manager.shutdown()
|
app/services/session_manager.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Session Manager - Handles conversation sessions and message history
|
| 3 |
+
Supports both in-memory and Redis-based storage
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
from typing import Dict, List, Optional, Any
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
+
import json
|
| 11 |
+
import uuid
|
| 12 |
+
|
| 13 |
+
from ..core.config import settings
|
| 14 |
+
from ..core.logging import LoggerMixin
|
| 15 |
+
from ..models.schemas import ChatMessage, ConversationHistory, SessionInfo
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SessionManager(LoggerMixin):
|
| 19 |
+
"""
|
| 20 |
+
Manages chat sessions and conversation history
|
| 21 |
+
Supports both in-memory and Redis storage backends
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self):
    """Set up storage, limits, and timeouts from application settings."""
    # In-memory fallback store used when Redis is absent: session_id -> history
    self.sessions: Dict[str, "ConversationHistory"] = {}
    self.redis_client = None
    self.use_redis = bool(settings.redis_url)
    # settings.session_timeout is expressed in minutes; keep seconds internally
    self.session_timeout = settings.session_timeout * 60
    self.max_sessions_per_user = settings.max_sessions_per_user
    self.max_messages_per_session = settings.max_messages_per_session
    # Background expiry sweeper; created by initialize()
    self._cleanup_task = None
|
| 34 |
+
|
| 35 |
+
async def initialize(self) -> bool:
    """Bring up the storage backend and start the cleanup task.

    Returns:
        bool: True on success, False when setup raised.
    """
    try:
        if self.use_redis:
            await self._initialize_redis()

        # Periodically evict expired sessions for the manager's lifetime.
        self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())

        storage = "redis" if self.use_redis else "memory"
        self.log_info("Session manager initialized",
                      storage_type=storage,
                      session_timeout=self.session_timeout)
        return True

    except Exception as exc:
        self.log_error("Failed to initialize session manager", error=str(exc))
        return False
|
| 52 |
+
|
| 53 |
+
async def _initialize_redis(self):
|
| 54 |
+
"""Initialize Redis connection"""
|
| 55 |
+
try:
|
| 56 |
+
import redis.asyncio as redis
|
| 57 |
+
self.redis_client = redis.from_url(settings.redis_url)
|
| 58 |
+
|
| 59 |
+
# Test connection
|
| 60 |
+
await self.redis_client.ping()
|
| 61 |
+
self.log_info("Redis connection established", url=settings.redis_url)
|
| 62 |
+
|
| 63 |
+
except ImportError:
|
| 64 |
+
self.log_warning("Redis not available, falling back to memory storage")
|
| 65 |
+
self.use_redis = False
|
| 66 |
+
except Exception as e:
|
| 67 |
+
self.log_error("Redis connection failed", error=str(e))
|
| 68 |
+
self.use_redis = False
|
| 69 |
+
|
| 70 |
+
async def shutdown(self):
    """Stop the cleanup task and close the Redis connection, if any."""
    try:
        task = self._cleanup_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass  # expected: we cancelled it ourselves

        if self.redis_client:
            await self.redis_client.close()

        self.log_info("Session manager shutdown complete")

    except Exception as exc:
        self.log_error("Session manager shutdown failed", error=str(exc))
|
| 87 |
+
|
| 88 |
+
async def create_session(self, session_id: str, user_id: Optional[str] = None) -> bool:
    """Create a new chat session.

    Args:
        session_id: Unique session identifier.
        user_id: Optional user identifier, used to enforce per-user limits.

    Returns:
        bool: True if the session exists afterwards (creation is idempotent),
        False when creation failed or the user hit their session limit.
    """
    try:
        # Idempotent: an already-existing session counts as success.
        if await self.session_exists(session_id):
            self.log_info("Session already exists", session_id=session_id)
            return True

        # Enforce the per-user cap when the caller identified the user.
        if user_id and self.max_sessions_per_user > 0:
            owned = await self.get_user_sessions(user_id)
            if len(owned) >= self.max_sessions_per_user:
                self.log_warning("User session limit exceeded",
                                 user_id=user_id,
                                 limit=self.max_sessions_per_user)
                return False

        fresh = ConversationHistory(
            session_id=session_id,
            messages=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            message_count=0
        )
        await self._store_session(fresh)

        self.log_info("Session created", session_id=session_id, user_id=user_id)
        return True

    except Exception as exc:
        self.log_error("Failed to create session", error=str(exc), session_id=session_id)
        return False
|
| 131 |
+
|
| 132 |
+
async def add_message(self, session_id: str, message: "ChatMessage") -> bool:
    """Append a message to a session, creating the session on demand.

    When the per-session cap is reached, the oldest messages are trimmed
    first so the new message always fits.

    Args:
        session_id: Session identifier.
        message: Message to append.

    Returns:
        bool: True if the message was stored, False on any failure.
    """
    try:
        session = await self.get_session(session_id)
        if session is None:
            # Lazily create the session on first message.
            await self.create_session(session_id)
            session = await self.get_session(session_id)

        if session is None:
            self.log_error("Failed to create session", session_id=session_id)
            return False

        cap = self.max_messages_per_session
        if cap > 0 and len(session.messages) >= cap:
            # Drop just enough old messages so the new one fits within the cap.
            drop = len(session.messages) - cap + 1
            session.messages = session.messages[drop:]
            self.log_info("Trimmed old messages",
                          session_id=session_id,
                          removed_count=drop)

        session.messages.append(message)
        session.message_count = len(session.messages)
        session.updated_at = datetime.utcnow()

        await self._store_session(session)

        self.log_debug("Message added to session",
                       session_id=session_id,
                       message_role=message.role,
                       total_messages=session.message_count)
        return True

    except Exception as exc:
        self.log_error("Failed to add message", error=str(exc), session_id=session_id)
        return False
|
| 181 |
+
|
| 182 |
+
async def get_session(self, session_id: str) -> Optional["ConversationHistory"]:
    """Fetch a session by id from the active storage backend.

    Args:
        session_id: Session identifier.

    Returns:
        The stored ConversationHistory, or None when absent or on error.
    """
    try:
        if self.use_redis:
            return await self._get_session_from_redis(session_id)
        return self.sessions.get(session_id)
    except Exception as exc:
        self.log_error("Failed to get session", error=str(exc), session_id=session_id)
        return None
|
| 201 |
+
|
| 202 |
+
async def session_exists(self, session_id: str) -> bool:
    """Return True when a session with this id is currently stored."""
    return await self.get_session(session_id) is not None
|
| 206 |
+
|
| 207 |
+
async def delete_session(self, session_id: str) -> bool:
    """Remove a session from storage.

    Args:
        session_id: Session identifier.

    Returns:
        bool: True on success (deleting a missing session still succeeds).
    """
    try:
        if self.use_redis:
            await self.redis_client.delete(f"session:{session_id}")
        else:
            self.sessions.pop(session_id, None)

        self.log_info("Session deleted", session_id=session_id)
        return True

    except Exception as exc:
        self.log_error("Failed to delete session", error=str(exc), session_id=session_id)
        return False
|
| 229 |
+
|
| 230 |
+
async def get_session_messages(self, session_id: str, limit: Optional[int] = None) -> List["ChatMessage"]:
    """Return a session's messages, optionally only the most recent ones.

    Args:
        session_id: Session identifier.
        limit: When positive, return only the last *limit* messages.

    Returns:
        List of messages; empty when the session does not exist.
    """
    session = await self.get_session(session_id)
    if session is None:
        return []
    if limit and limit > 0:
        return session.messages[-limit:]  # newest N, still oldest-first
    return session.messages
|
| 250 |
+
|
| 251 |
+
async def get_active_sessions(self) -> List["SessionInfo"]:
    """Summarize every currently stored session.

    Returns:
        List of SessionInfo objects; empty on error.
    """
    try:
        if not self.use_redis:
            return [self._session_to_info(s) for s in self.sessions.values()]

        infos = []
        # Scan Redis for session keys and hydrate each entry.
        for key in await self.redis_client.keys("session:*"):
            sid = key.decode().replace("session:", "")
            session = await self.get_session(sid)
            if session:
                infos.append(self._session_to_info(session))
        return infos

    except Exception as exc:
        self.log_error("Failed to get active sessions", error=str(exc))
        return []
|
| 274 |
+
|
| 275 |
+
async def get_user_sessions(self, user_id: str) -> List["SessionInfo"]:
    """Return sessions whose ids carry this user's prefix.

    NOTE(review): simplified — assumes session ids look like "<user_id>-...";
    a real system would keep an explicit user_id -> session_id mapping.
    """
    prefix = f"{user_id}-"
    return [info for info in await self.get_active_sessions()
            if info.session_id.startswith(prefix)]
|
| 281 |
+
|
| 282 |
+
def _session_to_info(self, session: "ConversationHistory") -> "SessionInfo":
    """Project a full conversation history onto its summary view."""
    return SessionInfo(
        session_id=session.session_id,
        created_at=session.created_at,
        updated_at=session.updated_at,
        message_count=session.message_count,
        # Sessions do not record a model; report the currently configured one.
        model_name=settings.model_name,
        is_active=True
    )
|
| 292 |
+
|
| 293 |
+
async def _store_session(self, session: ConversationHistory):
    """Store session in the appropriate backend (Redis or in-memory dict)."""
    if self.use_redis:
        await self._store_session_in_redis(session)
    else:
        # In-memory fallback: keyed by session id, expired later by the
        # background cleanup task.
        self.sessions[session.session_id] = session
|
| 299 |
+
|
| 300 |
+
async def _store_session_in_redis(self, session: ConversationHistory):
    """Serialize *session* to JSON and store it in Redis.

    The key is ``session:{session_id}``; SETEX attaches a TTL of
    ``self.session_timeout`` seconds so Redis expires idle sessions on its
    own (the in-memory store is swept by _cleanup_expired_sessions instead).
    """
    key = f"session:{session.session_id}"
    data = {
        "session_id": session.session_id,
        "messages": [
            {
                "role": msg.role,
                "content": msg.content,
                # ISO-8601 so _get_session_from_redis can round-trip the
                # value with datetime.fromisoformat().
                "timestamp": msg.timestamp.isoformat(),
                "metadata": msg.metadata or {}
            }
            for msg in session.messages
        ],
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
        "message_count": session.message_count
    }

    # default=str silently stringifies anything not JSON-serializable
    # (e.g. stray datetimes inside metadata) — lossy but crash-proof.
    await self.redis_client.setex(
        key,
        self.session_timeout,
        json.dumps(data, default=str)
    )
|
| 324 |
+
|
| 325 |
+
async def _get_session_from_redis(self, session_id: str) -> Optional[ConversationHistory]:
    """Load and deserialize a session from Redis.

    Returns None when the key is missing/expired, or when the stored JSON
    cannot be parsed back into a ConversationHistory (the parse failure is
    logged, never raised).
    """
    key = f"session:{session_id}"
    data = await self.redis_client.get(key)

    if not data:
        # Key absent or expired via its SETEX TTL.
        return None

    try:
        session_data = json.loads(data)
        # Rebuild messages; timestamps were written with isoformat() by
        # _store_session_in_redis, so fromisoformat() round-trips them.
        messages = [
            ChatMessage(
                role=msg["role"],
                content=msg["content"],
                timestamp=datetime.fromisoformat(msg["timestamp"]),
                metadata=msg.get("metadata")
            )
            for msg in session_data["messages"]
        ]

        return ConversationHistory(
            session_id=session_data["session_id"],
            messages=messages,
            created_at=datetime.fromisoformat(session_data["created_at"]),
            updated_at=datetime.fromisoformat(session_data["updated_at"]),
            message_count=session_data["message_count"]
        )

    except Exception as e:
        # Corrupt/legacy payloads are treated as a cache miss.
        self.log_error("Failed to parse session from Redis", error=str(e), session_id=session_id)
        return None
|
| 356 |
+
|
| 357 |
+
async def _cleanup_expired_sessions(self):
    """Background task that evicts expired in-memory sessions.

    Runs until cancelled, waking every 5 minutes. Only the in-memory store
    needs sweeping; Redis keys carry a SETEX TTL and expire on their own.
    """
    while True:
        try:
            await asyncio.sleep(300)  # Run every 5 minutes

            if not self.use_redis:  # Redis handles expiration automatically
                # NOTE(review): utcnow() is naive (and deprecated in 3.12+);
                # this assumes session.updated_at is also naive UTC — confirm.
                current_time = datetime.utcnow()
                expired_sessions = []

                # Collect ids first, delete afterwards, so self.sessions is
                # never mutated while being iterated.
                for session_id, session in self.sessions.items():
                    if (current_time - session.updated_at).total_seconds() > self.session_timeout:
                        expired_sessions.append(session_id)

                for session_id in expired_sessions:
                    del self.sessions[session_id]
                    self.log_debug("Expired session cleaned up", session_id=session_id)

                if expired_sessions:
                    self.log_info("Cleaned up expired sessions", count=len(expired_sessions))

        except asyncio.CancelledError:
            # Normal shutdown path: leave the loop quietly.
            break
        except Exception as e:
            # Keep the task alive across transient failures.
            self.log_error("Session cleanup failed", error=str(e))
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
# Global session manager instance shared by the application. The async
# wrappers below are thin module-level entry points suitable for dependency
# injection and app startup/shutdown hooks.
session_manager = SessionManager()


async def get_session_manager() -> SessionManager:
    """Get the global session manager instance (e.g. as a FastAPI dependency)."""
    return session_manager


async def initialize_session_manager() -> bool:
    """Initialize the global session manager; returns SessionManager.initialize()'s result."""
    return await session_manager.initialize()


async def shutdown_session_manager():
    """Shutdown the global session manager."""
    await session_manager.shutdown()
|
app/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Utilities package
|
app/utils/helpers.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions and helpers
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
import uuid
|
| 7 |
+
import hashlib
|
| 8 |
+
from typing import Optional, Dict, Any, List
|
| 9 |
+
from datetime import datetime, timezone
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def generate_session_id(user_id: Optional[str] = None) -> str:
    """Build a unique session identifier.

    The id has the form ``{prefix}-{UTC timestamp}-{random}`` where the
    prefix is the first 8 hex digits of md5(user_id) — hashed for privacy —
    or the literal ``anon`` when no user id is supplied.

    Args:
        user_id: Optional user identifier folded into the session id.

    Returns:
        Unique session identifier string.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    suffix = str(uuid.uuid4())[:8]

    prefix = (
        hashlib.md5(user_id.encode()).hexdigest()[:8]
        if user_id
        else "anon"
    )
    return f"{prefix}-{stamp}-{suffix}"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def generate_message_id() -> str:
    """Return a fresh message identifier of the form ``msg-<uuid4>``."""
    return "msg-{}".format(uuid.uuid4())
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def sanitize_text(text: str, max_length: int = 4000) -> str:
    """Collapse whitespace in *text* and cap its length.

    Runs of whitespace become single spaces and leading/trailing whitespace
    is dropped. Text longer than *max_length* is cut at the last word
    boundary within the limit and suffixed with ``...``.

    Args:
        text: Input text to sanitize (falsy input yields "").
        max_length: Maximum allowed length before truncation.

    Returns:
        Sanitized text.
    """
    if not text:
        return ""

    # str.split()/join collapses every whitespace run and strips both ends.
    cleaned = " ".join(text.split())

    if len(cleaned) <= max_length:
        return cleaned

    # Cut at the last complete word inside the limit, then mark truncation.
    head = cleaned[:max_length].rsplit(" ", 1)[0]
    return head + "..."
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def format_timestamp(dt: datetime) -> str:
    """Render *dt* as ``YYYY-MM-DD HH:MM:SS UTC`` for consistent display.

    Args:
        dt: Datetime object (assumed to represent UTC — the label is
            appended unconditionally).

    Returns:
        Formatted timestamp string.
    """
    return "{:%Y-%m-%d %H:%M:%S} UTC".format(dt)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of *text*.

    Uses the common ~4-characters-per-token heuristic; never returns less
    than 1 so empty input still counts as one token.

    Args:
        text: Input text.

    Returns:
        Estimated token count (>= 1).
    """
    # len(text) // 4 is 0 for short strings; `or 1` enforces the floor.
    return (len(text) // 4) or 1
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def truncate_conversation_history(
    messages: List[Dict[str, Any]],
    max_tokens: int = 2000
) -> List[Dict[str, Any]]:
    """Trim a conversation so its estimated token count fits *max_tokens*.

    System messages are always kept and charged against the budget first.
    Remaining budget is filled with the most recent non-system messages,
    stopping at the first (oldest-ward) message that would overflow.

    Args:
        messages: List of message dictionaries with "role" and "content".
        max_tokens: Maximum token budget for the whole conversation.

    Returns:
        Truncated list: all system messages followed by the kept tail of
        the chat, in original order.
    """
    if not messages:
        return messages

    # Single pass split into system vs. everything else.
    system_msgs: List[Dict[str, Any]] = []
    chat_msgs: List[Dict[str, Any]] = []
    for msg in messages:
        (system_msgs if msg.get("role") == "system" else chat_msgs).append(msg)

    budget = max_tokens - sum(
        estimate_tokens(m.get("content", "")) for m in system_msgs
    )
    if budget <= 0:
        # System prompt alone exhausts the budget; keep only it.
        return system_msgs

    # Walk newest-to-oldest, prepending while the budget allows.
    kept: List[Dict[str, Any]] = []
    used = 0
    for msg in reversed(chat_msgs):
        cost = estimate_tokens(msg.get("content", ""))
        if used + cost > budget:
            break
        kept.insert(0, msg)
        used += cost

    return system_msgs + kept
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def validate_session_id(session_id: str) -> bool:
    """Validate session ID format.

    A valid id is 5-100 characters drawn from ASCII letters, digits,
    hyphens and underscores.

    Args:
        session_id: Session identifier to validate.

    Returns:
        True if valid, False otherwise.
    """
    if not session_id or len(session_id) < 5 or len(session_id) > 100:
        return False

    # fullmatch avoids the classic `$` pitfall: with re.match + `$`, an id
    # ending in a newline ("abcde\n") would incorrectly validate, because
    # `$` also matches just before a final newline.
    return re.fullmatch(r'[a-zA-Z0-9_-]+', session_id) is not None
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def extract_model_name_from_path(model_path: str) -> str:
    """Extract the bare model name from a HuggingFace model path.

    Args:
        model_path: Full model path (e.g. "microsoft/DialoGPT-medium").

    Returns:
        The segment after the last "/", or the input unchanged when it
        contains no "/".
    """
    # rsplit with maxsplit=1 handles both cases in one expression.
    return model_path.rsplit("/", 1)[-1]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def format_model_info(model_info: Dict[str, Any]) -> Dict[str, Any]:
    """Format raw model information for API responses.

    Always includes "name", "type", "loaded" and "capabilities" (with
    defaults when absent); backend-specific keys ("device", "provider",
    "parameters") are copied through only when present.

    Args:
        model_info: Raw model information.

    Returns:
        Formatted model information dictionary.
    """
    formatted: Dict[str, Any] = {
        "name": model_info.get("name", "unknown"),
        "type": model_info.get("type", "unknown"),
        "loaded": model_info.get("loaded", False),
        "capabilities": model_info.get("capabilities", []),
    }

    # Optional, backend-specific fields: pass through verbatim if set.
    for optional_key in ("device", "provider", "parameters"):
        if optional_key in model_info:
            formatted[optional_key] = model_info[optional_key]

    return formatted
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def create_error_response(
    error_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None,
    request_id: Optional[str] = None
) -> Dict[str, Any]:
    """Create a standardized error response payload.

    Args:
        error_type: Machine-readable error category.
        message: Human-readable error message.
        details: Optional additional details (defaults to an empty dict).
        request_id: Optional request identifier; a fresh message id is
            generated when omitted.

    Returns:
        Dict with "error", "message", "details", "timestamp" (ISO-8601,
        UTC) and "request_id" keys.
    """
    return {
        "error": error_type,
        "message": message,
        "details": details or {},
        # Fix: datetime.utcnow() is deprecated (3.12+) and naive; use an
        # aware UTC timestamp, consistent with generate_session_id() above.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "request_id": request_id or generate_message_id()
    }
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def parse_model_backend_from_name(model_name: str) -> str:
    """Guess which backend serves *model_name* from naming conventions.

    Args:
        model_name: Model name or path.

    Returns:
        Suggested backend type: "openai", "anthropic", "hf_api" or "local".
    """
    name = model_name.lower()

    # Guard-clause chain instead of if/elif ladder.
    if "gpt" in name and ("3.5" in name or "4" in name):
        return "openai"
    if "claude" in name:
        return "anthropic"
    hf_providers = ("microsoft", "google", "meta", "huggingface")
    if any(provider in name for provider in hf_providers):
        return "hf_api"  # Likely available via HF API
    return "local"  # Default to local
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def get_supported_model_examples() -> Dict[str, List[str]]:
    """
    Get examples of supported models for each backend type.

    Returns:
        Dictionary mapping backend type ("local", "hf_api", "openai",
        "anthropic") to a list of example model identifiers.
    """
    # Static, hand-maintained examples — illustrative, not exhaustive.
    return {
        "local": [
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "microsoft/DialoGPT-medium",
            "Qwen/Qwen2.5-0.5B-Instruct",
            "microsoft/phi-2"
        ],
        "hf_api": [
            "microsoft/DialoGPT-large",
            "google/gemma-2b-it",
            "microsoft/phi-2",
            "meta-llama/Llama-2-7b-chat-hf"
        ],
        "openai": [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o"
        ],
        "anthropic": [
            "claude-3-haiku-20240307",
            "claude-3-sonnet-20240229",
            "claude-3-opus-20240229",
            "claude-3-5-sonnet-20241022"
        ]
    }
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def calculate_response_metrics(
    start_time: float,
    response_text: str,
    token_count: Optional[int] = None
) -> Dict[str, Any]:
    """Calculate response metrics for monitoring.

    Args:
        start_time: Request start time (``time.time()`` epoch seconds).
        response_text: Generated response text.
        token_count: Actual token count, if the backend reported one.

    Returns:
        Dict with timing, size and throughput figures; "actual_tokens" is
        None when no real count was supplied.
    """
    import time

    elapsed = time.time() - start_time

    # Truthy `or` on purpose: a reported count of 0 also falls back to the
    # heuristic estimate, matching the original behavior.
    tokens = token_count or estimate_tokens(response_text)
    throughput = tokens / elapsed if elapsed > 0 else 0

    return {
        "total_time": elapsed,
        "character_count": len(response_text),
        "estimated_tokens": tokens,
        "actual_tokens": token_count,
        "tokens_per_second": throughput,
        "words_count": len(response_text.split()),
    }
|
examples/test_backends.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example script to test different model backends
|
| 3 |
+
Demonstrates how to configure and use various model types
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import asyncio
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Add the app directory to the Python path
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
from app.core.config import Settings
|
| 16 |
+
from app.services.model_backends.local_hf import LocalHuggingFaceBackend
|
| 17 |
+
from app.services.model_backends.hf_api import HuggingFaceAPIBackend
|
| 18 |
+
from app.services.model_backends.openai_api import OpenAIAPIBackend
|
| 19 |
+
from app.services.model_backends.anthropic_api import AnthropicAPIBackend
|
| 20 |
+
from app.models.schemas import ChatMessage
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def test_local_hf_backend():
|
| 24 |
+
"""Test local HuggingFace backend"""
|
| 25 |
+
print("π€ Testing Local HuggingFace Backend")
|
| 26 |
+
print("-" * 40)
|
| 27 |
+
|
| 28 |
+
# Use a small model for testing
|
| 29 |
+
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 30 |
+
|
| 31 |
+
backend = LocalHuggingFaceBackend(
|
| 32 |
+
model_name=model_name,
|
| 33 |
+
device="cpu", # Use CPU for compatibility
|
| 34 |
+
temperature=0.7,
|
| 35 |
+
max_tokens=50
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
print(f"Loading model: {model_name}")
|
| 40 |
+
success = await backend.load_model()
|
| 41 |
+
|
| 42 |
+
if not success:
|
| 43 |
+
print("β Failed to load model")
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
print("β
Model loaded successfully")
|
| 47 |
+
|
| 48 |
+
# Test generation
|
| 49 |
+
messages = [
|
| 50 |
+
ChatMessage(role="user", content="Hello! What's your name?")
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
print("Generating response...")
|
| 54 |
+
start_time = time.time()
|
| 55 |
+
response = await backend.generate_response(messages, max_tokens=30)
|
| 56 |
+
end_time = time.time()
|
| 57 |
+
|
| 58 |
+
print(f"β
Response generated in {end_time - start_time:.2f}s")
|
| 59 |
+
print(f"Response: {response.message}")
|
| 60 |
+
|
| 61 |
+
# Test streaming
|
| 62 |
+
print("\nTesting streaming...")
|
| 63 |
+
full_response = ""
|
| 64 |
+
chunk_count = 0
|
| 65 |
+
|
| 66 |
+
async for chunk in backend.generate_stream(messages, max_tokens=30):
|
| 67 |
+
full_response += chunk.content
|
| 68 |
+
chunk_count += 1
|
| 69 |
+
if chunk.is_final:
|
| 70 |
+
break
|
| 71 |
+
|
| 72 |
+
print(f"β
Streaming completed with {chunk_count} chunks")
|
| 73 |
+
print(f"Streamed response: {full_response}")
|
| 74 |
+
|
| 75 |
+
# Cleanup
|
| 76 |
+
await backend.unload_model()
|
| 77 |
+
print("β
Model unloaded")
|
| 78 |
+
|
| 79 |
+
return True
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"β Local HF backend test failed: {e}")
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
async def test_hf_api_backend():
|
| 87 |
+
"""Test HuggingFace API backend"""
|
| 88 |
+
print("\nπ Testing HuggingFace API Backend")
|
| 89 |
+
print("-" * 40)
|
| 90 |
+
|
| 91 |
+
# Check if API token is available
|
| 92 |
+
api_token = os.getenv("HF_API_TOKEN")
|
| 93 |
+
if not api_token:
|
| 94 |
+
print("β οΈ HF_API_TOKEN not set, skipping HF API test")
|
| 95 |
+
return True
|
| 96 |
+
|
| 97 |
+
model_name = "microsoft/DialoGPT-medium"
|
| 98 |
+
|
| 99 |
+
backend = HuggingFaceAPIBackend(
|
| 100 |
+
model_name=model_name,
|
| 101 |
+
api_token=api_token,
|
| 102 |
+
temperature=0.7,
|
| 103 |
+
max_tokens=50
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
print(f"Initializing API client for: {model_name}")
|
| 108 |
+
success = await backend.load_model()
|
| 109 |
+
|
| 110 |
+
if not success:
|
| 111 |
+
print("β Failed to initialize API client")
|
| 112 |
+
return False
|
| 113 |
+
|
| 114 |
+
print("β
API client initialized")
|
| 115 |
+
|
| 116 |
+
# Test generation
|
| 117 |
+
messages = [
|
| 118 |
+
ChatMessage(role="user", content="Hello! How are you?")
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
print("Generating response via API...")
|
| 122 |
+
start_time = time.time()
|
| 123 |
+
response = await backend.generate_response(messages, max_tokens=30)
|
| 124 |
+
end_time = time.time()
|
| 125 |
+
|
| 126 |
+
print(f"β
Response generated in {end_time - start_time:.2f}s")
|
| 127 |
+
print(f"Response: {response.message}")
|
| 128 |
+
|
| 129 |
+
return True
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
print(f"β HF API backend test failed: {e}")
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
async def test_openai_backend():
|
| 137 |
+
"""Test OpenAI API backend"""
|
| 138 |
+
print("\nπ₯ Testing OpenAI API Backend")
|
| 139 |
+
print("-" * 40)
|
| 140 |
+
|
| 141 |
+
# Check if API key is available
|
| 142 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
| 143 |
+
if not api_key:
|
| 144 |
+
print("β οΈ OPENAI_API_KEY not set, skipping OpenAI test")
|
| 145 |
+
return True
|
| 146 |
+
|
| 147 |
+
model_name = "gpt-3.5-turbo"
|
| 148 |
+
|
| 149 |
+
backend = OpenAIAPIBackend(
|
| 150 |
+
model_name=model_name,
|
| 151 |
+
api_key=api_key,
|
| 152 |
+
temperature=0.7,
|
| 153 |
+
max_tokens=50
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
try:
|
| 157 |
+
print(f"Initializing OpenAI client for: {model_name}")
|
| 158 |
+
success = await backend.load_model()
|
| 159 |
+
|
| 160 |
+
if not success:
|
| 161 |
+
print("β Failed to initialize OpenAI client")
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
print("β
OpenAI client initialized")
|
| 165 |
+
|
| 166 |
+
# Test generation
|
| 167 |
+
messages = [
|
| 168 |
+
ChatMessage(role="user", content="Hello! What's the weather like?")
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
print("Generating response via OpenAI...")
|
| 172 |
+
start_time = time.time()
|
| 173 |
+
response = await backend.generate_response(messages, max_tokens=30)
|
| 174 |
+
end_time = time.time()
|
| 175 |
+
|
| 176 |
+
print(f"β
Response generated in {end_time - start_time:.2f}s")
|
| 177 |
+
print(f"Response: {response.message}")
|
| 178 |
+
|
| 179 |
+
# Test streaming
|
| 180 |
+
print("\nTesting streaming...")
|
| 181 |
+
full_response = ""
|
| 182 |
+
chunk_count = 0
|
| 183 |
+
|
| 184 |
+
async for chunk in backend.generate_stream(messages, max_tokens=30):
|
| 185 |
+
full_response += chunk.content
|
| 186 |
+
chunk_count += 1
|
| 187 |
+
if chunk.is_final:
|
| 188 |
+
break
|
| 189 |
+
|
| 190 |
+
print(f"β
Streaming completed with {chunk_count} chunks")
|
| 191 |
+
print(f"Streamed response: {full_response}")
|
| 192 |
+
|
| 193 |
+
return True
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print(f"β OpenAI backend test failed: {e}")
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
async def test_anthropic_backend():
|
| 201 |
+
"""Test Anthropic API backend"""
|
| 202 |
+
print("\nπ§ Testing Anthropic API Backend")
|
| 203 |
+
print("-" * 40)
|
| 204 |
+
|
| 205 |
+
# Check if API key is available
|
| 206 |
+
api_key = os.getenv("ANTHROPIC_API_KEY")
|
| 207 |
+
if not api_key:
|
| 208 |
+
print("β οΈ ANTHROPIC_API_KEY not set, skipping Anthropic test")
|
| 209 |
+
return True
|
| 210 |
+
|
| 211 |
+
model_name = "claude-3-haiku-20240307"
|
| 212 |
+
|
| 213 |
+
backend = AnthropicAPIBackend(
|
| 214 |
+
model_name=model_name,
|
| 215 |
+
api_key=api_key,
|
| 216 |
+
temperature=0.7,
|
| 217 |
+
max_tokens=50
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
try:
|
| 221 |
+
print(f"Initializing Anthropic client for: {model_name}")
|
| 222 |
+
success = await backend.load_model()
|
| 223 |
+
|
| 224 |
+
if not success:
|
| 225 |
+
print("β Failed to initialize Anthropic client")
|
| 226 |
+
return False
|
| 227 |
+
|
| 228 |
+
print("β
Anthropic client initialized")
|
| 229 |
+
|
| 230 |
+
# Test generation
|
| 231 |
+
messages = [
|
| 232 |
+
ChatMessage(role="user", content="Hello! Tell me about yourself.")
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
print("Generating response via Anthropic...")
|
| 236 |
+
start_time = time.time()
|
| 237 |
+
response = await backend.generate_response(messages, max_tokens=30)
|
| 238 |
+
end_time = time.time()
|
| 239 |
+
|
| 240 |
+
print(f"β
Response generated in {end_time - start_time:.2f}s")
|
| 241 |
+
print(f"Response: {response.message}")
|
| 242 |
+
|
| 243 |
+
return True
|
| 244 |
+
|
| 245 |
+
except Exception as e:
|
| 246 |
+
print(f"β Anthropic backend test failed: {e}")
|
| 247 |
+
return False
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
async def main():
|
| 251 |
+
"""Main test function"""
|
| 252 |
+
print("π Sema Chat Backend Testing")
|
| 253 |
+
print("=" * 50)
|
| 254 |
+
|
| 255 |
+
results = {}
|
| 256 |
+
|
| 257 |
+
# Test each backend
|
| 258 |
+
results["local_hf"] = await test_local_hf_backend()
|
| 259 |
+
results["hf_api"] = await test_hf_api_backend()
|
| 260 |
+
results["openai"] = await test_openai_backend()
|
| 261 |
+
results["anthropic"] = await test_anthropic_backend()
|
| 262 |
+
|
| 263 |
+
# Summary
|
| 264 |
+
print("\n" + "=" * 50)
|
| 265 |
+
print("π Test Results Summary")
|
| 266 |
+
print("-" * 25)
|
| 267 |
+
|
| 268 |
+
for backend, success in results.items():
|
| 269 |
+
status = "β
PASS" if success else "β FAIL"
|
| 270 |
+
print(f"{backend:15} {status}")
|
| 271 |
+
|
| 272 |
+
total_tests = len(results)
|
| 273 |
+
passed_tests = sum(results.values())
|
| 274 |
+
|
| 275 |
+
print(f"\nTotal: {passed_tests}/{total_tests} backends working")
|
| 276 |
+
|
| 277 |
+
if passed_tests == total_tests:
|
| 278 |
+
print("π All available backends are working!")
|
| 279 |
+
elif passed_tests > 0:
|
| 280 |
+
print("β οΈ Some backends are working, check configuration for others")
|
| 281 |
+
else:
|
| 282 |
+
print("β No backends are working, check your setup")
|
| 283 |
+
|
| 284 |
+
print("\nπ‘ Tips:")
|
| 285 |
+
print("- For HF API: Set HF_API_TOKEN environment variable")
|
| 286 |
+
print("- For OpenAI: Set OPENAI_API_KEY environment variable")
|
| 287 |
+
print("- For Anthropic: Set ANTHROPIC_API_KEY environment variable")
|
| 288 |
+
print("- For local models: Ensure you have enough RAM/VRAM")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
|
| 292 |
+
asyncio.run(main())
|
requirements.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
pydantic
|
| 4 |
+
python-multipart
|
| 5 |
+
websockets
|
| 6 |
+
sse-starlette
|
| 7 |
+
slowapi
|
| 8 |
+
prometheus-client
|
| 9 |
+
structlog
|
| 10 |
+
python-dotenv
|
| 11 |
+
httpx
|
| 12 |
+
aiofiles
|
| 13 |
+
|
| 14 |
+
# HuggingFace & ML
|
| 15 |
+
transformers
|
| 16 |
+
torch
|
| 17 |
+
huggingface-hub
|
| 18 |
+
accelerate
|
| 19 |
+
sentencepiece
|
| 20 |
+
|
| 21 |
+
# API Clients
|
| 22 |
+
openai
|
| 23 |
+
anthropic
|
| 24 |
+
|
| 25 |
+
# Utilities
|
| 26 |
+
uuid
|
| 27 |
+
asyncio-mqtt
|
| 28 |
+
redis
|
setup_huggingface.sh
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# π Sema Chat API - HuggingFace Spaces Setup Script
# This script helps you deploy Sema Chat API to HuggingFace Spaces with Gemma.
#
# It stages a copy of the project in ../sema-chat-deploy, generates the Space
# README (with HuggingFace front-matter) and a .gitignore, initializes a git
# repo pointed at the Space, and prints the remaining manual steps.

set -e

echo "π Sema Chat API - HuggingFace Spaces Setup"
echo "=========================================="

# Check if we're in the right directory
if [ ! -f "app/main.py" ]; then
    echo "β Error: Please run this script from the backend/sema-chat directory"
    echo " Current directory: $(pwd)"
    echo " Expected files: app/main.py, requirements.txt, Dockerfile"
    exit 1
fi

echo "β Found Sema Chat API files"

# Get user input (-r: keep backslashes literal instead of treating them as escapes)
read -r -p "π Enter your HuggingFace username: " HF_USERNAME
read -r -p "π Enter your Space name (e.g., sema-chat-gemma): " SPACE_NAME
read -r -p "π Enter your Google AI API key (or press Enter to skip): " GOOGLE_API_KEY

# Validate inputs
if [ -z "$HF_USERNAME" ]; then
    echo "β Error: HuggingFace username is required"
    exit 1
fi

if [ -z "$SPACE_NAME" ]; then
    echo "β Error: Space name is required"
    exit 1
fi

SPACE_URL="https://huggingface.co/spaces/$HF_USERNAME/$SPACE_NAME"
SPACE_REPO="https://huggingface.co/spaces/$HF_USERNAME/$SPACE_NAME"

echo ""
echo "π Configuration Summary:"
echo " HuggingFace Username: $HF_USERNAME"
echo " Space Name: $SPACE_NAME"
echo " Space URL: $SPACE_URL"
echo " Google AI Key: ${GOOGLE_API_KEY:+[PROVIDED]}${GOOGLE_API_KEY:-[NOT PROVIDED]}"
echo ""

read -r -p "π€ Continue with deployment? (y/N): " CONFIRM
if [[ ! $CONFIRM =~ ^[Yy]$ ]]; then
    echo "β Deployment cancelled"
    exit 0
fi

# Create deployment directory
DEPLOY_DIR="../sema-chat-deploy"
echo "π Creating deployment directory: $DEPLOY_DIR"
rm -rf "$DEPLOY_DIR"
mkdir -p "$DEPLOY_DIR"

# Copy all files
echo "π Copying files..."
cp -r . "$DEPLOY_DIR/"
cd "$DEPLOY_DIR"

# If the source tree was itself a git checkout, cp -r copied its .git along.
# Drop it so the fresh `git init` below starts clean and `git remote add`
# cannot collide with a pre-existing "origin".
rm -rf .git

# Create README.md for the Space (unquoted EOF: $HF_USERNAME/$SPACE_NAME expand)
echo "π Creating Space README..."
cat > README.md << EOF
---
title: Sema Chat API
emoji: π¬
colorFrom: purple
colorTo: pink
sdk: docker
pinned: false
license: mit
short_description: Modern chatbot API with Gemma integration and streaming capabilities
---

# Sema Chat API π¬

Modern chatbot API with streaming capabilities, powered by Google's Gemma model.

## π Features

- **Real-time Streaming**: Server-Sent Events and WebSocket support
- **Gemma Integration**: Powered by Google's Gemma 2 9B model
- **Session Management**: Persistent conversation contexts
- **RESTful API**: Clean, documented endpoints
- **Interactive UI**: Built-in Swagger documentation

## π API Endpoints

- **Chat**: \`POST /api/v1/chat\`
- **Streaming**: \`GET /api/v1/chat/stream\`
- **WebSocket**: \`ws://space-url/api/v1/chat/ws\`
- **Health**: \`GET /api/v1/health\`
- **Docs**: \`GET /\` (Swagger UI)

## π¬ Quick Test

\`\`\`bash
curl -X POST "https://$HF_USERNAME-$SPACE_NAME.hf.space/api/v1/chat" \\
  -H "Content-Type: application/json" \\
  -d '{
    "message": "Hello! Can you introduce yourself?",
    "session_id": "test-session"
  }'
\`\`\`

## π Streaming Test

\`\`\`bash
curl -N -H "Accept: text/event-stream" \\
  "https://$HF_USERNAME-$SPACE_NAME.hf.space/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
\`\`\`

## βοΈ Configuration

This Space is configured to use Google's Gemma model via AI Studio.
Set your \`GOOGLE_API_KEY\` in the Space settings to enable the API.

## π οΈ Built With

- **FastAPI**: Modern Python web framework
- **Google Gemma**: Advanced language model
- **Docker**: Containerized deployment
- **HuggingFace Spaces**: Hosting platform

---

Created by $HF_USERNAME | Powered by Sema AI
EOF

# Create .gitignore. NOTE: the delimiter is quoted ('EOF') so the heredoc is
# taken literally — otherwise bash expands $py in "*$py.class" to the empty
# string and the pattern is written as "*.class".
echo "π« Creating .gitignore..."
cat > .gitignore << 'EOF'
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

.pytest_cache/
.coverage
htmlcov/

.DS_Store
.vscode/
.idea/

logs/
*.log
EOF

# Initialize git repository
echo "π§ Initializing git repository..."
git init
# Force the branch name to "main" — git init's default branch may be "master"
# depending on git version/config, but the push instructions below say
# `git push origin main`.
git branch -M main
git remote add origin "$SPACE_REPO"

# Create initial commit
echo "π¦ Creating initial commit..."
git add .
git commit -m "Initial deployment of Sema Chat API with Gemma support

Features:
- Google Gemma 2 9B integration
- Real-time streaming responses
- Session management
- RESTful API with Swagger docs
- WebSocket support
- Health monitoring

Configuration:
- MODEL_TYPE=google
- MODEL_NAME=gemma-2-9b-it
- Port: 7860 (HuggingFace standard)
"

echo ""
echo "π Setup Complete!"
echo "=================="
echo ""
echo "π Next Steps:"
echo "1. Create your HuggingFace Space:"
echo " β Go to: https://huggingface.co/spaces"
echo " β Click 'Create new Space'"
echo " β Name: $SPACE_NAME"
echo " β SDK: Docker"
echo " β License: MIT"
echo ""
echo "2. Push your code:"
echo " β cd $DEPLOY_DIR"
echo " β git push origin main"
echo ""
echo "3. Configure environment variables in Space settings:"
if [ -n "$GOOGLE_API_KEY" ]; then
    echo " β MODEL_TYPE=google"
    echo " β MODEL_NAME=gemma-2-9b-it"
    echo " β GOOGLE_API_KEY=$GOOGLE_API_KEY"
else
    echo " β MODEL_TYPE=google"
    echo " β MODEL_NAME=gemma-2-9b-it"
    echo " β GOOGLE_API_KEY=your_google_api_key_here"
    echo ""
    echo " π Get your Google AI API key from: https://aistudio.google.com/"
fi
echo " β DEBUG=false"
echo " β ENVIRONMENT=production"
echo ""
echo "4. Wait for build and test:"
echo " β Space URL: $SPACE_URL"
echo " β API Docs: $SPACE_URL/"
echo " β Health Check: $SPACE_URL/api/v1/health"
echo ""
echo "π Your Sema Chat API will be live at:"
echo " $SPACE_URL"
echo ""
echo "Happy deploying! π¬β¨"
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Tests package
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for Sema Chat API
|
| 3 |
+
Tests all endpoints and functionality
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
import time
|
| 9 |
+
import asyncio
|
| 10 |
+
import websockets
|
| 11 |
+
from typing import Dict, Any
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SemaChatAPITester:
    """Integration-test client for a running Sema Chat API server.

    Exercises health/status, model info, regular chat, SSE streaming,
    WebSocket chat, session management and error handling against
    ``base_url``. Each instance derives a unique session id so repeated
    runs do not collide on server-side session state.
    """

    # Per-request timeout in seconds: a hung server fails the test run
    # instead of blocking it forever (requests has no default timeout).
    REQUEST_TIMEOUT = 60

    def __init__(self, base_url: str = "http://localhost:7860"):
        # Normalize: drop any trailing "/" so path joins don't produce "//".
        self.base_url = base_url.rstrip("/")
        self.session_id = f"test-session-{int(time.time())}"

    def test_health_endpoints(self):
        """Test health and status endpoints; return the detailed health payload."""
        print("π₯ Testing health endpoints...")

        # Test basic status
        response = requests.get(f"{self.base_url}/status", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200
        print("β Status endpoint working")

        # Test app-level health
        response = requests.get(f"{self.base_url}/health", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200
        print("β App health endpoint working")

        # Test detailed health
        response = requests.get(f"{self.base_url}/api/v1/health", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200
        health_data = response.json()
        print(f"β Detailed health check: {health_data['status']}")
        print(f" Model: {health_data['model_name']} ({health_data['model_type']})")
        print(f" Model loaded: {health_data['model_loaded']}")

        return health_data

    def test_model_info(self):
        """Test the model information endpoint; return its payload."""
        print("\nπ€ Testing model info...")

        response = requests.get(f"{self.base_url}/api/v1/model/info", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200

        model_info = response.json()
        print(f"β Model info retrieved")
        print(f" Name: {model_info['name']}")
        print(f" Type: {model_info['type']}")
        print(f" Loaded: {model_info['loaded']}")
        print(f" Capabilities: {model_info['capabilities']}")

        return model_info

    def test_regular_chat(self):
        """Test regular (non-streaming) chat; return the chat response payload."""
        print("\n㪠Testing regular chat...")

        chat_request = {
            "message": "Hello! Can you introduce yourself?",
            "session_id": self.session_id,
            "temperature": 0.7,
            "max_tokens": 100
        }

        start_time = time.time()
        response = requests.post(
            f"{self.base_url}/api/v1/chat",
            json=chat_request,
            headers={"Content-Type": "application/json"},
            timeout=self.REQUEST_TIMEOUT
        )
        end_time = time.time()

        assert response.status_code == 200
        chat_response = response.json()

        print(f"β Regular chat working")
        print(f" Response time: {end_time - start_time:.2f}s")
        print(f" Generation time: {chat_response['generation_time']:.2f}s")
        print(f" Response: {chat_response['message'][:100]}...")
        print(f" Session ID: {chat_response['session_id']}")
        print(f" Message ID: {chat_response['message_id']}")

        return chat_response

    def test_streaming_chat(self):
        """Test streaming chat via Server-Sent Events; return the full text."""
        print("\nπ Testing streaming chat...")

        params = {
            "message": "Tell me a short story about AI",
            "session_id": self.session_id,
            "temperature": 0.8,
            "max_tokens": 150
        }

        start_time = time.time()
        chunks_received = 0
        full_response = ""

        # `with` ensures the streamed connection is released even if an
        # assertion fails or the stream ends early.
        with requests.get(
            f"{self.base_url}/api/v1/chat/stream",
            params=params,
            headers={"Accept": "text/event-stream"},
            stream=True,
            timeout=self.REQUEST_TIMEOUT
        ) as response:
            assert response.status_code == 200

            for line in response.iter_lines():
                if line:
                    line_str = line.decode('utf-8')
                    if line_str.startswith('data: '):
                        try:
                            data = json.loads(line_str[6:])  # Remove 'data: ' prefix
                            if 'content' in data:
                                full_response += data['content']
                                chunks_received += 1

                            if data.get('is_final'):
                                break
                        except json.JSONDecodeError:
                            # Ignore keep-alives / non-JSON SSE lines.
                            continue

        end_time = time.time()

        print(f"β Streaming chat working")
        print(f" Total time: {end_time - start_time:.2f}s")
        print(f" Chunks received: {chunks_received}")
        print(f" Response: {full_response[:100]}...")

        return full_response

    def test_session_management(self):
        """Test session retrieval and active-session listing endpoints."""
        print("\nπ Testing session management...")

        # Get session history
        response = requests.get(
            f"{self.base_url}/api/v1/sessions/{self.session_id}",
            timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 200

        session_data = response.json()
        print(f"β Session retrieval working")
        print(f" Messages in session: {session_data['message_count']}")
        print(f" Session created: {session_data['created_at']}")

        # Get active sessions
        response = requests.get(f"{self.base_url}/api/v1/sessions", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200

        sessions = response.json()
        print(f"β Active sessions list working")
        print(f" Total active sessions: {len(sessions)}")

        return session_data

    async def test_websocket_chat(self):
        """Test WebSocket chat; return the full response text, or None on failure."""
        print("\nπ Testing WebSocket chat...")

        # Derive ws:// or wss:// from the HTTP base URL.
        ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url += "/api/v1/chat/ws"

        try:
            async with websockets.connect(ws_url) as websocket:
                # Send a message
                message = {
                    "message": "Hello via WebSocket!",
                    "session_id": f"{self.session_id}-ws",
                    "temperature": 0.7,
                    "max_tokens": 50
                }

                await websocket.send(json.dumps(message))

                # Receive response chunks
                chunks_received = 0
                full_response = ""

                while True:
                    try:
                        response = await asyncio.wait_for(websocket.recv(), timeout=30.0)
                        data = json.loads(response)

                        if data.get("type") == "chunk":
                            full_response += data.get("content", "")
                            chunks_received += 1

                            if data.get("is_final"):
                                break
                        elif data.get("type") == "error":
                            print(f"β WebSocket error: {data.get('error')}")
                            break

                    except asyncio.TimeoutError:
                        print("β οΈ WebSocket timeout")
                        break

                print(f"β WebSocket chat working")
                print(f" Chunks received: {chunks_received}")
                print(f" Response: {full_response[:100]}...")

                return full_response

        except Exception as e:
            # WebSocket support is optional in some deployments; report but
            # don't abort the whole suite.
            print(f"β WebSocket test failed: {e}")
            return None

    def test_error_handling(self):
        """Test validation and not-found error responses."""
        print("\nπ¨ Testing error handling...")

        # Test empty message
        response = requests.post(
            f"{self.base_url}/api/v1/chat",
            json={"message": "", "session_id": self.session_id},
            timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 422  # Validation error
        print("β Empty message validation working")

        # Test invalid session ID
        response = requests.get(
            f"{self.base_url}/api/v1/sessions/invalid-session-id-that-does-not-exist",
            timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 404
        print("β Invalid session handling working")

        # Test rate limiting (if enabled)
        print("β Error handling tests passed")

    def test_session_cleanup(self):
        """Delete the test session and verify it is gone."""
        print("\nπ§Ή Testing session cleanup...")

        # Clear the test session
        response = requests.delete(
            f"{self.base_url}/api/v1/sessions/{self.session_id}",
            timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 200
        print("β Session cleanup working")

        # Verify session is gone
        response = requests.get(
            f"{self.base_url}/api/v1/sessions/{self.session_id}",
            timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 404
        print("β Session deletion verified")

    def run_all_tests(self):
        """Run all tests; return True on full success, False otherwise."""
        print("π Starting Sema Chat API Tests")
        print("=" * 50)

        try:
            # Test basic endpoints
            health_data = self.test_health_endpoints()

            if not health_data.get('model_loaded'):
                print("β οΈ Model not loaded, skipping chat tests")
                return False

            model_info = self.test_model_info()

            # Test chat functionality
            self.test_regular_chat()
            self.test_streaming_chat()

            # Test session management
            self.test_session_management()

            # Test WebSocket (async)
            asyncio.run(self.test_websocket_chat())

            # Test error handling
            self.test_error_handling()

            # Cleanup
            self.test_session_cleanup()

            print("\n" + "=" * 50)
            print("π All tests passed successfully!")
            print(f"β API is working correctly with {model_info['name']}")
            return True

        except Exception as e:
            # Top-level boundary: report the failure with a traceback and
            # signal failure to the caller instead of crashing.
            print(f"\nβ Test failed: {e}")
            import traceback
            traceback.print_exc()
            return False
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def main():
    """Command-line entry point: parse the target URL, run the suite, exit.

    Exits with status 0 when every test passes and 1 otherwise, so the
    script can be used directly in CI pipelines.
    """
    import argparse

    arg_parser = argparse.ArgumentParser(description="Test Sema Chat API")
    arg_parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Base URL of the API (default: http://localhost:7860)",
    )
    cli_args = arg_parser.parse_args()

    # Translate the boolean outcome into a conventional process exit code.
    passed = SemaChatAPITester(cli_args.url).run_all_tests()
    sys.exit(0 if passed else 1)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# Script entry point: run the full test suite when executed directly.
if __name__ == "__main__":
    main()
|