kamau1 committed on
Commit 639f3bb · 1 Parent(s): 0943b9d

Initial Commit

.DS_Store ADDED
Binary file (6.15 kB).
 
CONFIGURATION_GUIDE.md ADDED
@@ -0,0 +1,270 @@
+ # 🔧 Sema Chat API Configuration Guide
+
+ ## 🎯 **MiniMax Integration**
+
+ ### Configuration
+ ```bash
+ MODEL_TYPE=minimax
+ MODEL_NAME=MiniMax-M1
+ MINIMAX_API_KEY=your_minimax_api_key
+ MINIMAX_API_URL=https://api.minimax.chat/v1/text/chatcompletion_v2
+ MINIMAX_MODEL_VERSION=abab6.5s-chat
+ ```
+
+ ### Features
+ - ✅ **Reasoning Capabilities**: Shows the model's thinking process
+ - ✅ **Streaming Support**: Real-time response generation
+ - ✅ **Custom API Integration**: Direct integration with the MiniMax API
+ - ✅ **Reasoning Content**: Displays both the reasoning and the final response
+
+ ### Example Usage
+ ```bash
+ curl -X POST "http://localhost:7860/api/v1/chat" \
+   -H "Content-Type: application/json" \
+   -d '{
+     "message": "Solve this math problem: 2x + 5 = 15",
+     "session_id": "minimax-test"
+   }'
+ ```
+
+ **Response includes reasoning:**
+ ```json
+ {
+   "message": "[Reasoning: I need to solve for x. First, subtract 5 from both sides: 2x = 10. Then divide by 2: x = 5]\n\nTo solve 2x + 5 = 15:\n1. Subtract 5 from both sides: 2x = 10\n2. Divide by 2: x = 5\n\nTherefore, x = 5.",
+   "session_id": "minimax-test",
+   "model_name": "MiniMax-M1"
+ }
+ ```
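If a client needs the final answer without the bracketed prefix, the `[Reasoning: ...]` block shown above can be split off with plain string handling. A minimal sketch, assuming the exact prefix format in the example response; `split_reasoning` is a hypothetical helper, not part of the API:

```python
def split_reasoning(message: str) -> tuple[str, str]:
    """Split a '[Reasoning: ...]' prefix off a response message.

    Returns (reasoning, answer); reasoning is '' when no prefix is present.
    """
    prefix = "[Reasoning:"
    if message.startswith(prefix):
        end = message.find("]")  # assumes no ']' inside the reasoning text
        if end != -1:
            reasoning = message[len(prefix):end].strip()
            answer = message[end + 1:].lstrip()
            return reasoning, answer
    return "", message


reasoning, answer = split_reasoning(
    "[Reasoning: subtract 5, then divide by 2]\n\nTherefore, x = 5."
)
```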
+
+ ---
+
+ ## 🔥 **Gemma Integration**
+
+ ### Option 1: Local Gemma (Free Tier)
+ ```bash
+ MODEL_TYPE=local
+ MODEL_NAME=google/gemma-2b-it
+ DEVICE=auto
+ ```
+
+ ### Option 2: Gemma via HuggingFace API
+ ```bash
+ MODEL_TYPE=hf_api
+ MODEL_NAME=google/gemma-2b-it
+ HF_API_TOKEN=your_hf_token
+ ```
+
+ ### Option 3: Gemma via Google AI Studio
+ ```bash
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+ GOOGLE_API_KEY=your_google_api_key
+ ```
+
+ ### Available Gemma and Gemini Models
+ - **gemma-2-2b-it** (2B parameters, instruction-tuned)
+ - **gemma-2-9b-it** (9B parameters, instruction-tuned)
+ - **gemma-2-27b-it** (27B parameters, instruction-tuned)
+ - **gemini-1.5-flash** (fast, efficient)
+ - **gemini-1.5-pro** (most capable)
+
+ ### Example Usage
+ ```bash
+ curl -X POST "http://localhost:7860/api/v1/chat" \
+   -H "Content-Type: application/json" \
+   -d '{
+     "message": "Explain quantum computing in simple terms",
+     "session_id": "gemma-test",
+     "temperature": 0.7
+   }'
+ ```
+
+ ---
+
+ ## 🚀 **Complete Backend Comparison**
+
+ | Backend | Cost | Setup | Streaming | Special Features |
+ |---------|------|-------|-----------|------------------|
+ | **Local** | Free | Medium | ✅ | Offline, private |
+ | **HF API** | Free/Paid | Easy | ✅ | Many models |
+ | **OpenAI** | Paid | Easy | ✅ | High quality |
+ | **Anthropic** | Paid | Easy | ✅ | Long context |
+ | **MiniMax** | Paid | Easy | ✅ | Reasoning |
+ | **Google** | Free/Paid | Easy | ✅ | Multimodal |
+
+ ---
+
+ ## 🔧 **Configuration Examples**
+
+ ### Free Tier Setup (HuggingFace Spaces)
+ ```bash
+ # Best for free deployment
+ MODEL_TYPE=local
+ MODEL_NAME=TinyLlama/TinyLlama-1.1B-Chat-v1.0
+ DEVICE=cpu
+ MAX_NEW_TOKENS=256
+ TEMPERATURE=0.7
+ ```
+
+ ### Production Setup (API-based)
+ ```bash
+ # Best for production with fallbacks
+ MODEL_TYPE=openai
+ MODEL_NAME=gpt-3.5-turbo
+ OPENAI_API_KEY=your_key
+
+ # Fallback configuration
+ FALLBACK_MODEL_TYPE=hf_api
+ FALLBACK_MODEL_NAME=microsoft/DialoGPT-medium
+ HF_API_TOKEN=your_token
+ ```
+
+ ### Research Setup (Multiple Models)
+ ```bash
+ # Primary: latest Gemini
+ MODEL_TYPE=google
+ MODEL_NAME=gemini-1.5-pro
+ GOOGLE_API_KEY=your_key
+
+ # For reasoning tasks
+ REASONING_MODEL_TYPE=minimax
+ REASONING_MODEL_NAME=MiniMax-M1
+ MINIMAX_API_KEY=your_key
+ ```
+
+ ---
+
+ ## 🎯 **Model Selection Guide**
+
+ ### For **Free Deployment** (HuggingFace Spaces):
+ 1. **TinyLlama/TinyLlama-1.1B-Chat-v1.0** - Smallest, fastest
+ 2. **microsoft/DialoGPT-medium** - Better conversations
+ 3. **Qwen/Qwen2.5-0.5B-Instruct** - Good instruction following
+
+ ### For **Reasoning Tasks**:
+ 1. **MiniMax M1** - Shows its thinking process
+ 2. **Claude-3 Opus** - Deep reasoning
+ 3. **GPT-4** - Complex problem solving
+
+ ### For **Conversations**:
+ 1. **Claude-3 Haiku** - Natural, fast
+ 2. **GPT-3.5-turbo** - Balanced cost/quality
+ 3. **Gemini-1.5-flash** - Fast, capable
+
+ ### For **Multilingual Use**:
+ 1. **Gemma-2-9b-it** - Good multilingual coverage
+ 2. **GPT-4** - Excellent multilingual coverage
+ 3. **Local models** - Depends on their training data
+
+ ---
+
+ ## 🔄 **Dynamic Model Switching**
+
+ The API supports runtime model switching:
+
+ ```bash
+ # Switch to MiniMax for reasoning
+ curl -X POST "http://localhost:7860/api/v1/model/switch" \
+   -H "Content-Type: application/json" \
+   -d '{"model_type": "minimax", "model_name": "MiniMax-M1"}'
+
+ # Switch back to a fast model
+ curl -X POST "http://localhost:7860/api/v1/model/switch" \
+   -H "Content-Type: application/json" \
+   -d '{"model_type": "google", "model_name": "gemini-1.5-flash"}'
+ ```
+
+ ---
+
+ ## 🧪 **Testing Your Setup**
+
+ ### Test All Backends
+ ```bash
+ python examples/test_backends.py
+ ```
+
+ ### Test a Specific Backend
+ ```bash
+ # Test MiniMax
+ MINIMAX_API_KEY=your_key python -c "
+ import asyncio
+ from app.services.model_backends.minimax_api import MiniMaxAPIBackend
+ from app.models.schemas import ChatMessage
+
+ async def test():
+     backend = MiniMaxAPIBackend('MiniMax-M1', api_key='your_key', api_url='your_url')
+     await backend.load_model()
+     messages = [ChatMessage(role='user', content='Hello')]
+     response = await backend.generate_response(messages)
+     print(response.message)
+
+ asyncio.run(test())
+ "
+ ```
+
+ ### Test Gemma
+ ```bash
+ # Test local Gemma
+ MODEL_TYPE=local MODEL_NAME=google/gemma-2b-it python tests/test_api.py
+
+ # Test Gemma via Google AI
+ MODEL_TYPE=google MODEL_NAME=gemma-2-9b-it GOOGLE_API_KEY=your_key python tests/test_api.py
+ ```
+
+ ---
+
+ ## 🚀 **Deployment Examples**
+
+ ### HuggingFace Spaces (Free)
+ ```yaml
+ # In your Space settings
+ MODEL_TYPE: local
+ MODEL_NAME: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+ DEVICE: cpu
+ ```
+
+ ### HuggingFace Spaces (With API)
+ ```yaml
+ # In your Space settings
+ MODEL_TYPE: google
+ MODEL_NAME: gemma-2-9b-it
+ GOOGLE_API_KEY: your_secret_key
+ ```
+
+ ### Docker Deployment
+ ```bash
+ docker run -e MODEL_TYPE=minimax \
+   -e MINIMAX_API_KEY=your_key \
+   -e MINIMAX_API_URL=your_url \
+   -p 7860:7860 \
+   sema-chat-api
+ ```
+
+ ---
+
+ ## 💡 **Pro Tips**
+
+ 1. **Start small**: Begin with TinyLlama for testing
+ 2. **Use APIs for production**: More reliable than local models
+ 3. **Enable streaming**: Better user experience
+ 4. **Monitor usage**: Track API costs and limits
+ 5. **Have fallbacks**: Configure multiple backends
+ 6. **Test thoroughly**: Use the provided test scripts
+
+ ---
+
+ ## 🔗 **Getting API Keys**
+
+ - **HuggingFace**: https://huggingface.co/settings/tokens
+ - **OpenAI**: https://platform.openai.com/api-keys
+ - **Anthropic**: https://console.anthropic.com/
+ - **Google AI**: https://aistudio.google.com/
+ - **MiniMax**: Contact MiniMax for API access
+
+ ---
+
+ **Your architecture is now ready for both MiniMax and Gemma! 🎉**
Dockerfile ADDED
@@ -0,0 +1,67 @@
+ # Sema Chat API Dockerfile
+ # Multi-stage build for an optimized production image
+
+ FROM python:3.11-slim AS builder
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     PIP_DISABLE_PIP_VERSION_CHECK=1
+
+ # Install build-time system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create and activate virtual environment
+ RUN python -m venv /opt/venv
+ ENV PATH="/opt/venv/bin:$PATH"
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --upgrade pip && \
+     pip install -r requirements.txt
+
+ # Production stage
+ FROM python:3.11-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PATH="/opt/venv/bin:$PATH" \
+     PYTHONPATH="/app"
+
+ # Install runtime dependencies
+ RUN apt-get update && apt-get install -y \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy virtual environment from builder stage
+ COPY --from=builder /opt/venv /opt/venv
+
+ # Create app directory and non-root user
+ RUN groupadd -r appuser && useradd -r -g appuser appuser
+ WORKDIR /app
+
+ # Copy application code
+ COPY . .
+
+ # Create necessary directories
+ RUN mkdir -p logs && \
+     chown -R appuser:appuser /app
+
+ # Switch to non-root user
+ USER appuser
+
+ # Expose port
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Default command
+ CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
HUGGINGFACE_DEPLOYMENT.md ADDED
@@ -0,0 +1,295 @@
+ # 🚀 HuggingFace Spaces Deployment Guide
+
+ ## 📋 **Quick Setup for Gemma**
+
+ ### Step 1: Create Your HuggingFace Space
+ 1. Go to [HuggingFace Spaces](https://huggingface.co/spaces)
+ 2. Click **"Create new Space"**
+ 3. Choose:
+    - **Space name**: `your-username/sema-chat-gemma`
+    - **License**: MIT
+    - **Space SDK**: Docker
+    - **Space hardware**: CPU basic (free) or T4 small (paid)
+
+ ### Step 2: Clone and Upload Files
+ ```bash
+ # Clone your new space
+ git clone https://huggingface.co/spaces/your-username/sema-chat-gemma
+ cd sema-chat-gemma
+
+ # Copy all files from backend/sema-chat/
+ cp -r /path/to/sema/backend/sema-chat/* .
+
+ # Add and commit
+ git add .
+ git commit -m "Initial Sema Chat API with Gemma support"
+ git push
+ ```
+
+ ### Step 3: Configure Environment Variables
+ In your Space settings, add these environment variables:
+
+ #### **Option A: Local Gemma (Free Tier)**
+ ```
+ MODEL_TYPE=local
+ MODEL_NAME=google/gemma-2b-it
+ DEVICE=cpu
+ TEMPERATURE=0.7
+ MAX_NEW_TOKENS=256
+ DEBUG=false
+ ENVIRONMENT=production
+ ```
+
+ #### **Option B: Gemma via Google AI Studio (Recommended)**
+ ```
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+ GOOGLE_API_KEY=your_google_api_key_here
+ TEMPERATURE=0.7
+ MAX_NEW_TOKENS=512
+ DEBUG=false
+ ENVIRONMENT=production
+ ```
+
+ #### **Option C: Gemma via HuggingFace API**
+ ```
+ MODEL_TYPE=hf_api
+ MODEL_NAME=google/gemma-2b-it
+ HF_API_TOKEN=your_hf_token_here
+ TEMPERATURE=0.7
+ MAX_NEW_TOKENS=512
+ DEBUG=false
+ ENVIRONMENT=production
+ ```
+
+ ---
+
+ ## 🔑 **Getting API Keys**
+
+ ### Google AI Studio API Key (Recommended)
+ 1. Go to [Google AI Studio](https://aistudio.google.com/)
+ 2. Sign in with your Google account
+ 3. Click **"Get API Key"**
+ 4. Create a new API key
+ 5. Copy the key and add it to your Space settings
+
+ ### HuggingFace API Token (Alternative)
+ 1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens)
+ 2. Click **"New token"**
+ 3. Choose **"Read"** access
+ 4. Copy the token and add it to your Space settings
+
+ ---
+
+ ## 📁 **Required Files Structure**
+
+ Make sure your Space has these files:
+ ```
+ your-space/
+ ├── app/                  # Main application code
+ ├── requirements.txt      # Python dependencies
+ ├── Dockerfile            # Container configuration
+ ├── README.md             # Space documentation
+ └── .gitignore            # Git ignore file
+ ```
+
+ ---
+
+ ## 🐳 **Dockerfile Configuration**
+
+ Your Dockerfile should be:
+ ```dockerfile
+ FROM python:3.11-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PYTHONPATH="/app"
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements and install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create non-root user
+ RUN useradd -m -u 1000 user
+ USER user
+
+ # Expose port 7860 (HuggingFace Spaces standard)
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Start the application
+ CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+ ```
+
+ ---
+
+ ## 🎯 **Recommended Configuration for a First Version**
+
+ For your first deployment, I recommend **Google AI Studio** with Gemma:
+
+ ### Environment Variables:
+ ```
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+ GOOGLE_API_KEY=your_api_key_here
+ TEMPERATURE=0.7
+ MAX_NEW_TOKENS=512
+ DEBUG=false
+ ENVIRONMENT=production
+ ENABLE_STREAMING=true
+ RATE_LIMIT=30
+ SESSION_TIMEOUT=30
+ ```
+
+ ### Why This Setup?
+ - ✅ **Fast deployment** - No model download needed
+ - ✅ **Reliable** - Google's infrastructure
+ - ✅ **Cost-effective** - Free tier available
+ - ✅ **Good performance** - Gemma 2 9B is a capable model
+ - ✅ **Streaming support** - Real-time responses
+
+ ---
+
+ ## 🧪 **Testing Your Deployment**
+
+ ### 1. Check Health
+ ```bash
+ curl https://your-username-sema-chat-gemma.hf.space/health
+ ```
+
+ ### 2. Test Chat
+ ```bash
+ curl -X POST "https://your-username-sema-chat-gemma.hf.space/api/v1/chat" \
+   -H "Content-Type: application/json" \
+   -d '{
+     "message": "Hello! Can you introduce yourself?",
+     "session_id": "test-session"
+   }'
+ ```
+
+ ### 3. Test Streaming
+ ```bash
+ curl -N -H "Accept: text/event-stream" \
+   "https://your-username-sema-chat-gemma.hf.space/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
+ ```
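The stream above arrives as Server-Sent Events with `event:` and `data:` lines. A minimal Python consumer can be sketched as below; the `chunk`/`done` event names match what the API emits, while the live-HTTP wiring is left to the reader (the parser itself works on any iterable of lines):

```python
import json


def parse_sse(lines):
    """Parse 'event:'/'data:' lines of an SSE stream into (event, data) pairs.

    Yields one pair per blank-line-terminated SSE message.
    """
    event, data = "message", []
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data.append(line[len("data:"):].strip())
        elif line == "":  # a blank line ends one SSE message
            if data:
                yield event, "\n".join(data)
            event, data = "message", []


# Example over a captured stream; a live client would iterate over the
# HTTP response body line by line instead of this list.
sample = [
    "event: chunk\n",
    'data: {"content": "Hello", "is_final": false}\n',
    "\n",
    "event: done\n",
    'data: {"message": "Stream completed"}\n',
    "\n",
]
events = [(name, json.loads(payload)) for name, payload in parse_sse(sample)]
```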
+
+ ### 4. Access Swagger UI
+ Visit: `https://your-username-sema-chat-gemma.hf.space/`
+
+ ---
+
+ ## 🔧 **Troubleshooting**
+
+ ### Common Issues:
+
+ #### 1. **Space Won't Start**
+ - Check the logs in your Space settings
+ - Verify all required files are present
+ - Check the Dockerfile syntax
+
+ #### 2. **Model Loading Fails**
+ - Verify the API key is correct
+ - Check the model name spelling
+ - Try a smaller model first
+
+ #### 3. **API Errors**
+ - Check environment variables
+ - Verify network connectivity
+ - Review application logs
+
+ #### 4. **Slow Responses**
+ - Use a smaller model (gemma-2-2b-it)
+ - Reduce MAX_NEW_TOKENS
+ - Enable streaming for better UX
+
+ ### Debug Commands:
+ ```bash
+ # Check the active model configuration
+ curl https://your-space.hf.space/api/v1/model/info
+
+ # Check detailed health
+ curl https://your-space.hf.space/api/v1/health
+
+ # View logs in Space settings
+ ```
+
+ ---
+
+ ## 🚀 **Step-by-Step Deployment**
+
+ ### 1. Prepare Your Space
+ ```bash
+ # Create and clone your space
+ git clone https://huggingface.co/spaces/your-username/sema-chat-gemma
+ cd sema-chat-gemma
+
+ # Copy files
+ cp -r ../sema/backend/sema-chat/* .
+ ```
+
+ ### 2. Set Environment Variables
+ Go to your Space settings and add:
+ ```
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+ GOOGLE_API_KEY=your_key_here
+ ```
+
+ ### 3. Deploy
+ ```bash
+ git add .
+ git commit -m "Deploy Sema Chat with Gemma"
+ git push
+ ```
+
+ ### 4. Wait for the Build
+ - The Space builds automatically (5-10 minutes)
+ - Check the build logs for any errors
+ - Once running, test the endpoints
+
+ ### 5. Share Your Space
+ Your API will be available at:
+ `https://your-username-sema-chat-gemma.hf.space/`
+
+ ---
+
+ ## 💡 **Pro Tips**
+
+ 1. **Start with Google AI Studio** - Easiest setup
+ 2. **Use environment variables** - Never hardcode API keys
+ 3. **Enable streaming** - Better user experience
+ 4. **Monitor usage** - Check API quotas
+ 5. **Test thoroughly** - Use the provided test scripts
+ 6. **Document your API** - Swagger UI is auto-generated
+
+ ---
+
+ ## 🎉 **You're Ready!**
+
+ With this setup, you'll have a production-ready chatbot API with:
+ - ✅ Gemma 2 9B via Google AI Studio
+ - ✅ Streaming responses
+ - ✅ Session management
+ - ✅ Rate limiting
+ - ✅ Health monitoring
+ - ✅ Interactive Swagger UI
+
+ **Your Space URL will be:**
+ `https://your-username-sema-chat-gemma.hf.space/`
+
+ Happy deploying! 🚀
README.md CHANGED
@@ -1,12 +1,217 @@
- ---
- title: Sema Chat
- emoji: 👀
- colorFrom: green
- colorTo: red
- sdk: docker
- pinned: false
- license: mit
- short_description: Chat Service for sema ai
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Sema Chat API 💬
+
+ A modern chatbot API with streaming support, flexible model backends, and production-ready features. Built with FastAPI and designed to keep pace with rapid GenAI advancements.
+
+ ## 🚀 Quick Start with Gemma
+
+ ### Option 1: Automated HuggingFace Spaces Deployment
+ ```bash
+ cd backend/sema-chat
+ ./setup_huggingface.sh
+ ```
+
+ ### Option 2: Manual Local Setup
+ ```bash
+ cd backend/sema-chat
+ pip install -r requirements.txt
+
+ # Copy and configure environment
+ cp .env.example .env
+
+ # For Gemma via Google AI Studio (recommended), edit .env:
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+ GOOGLE_API_KEY=your_google_api_key
+
+ # Run the API
+ uvicorn app.main:app --reload --host 0.0.0.0 --port 7860
+ ```
+
+ ### Option 3: Local Gemma (Free, No API Key)
+ ```bash
+ # Edit .env:
+ MODEL_TYPE=local
+ MODEL_NAME=google/gemma-2b-it
+ DEVICE=cpu
+
+ # Run (downloads the model on first run)
+ uvicorn app.main:app --reload --host 0.0.0.0 --port 7860
+ ```
+
+ ## 🌐 Access Your API
+
+ Once running, access:
+ - **Swagger UI**: http://localhost:7860/
+ - **Health Check**: http://localhost:7860/api/v1/health
+ - **Chat Endpoint**: http://localhost:7860/api/v1/chat
+
+ ## 🧪 Quick Test
+
+ ```bash
+ # Test chat
+ curl -X POST "http://localhost:7860/api/v1/chat" \
+   -H "Content-Type: application/json" \
+   -d '{
+     "message": "Hello! Can you introduce yourself?",
+     "session_id": "test-session"
+   }'
+
+ # Test streaming
+ curl -N -H "Accept: text/event-stream" \
+   "http://localhost:7860/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
+ ```
+
+ ## 🎯 Features
+
+ ### Core Capabilities
+ - ✅ **Real-time Streaming**: Server-Sent Events and WebSocket support
+ - ✅ **Multiple Model Backends**: Local, HuggingFace API, OpenAI, Anthropic, Google AI, MiniMax
+ - ✅ **Session Management**: Persistent conversation contexts
+ - ✅ **Rate Limiting**: Built-in protection with configurable limits
+ - ✅ **Health Monitoring**: Comprehensive health checks and metrics
+
+ ### Supported Models
+ - **Local**: TinyLlama, DialoGPT, Gemma, Qwen
+ - **Google AI**: Gemma-2-9b-it, Gemini-1.5-flash, Gemini-1.5-pro
+ - **OpenAI**: GPT-3.5-turbo, GPT-4, GPT-4-turbo
+ - **Anthropic**: Claude-3-haiku, Claude-3-sonnet, Claude-3-opus
+ - **HuggingFace API**: Any model available via the Inference API
+ - **MiniMax**: M1 model with reasoning capabilities
+
+ ## 🔧 Configuration
+
+ ### Environment Variables
+ ```bash
+ # Model backend (local, google, openai, anthropic, hf_api, minimax)
+ MODEL_TYPE=google
+ MODEL_NAME=gemma-2-9b-it
+
+ # API keys (as needed)
+ GOOGLE_API_KEY=your_key
+ OPENAI_API_KEY=your_key
+ ANTHROPIC_API_KEY=your_key
+ HF_API_TOKEN=your_token
+ MINIMAX_API_KEY=your_key
+
+ # Generation settings
+ TEMPERATURE=0.7
+ MAX_NEW_TOKENS=512
+ TOP_P=0.9
+
+ # Server settings
+ HOST=0.0.0.0
+ PORT=7860
+ DEBUG=false
+ ```
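The repository reads these variables through `app/core/config.py` ("environment-based configuration"), whose source is not shown in this commit chunk. As a minimal illustration of the pattern, not the project's actual implementation, settings like these can be loaded with a plain dataclass; all defaults below are assumptions:

```python
import os
from dataclasses import dataclass, field


@dataclass
class Settings:
    """Environment-backed settings; field names mirror the variables above."""
    model_type: str = field(default_factory=lambda: os.getenv("MODEL_TYPE", "local"))
    model_name: str = field(default_factory=lambda: os.getenv(
        "MODEL_NAME", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"))
    temperature: float = field(default_factory=lambda: float(os.getenv("TEMPERATURE", "0.7")))
    max_new_tokens: int = field(default_factory=lambda: int(os.getenv("MAX_NEW_TOKENS", "512")))
    rate_limit: int = field(default_factory=lambda: int(os.getenv("RATE_LIMIT", "30")))


# e.g. values injected by the Space's settings UI or a .env file
os.environ["MODEL_TYPE"] = "google"
settings = Settings()
```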
+
+ ## 📚 Documentation
+
+ - **[Configuration Guide](CONFIGURATION_GUIDE.md)** - Detailed setup for all backends
+ - **[HuggingFace Deployment](HUGGINGFACE_DEPLOYMENT.md)** - Step-by-step deployment guide
+ - **[API Documentation](http://localhost:7860/)** - Interactive Swagger UI
+
+ ## 🧪 Testing
+
+ ```bash
+ # Run comprehensive tests
+ python tests/test_api.py
+
+ # Test different backends
+ python examples/test_backends.py
+
+ # Test a specific backend
+ python examples/test_backends.py --backend google
+ ```
+
+ ## 🚀 Deployment
+
+ ### HuggingFace Spaces (Recommended)
+ 1. Run the setup script: `./setup_huggingface.sh`
+ 2. Create your Space on HuggingFace
+ 3. Push the generated code
+ 4. Set environment variables in the Space settings
+ 5. Your API will be live at: `https://username-spacename.hf.space/`
+
+ ### Docker
+ ```bash
+ docker build -t sema-chat-api .
+ docker run -e MODEL_TYPE=google \
+   -e GOOGLE_API_KEY=your_key \
+   -p 7860:7860 \
+   sema-chat-api
+ ```
+
+ ## 🔗 API Endpoints
+
+ ### Chat
+ - **`POST /api/v1/chat`** - Send a chat message
+ - **`GET /api/v1/chat/stream`** - Streaming chat (SSE)
+ - **`WebSocket /api/v1/chat/ws`** - Real-time WebSocket chat
+
+ ### Sessions
+ - **`GET /api/v1/sessions/{id}`** - Get conversation history
+ - **`DELETE /api/v1/sessions/{id}`** - Clear a conversation
+ - **`GET /api/v1/sessions`** - List active sessions
+
+ ### System
+ - **`GET /api/v1/health`** - Comprehensive health check
+ - **`GET /api/v1/model/info`** - Current model information
+ - **`GET /api/v1/status`** - Basic status
+
+ ## 💡 Why This Architecture?
+
+ 1. **Future-proof**: The modular design adapts to rapid GenAI advancements
+ 2. **Flexible**: Switch between local models and APIs with environment variables
+ 3. **Production-ready**: Rate limiting, monitoring, and error handling built in
+ 4. **Cost-effective**: Start free with local models, scale with APIs
+ 5. **Developer-friendly**: Comprehensive docs, tests, and examples
+
+ ## 🛠️ Development
+
+ ### Project Structure
+ ```
+ app/
+ ├── main.py                # FastAPI application
+ ├── api/v1/endpoints.py    # API routes
+ ├── core/
+ │   ├── config.py          # Environment-based configuration
+ │   └── logging.py         # Structured logging
+ ├── models/schemas.py      # Pydantic request/response models
+ ├── services/
+ │   ├── chat_manager.py    # Chat orchestration
+ │   ├── model_manager.py   # Backend selection
+ │   ├── session_manager.py # Conversation management
+ │   └── model_backends/    # Model implementations
+ └── utils/helpers.py       # Utility functions
+ ```
+
+ ### Adding New Backends
+ 1. Create a new backend in `app/services/model_backends/`
+ 2. Inherit from the `ModelBackend` base class
+ 3. Implement the required methods
+ 4. Add it to `ModelManager._create_backend()`
+ 5. Update the configuration and documentation
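As a sketch of steps 2-3, a new backend might look like the toy class below. The real `ModelBackend` base class lives in `app/services/model_backends/base.py` and is not shown in this commit chunk; the interface here is inferred from the `load_model()`/`generate_response()` calls in the docs, and `EchoBackend` is purely illustrative:

```python
import asyncio
from abc import ABC, abstractmethod


class ModelBackend(ABC):
    """Assumed shape of the base class in app/services/model_backends/base.py."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    @abstractmethod
    async def load_model(self) -> None: ...

    @abstractmethod
    async def generate_response(self, messages: list) -> str: ...


class EchoBackend(ModelBackend):
    """Toy backend: echoes the last user message back."""

    async def load_model(self) -> None:
        self.loaded = True  # a real backend would open a client or load weights

    async def generate_response(self, messages: list) -> str:
        return f"echo: {messages[-1]['content']}"


async def demo() -> str:
    backend = EchoBackend("echo-1")
    await backend.load_model()
    return await backend.generate_response([{"role": "user", "content": "Hello"}])


reply = asyncio.run(demo())
```

Note that the project's real `generate_response` returns a response object (the test snippets access `response.message`), so a production backend would return the appropriate schema type rather than a bare string.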
+
+ ## 🤝 Contributing
+
+ 1. Fork the repository
+ 2. Create a feature branch
+ 3. Add tests for new functionality
+ 4. Ensure all tests pass
+ 5. Submit a pull request
+
+ ## 📄 License
+
+ MIT License - see the LICENSE file for details.
+
+ ## 🙏 Acknowledgments
+
+ - **HuggingFace** for model hosting and the Spaces platform
+ - **Google** for the Gemma models and AI Studio
+ - **FastAPI** for the excellent web framework
+ - **OpenAI, Anthropic, MiniMax** for their APIs
+
  ---
 
+ **Ready to chat? Deploy your Sema Chat API today! 🚀💬**
app/__init__.py ADDED
@@ -0,0 +1 @@
+ # Sema Chat API Package
app/api/__init__.py ADDED
@@ -0,0 +1 @@
+ # API package
app/api/v1/__init__.py ADDED
@@ -0,0 +1 @@
+ # API v1 package
app/api/v1/endpoints.py ADDED
@@ -0,0 +1,363 @@
1
+ """
2
+ API v1 endpoints for Sema Chat API
3
+ """
4
+
5
+ import asyncio
6
+ import time
7
+ import uuid
8
+ from typing import List, Optional, Dict, Any
9
+ from datetime import datetime
10
+
11
+ from fastapi import APIRouter, HTTPException, Request, Query, WebSocket, WebSocketDisconnect
12
+ from fastapi.responses import StreamingResponse
13
+ from sse_starlette.sse import EventSourceResponse
14
+ from slowapi import Limiter
15
+ from slowapi.util import get_remote_address
16
+ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
17
+ from fastapi.responses import Response
18
+
19
+ from ...models.schemas import (
20
+ ChatRequest, ChatResponse, StreamChunk, ConversationHistory,
21
+ HealthResponse, ErrorResponse, ModelInfo, SessionInfo
22
+ )
23
+ from ...services.chat_manager import get_chat_manager
24
+ from ...services.model_manager import get_model_manager
25
+ from ...services.session_manager import get_session_manager
26
+ from ...services.model_backends.base import ModelBackendError, ModelNotLoadedError, GenerationError
27
+ from ...core.config import settings
28
+ from ...core.logging import get_logger
29
+
30
+ # Initialize router and rate limiter
31
+ router = APIRouter()
32
+ limiter = Limiter(key_func=get_remote_address)
33
+ logger = get_logger()
34
+
35
+ # WebSocket connection manager
36
+ class ConnectionManager:
37
+ def __init__(self):
38
+ self.active_connections: List[WebSocket] = []
39
+
40
+ async def connect(self, websocket: WebSocket):
41
+ await websocket.accept()
42
+ self.active_connections.append(websocket)
43
+
44
+ def disconnect(self, websocket: WebSocket):
45
+ self.active_connections.remove(websocket)
46
+
47
+ async def send_personal_message(self, message: str, websocket: WebSocket):
48
+ await websocket.send_text(message)
49
+
50
+ manager = ConnectionManager()
51
+
52
+
53
+ @router.post("/chat", response_model=ChatResponse)
+ @limiter.limit(f"{settings.rate_limit}/minute")
+ async def chat(request: ChatRequest, req: Request):
+     """
+     Send a chat message and get a complete response
+     """
+     start_time = time.time()
+
+     try:
+         chat_manager = await get_chat_manager()
+         response = await chat_manager.process_chat_request(request)
+
+         # Add timing information
+         total_time = time.time() - start_time
+         response.generation_time = getattr(response, 'generation_time', total_time)
+
+         logger.info("chat_request_completed",
+                     session_id=request.session_id,
+                     message_length=len(request.message),
+                     response_length=len(response.message),
+                     total_time=total_time)
+
+         return response
+
+     except ModelNotLoadedError as e:
+         logger.error("model_not_loaded", error=str(e), session_id=request.session_id)
+         raise HTTPException(status_code=503, detail="Model not available")
+
+     except GenerationError as e:
+         logger.error("generation_error", error=str(e), session_id=request.session_id)
+         raise HTTPException(status_code=500, detail="Failed to generate response")
+
+     except Exception as e:
+         logger.error("chat_request_failed", error=str(e), session_id=request.session_id)
+         raise HTTPException(status_code=500, detail="Internal server error")
+
+
+ @router.get("/chat/stream")
+ @limiter.limit(f"{settings.rate_limit}/minute")
+ async def chat_stream(
+     message: str = Query(..., description="Chat message"),
+     session_id: str = Query(..., description="Session ID"),
+     system_prompt: Optional[str] = Query(None, description="Custom system prompt"),
+     temperature: Optional[float] = Query(None, ge=0.0, le=1.0, description="Temperature"),
+     max_tokens: Optional[int] = Query(None, ge=1, le=2048, description="Max tokens"),
+     req: Request = None
+ ):
+     """
+     Send a chat message and get a streaming response via Server-Sent Events
+     """
+     try:
+         # Create chat request
+         chat_request = ChatRequest(
+             message=message,
+             session_id=session_id,
+             system_prompt=system_prompt,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             stream=True
+         )
+
+         chat_manager = await get_chat_manager()
+
+         async def event_generator():
+             try:
+                 async for chunk in chat_manager.process_streaming_chat_request(chat_request):
+                     # Format as SSE event
+                     chunk_data = {
+                         "content": chunk.content,
+                         "session_id": chunk.session_id,
+                         "message_id": chunk.message_id,
+                         "chunk_id": chunk.chunk_id,
+                         "is_final": chunk.is_final,
+                         "timestamp": chunk.timestamp.isoformat()
+                     }
+
+                     yield {
+                         "event": "chunk",
+                         "data": chunk_data
+                     }
+
+                     if chunk.is_final:
+                         yield {
+                             "event": "done",
+                             "data": {"message": "Stream completed"}
+                         }
+                         break
+
+             except Exception as e:
+                 logger.error("streaming_error", error=str(e), session_id=session_id)
+                 yield {
+                     "event": "error",
+                     "data": {"error": str(e)}
+                 }
+
+         return EventSourceResponse(event_generator())
+
+     except Exception as e:
+         logger.error("stream_setup_failed", error=str(e), session_id=session_id)
+         raise HTTPException(status_code=500, detail="Failed to set up stream")
+
+
+ @router.websocket("/chat/ws")
+ async def websocket_chat(websocket: WebSocket):
+     """
+     WebSocket endpoint for real-time chat
+     """
+     await manager.connect(websocket)
+     session_id = None
+
+     try:
+         while True:
+             # Receive message from client
+             data = await websocket.receive_json()
+
+             # Extract request data
+             message = data.get("message")
+             session_id = data.get("session_id")
+             system_prompt = data.get("system_prompt")
+             temperature = data.get("temperature")
+             max_tokens = data.get("max_tokens")
+
+             if not message or not session_id:
+                 await websocket.send_json({
+                     "error": "Message and session_id are required"
+                 })
+                 continue
+
+             # Create chat request
+             chat_request = ChatRequest(
+                 message=message,
+                 session_id=session_id,
+                 system_prompt=system_prompt,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 stream=True
+             )
+
+             # Process streaming request
+             chat_manager = await get_chat_manager()
+
+             try:
+                 async for chunk in chat_manager.process_streaming_chat_request(chat_request):
+                     await websocket.send_json({
+                         "type": "chunk",
+                         "content": chunk.content,
+                         "session_id": chunk.session_id,
+                         "message_id": chunk.message_id,
+                         "chunk_id": chunk.chunk_id,
+                         "is_final": chunk.is_final,
+                         "timestamp": chunk.timestamp.isoformat()
+                     })
+
+                     if chunk.is_final:
+                         break
+
+             except Exception as e:
+                 logger.error("websocket_generation_error", error=str(e), session_id=session_id)
+                 await websocket.send_json({
+                     "type": "error",
+                     "error": str(e)
+                 })
+
+     except WebSocketDisconnect:
+         manager.disconnect(websocket)
+         logger.info("websocket_disconnected", session_id=session_id)
+     except Exception as e:
+         logger.error("websocket_error", error=str(e), session_id=session_id)
+         manager.disconnect(websocket)
+
+
+ @router.get("/sessions/{session_id}", response_model=ConversationHistory)
+ async def get_session(session_id: str):
+     """
+     Get conversation history for a session
+     """
+     try:
+         chat_manager = await get_chat_manager()
+         history = await chat_manager.get_conversation_history(session_id)
+
+         if not history:
+             raise HTTPException(status_code=404, detail="Session not found")
+
+         return history
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error("get_session_failed", error=str(e), session_id=session_id)
+         raise HTTPException(status_code=500, detail="Failed to get session")
+
+
+ @router.delete("/sessions/{session_id}")
+ async def clear_session(session_id: str):
+     """
+     Clear conversation history for a session
+     """
+     try:
+         chat_manager = await get_chat_manager()
+         success = await chat_manager.clear_conversation(session_id)
+
+         if not success:
+             raise HTTPException(status_code=404, detail="Session not found")
+
+         return {"message": "Session cleared successfully"}
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error("clear_session_failed", error=str(e), session_id=session_id)
+         raise HTTPException(status_code=500, detail="Failed to clear session")
+
+
+ @router.get("/sessions", response_model=List[Dict[str, Any]])
+ async def get_active_sessions():
+     """
+     Get list of active chat sessions
+     """
+     try:
+         chat_manager = await get_chat_manager()
+         sessions = await chat_manager.get_active_sessions()
+         return sessions
+
+     except Exception as e:
+         logger.error("get_active_sessions_failed", error=str(e))
+         raise HTTPException(status_code=500, detail="Failed to get active sessions")
+
+
+ @router.get("/model/info", response_model=ModelInfo)
+ async def get_model_info():
+     """
+     Get information about the current model
+     """
+     try:
+         model_manager = await get_model_manager()
+         info = model_manager.get_model_info()
+
+         return ModelInfo(
+             name=info["name"],
+             type=info["type"],
+             loaded=info["loaded"],
+             parameters=info.get("parameters"),
+             capabilities=info.get("capabilities", [])
+         )
+
+     except Exception as e:
+         logger.error("get_model_info_failed", error=str(e))
+         raise HTTPException(status_code=500, detail="Failed to get model info")
+
+
+ @router.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """
+     Comprehensive health check endpoint
+     """
+     try:
+         chat_manager = await get_chat_manager()
+         health_data = await chat_manager.health_check()
+
+         # Extract key information
+         overall_status = health_data.get("overall", {})
+         model_info = health_data.get("model_manager", {})
+         session_info = health_data.get("session_manager", {})
+
+         return HealthResponse(
+             status=overall_status.get("status", "unknown"),
+             version=settings.app_version,
+             model_type=settings.model_type,
+             model_name=settings.model_name,
+             model_loaded=model_info.get("status") == "healthy",
+             uptime=time.time(),  # Simplified uptime
+             active_sessions=session_info.get("active_sessions", 0),
+             timestamp=datetime.utcnow()
+         )
+
+     except Exception as e:
+         logger.error("health_check_failed", error=str(e))
+         return HealthResponse(
+             status="unhealthy",
+             version=settings.app_version,
+             model_type=settings.model_type,
+             model_name=settings.model_name,
+             model_loaded=False,
+             uptime=time.time(),
+             active_sessions=0,
+             timestamp=datetime.utcnow()
+         )
+
+
+ @router.get("/status")
+ async def status():
+     """
+     Simple status endpoint
+     """
+     return {
+         "status": "ok",
+         "service": "sema-chat-api",
+         "version": settings.app_version,
+         "timestamp": datetime.utcnow().isoformat()
+     }
+
+
+ @router.get("/metrics")
+ async def metrics():
+     """
+     Prometheus metrics endpoint
+     """
+     if not settings.enable_metrics:
+         raise HTTPException(status_code=404, detail="Metrics not enabled")
+
+     return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
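The `event_generator` above emits `chunk` events followed by a final `done` event. A client has to split the SSE wire format back into events and reassemble the message from the `content` fields. A minimal stdlib sketch of that client-side parsing (the `parse_sse` helper is hypothetical, and it assumes each `data:` line carries JSON matching the `chunk_data` dict above):

```python
import json

def parse_sse(raw: str):
    """Parse an SSE payload into (event, data) pairs.

    Assumes the layout produced by the endpoint above: an ``event:`` line,
    a ``data:`` line, and a blank line terminating each event.
    """
    events = []
    event_name, data_lines = None, []
    for line in raw.splitlines():
        if line.startswith("event:"):
            event_name = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data_lines.append(line[len("data:"):].strip())
        elif line == "" and (event_name or data_lines):
            payload = "\n".join(data_lines)
            try:
                payload = json.loads(payload)  # keep raw string if not JSON
            except json.JSONDecodeError:
                pass
            events.append((event_name, payload))
            event_name, data_lines = None, []
    return events

# Reassemble the streamed message from "chunk" events
raw = (
    'event: chunk\ndata: {"content": "Hello", "chunk_id": 0, "is_final": false}\n\n'
    'event: chunk\ndata: {"content": " world", "chunk_id": 1, "is_final": true}\n\n'
    'event: done\ndata: {"message": "Stream completed"}\n\n'
)
text = "".join(d["content"] for e, d in parse_sse(raw) if e == "chunk")
print(text)  # Hello world
```

In a real client you would feed the response body through this incrementally (e.g. with `httpx` streaming) rather than buffering the whole payload.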
app/core/__init__.py ADDED
@@ -0,0 +1 @@
+ # Core configuration and utilities
app/core/config.py ADDED
@@ -0,0 +1,178 @@
+ """
+ Configuration management for Sema Chat API
+ Environment-driven settings for flexible model backends
+ """
+
+ from typing import List, Optional
+ from pydantic import BaseSettings, Field
+ from dotenv import load_dotenv
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+
+ class Settings(BaseSettings):
+     """Application settings with environment variable support"""
+
+     # =========================================================================
+     # APPLICATION SETTINGS
+     # =========================================================================
+
+     app_name: str = Field(default="Sema Chat API", env="APP_NAME")
+     app_version: str = Field(default="1.0.0", env="APP_VERSION")
+     environment: str = Field(default="development", env="ENVIRONMENT")
+     debug: bool = Field(default=True, env="DEBUG")
+
+     # =========================================================================
+     # SERVER SETTINGS
+     # =========================================================================
+
+     host: str = Field(default="0.0.0.0", env="HOST")
+     port: int = Field(default=7860, env="PORT")
+     cors_origins: List[str] = Field(default=["*"], env="CORS_ORIGINS")
+
+     # =========================================================================
+     # MODEL CONFIGURATION
+     # =========================================================================
+
+     model_type: str = Field(default="local", env="MODEL_TYPE")
+     model_name: str = Field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", env="MODEL_NAME")
+
+     # Local model settings
+     device: str = Field(default="auto", env="DEVICE")
+     max_length: int = Field(default=2048, env="MAX_LENGTH")
+     temperature: float = Field(default=0.7, env="TEMPERATURE")
+     top_p: float = Field(default=0.9, env="TOP_P")
+     top_k: int = Field(default=50, env="TOP_K")
+     max_new_tokens: int = Field(default=512, env="MAX_NEW_TOKENS")
+
+     # =========================================================================
+     # API KEYS AND TOKENS
+     # =========================================================================
+
+     # HuggingFace
+     hf_api_token: Optional[str] = Field(default=None, env="HF_API_TOKEN")
+     hf_inference_url: str = Field(
+         default="https://api-inference.huggingface.co/models/",
+         env="HF_INFERENCE_URL"
+     )
+
+     # OpenAI
+     openai_api_key: Optional[str] = Field(default=None, env="OPENAI_API_KEY")
+     openai_org_id: Optional[str] = Field(default=None, env="OPENAI_ORG_ID")
+
+     # Anthropic
+     anthropic_api_key: Optional[str] = Field(default=None, env="ANTHROPIC_API_KEY")
+
+     # MiniMax
+     minimax_api_key: Optional[str] = Field(default=None, env="MINIMAX_API_KEY")
+     minimax_api_url: Optional[str] = Field(default=None, env="MINIMAX_API_URL")
+     minimax_model_version: Optional[str] = Field(default=None, env="MINIMAX_MODEL_VERSION")
+
+     # Google AI Studio
+     google_api_key: Optional[str] = Field(default=None, env="GOOGLE_API_KEY")
+
+     # =========================================================================
+     # RATE LIMITING AND PERFORMANCE
+     # =========================================================================
+
+     rate_limit: int = Field(default=60, env="RATE_LIMIT")  # requests per minute
+     max_concurrent_streams: int = Field(default=10, env="MAX_CONCURRENT_STREAMS")
+     stream_delay: float = Field(default=0.01, env="STREAM_DELAY")
+
+     # =========================================================================
+     # SESSION MANAGEMENT
+     # =========================================================================
+
+     session_timeout: int = Field(default=30, env="SESSION_TIMEOUT")  # minutes
+     max_sessions_per_user: int = Field(default=5, env="MAX_SESSIONS_PER_USER")
+     max_messages_per_session: int = Field(default=100, env="MAX_MESSAGES_PER_SESSION")
+
+     # =========================================================================
+     # STREAMING SETTINGS
+     # =========================================================================
+
+     enable_streaming: bool = Field(default=True, env="ENABLE_STREAMING")
+
+     # =========================================================================
+     # LOGGING AND MONITORING
+     # =========================================================================
+
+     log_level: str = Field(default="INFO", env="LOG_LEVEL")
+     structured_logging: bool = Field(default=True, env="STRUCTURED_LOGGING")
+     log_file: Optional[str] = Field(default=None, env="LOG_FILE")
+
+     enable_metrics: bool = Field(default=True, env="ENABLE_METRICS")
+     metrics_path: str = Field(default="/metrics", env="METRICS_PATH")
+
+     # =========================================================================
+     # EXTERNAL SERVICES
+     # =========================================================================
+
+     redis_url: Optional[str] = Field(default=None, env="REDIS_URL")
+
+     # =========================================================================
+     # SECURITY
+     # =========================================================================
+
+     api_key: Optional[str] = Field(default=None, env="API_KEY")
+     jwt_secret: Optional[str] = Field(default=None, env="JWT_SECRET")
+
+     # =========================================================================
+     # SYSTEM PROMPTS
+     # =========================================================================
+
+     system_prompt: str = Field(
+         default="You are a helpful, harmless, and honest AI assistant. Respond in a friendly and professional manner.",
+         env="SYSTEM_PROMPT"
+     )
+
+     system_prompt_chat: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CHAT")
+     system_prompt_code: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CODE")
+     system_prompt_creative: Optional[str] = Field(default=None, env="SYSTEM_PROMPT_CREATIVE")
+
+     class Config:
+         env_file = ".env"
+         case_sensitive = False
+
+     def get_system_prompt(self, prompt_type: str = "default") -> str:
+         """Get system prompt based on type"""
+         if prompt_type == "chat" and self.system_prompt_chat:
+             return self.system_prompt_chat
+         elif prompt_type == "code" and self.system_prompt_code:
+             return self.system_prompt_code
+         elif prompt_type == "creative" and self.system_prompt_creative:
+             return self.system_prompt_creative
+         return self.system_prompt
+
+     def is_local_model(self) -> bool:
+         """Check if using local model backend"""
+         return self.model_type.lower() == "local"
+
+     def is_api_model(self) -> bool:
+         """Check if using an API-based model backend"""
+         # Keep in sync with validate_model_config below
+         return self.model_type.lower() in ["hf_api", "openai", "anthropic", "minimax", "google"]
+
+     def validate_model_config(self) -> bool:
+         """Validate model configuration based on type"""
+         if self.model_type == "hf_api" and not self.hf_api_token:
+             return False
+         elif self.model_type == "openai" and not self.openai_api_key:
+             return False
+         elif self.model_type == "anthropic" and not self.anthropic_api_key:
+             return False
+         elif self.model_type == "minimax" and (not self.minimax_api_key or not self.minimax_api_url):
+             return False
+         elif self.model_type == "google" and not self.google_api_key:
+             return False
+         return True
+
+
+ # Global settings instance
+ settings = Settings()
+
+
+ def get_settings() -> Settings:
+     """Get application settings"""
+     return settings
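Each `Field(..., env=...)` above resolves in order: the process environment first, then the `.env` file loaded by `load_dotenv()`, then the coded default. A stdlib sketch of that precedence (the `resolve` helper is hypothetical; the real lookup and type coercion are done by pydantic):

```python
import os

def resolve(name, default, cast=str):
    """Mimic the env-var lookup: a set variable wins, else the coded default."""
    raw = os.environ.get(name)
    return cast(raw) if raw is not None else default

os.environ["MODEL_TYPE"] = "minimax"   # e.g. exported in the shell or via .env
print(resolve("MODEL_TYPE", "local"))  # minimax
os.environ.pop("RATE_LIMIT", None)
print(resolve("RATE_LIMIT", 60, int))  # 60 (unset, so the default applies)
```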
app/core/logging.py ADDED
@@ -0,0 +1,85 @@
+ """
+ Structured logging configuration for Sema Chat API
+ """
+
+ import logging
+ import sys
+ from typing import Any
+ import structlog
+ from .config import settings
+
+
+ def configure_logging():
+     """Configure structured logging for the application"""
+
+     # Configure structlog
+     structlog.configure(
+         processors=[
+             structlog.stdlib.filter_by_level,
+             structlog.stdlib.add_logger_name,
+             structlog.stdlib.add_log_level,
+             structlog.stdlib.PositionalArgumentsFormatter(),
+             structlog.processors.TimeStamper(fmt="iso"),
+             structlog.processors.StackInfoRenderer(),
+             structlog.processors.format_exc_info,
+             structlog.processors.UnicodeDecoder(),
+             structlog.processors.JSONRenderer() if settings.structured_logging else structlog.dev.ConsoleRenderer(),
+         ],
+         context_class=dict,
+         logger_factory=structlog.stdlib.LoggerFactory(),
+         wrapper_class=structlog.stdlib.BoundLogger,
+         cache_logger_on_first_use=True,
+     )
+
+     # Configure standard logging
+     logging.basicConfig(
+         format="%(message)s",
+         stream=sys.stdout,
+         level=getattr(logging, settings.log_level.upper()),
+     )
+
+     # Configure file logging if specified
+     if settings.log_file:
+         file_handler = logging.FileHandler(settings.log_file)
+         file_handler.setLevel(getattr(logging, settings.log_level.upper()))
+
+         if settings.structured_logging:
+             file_handler.setFormatter(logging.Formatter('%(message)s'))
+         else:
+             file_handler.setFormatter(
+                 logging.Formatter(
+                     '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+                 )
+             )
+
+         logging.getLogger().addHandler(file_handler)
+
+
+ def get_logger(name: str = None) -> structlog.BoundLogger:
+     """Get a structured logger instance"""
+     return structlog.get_logger(name)
+
+
+ class LoggerMixin:
+     """Mixin class to add logging capabilities to any class"""
+
+     @property
+     def logger(self) -> structlog.BoundLogger:
+         """Get logger for this class"""
+         return get_logger(self.__class__.__name__)
+
+     def log_info(self, message: str, **kwargs: Any):
+         """Log info message with context"""
+         self.logger.info(message, **kwargs)
+
+     def log_error(self, message: str, **kwargs: Any):
+         """Log error message with context"""
+         self.logger.error(message, **kwargs)
+
+     def log_warning(self, message: str, **kwargs: Any):
+         """Log warning message with context"""
+         self.logger.warning(message, **kwargs)
+
+     def log_debug(self, message: str, **kwargs: Any):
+         """Log debug message with context"""
+         self.logger.debug(message, **kwargs)
app/main.py ADDED
@@ -0,0 +1,187 @@
+ """
+ Sema Chat API - Main Application
+ Modern chatbot API with streaming capabilities and flexible model backends
+ """
+
+ from fastapi import FastAPI, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.middleware.trustedhost import TrustedHostMiddleware
+ from fastapi.responses import RedirectResponse
+ from slowapi import _rate_limit_exceeded_handler
+ from slowapi.errors import RateLimitExceeded
+ import time
+
+ from .core.config import settings
+ from .core.logging import configure_logging, get_logger
+ from .services.chat_manager import initialize_chat_manager, shutdown_chat_manager
+ from .api.v1.endpoints import router as v1_router, limiter
+
+ # Configure logging
+ configure_logging()
+ logger = get_logger()
+
+ # Application startup time
+ startup_time = time.time()
+
+
+ def create_application() -> FastAPI:
+     """Create and configure the FastAPI application"""
+
+     # Create FastAPI app
+     app = FastAPI(
+         title=settings.app_name,
+         description="Modern chatbot API with streaming capabilities and flexible model backends",
+         version=settings.app_version,
+         docs_url="/docs" if settings.debug else "/",  # Swagger UI at root for HF Spaces
+         redoc_url="/redoc",
+         openapi_url="/openapi.json"
+     )
+
+     # Add rate limiting
+     app.state.limiter = limiter
+     app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+     # Add CORS middleware
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=settings.cors_origins,
+         allow_credentials=True,
+         allow_methods=["*"],
+         allow_headers=["*"],
+     )
+
+     # Add trusted host middleware for production
+     if settings.environment == "production":
+         app.add_middleware(
+             TrustedHostMiddleware,
+             allowed_hosts=["*"]  # Configure appropriately for production
+         )
+
+     # Add request timing middleware
+     @app.middleware("http")
+     async def add_process_time_header(request: Request, call_next):
+         start_time = time.time()
+         response = await call_next(request)
+         process_time = time.time() - start_time
+         response.headers["X-Process-Time"] = str(process_time)
+         response.headers["X-Request-ID"] = str(id(request))
+         return response
+
+     # Include API routes
+     app.include_router(v1_router, prefix="/api/v1", tags=["Chat API v1"])
+
+     # Root redirect for HuggingFace Spaces
+     @app.get("/", include_in_schema=False)
+     async def root():
+         """Redirect root to docs for HuggingFace Spaces"""
+         return RedirectResponse(url="/docs")
+
+     return app
+
+
+ # Create the application instance
+ app = create_application()
+
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize the application on startup"""
+     logger.info("application_startup",
+                 version=settings.app_version,
+                 environment=settings.environment,
+                 model_type=settings.model_type,
+                 model_name=settings.model_name)
+
+     print(f"\nπŸš€ Starting {settings.app_name} v{settings.app_version}")
+     print(f"πŸ“Š Environment: {settings.environment}")
+     print(f"πŸ€– Model Backend: {settings.model_type}")
+     print(f"🎯 Model: {settings.model_name}")
+     print("πŸ”„ Initializing chat services...")
+
+     try:
+         # Initialize chat manager (which initializes model and session managers)
+         success = await initialize_chat_manager()
+
+         if success:
+             logger.info("chat_services_initialized")
+             print("βœ… Chat services initialized successfully")
+             print("🌐 API Documentation: http://localhost:7860/docs")
+             print("πŸ“‘ WebSocket Chat: ws://localhost:7860/api/v1/chat/ws")
+             print("πŸ”„ Streaming Chat: http://localhost:7860/api/v1/chat/stream")
+             print("πŸ’¬ Regular Chat: http://localhost:7860/api/v1/chat")
+             print("❀️ Health Check: http://localhost:7860/api/v1/health")
+
+             if settings.enable_metrics:
+                 print("πŸ“ˆ Metrics: http://localhost:7860/api/v1/metrics")
+
+             print("\nπŸŽ‰ Sema Chat API is ready for conversations!")
+             print("=" * 60)
+         else:
+             logger.error("chat_services_initialization_failed")
+             print("❌ Failed to initialize chat services")
+             print("πŸ”§ Please check your configuration and try again")
+             raise RuntimeError("Chat services initialization failed")
+
+     except Exception as e:
+         logger.error("startup_failed", error=str(e))
+         print(f"πŸ’₯ Startup failed: {e}")
+         raise
+
+
+ @app.on_event("shutdown")
+ async def shutdown_event():
+     """Cleanup on application shutdown"""
+     logger.info("application_shutdown")
+     print("\nπŸ›‘ Shutting down Sema Chat API...")
+
+     try:
+         await shutdown_chat_manager()
+         print("βœ… Chat services shutdown complete")
+     except Exception as e:
+         logger.error("shutdown_failed", error=str(e))
+         print(f"⚠️ Shutdown warning: {e}")
+
+     print("πŸ‘‹ Goodbye!\n")
+
+
+ # Health check endpoint at app level
+ @app.get("/health", tags=["Health"])
+ async def app_health():
+     """Simple app-level health check"""
+     uptime = time.time() - startup_time
+     return {
+         "status": "healthy",
+         "service": "sema-chat-api",
+         "version": settings.app_version,
+         "uptime_seconds": uptime,
+         "model_type": settings.model_type,
+         "model_name": settings.model_name
+     }
+
+
+ # Status endpoint
+ @app.get("/status", tags=["Health"])
+ async def app_status():
+     """Simple status endpoint"""
+     return {
+         "status": "ok",
+         "service": "sema-chat-api",
+         "version": settings.app_version,
+         "environment": settings.environment
+     }
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     print(f"πŸš€ Starting Sema Chat API on {settings.host}:7860")
+     print(f"πŸ”§ Debug mode: {settings.debug}")
+     print(f"πŸ€– Model: {settings.model_type}/{settings.model_name}")
+
+     uvicorn.run(
+         "app.main:app",
+         host=settings.host,
+         port=7860,
+         reload=settings.debug,
+         log_level=settings.log_level.lower()
+     )
app/models/__init__.py ADDED
@@ -0,0 +1 @@
+ # Pydantic models and schemas
app/models/schemas.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for request/response validation
3
+ """
4
+
5
+ from typing import List, Optional, Dict, Any, Union
6
+ from pydantic import BaseModel, Field, validator
7
+ from datetime import datetime
8
+ import uuid
9
+
10
+
11
+ class ChatMessage(BaseModel):
12
+ """Individual chat message model"""
13
+
14
+ role: str = Field(..., description="Message role: 'user' or 'assistant'")
15
+ content: str = Field(..., description="Message content")
16
+ timestamp: datetime = Field(default_factory=datetime.utcnow, description="Message timestamp")
17
+ metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional message metadata")
18
+
19
+ @validator('role')
20
+ def validate_role(cls, v):
21
+ if v not in ['user', 'assistant', 'system']:
22
+ raise ValueError('Role must be user, assistant, or system')
23
+ return v
24
+
25
+
26
+ class ChatRequest(BaseModel):
27
+ """Chat request model"""
28
+
29
+ message: str = Field(
30
+ ...,
31
+ description="User message",
32
+ min_length=1,
33
+ max_length=4000,
34
+ example="Hello, how are you today?"
35
+ )
36
+ session_id: str = Field(
37
+ ...,
38
+ description="Session identifier for conversation context",
39
+ example="user-123-session"
40
+ )
41
+ system_prompt: Optional[str] = Field(
42
+ default=None,
43
+ description="Custom system prompt for this conversation",
44
+ max_length=1000
45
+ )
46
+ temperature: Optional[float] = Field(
47
+ default=None,
48
+ description="Sampling temperature (0.0 to 1.0)",
49
+ ge=0.0,
50
+ le=1.0
51
+ )
52
+ max_tokens: Optional[int] = Field(
53
+ default=None,
54
+ description="Maximum tokens to generate",
55
+ ge=1,
56
+ le=2048
57
+ )
58
+ stream: Optional[bool] = Field(
59
+ default=False,
60
+ description="Whether to stream the response"
61
+ )
62
+
63
+
64
+ class ChatResponse(BaseModel):
65
+ """Chat response model"""
66
+
67
+ message: str = Field(..., description="Assistant response message")
68
+ session_id: str = Field(..., description="Session identifier")
69
+ message_id: str = Field(..., description="Unique message identifier")
70
+ model_name: str = Field(..., description="Model used for generation")
71
+ timestamp: datetime = Field(default_factory=datetime.utcnow, description="Response timestamp")
72
+ generation_time: float = Field(..., description="Time taken to generate response (seconds)")
73
+ token_count: Optional[int] = Field(default=None, description="Number of tokens in response")
74
+ finish_reason: Optional[str] = Field(default=None, description="Reason generation finished")
75
+
76
+ class Config:
77
+ json_schema_extra = {
78
+ "example": {
79
+ "message": "Hello! I'm doing well, thank you for asking. How can I help you today?",
80
+ "session_id": "user-123-session",
81
+ "message_id": "msg-456-789",
82
+ "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
83
+ "timestamp": "2024-01-15T10:30:00Z",
84
+ "generation_time": 1.234,
85
+ "token_count": 25,
86
+ "finish_reason": "stop"
87
+ }
88
+ }
89
+
90
+
91
+ class StreamChunk(BaseModel):
92
+ """Streaming response chunk model"""
93
+
94
+ content: str = Field(..., description="Chunk content")
95
+ session_id: str = Field(..., description="Session identifier")
96
+ message_id: str = Field(..., description="Message identifier")
97
+ chunk_id: int = Field(..., description="Chunk sequence number")
98
+ is_final: bool = Field(default=False, description="Whether this is the final chunk")
99
+ timestamp: datetime = Field(default_factory=datetime.utcnow, description="Chunk timestamp")
100
+
101
+
102
+ class ConversationHistory(BaseModel):
103
+ """Conversation history model"""
104
+
105
+ session_id: str = Field(..., description="Session identifier")
106
+ messages: List[ChatMessage] = Field(..., description="List of messages in conversation")
107
+     created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
+     updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
+     message_count: int = Field(..., description="Total number of messages")
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "session_id": "user-123-session",
+                 "messages": [
+                     {
+                         "role": "user",
+                         "content": "Hello!",
+                         "timestamp": "2024-01-15T10:30:00Z"
+                     },
+                     {
+                         "role": "assistant",
+                         "content": "Hello! How can I help you today?",
+                         "timestamp": "2024-01-15T10:30:01Z"
+                     }
+                 ],
+                 "created_at": "2024-01-15T10:30:00Z",
+                 "updated_at": "2024-01-15T10:30:01Z",
+                 "message_count": 2
+             }
+         }
+
+
+ class SessionInfo(BaseModel):
+     """Session information model"""
+
+     session_id: str = Field(..., description="Session identifier")
+     created_at: datetime = Field(..., description="Session creation time")
+     updated_at: datetime = Field(..., description="Last activity time")
+     message_count: int = Field(..., description="Number of messages in session")
+     model_name: str = Field(..., description="Model used in this session")
+     is_active: bool = Field(..., description="Whether session is active")
+
+
+ class HealthResponse(BaseModel):
+     """Health check response model"""
+
+     status: str = Field(..., description="API health status")
+     version: str = Field(..., description="API version")
+     model_type: str = Field(..., description="Current model backend type")
+     model_name: str = Field(..., description="Current model name")
+     model_loaded: bool = Field(..., description="Whether model is loaded and ready")
+     uptime: float = Field(..., description="API uptime in seconds")
+     active_sessions: int = Field(..., description="Number of active chat sessions")
+     timestamp: datetime = Field(default_factory=datetime.utcnow, description="Health check timestamp")
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "status": "healthy",
+                 "version": "1.0.0",
+                 "model_type": "local",
+                 "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                 "model_loaded": True,
+                 "uptime": 3600.5,
+                 "active_sessions": 5,
+                 "timestamp": "2024-01-15T10:30:00Z"
+             }
+         }
+
+
+ class ErrorResponse(BaseModel):
+     """Error response model"""
+
+     error: str = Field(..., description="Error type")
+     message: str = Field(..., description="Error message")
+     details: Optional[Dict[str, Any]] = Field(default=None, description="Additional error details")
+     timestamp: datetime = Field(default_factory=datetime.utcnow, description="Error timestamp")
+     request_id: Optional[str] = Field(default=None, description="Request identifier for debugging")
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "error": "validation_error",
+                 "message": "Message content is required",
+                 "details": {"field": "message", "constraint": "min_length"},
+                 "timestamp": "2024-01-15T10:30:00Z",
+                 "request_id": "req-123-456"
+             }
+         }
+
+
+ class ModelInfo(BaseModel):
+     """Model information model"""
+
+     name: str = Field(..., description="Model name")
+     type: str = Field(..., description="Model backend type")
+     loaded: bool = Field(..., description="Whether model is loaded")
+     parameters: Optional[Dict[str, Any]] = Field(default=None, description="Model parameters")
+     capabilities: List[str] = Field(default_factory=list, description="Model capabilities")
+
+     class Config:
+         json_schema_extra = {
+             "example": {
+                 "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                 "type": "local",
+                 "loaded": True,
+                 "parameters": {
+                     "temperature": 0.7,
+                     "max_tokens": 512,
+                     "top_p": 0.9
+                 },
+                 "capabilities": ["chat", "instruction_following", "streaming"]
+             }
+         }
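The response models above all serialize to flat JSON payloads matching their `json_schema_extra` examples. As a rough illustration of the shape (a stdlib-only stand-in with a hypothetical `HealthStatus` name, not the Pydantic model from the commit):

```python
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone

# Simplified stand-in mirroring the HealthResponse fields above (illustrative only)
@dataclass
class HealthStatus:
    status: str
    version: str
    model_type: str
    model_name: str
    model_loaded: bool
    uptime: float
    active_sessions: int
    # default_factory gives each instance its own timestamp, like the Pydantic Field
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

payload = asdict(HealthStatus(
    status="healthy", version="1.0.0", model_type="local",
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    model_loaded=True, uptime=3600.5, active_sessions=5,
))
print(payload["status"], payload["active_sessions"])  # β†’ healthy 5
```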
app/services/__init__.py ADDED
@@ -0,0 +1 @@
+ # Services package
app/services/chat_manager.py ADDED
@@ -0,0 +1,398 @@
+ """
+ Chat Manager - Main orchestrator for chat functionality
+ Coordinates between model backends and session management
+ """
+
+ import time
+ import uuid
+ from typing import AsyncGenerator, List, Optional, Dict, Any
+ from datetime import datetime
+
+ from ..core.config import settings
+ from ..core.logging import LoggerMixin
+ from ..models.schemas import (
+     ChatMessage, ChatRequest, ChatResponse, StreamChunk,
+     ConversationHistory, ErrorResponse
+ )
+ from .model_manager import get_model_manager
+ from .session_manager import get_session_manager
+ from .model_backends.base import ModelBackendError, ModelNotLoadedError, GenerationError
+
+
+ class ChatManager(LoggerMixin):
+     """
+     Main chat service that orchestrates conversation handling
+     """
+
+     def __init__(self):
+         self.is_initialized = False
+         self.active_streams = 0
+         self.max_concurrent_streams = settings.max_concurrent_streams
+
+     async def initialize(self) -> bool:
+         """Initialize the chat manager"""
+         try:
+             self.log_info("Initializing chat manager")
+
+             # Initialize model manager
+             model_manager = await get_model_manager()
+             if not await model_manager.initialize():
+                 self.log_error("Failed to initialize model manager")
+                 return False
+
+             # Initialize session manager
+             session_manager = await get_session_manager()
+             if not await session_manager.initialize():
+                 self.log_error("Failed to initialize session manager")
+                 return False
+
+             self.is_initialized = True
+             self.log_info("Chat manager initialized successfully")
+             return True
+
+         except Exception as e:
+             self.log_error("Chat manager initialization failed", error=str(e))
+             return False
+
+     async def shutdown(self):
+         """Shutdown the chat manager"""
+         try:
+             self.log_info("Shutting down chat manager")
+
+             # Shutdown managers
+             model_manager = await get_model_manager()
+             await model_manager.shutdown()
+
+             session_manager = await get_session_manager()
+             await session_manager.shutdown()
+
+             self.is_initialized = False
+             self.log_info("Chat manager shutdown complete")
+
+         except Exception as e:
+             self.log_error("Chat manager shutdown failed", error=str(e))
+
+     async def process_chat_request(self, request: ChatRequest) -> ChatResponse:
+         """
+         Process a non-streaming chat request
+
+         Args:
+             request: Chat request containing message and parameters
+
+         Returns:
+             ChatResponse: Complete response
+         """
+         if not self.is_initialized:
+             raise RuntimeError("Chat manager not initialized")
+
+         start_time = time.time()
+
+         try:
+             # Get managers
+             model_manager = await get_model_manager()
+             session_manager = await get_session_manager()
+
+             if not model_manager.is_ready():
+                 raise ModelNotLoadedError("Model not ready for inference")
+
+             # Ensure session exists
+             await session_manager.create_session(request.session_id)
+
+             # Add user message to session
+             user_message = ChatMessage(
+                 role="user",
+                 content=request.message,
+                 timestamp=datetime.utcnow(),
+                 metadata={"session_id": request.session_id}
+             )
+             await session_manager.add_message(request.session_id, user_message)
+
+             # Get conversation history
+             messages = await self._prepare_messages_for_model(
+                 request.session_id,
+                 request.system_prompt
+             )
+
+             # Generate response
+             backend = model_manager.get_backend()
+             response = await backend.generate_response(
+                 messages=messages,
+                 temperature=request.temperature or settings.temperature,
+                 max_tokens=request.max_tokens or settings.max_new_tokens
+             )
+
+             # Add assistant message to session
+             assistant_message = ChatMessage(
+                 role="assistant",
+                 content=response.message,
+                 timestamp=datetime.utcnow(),
+                 metadata={"session_id": request.session_id, "message_id": response.message_id}
+             )
+             await session_manager.add_message(request.session_id, assistant_message)
+
+             # Update response with correct session info
+             response.session_id = request.session_id
+
+             self.log_info("Chat request processed",
+                           session_id=request.session_id,
+                           generation_time=response.generation_time,
+                           total_time=time.time() - start_time)
+
+             return response
+
+         except ModelBackendError as e:
+             self.log_error("Model backend error", error=str(e), session_id=request.session_id)
+             raise
+         except Exception as e:
+             self.log_error("Chat request processing failed", error=str(e), session_id=request.session_id)
+             raise
+
+     async def process_streaming_chat_request(
+         self,
+         request: ChatRequest
+     ) -> AsyncGenerator[StreamChunk, None]:
+         """
+         Process a streaming chat request
+
+         Args:
+             request: Chat request containing message and parameters
+
+         Yields:
+             StreamChunk: Response chunks
+         """
+         if not self.is_initialized:
+             raise RuntimeError("Chat manager not initialized")
+
+         # Check concurrent stream limit
+         if self.active_streams >= self.max_concurrent_streams:
+             raise RuntimeError(f"Maximum concurrent streams ({self.max_concurrent_streams}) exceeded")
+
+         self.active_streams += 1
+
+         try:
+             # Get managers
+             model_manager = await get_model_manager()
+             session_manager = await get_session_manager()
+
+             if not model_manager.is_ready():
+                 raise ModelNotLoadedError("Model not ready for inference")
+
+             # Ensure session exists
+             await session_manager.create_session(request.session_id)
+
+             # Add user message to session
+             user_message = ChatMessage(
+                 role="user",
+                 content=request.message,
+                 timestamp=datetime.utcnow(),
+                 metadata={"session_id": request.session_id}
+             )
+             await session_manager.add_message(request.session_id, user_message)
+
+             # Get conversation history
+             messages = await self._prepare_messages_for_model(
+                 request.session_id,
+                 request.system_prompt
+             )
+
+             # Generate streaming response
+             backend = model_manager.get_backend()
+             full_response = ""
+             message_id = None
+
+             async for chunk in backend.generate_stream(
+                 messages=messages,
+                 temperature=request.temperature or settings.temperature,
+                 max_tokens=request.max_tokens or settings.max_new_tokens
+             ):
+                 if message_id is None:
+                     message_id = chunk.message_id
+
+                 full_response += chunk.content
+                 yield chunk
+
+             # Add complete assistant message to session
+             if full_response.strip():
+                 assistant_message = ChatMessage(
+                     role="assistant",
+                     content=full_response.strip(),
+                     timestamp=datetime.utcnow(),
+                     metadata={"session_id": request.session_id, "message_id": message_id}
+                 )
+                 await session_manager.add_message(request.session_id, assistant_message)
+
+             self.log_info("Streaming chat request processed",
+                           session_id=request.session_id,
+                           response_length=len(full_response))
+
+         except ModelBackendError as e:
+             self.log_error("Model backend error in streaming", error=str(e), session_id=request.session_id)
+             raise
+         except Exception as e:
+             self.log_error("Streaming chat request failed", error=str(e), session_id=request.session_id)
+             raise
+         finally:
+             self.active_streams -= 1
+
+     async def _prepare_messages_for_model(
+         self,
+         session_id: str,
+         custom_system_prompt: Optional[str] = None
+     ) -> List[ChatMessage]:
+         """
+         Prepare messages for model input, including system prompt and history
+
+         Args:
+             session_id: Session identifier
+             custom_system_prompt: Optional custom system prompt
+
+         Returns:
+             List of ChatMessage objects ready for model input
+         """
+         session_manager = await get_session_manager()
+
+         # Get conversation history
+         history_messages = await session_manager.get_session_messages(session_id)
+
+         # Prepare messages list
+         messages = []
+
+         # Add system prompt if provided
+         system_prompt = custom_system_prompt or settings.get_system_prompt()
+         if system_prompt:
+             messages.append(ChatMessage(
+                 role="system",
+                 content=system_prompt,
+                 timestamp=datetime.utcnow()
+             ))
+
+         # Add conversation history
+         messages.extend(history_messages)
+
+         return messages
+
+     async def get_conversation_history(self, session_id: str) -> Optional[ConversationHistory]:
+         """
+         Get conversation history for a session
+
+         Args:
+             session_id: Session identifier
+
+         Returns:
+             ConversationHistory or None if session not found
+         """
+         try:
+             session_manager = await get_session_manager()
+             return await session_manager.get_session(session_id)
+
+         except Exception as e:
+             self.log_error("Failed to get conversation history", error=str(e), session_id=session_id)
+             return None
+
+     async def clear_conversation(self, session_id: str) -> bool:
+         """
+         Clear conversation history for a session
+
+         Args:
+             session_id: Session identifier
+
+         Returns:
+             bool: True if cleared successfully
+         """
+         try:
+             session_manager = await get_session_manager()
+             return await session_manager.delete_session(session_id)
+
+         except Exception as e:
+             self.log_error("Failed to clear conversation", error=str(e), session_id=session_id)
+             return False
+
+     async def get_active_sessions(self) -> List[Dict[str, Any]]:
+         """Get information about active chat sessions"""
+         try:
+             session_manager = await get_session_manager()
+             sessions = await session_manager.get_active_sessions()
+
+             return [
+                 {
+                     "session_id": session.session_id,
+                     "created_at": session.created_at.isoformat(),
+                     "updated_at": session.updated_at.isoformat(),
+                     "message_count": session.message_count,
+                     "model_name": session.model_name,
+                     "is_active": session.is_active
+                 }
+                 for session in sessions
+             ]
+
+         except Exception as e:
+             self.log_error("Failed to get active sessions", error=str(e))
+             return []
+
+     async def health_check(self) -> Dict[str, Any]:
+         """Perform a comprehensive health check"""
+         try:
+             health_status = {
+                 "chat_manager": {
+                     "status": "healthy" if self.is_initialized else "unhealthy",
+                     "initialized": self.is_initialized,
+                     "active_streams": self.active_streams,
+                     "max_concurrent_streams": self.max_concurrent_streams
+                 }
+             }
+
+             # Check model manager
+             model_manager = await get_model_manager()
+             model_health = await model_manager.health_check()
+             health_status["model_manager"] = model_health
+
+             # Check session manager
+             session_manager = await get_session_manager()
+             active_sessions = await session_manager.get_active_sessions()
+             health_status["session_manager"] = {
+                 "status": "healthy",
+                 "active_sessions": len(active_sessions),
+                 "storage_type": "redis" if session_manager.use_redis else "memory"
+             }
+
+             # Overall status
+             overall_healthy = (
+                 self.is_initialized and
+                 model_health.get("status") == "healthy"
+             )
+
+             health_status["overall"] = {
+                 "status": "healthy" if overall_healthy else "unhealthy",
+                 "timestamp": datetime.utcnow().isoformat()
+             }
+
+             return health_status
+
+         except Exception as e:
+             self.log_error("Health check failed", error=str(e))
+             return {
+                 "overall": {
+                     "status": "unhealthy",
+                     "error": str(e),
+                     "timestamp": datetime.utcnow().isoformat()
+                 }
+             }
+
+
+ # Global chat manager instance
+ chat_manager = ChatManager()
+
+
+ async def get_chat_manager() -> ChatManager:
+     """Get the global chat manager instance"""
+     return chat_manager
+
+
+ async def initialize_chat_manager() -> bool:
+     """Initialize the global chat manager"""
+     return await chat_manager.initialize()
+
+
+ async def shutdown_chat_manager():
+     """Shutdown the global chat manager"""
+     await chat_manager.shutdown()
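The key concurrency detail in `process_streaming_chat_request` above is the `active_streams` counter: incremented before streaming begins, checked against `max_concurrent_streams`, and decremented in a `finally` block so the slot is released even if the consumer aborts mid-stream. A standalone sketch of just that guard (the `StreamLimiter` class is a hypothetical stand-in, not part of the commit):

```python
import asyncio

# Minimal stand-in for ChatManager's concurrent-stream guard
class StreamLimiter:
    def __init__(self, max_concurrent_streams: int):
        self.max_concurrent_streams = max_concurrent_streams
        self.active_streams = 0

    async def stream(self, chunks):
        # Reject new streams once the limit is reached
        if self.active_streams >= self.max_concurrent_streams:
            raise RuntimeError(
                f"Maximum concurrent streams ({self.max_concurrent_streams}) exceeded"
            )
        self.active_streams += 1
        try:
            for chunk in chunks:
                yield chunk  # forward each chunk as it is produced
        finally:
            # always release the slot, even on consumer abort or error
            self.active_streams -= 1

async def demo():
    limiter = StreamLimiter(max_concurrent_streams=1)
    out = [c async for c in limiter.stream(["Hel", "lo"])]
    return out, limiter.active_streams

print(asyncio.run(demo()))  # β†’ (['Hel', 'lo'], 0)
```

The `finally` is what keeps the counter from leaking: without it, a client disconnecting mid-stream would permanently consume a slot.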
app/services/model_backends/__init__.py ADDED
@@ -0,0 +1 @@
+ # Model backends package
app/services/model_backends/anthropic_api.py ADDED
@@ -0,0 +1,319 @@
+ """
+ Anthropic API backend
+ Uses Anthropic's API for Claude models
+ """
+
+ import asyncio
+ import time
+ import uuid
+ from typing import AsyncGenerator, List, Dict, Any, Optional
+ from datetime import datetime
+ import anthropic
+
+ from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
+ from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
+ from ...core.config import settings
+
+
+ class AnthropicAPIBackend(ModelBackend):
+     """Anthropic API backend for Claude models"""
+
+     def __init__(self, model_name: str, **kwargs):
+         super().__init__(model_name, **kwargs)
+         self.client = None
+         self.api_key = kwargs.get('api_key', settings.anthropic_api_key)
+         self.capabilities = ["chat", "streaming", "api_based", "long_context"]
+
+         # Generation parameters
+         self.parameters = {
+             'temperature': kwargs.get('temperature', settings.temperature),
+             'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
+             'top_p': kwargs.get('top_p', settings.top_p),
+         }
+
+     async def load_model(self) -> bool:
+         """Initialize the Anthropic API client"""
+         try:
+             if not self.api_key:
+                 raise ModelLoadError("Anthropic API key is required")
+
+             self.log_info("Initializing Anthropic API client", model=self.model_name)
+
+             # Initialize the Anthropic client
+             self.client = anthropic.AsyncAnthropic(
+                 api_key=self.api_key
+             )
+
+             # Test the connection
+             await self._test_connection()
+
+             self.is_loaded = True
+             self.log_info("Anthropic API client initialized successfully", model=self.model_name)
+
+             return True
+
+         except Exception as e:
+             self.log_error("Failed to initialize Anthropic API client", error=str(e), model=self.model_name)
+             raise ModelLoadError(f"Failed to initialize Anthropic API for {self.model_name}: {str(e)}")
+
+     async def unload_model(self) -> bool:
+         """Clean up the API client"""
+         try:
+             if self.client:
+                 await self.client.close()
+                 self.client = None
+             self.is_loaded = False
+             self.log_info("Anthropic API client cleaned up", model=self.model_name)
+             return True
+
+         except Exception as e:
+             self.log_error("Failed to cleanup Anthropic API client", error=str(e), model=self.model_name)
+             return False
+
+     async def _test_connection(self):
+         """Test the Anthropic API connection"""
+         try:
+             # Simple test request
+             response = await self.client.messages.create(
+                 model=self.model_name,
+                 max_tokens=5,
+                 temperature=0.1,
+                 messages=[{"role": "user", "content": "Hello"}]
+             )
+
+             self.log_info("Anthropic API connection test successful", model=self.model_name)
+
+         except Exception as e:
+             self.log_error("Anthropic API connection test failed", error=str(e), model=self.model_name)
+             raise
+
+     def _format_messages_for_api(self, messages: List[ChatMessage]) -> tuple:
+         """Format messages for Anthropic API (separate system and messages)"""
+         system_message = None
+         formatted_messages = []
+
+         for msg in messages:
+             if msg.role == "system":
+                 system_message = msg.content
+             else:
+                 formatted_messages.append({
+                     "role": msg.role,
+                     "content": msg.content
+                 })
+
+         return system_message, formatted_messages
+
+     async def generate_response(
+         self,
+         messages: List[ChatMessage],
+         temperature: float = 0.7,
+         max_tokens: int = 512,
+         **kwargs
+     ) -> ChatResponse:
+         """Generate a complete response using Anthropic API"""
+         if not self.is_loaded:
+             raise ModelNotLoadedError("Anthropic API client not initialized")
+
+         start_time = time.time()
+         message_id = str(uuid.uuid4())
+
+         try:
+             # Validate parameters
+             params = self.validate_parameters(
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 **kwargs
+             )
+
+             # Format messages
+             system_message, api_messages = self._format_messages_for_api(messages)
+
+             # Prepare request parameters
+             request_params = {
+                 "model": self.model_name,
+                 "messages": api_messages,
+                 "max_tokens": params['max_tokens'],
+                 "temperature": params['temperature'],
+                 "top_p": params.get('top_p', 0.9),
+                 "stream": False
+             }
+
+             # Add system message if present
+             if system_message:
+                 request_params["system"] = system_message
+
+             # Make API call
+             response = await self.client.messages.create(**request_params)
+
+             # Extract response
+             response_text = response.content[0].text if response.content else ""
+             finish_reason = getattr(response, 'stop_reason', 'stop')
+             token_count = getattr(response.usage, 'output_tokens', None) if hasattr(response, 'usage') else None
+
+             generation_time = time.time() - start_time
+
+             return ChatResponse(
+                 message=response_text.strip(),
+                 session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
+                 message_id=message_id,
+                 model_name=self.model_name,
+                 generation_time=generation_time,
+                 token_count=token_count,
+                 finish_reason=finish_reason
+             )
+
+         except Exception as e:
+             self.log_error("Anthropic API generation failed", error=str(e), model=self.model_name)
+             raise GenerationError(f"Failed to generate response via Anthropic API: {str(e)}")
+
+     async def generate_stream(
+         self,
+         messages: List[ChatMessage],
+         temperature: float = 0.7,
+         max_tokens: int = 512,
+         **kwargs
+     ) -> AsyncGenerator[StreamChunk, None]:
+         """Generate a streaming response using Anthropic API"""
+         if not self.is_loaded:
+             raise ModelNotLoadedError("Anthropic API client not initialized")
+
+         message_id = str(uuid.uuid4())
+         session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
+         chunk_id = 0
+
+         try:
+             # Validate parameters
+             params = self.validate_parameters(
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 **kwargs
+             )
+
+             # Format messages
+             system_message, api_messages = self._format_messages_for_api(messages)
+
+             # Prepare request parameters
+             request_params = {
+                 "model": self.model_name,
+                 "messages": api_messages,
+                 "max_tokens": params['max_tokens'],
+                 "temperature": params['temperature'],
+                 "top_p": params.get('top_p', 0.9),
+                 "stream": True
+             }
+
+             # Add system message if present
+             if system_message:
+                 request_params["system"] = system_message
+
+             # Create streaming request
+             stream = await self.client.messages.create(**request_params)
+
+             # Process streaming chunks
+             async for chunk in stream:
+                 if chunk.type == "content_block_delta":
+                     if hasattr(chunk.delta, 'text') and chunk.delta.text:
+                         yield StreamChunk(
+                             content=chunk.delta.text,
+                             session_id=session_id,
+                             message_id=message_id,
+                             chunk_id=chunk_id,
+                             is_final=False
+                         )
+                         chunk_id += 1
+
+                         # Add small delay to prevent overwhelming the client
+                         await asyncio.sleep(settings.stream_delay)
+
+                 elif chunk.type == "message_stop":
+                     break
+
+             # Send final chunk
+             yield StreamChunk(
+                 content="",
+                 session_id=session_id,
+                 message_id=message_id,
+                 chunk_id=chunk_id,
+                 is_final=True
+             )
+
+         except Exception as e:
+             self.log_error("Anthropic API streaming failed", error=str(e), model=self.model_name)
+             raise GenerationError(f"Failed to generate streaming response via Anthropic API: {str(e)}")
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """Get information about the current model"""
+         return {
+             "name": self.model_name,
+             "type": "anthropic_api",
+             "loaded": self.is_loaded,
+             "provider": "Anthropic",
+             "capabilities": self.capabilities,
+             "parameters": self.parameters,
+             "requires_api_key": True,
+             "api_key_configured": bool(self.api_key),
+             "context_window": self._get_context_window()
+         }
+
+     def _get_context_window(self) -> int:
+         """Get the context window size for the model"""
+         context_windows = {
+             "claude-3-haiku-20240307": 200000,
+             "claude-3-sonnet-20240229": 200000,
+             "claude-3-opus-20240229": 200000,
+             "claude-3-5-sonnet-20241022": 200000,
+             "claude-3-5-haiku-20241022": 200000,
+         }
+         return context_windows.get(self.model_name, 100000)
+
+     async def health_check(self) -> Dict[str, Any]:
+         """Perform a health check on the Anthropic API"""
+         try:
+             if not self.is_loaded:
+                 return {
+                     "status": "unhealthy",
+                     "reason": "client_not_initialized",
+                     "model_name": self.model_name
+                 }
+
+             # Test API connectivity
+             start_time = time.time()
+             test_messages = [ChatMessage(role="user", content="Hi")]
+
+             try:
+                 response = await asyncio.wait_for(
+                     self.generate_response(
+                         test_messages,
+                         temperature=0.1,
+                         max_tokens=5
+                     ),
+                     timeout=10.0
+                 )
+
+                 response_time = time.time() - start_time
+
+                 return {
+                     "status": "healthy",
+                     "model_name": self.model_name,
+                     "response_time": response_time,
+                     "provider": "Anthropic",
+                     "context_window": self._get_context_window()
+                 }
+
+             except asyncio.TimeoutError:
+                 return {
+                     "status": "unhealthy",
+                     "reason": "api_timeout",
+                     "model_name": self.model_name,
+                     "provider": "Anthropic"
+                 }
+
+         except Exception as e:
+             self.log_error("Anthropic API health check failed", error=str(e), model=self.model_name)
+             return {
+                 "status": "unhealthy",
+                 "reason": "api_error",
+                 "error": str(e),
+                 "model_name": self.model_name,
+                 "provider": "Anthropic"
+             }
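A detail worth noting in the backend above: `_format_messages_for_api` splits the system prompt out of the message list, because Anthropic's Messages API takes the system prompt as a separate top-level field rather than as a `{"role": "system"}` message. The core logic can be sketched standalone with plain dicts (the `split_system_message` name is illustrative, not from the commit):

```python
# Standalone sketch of the system/message split performed by
# AnthropicAPIBackend._format_messages_for_api above
def split_system_message(messages):
    system_message = None
    formatted = []
    for msg in messages:
        if msg["role"] == "system":
            # goes into the request's top-level "system" field
            system_message = msg["content"]
        else:
            formatted.append({"role": msg["role"], "content": msg["content"]})
    return system_message, formatted

system, convo = split_system_message([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
])
print(system)  # β†’ You are helpful.
print(convo)   # β†’ [{'role': 'user', 'content': 'Hi'}]
```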
app/services/model_backends/base.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract base class for model backends
3
+ Defines the interface that all model backends must implement
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import AsyncGenerator, List, Dict, Any, Optional
8
+ from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
9
+ from ...core.logging import LoggerMixin
10
+
11
+
12
+ class ModelBackend(ABC, LoggerMixin):
13
+ """Abstract base class for all model backends"""
14
+
15
+ def __init__(self, model_name: str, **kwargs):
16
+ self.model_name = model_name
17
+ self.is_loaded = False
18
+ self.capabilities = []
19
+ self.parameters = {}
20
+
21
+ @abstractmethod
22
+ async def load_model(self) -> bool:
23
+ """
24
+ Load the model and prepare it for inference
25
+
26
+ Returns:
27
+ bool: True if model loaded successfully, False otherwise
28
+ """
29
+ pass
30
+
31
+ @abstractmethod
32
+ async def unload_model(self) -> bool:
33
+ """
34
+ Unload the model and free resources
35
+
36
+ Returns:
37
+ bool: True if model unloaded successfully, False otherwise
38
+ """
39
+ pass
40
+
41
+ @abstractmethod
42
+ async def generate_response(
43
+ self,
44
+ messages: List[ChatMessage],
45
+ temperature: float = 0.7,
46
+ max_tokens: int = 512,
47
+ **kwargs
48
+ ) -> ChatResponse:
49
+ """
50
+ Generate a complete response (non-streaming)
51
+
52
+ Args:
53
+ messages: List of conversation messages
54
+ temperature: Sampling temperature
55
+ max_tokens: Maximum tokens to generate
56
+ **kwargs: Additional generation parameters
57
+
58
+ Returns:
59
+ ChatResponse: Complete response
60
+ """
61
+ pass
62
+
63
+ @abstractmethod
64
+ async def generate_stream(
65
+ self,
66
+ messages: List[ChatMessage],
67
+ temperature: float = 0.7,
68
+ max_tokens: int = 512,
69
+ **kwargs
70
+ ) -> AsyncGenerator[StreamChunk, None]:
71
+ """
72
+ Generate a streaming response
73
+
74
+ Args:
75
+ messages: List of conversation messages
76
+ temperature: Sampling temperature
77
+ max_tokens: Maximum tokens to generate
78
+ **kwargs: Additional generation parameters
79
+
80
+ Yields:
81
+ StreamChunk: Response chunks
82
+ """
83
+ pass
84
+
85
+ @abstractmethod
86
+ def get_model_info(self) -> Dict[str, Any]:
87
+ """
88
+ Get information about the current model
89
+
90
+ Returns:
91
+ Dict containing model information
92
+ """
93
+ pass
94
+
95
+ def supports_streaming(self) -> bool:
96
+ """Check if this backend supports streaming"""
97
+ return "streaming" in self.capabilities
98
+
99
+ def supports_chat(self) -> bool:
100
+ """Check if this backend supports chat/conversation"""
101
+ return "chat" in self.capabilities
102
+
103
+ def is_model_loaded(self) -> bool:
104
+ """Check if model is loaded and ready"""
105
+ return self.is_loaded
106
+
107
+ def format_messages_for_model(self, messages: List[ChatMessage]) -> Any:
108
+ """
109
+ Format messages for the specific model format
110
+ Override in subclasses if needed
111
+
112
+ Args:
113
+ messages: List of ChatMessage objects
114
+
115
+ Returns:
116
+ Formatted messages for the model
117
+ """
118
+ return [{"role": msg.role, "content": msg.content} for msg in messages]
119
+
120
+ def validate_parameters(self, **kwargs) -> Dict[str, Any]:
121
+ """
122
+ Validate and normalize generation parameters
123
+
124
+ Args:
125
+ **kwargs: Generation parameters
126
+
127
+ Returns:
128
+        Dict of validated parameters
+        """
+        validated = {}
+
+        # Temperature validation
+        temperature = kwargs.get('temperature', 0.7)
+        validated['temperature'] = max(0.0, min(1.0, float(temperature)))
+
+        # Max tokens validation
+        max_tokens = kwargs.get('max_tokens', 512)
+        validated['max_tokens'] = max(1, min(2048, int(max_tokens)))
+
+        # Top-p validation
+        top_p = kwargs.get('top_p', 0.9)
+        validated['top_p'] = max(0.0, min(1.0, float(top_p)))
+
+        # Top-k validation
+        top_k = kwargs.get('top_k', 50)
+        validated['top_k'] = max(1, int(top_k))
+
+        return validated
+
+    async def health_check(self) -> Dict[str, Any]:
+        """
+        Perform a health check on the model backend
+
+        Returns:
+            Dict containing health status
+        """
+        try:
+            if not self.is_loaded:
+                return {
+                    "status": "unhealthy",
+                    "reason": "model_not_loaded",
+                    "model_name": self.model_name
+                }
+
+            # Try a simple generation to test the model
+            test_messages = [
+                ChatMessage(role="user", content="Hello")
+            ]
+
+            # Use a timeout for the health check
+            import asyncio
+            try:
+                response = await asyncio.wait_for(
+                    self.generate_response(
+                        test_messages,
+                        temperature=0.1,
+                        max_tokens=10
+                    ),
+                    timeout=10.0
+                )
+
+                return {
+                    "status": "healthy",
+                    "model_name": self.model_name,
+                    "response_time": getattr(response, 'generation_time', 0.0)
+                }
+
+            except asyncio.TimeoutError:
+                return {
+                    "status": "unhealthy",
+                    "reason": "timeout",
+                    "model_name": self.model_name
+                }
+
+        except Exception as e:
+            self.log_error("Health check failed", error=str(e), model=self.model_name)
+            return {
+                "status": "unhealthy",
+                "reason": "generation_error",
+                "error": str(e),
+                "model_name": self.model_name
+            }
+
+
+class ModelBackendError(Exception):
+    """Base exception for model backend errors"""
+    pass
+
+
+class ModelLoadError(ModelBackendError):
+    """Exception raised when model loading fails"""
+    pass
+
+
+class GenerationError(ModelBackendError):
+    """Exception raised when text generation fails"""
+    pass
+
+
+class ModelNotLoadedError(ModelBackendError):
+    """Exception raised when trying to use an unloaded model"""
+    pass
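The clamping rules enforced by `validate_parameters` above can be exercised on their own. A minimal standalone sketch (a plain function with the same bounds, rather than the backend method, so it runs without the app):

```python
# Standalone sketch of the bounds in ModelBackend.validate_parameters:
# temperature and top_p clamped to [0.0, 1.0], max_tokens to [1, 2048], top_k >= 1.
def clamp_params(temperature=0.7, max_tokens=512, top_p=0.9, top_k=50):
    return {
        "temperature": max(0.0, min(1.0, float(temperature))),
        "max_tokens": max(1, min(2048, int(max_tokens))),
        "top_p": max(0.0, min(1.0, float(top_p))),
        "top_k": max(1, int(top_k)),
    }

# Out-of-range values are clamped rather than rejected.
out_of_range = clamp_params(temperature=5.0, max_tokens=0, top_k=-3)
```

Clamping (instead of raising) means a request with a wild `temperature` still generates, just with the nearest valid value.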
app/services/model_backends/google_api.py ADDED
@@ -0,0 +1,304 @@
+"""
+Google AI Studio API backend
+Uses Google's AI Studio API for Gemma and other Google models
+"""
+
+import asyncio
+import time
+import uuid
+import json
+from typing import AsyncGenerator, List, Dict, Any, Optional
+from datetime import datetime
+import httpx
+
+from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
+from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
+from ...core.config import settings
+
+
+class GoogleAIBackend(ModelBackend):
+    """Google AI Studio API backend for Gemma and other Google models"""
+
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__(model_name, **kwargs)
+        self.api_key = kwargs.get('api_key', settings.google_api_key)
+        self.base_url = "https://generativelanguage.googleapis.com/v1beta"
+        self.capabilities = ["chat", "streaming", "api_based"]
+
+        # Generation parameters
+        self.parameters = {
+            'temperature': kwargs.get('temperature', settings.temperature),
+            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
+            'top_p': kwargs.get('top_p', settings.top_p),
+            'top_k': kwargs.get('top_k', settings.top_k),
+        }
+
+    async def load_model(self) -> bool:
+        """Initialize the Google AI API client"""
+        try:
+            if not self.api_key:
+                raise ModelLoadError("Google AI API key is required")
+
+            self.log_info("Initializing Google AI API client", model=self.model_name)
+
+            # Test the connection
+            await self._test_connection()
+
+            self.is_loaded = True
+            self.log_info("Google AI API client initialized successfully", model=self.model_name)
+
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to initialize Google AI API client", error=str(e), model=self.model_name)
+            raise ModelLoadError(f"Failed to initialize Google AI API for {self.model_name}: {str(e)}")
+
+    async def unload_model(self) -> bool:
+        """Clean up the API client"""
+        try:
+            self.is_loaded = False
+            self.log_info("Google AI API client cleaned up", model=self.model_name)
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to cleanup Google AI API client", error=str(e), model=self.model_name)
+            return False
+
+    async def _test_connection(self):
+        """Test the Google AI API connection"""
+        try:
+            url = f"{self.base_url}/models/{self.model_name}:generateContent"
+
+            test_data = {
+                "contents": [
+                    {
+                        "parts": [{"text": "Hello"}]
+                    }
+                ],
+                "generationConfig": {
+                    "maxOutputTokens": 5,
+                    "temperature": 0.1
+                }
+            }
+
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    f"{url}?key={self.api_key}",
+                    headers={'Content-Type': 'application/json'},
+                    json=test_data,
+                    timeout=10.0
+                )
+
+                if response.status_code != 200:
+                    raise Exception(f"API test failed with status {response.status_code}: {response.text}")
+
+            self.log_info("Google AI API connection test successful", model=self.model_name)
+
+        except Exception as e:
+            self.log_error("Google AI API connection test failed", error=str(e), model=self.model_name)
+            raise
+
+    def _format_messages_for_api(self, messages: List[ChatMessage]) -> Dict[str, Any]:
+        """Format messages for Google AI API"""
+        contents = []
+        system_instruction = None
+
+        for msg in messages:
+            if msg.role == "system":
+                system_instruction = msg.content
+            elif msg.role == "user":
+                contents.append({
+                    "role": "user",
+                    "parts": [{"text": msg.content}]
+                })
+            elif msg.role == "assistant":
+                contents.append({
+                    "role": "model",
+                    "parts": [{"text": msg.content}]
+                })
+
+        result = {"contents": contents}
+        if system_instruction:
+            result["systemInstruction"] = {"parts": [{"text": system_instruction}]}
+
+        return result
+
+    async def generate_response(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> ChatResponse:
+        """Generate a complete response using Google AI API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("Google AI API client not initialized")
+
+        start_time = time.time()
+        message_id = str(uuid.uuid4())
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_data = self._format_messages_for_api(messages)
+
+            # Add generation config
+            api_data["generationConfig"] = {
+                "maxOutputTokens": params['max_tokens'],
+                "temperature": params['temperature'],
+                "topP": params.get('top_p', 0.9),
+                "topK": params.get('top_k', 40)
+            }
+
+            # Make API call
+            url = f"{self.base_url}/models/{self.model_name}:generateContent"
+
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    f"{url}?key={self.api_key}",
+                    headers={'Content-Type': 'application/json'},
+                    json=api_data,
+                    timeout=30.0
+                )
+
+            if response.status_code != 200:
+                raise GenerationError(f"API request failed with status {response.status_code}: {response.text}")
+
+            response_data = response.json()
+
+            # Extract response text
+            if 'candidates' in response_data and response_data['candidates']:
+                candidate = response_data['candidates'][0]
+                if 'content' in candidate and 'parts' in candidate['content']:
+                    parts = candidate['content']['parts']
+                    response_text = ''.join(part.get('text', '') for part in parts)
+                else:
+                    response_text = str(response_data)
+            else:
+                response_text = str(response_data)
+
+            generation_time = time.time() - start_time
+
+            return ChatResponse(
+                message=response_text.strip(),
+                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
+                message_id=message_id,
+                model_name=self.model_name,
+                generation_time=generation_time,
+                token_count=len(response_text.split()),  # Approximate token count
+                finish_reason="stop"
+            )
+
+        except Exception as e:
+            self.log_error("Google AI API generation failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate response via Google AI API: {str(e)}")
+
+    async def generate_stream(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Generate a streaming response using Google AI API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("Google AI API client not initialized")
+
+        message_id = str(uuid.uuid4())
+        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
+        chunk_id = 0
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_data = self._format_messages_for_api(messages)
+
+            # Add generation config
+            api_data["generationConfig"] = {
+                "maxOutputTokens": params['max_tokens'],
+                "temperature": params['temperature'],
+                "topP": params.get('top_p', 0.9),
+                "topK": params.get('top_k', 40)
+            }
+
+            # Make streaming API call
+            url = f"{self.base_url}/models/{self.model_name}:streamGenerateContent"
+
+            async with httpx.AsyncClient() as client:
+                async with client.stream(
+                    'POST',
+                    f"{url}?key={self.api_key}",
+                    headers={'Content-Type': 'application/json'},
+                    json=api_data,
+                    timeout=60.0
+                ) as response:
+
+                    if response.status_code != 200:
+                        raise GenerationError(f"Streaming request failed with status {response.status_code}")
+
+                    async for line in response.aiter_lines():
+                        if line.strip():
+                            try:
+                                # Google AI API returns JSON objects separated by newlines
+                                data = json.loads(line)
+
+                                if 'candidates' in data and data['candidates']:
+                                    candidate = data['candidates'][0]
+                                    if 'content' in candidate and 'parts' in candidate['content']:
+                                        parts = candidate['content']['parts']
+                                        content = ''.join(part.get('text', '') for part in parts)
+
+                                        if content:
+                                            yield StreamChunk(
+                                                content=content,
+                                                session_id=session_id,
+                                                message_id=message_id,
+                                                chunk_id=chunk_id,
+                                                is_final=False
+                                            )
+                                            chunk_id += 1
+
+                                            # Add small delay
+                                            await asyncio.sleep(settings.stream_delay)
+
+                            except json.JSONDecodeError:
+                                continue
+
+            # Send final chunk
+            yield StreamChunk(
+                content="",
+                session_id=session_id,
+                message_id=message_id,
+                chunk_id=chunk_id,
+                is_final=True
+            )
+
+        except Exception as e:
+            self.log_error("Google AI API streaming failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate streaming response via Google AI API: {str(e)}")
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "type": "google_ai",
+            "loaded": self.is_loaded,
+            "provider": "Google AI Studio",
+            "capabilities": self.capabilities,
+            "parameters": self.parameters,
+            "requires_api_key": True,
+            "api_key_configured": bool(self.api_key),
+            "base_url": self.base_url
+        }
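The role mapping done by `_format_messages_for_api` (system messages pulled out into `systemInstruction`, `assistant` renamed to `model`) can be sketched as a standalone function. Plain dicts stand in for the app's `ChatMessage` schema here, so the sketch runs on its own:

```python
# Standalone sketch of GoogleAIBackend._format_messages_for_api.
# Messages are plain {"role", "content"} dicts (an assumption made so this
# runs without the app's ChatMessage schema).
from typing import Any, Dict, List

def format_for_google_ai(messages: List[Dict[str, str]]) -> Dict[str, Any]:
    contents = []
    system_instruction = None
    for msg in messages:
        if msg["role"] == "system":
            # System text is not a turn; it becomes systemInstruction.
            system_instruction = msg["content"]
        elif msg["role"] == "user":
            contents.append({"role": "user", "parts": [{"text": msg["content"]}]})
        elif msg["role"] == "assistant":
            # The Google AI API calls the assistant role "model".
            contents.append({"role": "model", "parts": [{"text": msg["content"]}]})
    result: Dict[str, Any] = {"contents": contents}
    if system_instruction:
        result["systemInstruction"] = {"parts": [{"text": system_instruction}]}
    return result

payload = format_for_google_ai([
    {"role": "system", "content": "Be brief"},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
])
```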
app/services/model_backends/hf_api.py ADDED
@@ -0,0 +1,303 @@
+"""
+HuggingFace Inference API backend
+Uses HuggingFace's hosted inference API for model access
+"""
+
+import asyncio
+import time
+import uuid
+import json
+from typing import AsyncGenerator, List, Dict, Any, Optional
+from datetime import datetime
+import httpx
+from huggingface_hub import InferenceClient
+
+from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
+from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
+from ...core.config import settings
+
+
+class HuggingFaceAPIBackend(ModelBackend):
+    """HuggingFace Inference API backend"""
+
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__(model_name, **kwargs)
+        self.client = None
+        self.api_token = kwargs.get('api_token', settings.hf_api_token)
+        self.inference_url = kwargs.get('inference_url', settings.hf_inference_url)
+        self.capabilities = ["chat", "streaming", "api_based"]
+
+        # Generation parameters
+        self.parameters = {
+            'temperature': kwargs.get('temperature', settings.temperature),
+            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
+            'top_p': kwargs.get('top_p', settings.top_p),
+        }
+
+    async def load_model(self) -> bool:
+        """Initialize the HuggingFace API client"""
+        try:
+            if not self.api_token:
+                raise ModelLoadError("HuggingFace API token is required")
+
+            self.log_info("Initializing HuggingFace API client", model=self.model_name)
+
+            # Initialize the inference client
+            self.client = InferenceClient(
+                model=self.model_name,
+                token=self.api_token
+            )
+
+            # Test the connection with a simple request
+            await self._test_connection()
+
+            self.is_loaded = True
+            self.log_info("HuggingFace API client initialized successfully", model=self.model_name)
+
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to initialize HuggingFace API client", error=str(e), model=self.model_name)
+            raise ModelLoadError(f"Failed to initialize HuggingFace API for {self.model_name}: {str(e)}")
+
+    async def unload_model(self) -> bool:
+        """Clean up the API client"""
+        try:
+            self.client = None
+            self.is_loaded = False
+            self.log_info("HuggingFace API client cleaned up", model=self.model_name)
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to cleanup API client", error=str(e), model=self.model_name)
+            return False
+
+    async def _test_connection(self):
+        """Test the API connection"""
+        try:
+            # Simple test message
+            test_messages = [{"role": "user", "content": "Hello"}]
+
+            # Use asyncio to run the sync client method
+            loop = asyncio.get_event_loop()
+            response = await loop.run_in_executor(
+                None,
+                lambda: self.client.chat_completion(
+                    messages=test_messages,
+                    max_tokens=10,
+                    temperature=0.1
+                )
+            )
+
+            self.log_info("API connection test successful", model=self.model_name)
+
+        except Exception as e:
+            self.log_error("API connection test failed", error=str(e), model=self.model_name)
+            raise
+
+    def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
+        """Format messages for HuggingFace API"""
+        formatted = []
+        for msg in messages:
+            formatted.append({
+                "role": msg.role,
+                "content": msg.content
+            })
+        return formatted
+
+    async def generate_response(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> ChatResponse:
+        """Generate a complete response using HuggingFace API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("HuggingFace API client not initialized")
+
+        start_time = time.time()
+        message_id = str(uuid.uuid4())
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Make API call
+            loop = asyncio.get_event_loop()
+            response = await loop.run_in_executor(
+                None,
+                lambda: self.client.chat_completion(
+                    messages=api_messages,
+                    max_tokens=params['max_tokens'],
+                    temperature=params['temperature'],
+                    top_p=params.get('top_p', 0.9),
+                    stream=False
+                )
+            )
+
+            # Extract response text
+            if hasattr(response, 'choices') and response.choices:
+                response_text = response.choices[0].message.content
+                finish_reason = getattr(response.choices[0], 'finish_reason', 'stop')
+            else:
+                response_text = str(response)
+                finish_reason = 'unknown'
+
+            generation_time = time.time() - start_time
+
+            return ChatResponse(
+                message=response_text.strip(),
+                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
+                message_id=message_id,
+                model_name=self.model_name,
+                generation_time=generation_time,
+                token_count=len(response_text.split()),  # Approximate token count
+                finish_reason=finish_reason
+            )
+
+        except Exception as e:
+            self.log_error("HuggingFace API generation failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate response via HuggingFace API: {str(e)}")
+
+    async def generate_stream(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Generate a streaming response using HuggingFace API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("HuggingFace API client not initialized")
+
+        message_id = str(uuid.uuid4())
+        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
+        chunk_id = 0
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Create streaming generator
+            loop = asyncio.get_event_loop()
+
+            def stream_generator():
+                return self.client.chat_completion(
+                    messages=api_messages,
+                    max_tokens=params['max_tokens'],
+                    temperature=params['temperature'],
+                    top_p=params.get('top_p', 0.9),
+                    stream=True
+                )
+
+            # Get the streaming response
+            stream = await loop.run_in_executor(None, stream_generator)
+
+            # Process streaming chunks
+            for chunk in stream:
+                if hasattr(chunk, 'choices') and chunk.choices:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'content') and delta.content:
+                        yield StreamChunk(
+                            content=delta.content,
+                            session_id=session_id,
+                            message_id=message_id,
+                            chunk_id=chunk_id,
+                            is_final=False
+                        )
+                        chunk_id += 1
+
+                        # Add small delay to prevent overwhelming the client
+                        await asyncio.sleep(settings.stream_delay)
+
+            # Send final chunk
+            yield StreamChunk(
+                content="",
+                session_id=session_id,
+                message_id=message_id,
+                chunk_id=chunk_id,
+                is_final=True
+            )
+
+        except Exception as e:
+            self.log_error("HuggingFace API streaming failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate streaming response via HuggingFace API: {str(e)}")
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "type": "huggingface_api",
+            "loaded": self.is_loaded,
+            "api_endpoint": self.inference_url,
+            "capabilities": self.capabilities,
+            "parameters": self.parameters,
+            "requires_token": True,
+            "token_configured": bool(self.api_token)
+        }
+
+    async def health_check(self) -> Dict[str, Any]:
+        """Perform a health check on the HuggingFace API"""
+        try:
+            if not self.is_loaded:
+                return {
+                    "status": "unhealthy",
+                    "reason": "client_not_initialized",
+                    "model_name": self.model_name
+                }
+
+            # Test API connectivity
+            start_time = time.time()
+            test_messages = [ChatMessage(role="user", content="Hi")]
+
+            try:
+                response = await asyncio.wait_for(
+                    self.generate_response(
+                        test_messages,
+                        temperature=0.1,
+                        max_tokens=5
+                    ),
+                    timeout=15.0
+                )
+
+                response_time = time.time() - start_time
+
+                return {
+                    "status": "healthy",
+                    "model_name": self.model_name,
+                    "response_time": response_time,
+                    "api_endpoint": self.inference_url
+                }
+
+            except asyncio.TimeoutError:
+                return {
+                    "status": "unhealthy",
+                    "reason": "api_timeout",
+                    "model_name": self.model_name,
+                    "api_endpoint": self.inference_url
+                }
+
+        except Exception as e:
+            self.log_error("HuggingFace API health check failed", error=str(e), model=self.model_name)
+            return {
+                "status": "unhealthy",
+                "reason": "api_error",
+                "error": str(e),
+                "model_name": self.model_name,
+                "api_endpoint": self.inference_url
+            }
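This backend repeatedly wraps the synchronous `InferenceClient` call in `loop.run_in_executor` so it does not block the event loop. A minimal sketch of that pattern, with a hypothetical `slow_sync_call` standing in for `client.chat_completion`:

```python
# Sketch of the run_in_executor pattern used by HuggingFaceAPIBackend.
# slow_sync_call is a hypothetical stand-in for the blocking
# InferenceClient.chat_completion call.
import asyncio

def slow_sync_call(prompt: str) -> str:
    # Pretend this does blocking network I/O.
    return f"echo: {prompt}"

async def call_without_blocking(prompt: str) -> str:
    loop = asyncio.get_event_loop()
    # run_in_executor(None, ...) moves the blocking call onto the
    # default thread pool, leaving the event loop free.
    return await loop.run_in_executor(None, lambda: slow_sync_call(prompt))

result = asyncio.run(call_without_blocking("Hi"))
```

The same lambda-capturing style appears in `_test_connection` and `generate_response` above.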
app/services/model_backends/local_hf.py ADDED
@@ -0,0 +1,330 @@
1
+ """
2
+ Local HuggingFace model backend
3
+ Loads and runs models locally using transformers library
4
+ """
5
+
6
+ import asyncio
7
+ import time
8
+ import uuid
9
+ from typing import AsyncGenerator, List, Dict, Any, Optional
10
+ from datetime import datetime
11
+ import torch
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ TextIteratorStreamer,
16
+ GenerationConfig
17
+ )
18
+ from threading import Thread
19
+ from queue import Queue
20
+
21
+ from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
22
+ from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
23
+ from ...core.config import settings
24
+
25
+
26
+ class LocalHuggingFaceBackend(ModelBackend):
27
+ """Local HuggingFace model backend using transformers"""
28
+
29
+ def __init__(self, model_name: str, **kwargs):
30
+ super().__init__(model_name, **kwargs)
31
+ self.tokenizer = None
32
+ self.model = None
33
+ self.device = kwargs.get('device', settings.device)
34
+ self.capabilities = ["chat", "streaming", "instruction_following"]
35
+
36
+ # Generation parameters
37
+ self.parameters = {
38
+ 'temperature': kwargs.get('temperature', settings.temperature),
39
+ 'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
40
+ 'top_p': kwargs.get('top_p', settings.top_p),
41
+ 'top_k': kwargs.get('top_k', settings.top_k),
42
+ }
43
+
44
+ async def load_model(self) -> bool:
45
+ """Load the HuggingFace model and tokenizer"""
46
+ try:
47
+ self.log_info("Loading local HuggingFace model", model=self.model_name)
48
+
49
+ # Determine device
50
+ if self.device == "auto":
51
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
52
+
53
+ self.log_info("Using device", device=self.device)
54
+
55
+ # Load tokenizer
56
+ self.log_info("Loading tokenizer")
57
+ self.tokenizer = AutoTokenizer.from_pretrained(
58
+ self.model_name,
59
+ trust_remote_code=True,
60
+ padding_side="left"
61
+ )
62
+
63
+ # Add pad token if not present
64
+ if self.tokenizer.pad_token is None:
65
+ self.tokenizer.pad_token = self.tokenizer.eos_token
66
+
67
+ # Load model
68
+ self.log_info("Loading model")
69
+ self.model = AutoModelForCausalLM.from_pretrained(
70
+ self.model_name,
71
+ trust_remote_code=True,
72
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
73
+ device_map="auto" if self.device == "cuda" else None,
74
+ low_cpu_mem_usage=True
75
+ )
76
+
77
+ if self.device == "cpu":
78
+ self.model = self.model.to(self.device)
79
+
80
+ # Set model to evaluation mode
81
+ self.model.eval()
82
+
83
+ self.is_loaded = True
84
+ self.log_info("Model loaded successfully",
85
+ model=self.model_name,
86
+ device=self.device,
87
+ parameters=self.model.num_parameters() if hasattr(self.model, 'num_parameters') else 'unknown')
88
+
89
+ return True
90
+
91
+ except Exception as e:
92
+ self.log_error("Failed to load model", error=str(e), model=self.model_name)
93
+ raise ModelLoadError(f"Failed to load model {self.model_name}: {str(e)}")
94
+
95
+ async def unload_model(self) -> bool:
96
+ """Unload the model and free memory"""
97
+ try:
98
+ if self.model is not None:
99
+ del self.model
100
+ self.model = None
101
+
102
+ if self.tokenizer is not None:
103
+ del self.tokenizer
104
+ self.tokenizer = None
105
+
106
+ # Clear CUDA cache if using GPU
107
+ if torch.cuda.is_available():
108
+ torch.cuda.empty_cache()
109
+
110
+ self.is_loaded = False
111
+ self.log_info("Model unloaded successfully", model=self.model_name)
112
+ return True
113
+
114
+ except Exception as e:
115
+ self.log_error("Failed to unload model", error=str(e), model=self.model_name)
116
+ return False
117
+
118
+ def _prepare_chat_input(self, messages: List[ChatMessage]) -> str:
119
+ """Prepare chat messages for the model"""
120
+ if not self.tokenizer:
121
+ raise ModelNotLoadedError("Tokenizer not loaded")
122
+
123
+ # Check if tokenizer has chat template
124
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template:
125
+ # Use the model's chat template
126
+ formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
127
+ return self.tokenizer.apply_chat_template(
128
+ formatted_messages,
129
+ tokenize=False,
130
+ add_generation_prompt=True
131
+ )
132
+ else:
133
+ # Fallback to simple concatenation
134
+ chat_text = ""
135
+ for msg in messages:
136
+ if msg.role == "system":
137
+ chat_text += f"System: {msg.content}\n"
138
+ elif msg.role == "user":
139
+ chat_text += f"User: {msg.content}\n"
140
+ elif msg.role == "assistant":
141
+ chat_text += f"Assistant: {msg.content}\n"
142
+
143
+ chat_text += "Assistant: "
144
+ return chat_text
145
+
146
+ async def generate_response(
147
+ self,
148
+ messages: List[ChatMessage],
149
+ temperature: float = 0.7,
150
+ max_tokens: int = 512,
151
+ **kwargs
152
+ ) -> ChatResponse:
153
+ """Generate a complete response"""
154
+ if not self.is_loaded:
155
+ raise ModelNotLoadedError("Model not loaded")
156
+
157
+ start_time = time.time()
158
+ message_id = str(uuid.uuid4())
159
+
160
+ try:
161
+ # Validate parameters
162
+ params = self.validate_parameters(
163
+ temperature=temperature,
164
+ max_tokens=max_tokens,
165
+ **kwargs
166
+ )
167
+
168
+ # Prepare input
169
+ chat_input = self._prepare_chat_input(messages)
170
+
171
+ # Tokenize input
172
+ inputs = self.tokenizer(
173
+ chat_input,
174
+ return_tensors="pt",
175
+ padding=True,
176
+ truncation=True,
177
+ max_length=settings.max_length - params['max_tokens']
178
+ ).to(self.device)
179
+
180
+ # Generate response
181
+ with torch.no_grad():
182
+ outputs = self.model.generate(
183
+ **inputs,
184
+ max_new_tokens=params['max_tokens'],
185
+ temperature=params['temperature'],
186
+ top_p=params['top_p'],
187
+ top_k=params['top_k'],
188
+ do_sample=True,
189
+ pad_token_id=self.tokenizer.pad_token_id,
190
+ eos_token_id=self.tokenizer.eos_token_id,
191
+ repetition_penalty=1.1,
192
+ no_repeat_ngram_size=3
193
+ )
194
+
195
+ # Decode response
196
+ input_length = inputs['input_ids'].shape[1]
197
+ generated_tokens = outputs[0][input_length:]
198
+ response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
199
+
200
+ generation_time = time.time() - start_time
201
+
202
+ return ChatResponse(
203
+ message=response_text.strip(),
204
+ session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
205
+ message_id=message_id,
206
+ model_name=self.model_name,
207
+ generation_time=generation_time,
208
+ token_count=len(generated_tokens),
209
+ finish_reason="stop"
210
+ )
211
+
212
+ except Exception as e:
213
+ self.log_error("Generation failed", error=str(e), model=self.model_name)
214
+ raise GenerationError(f"Failed to generate response: {str(e)}")
215
+
216
+ async def generate_stream(
217
+ self,
218
+ messages: List[ChatMessage],
219
+ temperature: float = 0.7,
220
+ max_tokens: int = 512,
221
+ **kwargs
222
+ ) -> AsyncGenerator[StreamChunk, None]:
223
+ """Generate a streaming response"""
224
+ if not self.is_loaded:
225
+ raise ModelNotLoadedError("Model not loaded")
226
+
227
+ message_id = str(uuid.uuid4())
228
+ session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
229
+ chunk_id = 0
230
+
231
+ try:
232
+ # Validate parameters
233
+ params = self.validate_parameters(
234
+ temperature=temperature,
235
+ max_tokens=max_tokens,
236
+ **kwargs
237
+ )
238
+
239
+ # Prepare input
240
+ chat_input = self._prepare_chat_input(messages)
241
+
242
+ # Tokenize input
243
+ inputs = self.tokenizer(
244
+ chat_input,
245
+ return_tensors="pt",
246
+ padding=True,
247
+ truncation=True,
248
+ max_length=settings.max_length - params['max_tokens']
249
+ ).to(self.device)
250
+
251
+ # Create streamer
252
+ streamer = TextIteratorStreamer(
253
+ self.tokenizer,
254
+ skip_prompt=True,
255
+ skip_special_tokens=True
256
+ )
257
+
258
+ # Generation parameters
259
+ generation_kwargs = {
260
+ **inputs,
261
+ 'max_new_tokens': params['max_tokens'],
262
+ 'temperature': params['temperature'],
263
+ 'top_p': params['top_p'],
264
+ 'top_k': params['top_k'],
265
+ 'do_sample': True,
266
+                'pad_token_id': self.tokenizer.pad_token_id,
+                'eos_token_id': self.tokenizer.eos_token_id,
+                'repetition_penalty': 1.1,
+                'no_repeat_ngram_size': 3,
+                'streamer': streamer
+            }
+
+            # Start generation in a separate thread
+            generation_thread = Thread(
+                target=self.model.generate,
+                kwargs=generation_kwargs
+            )
+            generation_thread.start()
+
+            # Stream the response
+            for chunk_text in streamer:
+                if chunk_text:  # Skip empty chunks
+                    yield StreamChunk(
+                        content=chunk_text,
+                        session_id=session_id,
+                        message_id=message_id,
+                        chunk_id=chunk_id,
+                        is_final=False
+                    )
+                    chunk_id += 1
+
+                # Add small delay to prevent overwhelming the client
+                await asyncio.sleep(settings.stream_delay)
+
+            # Send final chunk
+            yield StreamChunk(
+                content="",
+                session_id=session_id,
+                message_id=message_id,
+                chunk_id=chunk_id,
+                is_final=True
+            )
+
+            # Wait for generation thread to complete
+            generation_thread.join()
+
+        except Exception as e:
+            self.log_error("Streaming generation failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate streaming response: {str(e)}")
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        info = {
+            "name": self.model_name,
+            "type": "local_huggingface",
+            "loaded": self.is_loaded,
+            "device": self.device,
+            "capabilities": self.capabilities,
+            "parameters": self.parameters
+        }
+
+        if self.model and hasattr(self.model, 'config'):
+            info["model_config"] = {
+                "vocab_size": getattr(self.model.config, 'vocab_size', None),
+                "hidden_size": getattr(self.model.config, 'hidden_size', None),
+                "num_layers": getattr(self.model.config, 'num_hidden_layers', None),
+                "num_attention_heads": getattr(self.model.config, 'num_attention_heads', None),
+            }
+
+        return info
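The thread-plus-streamer pattern in `generate_stream` above can be sketched in isolation. This is a minimal stand-in, not the repo's code: `ToyStreamer` plays the role of transformers' `TextIteratorStreamer`, and `fake_generate` stands in for `model.generate(..., streamer=streamer)`; all names here are illustrative.

```python
import queue
from threading import Thread

class ToyStreamer:
    """Queue-backed stand-in for TextIteratorStreamer: the producer
    thread puts text chunks, the consumer iterates until end()."""
    _DONE = object()

    def __init__(self):
        self._q = queue.Queue()

    def put(self, text):
        self._q.put(text)

    def end(self):
        self._q.put(self._DONE)

    def __iter__(self):
        while True:
            item = self._q.get()
            if item is self._DONE:
                return
            yield item

def fake_generate(streamer):
    # Stands in for model.generate(..., streamer=streamer) running
    # in a worker thread while the caller consumes chunks.
    for token in ["Hello", " ", "world"]:
        streamer.put(token)
    streamer.end()

streamer = ToyStreamer()
t = Thread(target=fake_generate, kwargs={"streamer": streamer})
t.start()
chunks = [c for c in streamer if c]  # skip empty chunks, as the backend does
t.join()
print("".join(chunks))  # -> Hello world
```

The point of the pattern is that `generate` blocks the worker thread while the consumer loop yields chunks as they arrive, so the HTTP handler can stream without waiting for the full completion.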
app/services/model_backends/minimax_api.py ADDED
@@ -0,0 +1,341 @@
+"""
+MiniMax API backend
+Uses MiniMax's API for their M1 model with reasoning capabilities
+"""
+
+import asyncio
+import time
+import uuid
+import json
+from typing import AsyncGenerator, List, Dict, Any, Optional
+from datetime import datetime
+import httpx
+
+from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
+from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
+from ...core.config import settings
+
+
+class MiniMaxAPIBackend(ModelBackend):
+    """MiniMax API backend for the M1 model"""
+
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__(model_name, **kwargs)
+        self.api_url = kwargs.get('api_url', settings.minimax_api_url)
+        self.api_key = kwargs.get('api_key', settings.minimax_api_key)
+        self.model_version = kwargs.get('model_version', settings.minimax_model_version)
+        self.capabilities = ["chat", "streaming", "reasoning", "api_based"]
+
+        # Generation parameters
+        self.parameters = {
+            'temperature': kwargs.get('temperature', settings.temperature),
+            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
+            'top_p': kwargs.get('top_p', settings.top_p),
+        }
+
+    async def load_model(self) -> bool:
+        """Initialize the MiniMax API client"""
+        try:
+            if not self.api_key or not self.api_url:
+                raise ModelLoadError("MiniMax API key and URL are required")
+
+            self.log_info("Initializing MiniMax API client", model=self.model_name)
+
+            # Test the connection
+            await self._test_connection()
+
+            self.is_loaded = True
+            self.log_info("MiniMax API client initialized successfully", model=self.model_name)
+
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to initialize MiniMax API client", error=str(e), model=self.model_name)
+            raise ModelLoadError(f"Failed to initialize MiniMax API for {self.model_name}: {str(e)}")
+
+    async def unload_model(self) -> bool:
+        """Clean up the API client"""
+        try:
+            self.is_loaded = False
+            self.log_info("MiniMax API client cleaned up", model=self.model_name)
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to cleanup MiniMax API client", error=str(e), model=self.model_name)
+            return False
+
+    async def _test_connection(self):
+        """Test the MiniMax API connection"""
+        try:
+            test_data = {
+                'model': self.model_version,
+                'messages': [{"role": "user", "content": "Hello"}],
+                'stream': False,
+                'max_tokens': 5,
+                'temperature': 0.1
+            }
+
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    self.api_url,
+                    headers={
+                        'Content-Type': 'application/json',
+                        'Authorization': f'Bearer {self.api_key}'
+                    },
+                    json=test_data,
+                    timeout=10.0
+                )
+
+            if response.status_code != 200:
+                raise Exception(f"API test failed with status {response.status_code}")
+
+            self.log_info("MiniMax API connection test successful", model=self.model_name)
+
+        except Exception as e:
+            self.log_error("MiniMax API connection test failed", error=str(e), model=self.model_name)
+            raise
+
+    def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
+        """Format messages for the MiniMax API"""
+        formatted = []
+        for msg in messages:
+            formatted.append({
+                "role": msg.role,
+                "content": msg.content
+            })
+        return formatted
+
+    async def generate_response(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> ChatResponse:
+        """Generate a complete response using the MiniMax API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("MiniMax API client not initialized")
+
+        start_time = time.time()
+        message_id = str(uuid.uuid4())
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Prepare request data
+            request_data = {
+                'model': self.model_version,
+                'messages': api_messages,
+                'stream': False,
+                'max_tokens': params['max_tokens'],
+                'temperature': params['temperature'],
+                'top_p': params.get('top_p', 0.9)
+            }
+
+            # Make API call
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    self.api_url,
+                    headers={
+                        'Content-Type': 'application/json',
+                        'Authorization': f'Bearer {self.api_key}'
+                    },
+                    json=request_data,
+                    timeout=30.0
+                )
+
+            if response.status_code != 200:
+                raise GenerationError(f"API request failed with status {response.status_code}")
+
+            response_data = response.json()
+
+            # Extract response text
+            if 'choices' in response_data and response_data['choices']:
+                choice = response_data['choices'][0]
+                if 'message' in choice:
+                    response_text = choice['message'].get('content', '')
+                    reasoning_content = choice['message'].get('reasoning_content', '')
+
+                    # Combine reasoning and main content if both exist
+                    if reasoning_content and response_text:
+                        full_response = f"[Reasoning: {reasoning_content}]\n\n{response_text}"
+                    else:
+                        full_response = response_text or reasoning_content
+                else:
+                    full_response = str(response_data)
+            else:
+                full_response = str(response_data)
+
+            generation_time = time.time() - start_time
+
+            return ChatResponse(
+                message=full_response.strip(),
+                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
+                message_id=message_id,
+                model_name=self.model_name,
+                generation_time=generation_time,
+                token_count=len(full_response.split()),  # Approximate token count
+                finish_reason="stop"
+            )
+
+        except Exception as e:
+            self.log_error("MiniMax API generation failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate response via MiniMax API: {str(e)}")
+
+    async def generate_stream(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Generate a streaming response using the MiniMax API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("MiniMax API client not initialized")
+
+        message_id = str(uuid.uuid4())
+        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
+        chunk_id = 0
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Prepare request data
+            request_data = {
+                'model': self.model_version,
+                'messages': api_messages,
+                'stream': True,
+                'max_tokens': params['max_tokens'],
+                'temperature': params['temperature'],
+                'top_p': params.get('top_p', 0.9)
+            }
+
+            # Make streaming API call
+            async with httpx.AsyncClient() as client:
+                async with client.stream(
+                    'POST',
+                    self.api_url,
+                    headers={
+                        'Content-Type': 'application/json',
+                        'Authorization': f'Bearer {self.api_key}'
+                    },
+                    json=request_data,
+                    timeout=60.0
+                ) as response:
+
+                    if response.status_code != 200:
+                        raise GenerationError(f"Streaming request failed with status {response.status_code}")
+
+                    async for line in response.aiter_lines():
+                        if line.startswith('data:'):
+                            try:
+                                data = json.loads(line[5:])  # Remove 'data:' prefix
+
+                                if 'choices' not in data:
+                                    continue
+
+                                choice = data['choices'][0]
+
+                                # Handle delta content
+                                if 'delta' in choice:
+                                    delta = choice['delta']
+                                    reasoning_content = delta.get('reasoning_content', '')
+                                    content = delta.get('content', '')
+
+                                    # Send reasoning content if available
+                                    if reasoning_content:
+                                        yield StreamChunk(
+                                            content=f"[Thinking: {reasoning_content}]",
+                                            session_id=session_id,
+                                            message_id=message_id,
+                                            chunk_id=chunk_id,
+                                            is_final=False
+                                        )
+                                        chunk_id += 1
+
+                                    # Send main content
+                                    if content:
+                                        yield StreamChunk(
+                                            content=content,
+                                            session_id=session_id,
+                                            message_id=message_id,
+                                            chunk_id=chunk_id,
+                                            is_final=False
+                                        )
+                                        chunk_id += 1
+
+                                # Handle complete message
+                                elif 'message' in choice:
+                                    message_data = choice['message']
+                                    reasoning_content = message_data.get('reasoning_content', '')
+                                    main_content = message_data.get('content', '')
+
+                                    if reasoning_content:
+                                        yield StreamChunk(
+                                            content=f"\n[Final reasoning: {reasoning_content}]\n",
+                                            session_id=session_id,
+                                            message_id=message_id,
+                                            chunk_id=chunk_id,
+                                            is_final=False
+                                        )
+                                        chunk_id += 1
+
+                                    if main_content:
+                                        yield StreamChunk(
+                                            content=main_content,
+                                            session_id=session_id,
+                                            message_id=message_id,
+                                            chunk_id=chunk_id,
+                                            is_final=False
+                                        )
+                                        chunk_id += 1
+
+                            except json.JSONDecodeError:
+                                continue
+
+                            # Add small delay
+                            await asyncio.sleep(settings.stream_delay)
+
+            # Send final chunk
+            yield StreamChunk(
+                content="",
+                session_id=session_id,
+                message_id=message_id,
+                chunk_id=chunk_id,
+                is_final=True
+            )
+
+        except Exception as e:
+            self.log_error("MiniMax API streaming failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate streaming response via MiniMax API: {str(e)}")
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "type": "minimax_api",
+            "loaded": self.is_loaded,
+            "provider": "MiniMax",
+            "model_version": self.model_version,
+            "capabilities": self.capabilities,
+            "parameters": self.parameters,
+            "requires_api_key": True,
+            "api_key_configured": bool(self.api_key),
+            "api_url": self.api_url
+        }
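The SSE line handling in `generate_stream` above can be exercised on its own. The sketch below assumes the same response shape the backend parses (`data: {json}` lines carrying `choices[].delta` with optional `reasoning_content`/`content`); the helper name and sample lines are illustrative, not from the repo:

```python
import json

def parse_sse_lines(lines):
    """Yield (kind, text) pairs from SSE 'data:' lines, mirroring the
    delta handling in MiniMaxAPIBackend.generate_stream."""
    for line in lines:
        if not line.startswith('data:'):
            continue  # comments / keep-alives
        try:
            data = json.loads(line[5:])  # drop the 'data:' prefix
        except json.JSONDecodeError:
            continue  # e.g. the 'data: [DONE]' sentinel
        for choice in data.get('choices', []):
            delta = choice.get('delta', {})
            if delta.get('reasoning_content'):
                yield ('reasoning', delta['reasoning_content'])
            if delta.get('content'):
                yield ('content', delta['content'])

lines = [
    'data: {"choices": [{"delta": {"reasoning_content": "thinking..."}}]}',
    ': keep-alive',
    'data: {"choices": [{"delta": {"content": "x = 5"}}]}',
    'data: [DONE]',
]
print(list(parse_sse_lines(lines)))
# -> [('reasoning', 'thinking...'), ('content', 'x = 5')]
```

Note that `json.loads` tolerates the leading space after `data:`, and the `[DONE]` sentinel is naturally skipped by the `JSONDecodeError` branch, the same way the backend's `continue` handles malformed lines.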
app/services/model_backends/openai_api.py ADDED
@@ -0,0 +1,291 @@
+"""
+OpenAI API backend
+Uses OpenAI's API for model access (GPT-3.5, GPT-4, etc.)
+"""
+
+import asyncio
+import time
+import uuid
+from typing import AsyncGenerator, List, Dict, Any, Optional
+from datetime import datetime
+import openai
+
+from .base import ModelBackend, ModelLoadError, GenerationError, ModelNotLoadedError
+from ...models.schemas import ChatMessage, ChatResponse, StreamChunk
+from ...core.config import settings
+
+
+class OpenAIAPIBackend(ModelBackend):
+    """OpenAI API backend"""
+
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__(model_name, **kwargs)
+        self.client = None
+        self.api_key = kwargs.get('api_key', settings.openai_api_key)
+        self.org_id = kwargs.get('org_id', settings.openai_org_id)
+        self.capabilities = ["chat", "streaming", "api_based", "function_calling"]
+
+        # Generation parameters
+        self.parameters = {
+            'temperature': kwargs.get('temperature', settings.temperature),
+            'max_tokens': kwargs.get('max_tokens', settings.max_new_tokens),
+            'top_p': kwargs.get('top_p', settings.top_p),
+        }
+
+    async def load_model(self) -> bool:
+        """Initialize the OpenAI API client"""
+        try:
+            if not self.api_key:
+                raise ModelLoadError("OpenAI API key is required")
+
+            self.log_info("Initializing OpenAI API client", model=self.model_name)
+
+            # Initialize the OpenAI client
+            self.client = openai.AsyncOpenAI(
+                api_key=self.api_key,
+                organization=self.org_id
+            )
+
+            # Test the connection
+            await self._test_connection()
+
+            self.is_loaded = True
+            self.log_info("OpenAI API client initialized successfully", model=self.model_name)
+
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to initialize OpenAI API client", error=str(e), model=self.model_name)
+            raise ModelLoadError(f"Failed to initialize OpenAI API for {self.model_name}: {str(e)}")
+
+    async def unload_model(self) -> bool:
+        """Clean up the API client"""
+        try:
+            if self.client:
+                await self.client.close()
+                self.client = None
+            self.is_loaded = False
+            self.log_info("OpenAI API client cleaned up", model=self.model_name)
+            return True
+
+        except Exception as e:
+            self.log_error("Failed to cleanup OpenAI API client", error=str(e), model=self.model_name)
+            return False
+
+    async def _test_connection(self):
+        """Test the OpenAI API connection"""
+        try:
+            # Simple test request
+            response = await self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[{"role": "user", "content": "Hello"}],
+                max_tokens=5,
+                temperature=0.1
+            )
+
+            self.log_info("OpenAI API connection test successful", model=self.model_name)
+
+        except Exception as e:
+            self.log_error("OpenAI API connection test failed", error=str(e), model=self.model_name)
+            raise
+
+    def _format_messages_for_api(self, messages: List[ChatMessage]) -> List[Dict[str, str]]:
+        """Format messages for the OpenAI API"""
+        formatted = []
+        for msg in messages:
+            formatted.append({
+                "role": msg.role,
+                "content": msg.content
+            })
+        return formatted
+
+    async def generate_response(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> ChatResponse:
+        """Generate a complete response using the OpenAI API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("OpenAI API client not initialized")
+
+        start_time = time.time()
+        message_id = str(uuid.uuid4())
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Make API call
+            response = await self.client.chat.completions.create(
+                model=self.model_name,
+                messages=api_messages,
+                max_tokens=params['max_tokens'],
+                temperature=params['temperature'],
+                top_p=params.get('top_p', 0.9),
+                stream=False
+            )
+
+            # Extract response
+            response_text = response.choices[0].message.content
+            finish_reason = response.choices[0].finish_reason
+            token_count = response.usage.completion_tokens if response.usage else None
+
+            generation_time = time.time() - start_time
+
+            return ChatResponse(
+                message=response_text.strip(),
+                session_id=messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown',
+                message_id=message_id,
+                model_name=self.model_name,
+                generation_time=generation_time,
+                token_count=token_count,
+                finish_reason=finish_reason
+            )
+
+        except Exception as e:
+            self.log_error("OpenAI API generation failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate response via OpenAI API: {str(e)}")
+
+    async def generate_stream(
+        self,
+        messages: List[ChatMessage],
+        temperature: float = 0.7,
+        max_tokens: int = 512,
+        **kwargs
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Generate a streaming response using the OpenAI API"""
+        if not self.is_loaded:
+            raise ModelNotLoadedError("OpenAI API client not initialized")
+
+        message_id = str(uuid.uuid4())
+        session_id = messages[-1].metadata.get('session_id', 'unknown') if messages[-1].metadata else 'unknown'
+        chunk_id = 0
+
+        try:
+            # Validate parameters
+            params = self.validate_parameters(
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            # Format messages
+            api_messages = self._format_messages_for_api(messages)
+
+            # Create streaming request
+            stream = await self.client.chat.completions.create(
+                model=self.model_name,
+                messages=api_messages,
+                max_tokens=params['max_tokens'],
+                temperature=params['temperature'],
+                top_p=params.get('top_p', 0.9),
+                stream=True
+            )
+
+            # Process streaming chunks
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    content = chunk.choices[0].delta.content
+
+                    yield StreamChunk(
+                        content=content,
+                        session_id=session_id,
+                        message_id=message_id,
+                        chunk_id=chunk_id,
+                        is_final=False
+                    )
+                    chunk_id += 1
+
+                    # Add small delay to prevent overwhelming the client
+                    await asyncio.sleep(settings.stream_delay)
+
+                # Check if this is the final chunk
+                if chunk.choices and chunk.choices[0].finish_reason:
+                    break
+
+            # Send final chunk
+            yield StreamChunk(
+                content="",
+                session_id=session_id,
+                message_id=message_id,
+                chunk_id=chunk_id,
+                is_final=True
+            )
+
+        except Exception as e:
+            self.log_error("OpenAI API streaming failed", error=str(e), model=self.model_name)
+            raise GenerationError(f"Failed to generate streaming response via OpenAI API: {str(e)}")
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "type": "openai_api",
+            "loaded": self.is_loaded,
+            "provider": "OpenAI",
+            "capabilities": self.capabilities,
+            "parameters": self.parameters,
+            "requires_api_key": True,
+            "api_key_configured": bool(self.api_key),
+            "organization": self.org_id
+        }
+
+    async def health_check(self) -> Dict[str, Any]:
+        """Perform a health check on the OpenAI API"""
+        try:
+            if not self.is_loaded:
+                return {
+                    "status": "unhealthy",
+                    "reason": "client_not_initialized",
+                    "model_name": self.model_name
+                }
+
+            # Test API connectivity
+            start_time = time.time()
+            test_messages = [ChatMessage(role="user", content="Hi")]
+
+            try:
+                response = await asyncio.wait_for(
+                    self.generate_response(
+                        test_messages,
+                        temperature=0.1,
+                        max_tokens=5
+                    ),
+                    timeout=10.0
+                )
+
+                response_time = time.time() - start_time
+
+                return {
+                    "status": "healthy",
+                    "model_name": self.model_name,
+                    "response_time": response_time,
+                    "provider": "OpenAI"
+                }
+
+            except asyncio.TimeoutError:
+                return {
+                    "status": "unhealthy",
+                    "reason": "api_timeout",
+                    "model_name": self.model_name,
+                    "provider": "OpenAI"
+                }
+
+        except Exception as e:
+            self.log_error("OpenAI API health check failed", error=str(e), model=self.model_name)
+            return {
+                "status": "unhealthy",
+                "reason": "api_error",
+                "error": str(e),
+                "model_name": self.model_name,
+                "provider": "OpenAI"
+            }
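The timeout guard in `health_check` above relies on `asyncio.wait_for` converting a slow awaitable into `asyncio.TimeoutError`. A self-contained sketch of that pattern (the coroutine names and timings here are made up for illustration):

```python
import asyncio

async def slow_call():
    # Stands in for a real API round-trip that takes ~0.3 s.
    await asyncio.sleep(0.3)
    return "ok"

async def health_check(timeout):
    """Return a health dict, marking the service unhealthy on timeout."""
    try:
        result = await asyncio.wait_for(slow_call(), timeout=timeout)
        return {"status": "healthy", "result": result}
    except asyncio.TimeoutError:
        return {"status": "unhealthy", "reason": "api_timeout"}

print(asyncio.run(health_check(0.05))["status"])  # -> unhealthy
print(asyncio.run(health_check(1.0))["status"])   # -> healthy
```

`asyncio.wait_for` cancels the inner coroutine when the deadline passes, so a hung upstream API cannot wedge the health endpoint.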
app/services/model_manager.py ADDED
@@ -0,0 +1,382 @@
1
+ """
2
+ Model Manager - Central hub for managing different model backends
3
+ Handles backend selection based on environment configuration
4
+ """
5
+
6
+ from typing import Optional, Dict, Any
7
+ from ..core.config import settings
8
+ from ..core.logging import LoggerMixin
9
+ from .model_backends.base import ModelBackend, ModelBackendError
10
+ from .model_backends.local_hf import LocalHuggingFaceBackend
11
+ from .model_backends.hf_api import HuggingFaceAPIBackend
12
+ from .model_backends.openai_api import OpenAIAPIBackend
13
+ from .model_backends.anthropic_api import AnthropicAPIBackend
14
+ from .model_backends.minimax_api import MiniMaxAPIBackend
15
+ from .model_backends.google_api import GoogleAIBackend
16
+
17
+
18
+ class ModelManager(LoggerMixin):
19
+ """
20
+ Central manager for model backends
21
+ Handles initialization, switching, and management of different model types
22
+ """
23
+
24
+ def __init__(self):
25
+ self.current_backend: Optional[ModelBackend] = None
26
+ self.backend_type = settings.model_type.lower()
27
+ self.model_name = settings.model_name
28
+ self.is_initialized = False
29
+
30
+ async def initialize(self) -> bool:
31
+ """
32
+ Initialize the model backend based on configuration
33
+
34
+ Returns:
35
+ bool: True if initialization successful, False otherwise
36
+ """
37
+ try:
38
+ self.log_info("Initializing model manager",
39
+ backend_type=self.backend_type,
40
+ model_name=self.model_name)
41
+
42
+ # Validate configuration
43
+ if not settings.validate_model_config():
44
+ self.log_error("Invalid model configuration", backend_type=self.backend_type)
45
+ return False
46
+
47
+ # Create the appropriate backend
48
+ backend = self._create_backend()
49
+ if not backend:
50
+ self.log_error("Failed to create backend", backend_type=self.backend_type)
51
+ return False
52
+
53
+ # Load the model
54
+ success = await backend.load_model()
55
+ if success:
56
+ self.current_backend = backend
57
+ self.is_initialized = True
58
+ self.log_info("Model manager initialized successfully",
59
+ backend_type=self.backend_type,
60
+ model_name=self.model_name)
61
+ return True
62
+ else:
63
+ self.log_error("Failed to load model", backend_type=self.backend_type)
64
+ return False
65
+
66
+ except Exception as e:
67
+ self.log_error("Model manager initialization failed",
68
+ error=str(e),
69
+ backend_type=self.backend_type)
70
+ return False
71
+
72
+ def _create_backend(self) -> Optional[ModelBackend]:
73
+ """Create the appropriate model backend based on configuration"""
74
+ try:
75
+ if self.backend_type == "local":
76
+ return LocalHuggingFaceBackend(
77
+ model_name=self.model_name,
78
+ device=settings.device,
79
+ temperature=settings.temperature,
80
+ max_tokens=settings.max_new_tokens,
81
+ top_p=settings.top_p,
82
+ top_k=settings.top_k
83
+ )
84
+
85
+ elif self.backend_type == "hf_api":
86
+ return HuggingFaceAPIBackend(
87
+ model_name=self.model_name,
88
+ api_token=settings.hf_api_token,
89
+ inference_url=settings.hf_inference_url,
90
+ temperature=settings.temperature,
91
+ max_tokens=settings.max_new_tokens,
92
+ top_p=settings.top_p
93
+ )
94
+
95
+ elif self.backend_type == "openai":
96
+ return OpenAIAPIBackend(
97
+ model_name=self.model_name,
98
+ api_key=settings.openai_api_key,
99
+ org_id=settings.openai_org_id,
100
+ temperature=settings.temperature,
101
+ max_tokens=settings.max_new_tokens,
102
+ top_p=settings.top_p
103
+ )
104
+
105
+ elif self.backend_type == "anthropic":
106
+ return AnthropicAPIBackend(
107
+ model_name=self.model_name,
108
+ api_key=settings.anthropic_api_key,
109
+ temperature=settings.temperature,
110
+ max_tokens=settings.max_new_tokens,
111
+ top_p=settings.top_p
112
+ )
113
+
114
+ elif self.backend_type == "minimax":
115
+ return MiniMaxAPIBackend(
116
+ model_name=self.model_name,
117
+ api_key=settings.minimax_api_key,
118
+ api_url=settings.minimax_api_url,
119
+ model_version=settings.minimax_model_version,
120
+ temperature=settings.temperature,
121
+ max_tokens=settings.max_new_tokens,
122
+ top_p=settings.top_p
123
+ )
124
+
125
+ elif self.backend_type == "google":
126
+ return GoogleAIBackend(
127
+ model_name=self.model_name,
128
+ api_key=settings.google_api_key,
129
+ temperature=settings.temperature,
130
+ max_tokens=settings.max_new_tokens,
131
+ top_p=settings.top_p,
132
+ top_k=settings.top_k
133
+ )
134
+
135
+ else:
136
+ self.log_error("Unsupported backend type", backend_type=self.backend_type)
137
+ return None
138
+
139
+ except Exception as e:
140
+ self.log_error("Failed to create backend", error=str(e), backend_type=self.backend_type)
141
+ return None
142
+
143
+ async def shutdown(self) -> bool:
144
+ """
145
+ Shutdown the current backend and cleanup resources
146
+
147
+ Returns:
148
+ bool: True if shutdown successful, False otherwise
149
+ """
150
+ try:
151
+ if self.current_backend:
152
+ success = await self.current_backend.unload_model()
153
+ self.current_backend = None
154
+ self.is_initialized = False
155
+ self.log_info("Model manager shutdown successfully")
156
+ return success
157
+ return True
158
+
159
+ except Exception as e:
160
+ self.log_error("Model manager shutdown failed", error=str(e))
161
+ return False
162
+
163
+ def get_backend(self) -> Optional[ModelBackend]:
164
+ """
165
+ Get the current model backend
166
+
167
+ Returns:
168
+ ModelBackend: Current backend instance or None if not initialized
169
+ """
170
+ return self.current_backend
171
+
172
+ def is_ready(self) -> bool:
173
+ """
174
+ Check if the model manager is ready for inference
175
+
176
+ Returns:
177
+ bool: True if ready, False otherwise
178
+ """
179
+ return (self.is_initialized and
180
+ self.current_backend is not None and
181
+ self.current_backend.is_model_loaded())
182
+
183
+ def get_model_info(self) -> Dict[str, Any]:
184
+ """
185
+ Get information about the current model
186
+
187
+ Returns:
188
+ Dict containing model information
189
+ """
190
+ if not self.current_backend:
191
+ return {
192
+ "status": "not_initialized",
193
+ "backend_type": self.backend_type,
194
+ "model_name": self.model_name,
195
+ "is_ready": False
196
+ }
197
+
198
+ info = self.current_backend.get_model_info()
199
+ info.update({
200
+ "is_ready": self.is_ready(),
201
+ "manager_initialized": self.is_initialized
202
+ })
203
+
204
+ return info
205
+
206
+ async def health_check(self) -> Dict[str, Any]:
207
+ """
208
+ Perform a comprehensive health check
209
+
210
+ Returns:
211
+ Dict containing health status
212
+ """
213
+ if not self.is_ready():
214
+ return {
215
+ "status": "unhealthy",
216
+ "reason": "manager_not_ready",
217
+ "backend_type": self.backend_type,
218
+ "model_name": self.model_name,
219
+ "is_initialized": self.is_initialized,
220
+ "backend_loaded": self.current_backend is not None
221
+ }
222
+
223
+ # Delegate to backend health check
224
+ backend_health = await self.current_backend.health_check()
225
+
226
+ # Add manager-level information
227
+ backend_health.update({
228
+ "manager_status": "healthy",
229
+ "backend_type": self.backend_type,
230
+ "is_ready": self.is_ready()
231
+ })
232
+
233
+ return backend_health
234
+
235
+ async def switch_model(self, new_model_name: str, new_backend_type: Optional[str] = None) -> bool:
236
+ """
237
+ Switch to a different model (and optionally backend type)
238
+
239
+ Args:
240
+ new_model_name: Name of the new model
241
+ new_backend_type: Optional new backend type
242
+
243
+ Returns:
244
+ bool: True if switch successful, False otherwise
245
+ """
246
+ try:
247
+ self.log_info("Switching model",
248
+ current_model=self.model_name,
249
+ new_model=new_model_name,
250
+ current_backend=self.backend_type,
251
+ new_backend=new_backend_type)
252
+
253
+ # Shutdown current backend
254
+ if self.current_backend:
255
+ await self.current_backend.unload_model()
256
+ self.current_backend = None
257
+
258
+ # Update configuration
259
+ old_model_name = self.model_name
260
+ old_backend_type = self.backend_type
261
+
262
+ self.model_name = new_model_name
263
+ if new_backend_type:
264
+ self.backend_type = new_backend_type.lower()
265
+
266
+ # Try to initialize new backend
267
+ success = await self.initialize()
268
+
269
+ if not success:
270
+ # Rollback on failure
271
                self.log_warning("Model switch failed, rolling back",
                                 failed_model=new_model_name,
                                 rollback_model=old_model_name)

                self.model_name = old_model_name
                self.backend_type = old_backend_type
                await self.initialize()  # Try to restore previous state

                return False

            self.log_info("Model switch successful",
                          new_model=new_model_name,
                          new_backend=self.backend_type)
            return True

        except Exception as e:
            self.log_error("Model switch failed", error=str(e))
            return False

    def get_supported_backends(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about supported backends

        Returns:
            Dict containing backend information
        """
        return {
            "local": {
                "name": "Local HuggingFace",
                "description": "Run models locally using transformers",
                "requires": ["model_name", "device"],
                "capabilities": ["chat", "streaming", "offline"],
                "example_models": [
                    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    "microsoft/DialoGPT-medium",
                    "Qwen/Qwen2.5-0.5B-Instruct"
                ]
            },
            "hf_api": {
                "name": "HuggingFace Inference API",
                "description": "Use HuggingFace's hosted inference API",
                "requires": ["model_name", "hf_api_token"],
                "capabilities": ["chat", "streaming", "serverless"],
                "example_models": [
                    "microsoft/DialoGPT-large",
                    "microsoft/phi-2",
                    "google/gemma-2b-it"
                ]
            },
            "openai": {
                "name": "OpenAI API",
                "description": "Use OpenAI's GPT models",
                "requires": ["model_name", "openai_api_key"],
                "capabilities": ["chat", "streaming", "function_calling"],
                "example_models": [
                    "gpt-3.5-turbo",
                    "gpt-4",
                    "gpt-4-turbo"
                ]
            },
            "anthropic": {
                "name": "Anthropic API",
                "description": "Use Anthropic's Claude models",
                "requires": ["model_name", "anthropic_api_key"],
                "capabilities": ["chat", "streaming", "long_context"],
                "example_models": [
                    "claude-3-haiku-20240307",
                    "claude-3-sonnet-20240229",
                    "claude-3-opus-20240229"
                ]
            },
            "minimax": {
                "name": "MiniMax API",
                "description": "Use MiniMax's M1 model with reasoning capabilities",
                "requires": ["model_name", "minimax_api_key", "minimax_api_url"],
                "capabilities": ["chat", "streaming", "reasoning"],
                "example_models": [
                    "MiniMax-M1"
                ]
            },
            "google": {
                "name": "Google AI Studio",
                "description": "Use Google's Gemma and other models via AI Studio",
                "requires": ["model_name", "google_api_key"],
                "capabilities": ["chat", "streaming", "multimodal"],
                "example_models": [
                    "gemini-1.5-flash",
                    "gemini-1.5-pro",
                    "gemma-2-9b-it",
                    "gemma-2-27b-it"
                ]
            }
        }


# Global model manager instance
model_manager = ModelManager()


async def get_model_manager() -> ModelManager:
    """Get the global model manager instance"""
    return model_manager


async def initialize_model_manager() -> bool:
    """Initialize the global model manager"""
    return await model_manager.initialize()


async def shutdown_model_manager() -> bool:
    """Shutdown the global model manager"""
    return await model_manager.shutdown()
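The switch path above replaces the active model, and if initializing the new one fails it restores the old name and backend before returning. A minimal standalone sketch of that rollback pattern (the `initialize` stub here is a hypothetical stand-in for `ModelManager.initialize()`):

```python
import asyncio


async def initialize(state: dict) -> bool:
    # Hypothetical stand-in: pretend only "tiny-llama" loads successfully.
    return state["model"] == "tiny-llama"


async def switch_model(state: dict, new_model: str) -> bool:
    old_model = state["model"]
    state["model"] = new_model
    if not await initialize(state):
        # Roll back to the previous model and try to restore it
        state["model"] = old_model
        await initialize(state)
        return False
    return True


state = {"model": "tiny-llama"}
assert asyncio.run(switch_model(state, "gpt-4")) is False   # fails, rolls back
assert state["model"] == "tiny-llama"
assert asyncio.run(switch_model(state, "tiny-llama")) is True
```

The key property is that a failed switch leaves the manager pointing at the previously working model rather than a half-initialized one.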
app/services/session_manager.py ADDED
@@ -0,0 +1,400 @@
"""
Session Manager - Handles conversation sessions and message history
Supports both in-memory and Redis-based storage
"""

import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import json
import uuid

from ..core.config import settings
from ..core.logging import LoggerMixin
from ..models.schemas import ChatMessage, ConversationHistory, SessionInfo


class SessionManager(LoggerMixin):
    """
    Manages chat sessions and conversation history
    Supports both in-memory and Redis storage backends
    """

    def __init__(self):
        self.sessions: Dict[str, ConversationHistory] = {}
        self.redis_client = None
        self.use_redis = bool(settings.redis_url)
        self.session_timeout = settings.session_timeout * 60  # Convert to seconds
        self.max_sessions_per_user = settings.max_sessions_per_user
        self.max_messages_per_session = settings.max_messages_per_session

        # Cleanup task
        self._cleanup_task = None

    async def initialize(self) -> bool:
        """Initialize the session manager"""
        try:
            if self.use_redis:
                await self._initialize_redis()

            # Start cleanup task
            self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())

            self.log_info("Session manager initialized",
                          storage_type="redis" if self.use_redis else "memory",
                          session_timeout=self.session_timeout)
            return True

        except Exception as e:
            self.log_error("Failed to initialize session manager", error=str(e))
            return False

    async def _initialize_redis(self):
        """Initialize Redis connection"""
        try:
            import redis.asyncio as redis
            self.redis_client = redis.from_url(settings.redis_url)

            # Test connection
            await self.redis_client.ping()
            self.log_info("Redis connection established", url=settings.redis_url)

        except ImportError:
            self.log_warning("Redis not available, falling back to memory storage")
            self.use_redis = False
        except Exception as e:
            self.log_error("Redis connection failed", error=str(e))
            self.use_redis = False

    async def shutdown(self):
        """Shutdown the session manager"""
        try:
            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass

            if self.redis_client:
                await self.redis_client.close()

            self.log_info("Session manager shutdown complete")

        except Exception as e:
            self.log_error("Session manager shutdown failed", error=str(e))

    async def create_session(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """
        Create a new chat session

        Args:
            session_id: Unique session identifier
            user_id: Optional user identifier for session limits

        Returns:
            bool: True if session created successfully
        """
        try:
            # Check if session already exists
            if await self.session_exists(session_id):
                self.log_info("Session already exists", session_id=session_id)
                return True

            # Check user session limits if user_id provided
            if user_id and self.max_sessions_per_user > 0:
                user_sessions = await self.get_user_sessions(user_id)
                if len(user_sessions) >= self.max_sessions_per_user:
                    self.log_warning("User session limit exceeded",
                                     user_id=user_id,
                                     limit=self.max_sessions_per_user)
                    return False

            # Create new session
            session = ConversationHistory(
                session_id=session_id,
                messages=[],
                created_at=datetime.utcnow(),
                updated_at=datetime.utcnow(),
                message_count=0
            )

            await self._store_session(session)

            self.log_info("Session created", session_id=session_id, user_id=user_id)
            return True

        except Exception as e:
            self.log_error("Failed to create session", error=str(e), session_id=session_id)
            return False

    async def add_message(self, session_id: str, message: ChatMessage) -> bool:
        """
        Add a message to a session

        Args:
            session_id: Session identifier
            message: Message to add

        Returns:
            bool: True if message added successfully
        """
        try:
            # Get or create session
            session = await self.get_session(session_id)
            if not session:
                await self.create_session(session_id)
                session = await self.get_session(session_id)

            if not session:
                self.log_error("Failed to create session", session_id=session_id)
                return False

            # Check message limit
            if (self.max_messages_per_session > 0 and
                    len(session.messages) >= self.max_messages_per_session):
                # Remove oldest messages to make room
                messages_to_remove = len(session.messages) - self.max_messages_per_session + 1
                session.messages = session.messages[messages_to_remove:]
                self.log_info("Trimmed old messages",
                              session_id=session_id,
                              removed_count=messages_to_remove)

            # Add message
            session.messages.append(message)
            session.message_count = len(session.messages)
            session.updated_at = datetime.utcnow()

            # Store updated session
            await self._store_session(session)

            self.log_debug("Message added to session",
                           session_id=session_id,
                           message_role=message.role,
                           total_messages=session.message_count)
            return True

        except Exception as e:
            self.log_error("Failed to add message", error=str(e), session_id=session_id)
            return False

    async def get_session(self, session_id: str) -> Optional[ConversationHistory]:
        """
        Get a session by ID

        Args:
            session_id: Session identifier

        Returns:
            ConversationHistory or None if not found
        """
        try:
            if self.use_redis:
                return await self._get_session_from_redis(session_id)
            else:
                return self.sessions.get(session_id)

        except Exception as e:
            self.log_error("Failed to get session", error=str(e), session_id=session_id)
            return None

    async def session_exists(self, session_id: str) -> bool:
        """Check if a session exists"""
        session = await self.get_session(session_id)
        return session is not None

    async def delete_session(self, session_id: str) -> bool:
        """
        Delete a session

        Args:
            session_id: Session identifier

        Returns:
            bool: True if session deleted successfully
        """
        try:
            if self.use_redis:
                await self.redis_client.delete(f"session:{session_id}")
            else:
                self.sessions.pop(session_id, None)

            self.log_info("Session deleted", session_id=session_id)
            return True

        except Exception as e:
            self.log_error("Failed to delete session", error=str(e), session_id=session_id)
            return False

    async def get_session_messages(self, session_id: str, limit: Optional[int] = None) -> List[ChatMessage]:
        """
        Get messages from a session

        Args:
            session_id: Session identifier
            limit: Optional limit on number of messages to return

        Returns:
            List of ChatMessage objects
        """
        session = await self.get_session(session_id)
        if not session:
            return []

        messages = session.messages
        if limit and limit > 0:
            messages = messages[-limit:]  # Get last N messages

        return messages

    async def get_active_sessions(self) -> List[SessionInfo]:
        """Get information about all active sessions"""
        try:
            sessions = []

            if self.use_redis:
                # Get all session keys from Redis
                keys = await self.redis_client.keys("session:*")
                for key in keys:
                    session_id = key.decode().replace("session:", "")
                    session = await self.get_session(session_id)
                    if session:
                        sessions.append(self._session_to_info(session))
            else:
                # Get from memory
                for session in self.sessions.values():
                    sessions.append(self._session_to_info(session))

            return sessions

        except Exception as e:
            self.log_error("Failed to get active sessions", error=str(e))
            return []

    async def get_user_sessions(self, user_id: str) -> List[SessionInfo]:
        """Get sessions for a specific user (requires user_id in session metadata)"""
        # This is a simplified implementation
        # In a real system, you'd store user_id -> session_id mappings
        all_sessions = await self.get_active_sessions()
        return [s for s in all_sessions if s.session_id.startswith(f"{user_id}-")]

    def _session_to_info(self, session: ConversationHistory) -> SessionInfo:
        """Convert ConversationHistory to SessionInfo"""
        return SessionInfo(
            session_id=session.session_id,
            created_at=session.created_at,
            updated_at=session.updated_at,
            message_count=session.message_count,
            model_name=settings.model_name,  # Current model
            is_active=True
        )

    async def _store_session(self, session: ConversationHistory):
        """Store session in the appropriate backend"""
        if self.use_redis:
            await self._store_session_in_redis(session)
        else:
            self.sessions[session.session_id] = session

    async def _store_session_in_redis(self, session: ConversationHistory):
        """Store session in Redis"""
        key = f"session:{session.session_id}"
        data = {
            "session_id": session.session_id,
            "messages": [
                {
                    "role": msg.role,
                    "content": msg.content,
                    "timestamp": msg.timestamp.isoformat(),
                    "metadata": msg.metadata or {}
                }
                for msg in session.messages
            ],
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "message_count": session.message_count
        }

        await self.redis_client.setex(
            key,
            self.session_timeout,
            json.dumps(data, default=str)
        )

    async def _get_session_from_redis(self, session_id: str) -> Optional[ConversationHistory]:
        """Get session from Redis"""
        key = f"session:{session_id}"
        data = await self.redis_client.get(key)

        if not data:
            return None

        try:
            session_data = json.loads(data)
            messages = [
                ChatMessage(
                    role=msg["role"],
                    content=msg["content"],
                    timestamp=datetime.fromisoformat(msg["timestamp"]),
                    metadata=msg.get("metadata")
                )
                for msg in session_data["messages"]
            ]

            return ConversationHistory(
                session_id=session_data["session_id"],
                messages=messages,
                created_at=datetime.fromisoformat(session_data["created_at"]),
                updated_at=datetime.fromisoformat(session_data["updated_at"]),
                message_count=session_data["message_count"]
            )

        except Exception as e:
            self.log_error("Failed to parse session from Redis", error=str(e), session_id=session_id)
            return None

    async def _cleanup_expired_sessions(self):
        """Background task to cleanup expired sessions"""
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes

                if not self.use_redis:  # Redis handles expiration automatically
                    current_time = datetime.utcnow()
                    expired_sessions = []

                    for session_id, session in self.sessions.items():
                        if (current_time - session.updated_at).total_seconds() > self.session_timeout:
                            expired_sessions.append(session_id)

                    for session_id in expired_sessions:
                        del self.sessions[session_id]
                        self.log_debug("Expired session cleaned up", session_id=session_id)

                    if expired_sessions:
                        self.log_info("Cleaned up expired sessions", count=len(expired_sessions))

            except asyncio.CancelledError:
                break
            except Exception as e:
                self.log_error("Session cleanup failed", error=str(e))


# Global session manager instance
session_manager = SessionManager()


async def get_session_manager() -> SessionManager:
    """Get the global session manager instance"""
    return session_manager


async def initialize_session_manager() -> bool:
    """Initialize the global session manager"""
    return await session_manager.initialize()


async def shutdown_session_manager():
    """Shutdown the global session manager"""
    await session_manager.shutdown()
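`add_message` above caps per-session history: once `max_messages_per_session` is reached it drops the oldest entries, leaving room for the incoming message. The trimming arithmetic in isolation, as a small sketch mirroring that step:

```python
def trim_messages(messages: list, max_messages: int) -> list:
    """Drop the oldest entries so one new message still fits,
    mirroring the trim step in SessionManager.add_message."""
    if max_messages > 0 and len(messages) >= max_messages:
        to_remove = len(messages) - max_messages + 1
        return messages[to_remove:]
    return messages


history = [f"msg-{i}" for i in range(5)]
assert trim_messages(history, 5) == ["msg-1", "msg-2", "msg-3", "msg-4"]
assert trim_messages(history, 10) == history  # under the cap: unchanged
```

Note the `+ 1`: the session is trimmed to `max_messages - 1` entries so that appending the new message lands exactly at the cap.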
app/utils/__init__.py ADDED
@@ -0,0 +1 @@
# Utilities package
app/utils/helpers.py ADDED
@@ -0,0 +1,309 @@
"""
Utility functions and helpers
"""

import re
import uuid
import hashlib
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone


def generate_session_id(user_id: Optional[str] = None) -> str:
    """
    Generate a unique session ID

    Args:
        user_id: Optional user identifier to include in session ID

    Returns:
        Unique session identifier
    """
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    random_part = str(uuid.uuid4())[:8]

    if user_id:
        # Create a hash of user_id for privacy
        user_hash = hashlib.md5(user_id.encode()).hexdigest()[:8]
        return f"{user_hash}-{timestamp}-{random_part}"
    else:
        return f"anon-{timestamp}-{random_part}"


def generate_message_id() -> str:
    """Generate a unique message ID"""
    return f"msg-{uuid.uuid4()}"


def sanitize_text(text: str, max_length: int = 4000) -> str:
    """
    Sanitize and clean text input

    Args:
        text: Input text to sanitize
        max_length: Maximum allowed length

    Returns:
        Sanitized text
    """
    if not text:
        return ""

    # Remove excessive whitespace
    text = re.sub(r'\s+', ' ', text.strip())

    # Truncate if too long
    if len(text) > max_length:
        text = text[:max_length].rsplit(' ', 1)[0] + "..."

    return text


def format_timestamp(dt: datetime) -> str:
    """
    Format datetime for consistent display

    Args:
        dt: Datetime object

    Returns:
        Formatted timestamp string
    """
    return dt.strftime("%Y-%m-%d %H:%M:%S UTC")


def estimate_tokens(text: str) -> int:
    """
    Rough estimation of token count for text

    Args:
        text: Input text

    Returns:
        Estimated token count
    """
    # Very rough estimation: ~4 characters per token on average
    return max(1, len(text) // 4)


def truncate_conversation_history(
    messages: List[Dict[str, Any]],
    max_tokens: int = 2000
) -> List[Dict[str, Any]]:
    """
    Truncate conversation history to fit within token limit

    Args:
        messages: List of message dictionaries
        max_tokens: Maximum token limit

    Returns:
        Truncated list of messages
    """
    if not messages:
        return messages

    # Always keep system message if present
    system_messages = [msg for msg in messages if msg.get("role") == "system"]
    other_messages = [msg for msg in messages if msg.get("role") != "system"]

    # Estimate tokens for system messages
    system_tokens = sum(estimate_tokens(msg.get("content", "")) for msg in system_messages)
    available_tokens = max_tokens - system_tokens

    if available_tokens <= 0:
        return system_messages

    # Add messages from the end (most recent first) until we hit the limit
    selected_messages = []
    current_tokens = 0

    for msg in reversed(other_messages):
        msg_tokens = estimate_tokens(msg.get("content", ""))
        if current_tokens + msg_tokens <= available_tokens:
            selected_messages.insert(0, msg)
            current_tokens += msg_tokens
        else:
            break

    return system_messages + selected_messages


def validate_session_id(session_id: str) -> bool:
    """
    Validate session ID format

    Args:
        session_id: Session identifier to validate

    Returns:
        True if valid, False otherwise
    """
    if not session_id or len(session_id) < 5 or len(session_id) > 100:
        return False

    # Allow alphanumeric, hyphens, and underscores
    return bool(re.match(r'^[a-zA-Z0-9_-]+$', session_id))


def extract_model_name_from_path(model_path: str) -> str:
    """
    Extract clean model name from HuggingFace model path

    Args:
        model_path: Full model path (e.g., "microsoft/DialoGPT-medium")

    Returns:
        Clean model name
    """
    if "/" in model_path:
        return model_path.split("/")[-1]
    return model_path


def format_model_info(model_info: Dict[str, Any]) -> Dict[str, Any]:
    """
    Format model information for API responses

    Args:
        model_info: Raw model information

    Returns:
        Formatted model information
    """
    formatted = {
        "name": model_info.get("name", "unknown"),
        "type": model_info.get("type", "unknown"),
        "loaded": model_info.get("loaded", False),
        "capabilities": model_info.get("capabilities", []),
    }

    # Add backend-specific information
    if "device" in model_info:
        formatted["device"] = model_info["device"]

    if "provider" in model_info:
        formatted["provider"] = model_info["provider"]

    if "parameters" in model_info:
        formatted["parameters"] = model_info["parameters"]

    return formatted


def create_error_response(
    error_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None,
    request_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create standardized error response

    Args:
        error_type: Type of error
        message: Error message
        details: Optional additional details
        request_id: Optional request identifier

    Returns:
        Formatted error response
    """
    return {
        "error": error_type,
        "message": message,
        "details": details or {},
        "timestamp": datetime.utcnow().isoformat(),
        "request_id": request_id or generate_message_id()
    }


def parse_model_backend_from_name(model_name: str) -> str:
    """
    Guess the appropriate backend type from model name

    Args:
        model_name: Model name or path

    Returns:
        Suggested backend type
    """
    model_lower = model_name.lower()

    if "gpt" in model_lower and ("3.5" in model_lower or "4" in model_lower):
        return "openai"
    elif "claude" in model_lower:
        return "anthropic"
    elif any(provider in model_lower for provider in ["microsoft", "google", "meta", "huggingface"]):
        return "hf_api"  # Likely available via HF API
    else:
        return "local"  # Default to local


def get_supported_model_examples() -> Dict[str, List[str]]:
    """
    Get examples of supported models for each backend type

    Returns:
        Dictionary mapping backend types to example models
    """
    return {
        "local": [
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "microsoft/DialoGPT-medium",
            "Qwen/Qwen2.5-0.5B-Instruct",
            "microsoft/phi-2"
        ],
        "hf_api": [
            "microsoft/DialoGPT-large",
            "google/gemma-2b-it",
            "microsoft/phi-2",
            "meta-llama/Llama-2-7b-chat-hf"
        ],
        "openai": [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o"
        ],
        "anthropic": [
            "claude-3-haiku-20240307",
            "claude-3-sonnet-20240229",
            "claude-3-opus-20240229",
            "claude-3-5-sonnet-20241022"
        ]
    }


def calculate_response_metrics(
    start_time: float,
    response_text: str,
    token_count: Optional[int] = None
) -> Dict[str, Any]:
    """
    Calculate response metrics for monitoring

    Args:
        start_time: Request start time
        response_text: Generated response text
        token_count: Actual token count if available

    Returns:
        Dictionary of metrics
    """
    import time

    end_time = time.time()
    total_time = end_time - start_time

    estimated_tokens = token_count or estimate_tokens(response_text)
    tokens_per_second = estimated_tokens / total_time if total_time > 0 else 0

    return {
        "total_time": total_time,
        "character_count": len(response_text),
        "estimated_tokens": estimated_tokens,
        "actual_tokens": token_count,
        "tokens_per_second": tokens_per_second,
        "words_count": len(response_text.split())
    }
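The two pure helpers `estimate_tokens` and `truncate_conversation_history` compose: the truncation walks messages newest-first and keeps whatever fits under the token budget left after reserving space for system messages. Reproduced standalone here for a quick check (same logic as in `helpers.py`, without the app imports):

```python
from typing import Any, Dict, List


def estimate_tokens(text: str) -> int:
    # Very rough estimation: ~4 characters per token on average
    return max(1, len(text) // 4)


def truncate_conversation_history(
    messages: List[Dict[str, Any]], max_tokens: int = 2000
) -> List[Dict[str, Any]]:
    system = [m for m in messages if m.get("role") == "system"]
    other = [m for m in messages if m.get("role") != "system"]
    budget = max_tokens - sum(estimate_tokens(m.get("content", "")) for m in system)
    if budget <= 0:
        return system
    kept, used = [], 0
    for m in reversed(other):  # newest first
        t = estimate_tokens(m.get("content", ""))
        if used + t <= budget:
            kept.insert(0, m)
            used += t
        else:
            break
    return system + kept


msgs = (
    [{"role": "system", "content": "You are terse."}]          # ~3 tokens
    + [{"role": "user", "content": "x" * 40, "id": i} for i in range(3)]  # ~10 each
)
out = truncate_conversation_history(msgs, max_tokens=25)
# System message kept, plus the two newest user messages that fit the budget
assert [m.get("id") for m in out] == [None, 1, 2]
```

Because the loop breaks on the first message that does not fit, a single very long old message cuts off everything older than it, which keeps the retained history contiguous.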
examples/test_backends.py ADDED
@@ -0,0 +1,292 @@
1
+ """
2
+ Example script to test different model backends
3
+ Demonstrates how to configure and use various model types
4
+ """
5
+
6
+ import os
7
+ import asyncio
8
+ import sys
9
+ import time
10
+ from pathlib import Path
11
+
12
+ # Add the app directory to the Python path
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ from app.core.config import Settings
16
+ from app.services.model_backends.local_hf import LocalHuggingFaceBackend
17
+ from app.services.model_backends.hf_api import HuggingFaceAPIBackend
18
+ from app.services.model_backends.openai_api import OpenAIAPIBackend
19
+ from app.services.model_backends.anthropic_api import AnthropicAPIBackend
20
+ from app.models.schemas import ChatMessage
21
+
22
+
23
+ async def test_local_hf_backend():
24
+ """Test local HuggingFace backend"""
25
+ print("πŸ€– Testing Local HuggingFace Backend")
26
+ print("-" * 40)
27
+
28
+ # Use a small model for testing
29
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
30
+
31
+ backend = LocalHuggingFaceBackend(
32
+ model_name=model_name,
33
+ device="cpu", # Use CPU for compatibility
34
+ temperature=0.7,
35
+ max_tokens=50
36
+ )
37
+
38
+ try:
39
+ print(f"Loading model: {model_name}")
40
+ success = await backend.load_model()
41
+
42
+ if not success:
43
+ print("❌ Failed to load model")
44
+ return False
45
+
46
+ print("βœ… Model loaded successfully")
47
+
48
+ # Test generation
49
+ messages = [
50
+ ChatMessage(role="user", content="Hello! What's your name?")
51
+ ]
52
+
53
+ print("Generating response...")
54
+ start_time = time.time()
55
+ response = await backend.generate_response(messages, max_tokens=30)
56
+ end_time = time.time()
57
+
58
+ print(f"βœ… Response generated in {end_time - start_time:.2f}s")
59
+ print(f"Response: {response.message}")
60
+
61
+ # Test streaming
62
+ print("\nTesting streaming...")
63
+ full_response = ""
64
+ chunk_count = 0
65
+
66
+ async for chunk in backend.generate_stream(messages, max_tokens=30):
67
+ full_response += chunk.content
68
+ chunk_count += 1
69
+ if chunk.is_final:
70
+ break
71
+
72
+ print(f"βœ… Streaming completed with {chunk_count} chunks")
73
+ print(f"Streamed response: {full_response}")
74
+
75
+ # Cleanup
76
+ await backend.unload_model()
77
+ print("βœ… Model unloaded")
78
+
79
+ return True
80
+
81
+ except Exception as e:
82
+ print(f"❌ Local HF backend test failed: {e}")
83
+ return False
84
+
85
+
86
+ async def test_hf_api_backend():
87
+ """Test HuggingFace API backend"""
88
+ print("\n🌐 Testing HuggingFace API Backend")
89
+ print("-" * 40)
90
+
91
+ # Check if API token is available
92
+ api_token = os.getenv("HF_API_TOKEN")
93
+ if not api_token:
94
+ print("⚠️ HF_API_TOKEN not set, skipping HF API test")
95
+ return True
96
+
97
+ model_name = "microsoft/DialoGPT-medium"
98
+
99
+ backend = HuggingFaceAPIBackend(
100
+ model_name=model_name,
101
+ api_token=api_token,
102
+ temperature=0.7,
103
+ max_tokens=50
104
+ )
105
+
106
+ try:
107
+ print(f"Initializing API client for: {model_name}")
108
+ success = await backend.load_model()
109
+
110
+ if not success:
111
+ print("❌ Failed to initialize API client")
112
+ return False
113
+
114
+ print("βœ… API client initialized")
115
+
116
+ # Test generation
117
+ messages = [
118
+ ChatMessage(role="user", content="Hello! How are you?")
119
+ ]
120
+
121
+ print("Generating response via API...")
122
+ start_time = time.time()
123
+ response = await backend.generate_response(messages, max_tokens=30)
124
+ end_time = time.time()
125
+
126
+ print(f"βœ… Response generated in {end_time - start_time:.2f}s")
127
+ print(f"Response: {response.message}")
128
+
129
+ return True
130
+
131
+ except Exception as e:
132
+ print(f"❌ HF API backend test failed: {e}")
133
+ return False
134
+
135
+
136
+ async def test_openai_backend():
137
+ """Test OpenAI API backend"""
138
+ print("\nπŸ”₯ Testing OpenAI API Backend")
139
+ print("-" * 40)
140
+
141
+ # Check if API key is available
142
+ api_key = os.getenv("OPENAI_API_KEY")
143
+ if not api_key:
144
+ print("⚠️ OPENAI_API_KEY not set, skipping OpenAI test")
145
+ return True
146
+
147
+ model_name = "gpt-3.5-turbo"
148
+
149
+ backend = OpenAIAPIBackend(
150
+ model_name=model_name,
151
+ api_key=api_key,
152
+ temperature=0.7,
153
+ max_tokens=50
154
+ )
155
+
156
+ try:
157
+ print(f"Initializing OpenAI client for: {model_name}")
158
+ success = await backend.load_model()
159
+
160
+ if not success:
161
+ print("❌ Failed to initialize OpenAI client")
162
+ return False
163
+
164
+ print("βœ… OpenAI client initialized")
165
+
166
+ # Test generation
167
+ messages = [
168
+ ChatMessage(role="user", content="Hello! What's the weather like?")
169
+ ]
170
+
171
+ print("Generating response via OpenAI...")
172
+ start_time = time.time()
173
+ response = await backend.generate_response(messages, max_tokens=30)
174
+ end_time = time.time()
175
+
176
+ print(f"βœ… Response generated in {end_time - start_time:.2f}s")
177
+ print(f"Response: {response.message}")
178
+
179
+ # Test streaming
180
+ print("\nTesting streaming...")
181
+ full_response = ""
182
+ chunk_count = 0
183
+
184
+ async for chunk in backend.generate_stream(messages, max_tokens=30):
185
+ full_response += chunk.content
186
+ chunk_count += 1
187
+ if chunk.is_final:
188
+ break
189
+
190
+ print(f"βœ… Streaming completed with {chunk_count} chunks")
191
+         print(f"Streamed response: {full_response}")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ OpenAI backend test failed: {e}")
+         return False
+
+
+ async def test_anthropic_backend():
+     """Test Anthropic API backend"""
+     print("\n🧠 Testing Anthropic API Backend")
+     print("-" * 40)
+
+     # Check if API key is available
+     api_key = os.getenv("ANTHROPIC_API_KEY")
+     if not api_key:
+         print("⚠️ ANTHROPIC_API_KEY not set, skipping Anthropic test")
+         return True
+
+     model_name = "claude-3-haiku-20240307"
+
+     backend = AnthropicAPIBackend(
+         model_name=model_name,
+         api_key=api_key,
+         temperature=0.7,
+         max_tokens=50
+     )
+
+     try:
+         print(f"Initializing Anthropic client for: {model_name}")
+         success = await backend.load_model()
+
+         if not success:
+             print("❌ Failed to initialize Anthropic client")
+             return False
+
+         print("βœ… Anthropic client initialized")
+
+         # Test generation
+         messages = [
+             ChatMessage(role="user", content="Hello! Tell me about yourself.")
+         ]
+
+         print("Generating response via Anthropic...")
+         start_time = time.time()
+         response = await backend.generate_response(messages, max_tokens=30)
+         end_time = time.time()
+
+         print(f"βœ… Response generated in {end_time - start_time:.2f}s")
+         print(f"Response: {response.message}")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Anthropic backend test failed: {e}")
+         return False
+
+
+ async def main():
+     """Main test function"""
+     print("πŸš€ Sema Chat Backend Testing")
+     print("=" * 50)
+
+     results = {}
+
+     # Test each backend
+     results["local_hf"] = await test_local_hf_backend()
+     results["hf_api"] = await test_hf_api_backend()
+     results["openai"] = await test_openai_backend()
+     results["anthropic"] = await test_anthropic_backend()
+
+     # Summary
+     print("\n" + "=" * 50)
+     print("πŸ“Š Test Results Summary")
+     print("-" * 25)
+
+     for backend, success in results.items():
+         status = "βœ… PASS" if success else "❌ FAIL"
+         print(f"{backend:15} {status}")
+
+     total_tests = len(results)
+     passed_tests = sum(results.values())
+
+     print(f"\nTotal: {passed_tests}/{total_tests} backends working")
+
+     if passed_tests == total_tests:
+         print("πŸŽ‰ All available backends are working!")
+     elif passed_tests > 0:
+         print("⚠️ Some backends are working, check configuration for others")
+     else:
+         print("❌ No backends are working, check your setup")
+
+     print("\nπŸ’‘ Tips:")
+     print("- For HF API: Set HF_API_TOKEN environment variable")
+     print("- For OpenAI: Set OPENAI_API_KEY environment variable")
+     print("- For Anthropic: Set ANTHROPIC_API_KEY environment variable")
+     print("- For local models: Ensure you have enough RAM/VRAM")
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
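The summary block in `main()` counts passing backends with `sum(results.values())`, which works because Python's `bool` is a subclass of `int`. A minimal standalone sketch of that idiom:

```python
# A dict of pass/fail flags per backend, shaped like the one main() builds.
results = {"local_hf": True, "hf_api": False, "openai": True}

# bool subclasses int, so True sums as 1 and False as 0.
passed = sum(results.values())
print(f"{passed}/{len(results)} backends working")  # β†’ 2/3 backends working
```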
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ fastapi
+ uvicorn
+ pydantic
+ python-multipart
+ websockets
+ sse-starlette
+ slowapi
+ prometheus-client
+ structlog
+ python-dotenv
+ httpx
+ aiofiles
+
+ # HuggingFace & ML
+ transformers
+ torch
+ huggingface-hub
+ accelerate
+ sentencepiece
+
+ # API Clients
+ openai
+ anthropic
+
+ # Utilities
+ asyncio-mqtt
+ redis
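None of the entries above are version-pinned, so each Space rebuild can pull breaking releases. Pinning makes builds reproducible; the version numbers below are purely illustrative, not taken from this repo:

```text
# requirements.txt fragment (example pins - substitute the versions you tested)
fastapi==0.110.0
uvicorn==0.29.0
transformers==4.40.0
```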
setup_huggingface.sh ADDED
@@ -0,0 +1,242 @@
+ #!/bin/bash
+
+ # πŸš€ Sema Chat API - HuggingFace Spaces Setup Script
+ # This script helps you deploy Sema Chat API to HuggingFace Spaces with Gemma
+
+ set -e
+
+ echo "πŸš€ Sema Chat API - HuggingFace Spaces Setup"
+ echo "=========================================="
+
+ # Check if we're in the right directory
+ if [ ! -f "app/main.py" ]; then
+     echo "❌ Error: Please run this script from the backend/sema-chat directory"
+     echo "   Current directory: $(pwd)"
+     echo "   Expected files: app/main.py, requirements.txt, Dockerfile"
+     exit 1
+ fi
+
+ echo "βœ… Found Sema Chat API files"
+
+ # Get user input
+ read -p "πŸ“ Enter your HuggingFace username: " HF_USERNAME
+ read -p "πŸ“ Enter your Space name (e.g., sema-chat-gemma): " SPACE_NAME
+ read -p "πŸ”‘ Enter your Google AI API key (or press Enter to skip): " GOOGLE_API_KEY
+
+ # Validate inputs
+ if [ -z "$HF_USERNAME" ]; then
+     echo "❌ Error: HuggingFace username is required"
+     exit 1
+ fi
+
+ if [ -z "$SPACE_NAME" ]; then
+     echo "❌ Error: Space name is required"
+     exit 1
+ fi
+
+ SPACE_URL="https://huggingface.co/spaces/$HF_USERNAME/$SPACE_NAME"
+ SPACE_REPO="https://huggingface.co/spaces/$HF_USERNAME/$SPACE_NAME"
+
+ echo ""
+ echo "πŸ“‹ Configuration Summary:"
+ echo "   HuggingFace Username: $HF_USERNAME"
+ echo "   Space Name: $SPACE_NAME"
+ echo "   Space URL: $SPACE_URL"
+ echo "   Google AI Key: $([ -n "$GOOGLE_API_KEY" ] && echo '[PROVIDED]' || echo '[NOT PROVIDED]')"
+ echo ""
+
+ read -p "πŸ€” Continue with deployment? (y/N): " CONFIRM
+ if [[ ! $CONFIRM =~ ^[Yy]$ ]]; then
+     echo "❌ Deployment cancelled"
+     exit 0
+ fi
+
+ # Create deployment directory
+ DEPLOY_DIR="../sema-chat-deploy"
+ echo "πŸ“ Creating deployment directory: $DEPLOY_DIR"
+ rm -rf "$DEPLOY_DIR"
+ mkdir -p "$DEPLOY_DIR"
+
+ # Copy all files
+ echo "πŸ“‹ Copying files..."
+ cp -r . "$DEPLOY_DIR/"
+ cd "$DEPLOY_DIR"
+
+ # Create README.md for the Space
+ echo "πŸ“ Creating Space README..."
+ cat > README.md << EOF
+ ---
+ title: Sema Chat API
+ emoji: πŸ’¬
+ colorFrom: purple
+ colorTo: pink
+ sdk: docker
+ pinned: false
+ license: mit
+ short_description: Modern chatbot API with Gemma integration and streaming capabilities
+ ---
+
+ # Sema Chat API πŸ’¬
+
+ Modern chatbot API with streaming capabilities, powered by Google's Gemma model.
+
+ ## πŸš€ Features
+
+ - **Real-time Streaming**: Server-Sent Events and WebSocket support
+ - **Gemma Integration**: Powered by Google's Gemma 2 9B model
+ - **Session Management**: Persistent conversation contexts
+ - **RESTful API**: Clean, documented endpoints
+ - **Interactive UI**: Built-in Swagger documentation
+
+ ## πŸ”— API Endpoints
+
+ - **Chat**: \`POST /api/v1/chat\`
+ - **Streaming**: \`GET /api/v1/chat/stream\`
+ - **WebSocket**: \`ws://space-url/api/v1/chat/ws\`
+ - **Health**: \`GET /api/v1/health\`
+ - **Docs**: \`GET /\` (Swagger UI)
+
+ ## πŸ’¬ Quick Test
+
+ \`\`\`bash
+ curl -X POST "https://$HF_USERNAME-$SPACE_NAME.hf.space/api/v1/chat" \\
+   -H "Content-Type: application/json" \\
+   -d '{
+     "message": "Hello! Can you introduce yourself?",
+     "session_id": "test-session"
+   }'
+ \`\`\`
+
+ ## πŸ”„ Streaming Test
+
+ \`\`\`bash
+ curl -N -H "Accept: text/event-stream" \\
+   "https://$HF_USERNAME-$SPACE_NAME.hf.space/api/v1/chat/stream?message=Tell%20me%20about%20AI&session_id=test"
+ \`\`\`
+
+ ## βš™οΈ Configuration
+
+ This Space is configured to use Google's Gemma model via AI Studio.
+ Set your \`GOOGLE_API_KEY\` in the Space settings to enable the API.
+
+ ## πŸ› οΈ Built With
+
+ - **FastAPI**: Modern Python web framework
+ - **Google Gemma**: Advanced language model
+ - **Docker**: Containerized deployment
+ - **HuggingFace Spaces**: Hosting platform
+
+ ---
+
+ Created by $HF_USERNAME | Powered by Sema AI
+ EOF
+
+ # Create .gitignore (delimiter quoted so patterns like *$py.class are kept literal)
+ echo "🚫 Creating .gitignore..."
+ cat > .gitignore << 'EOF'
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ .DS_Store
+ .vscode/
+ .idea/
+
+ logs/
+ *.log
+ EOF
+
+ # Initialize git repository
+ echo "πŸ”§ Initializing git repository..."
+ git init
+ git remote add origin "$SPACE_REPO"
+
+ # Create initial commit
+ echo "πŸ“¦ Creating initial commit..."
+ git add .
+ git commit -m "Initial deployment of Sema Chat API with Gemma support
+
+ Features:
+ - Google Gemma 2 9B integration
+ - Real-time streaming responses
+ - Session management
+ - RESTful API with Swagger docs
+ - WebSocket support
+ - Health monitoring
+
+ Configuration:
+ - MODEL_TYPE=google
+ - MODEL_NAME=gemma-2-9b-it
+ - Port: 7860 (HuggingFace standard)
+ "
+
+ echo ""
+ echo "πŸŽ‰ Setup Complete!"
+ echo "=================="
+ echo ""
+ echo "πŸ“‹ Next Steps:"
+ echo "1. Create your HuggingFace Space:"
+ echo "   β†’ Go to: https://huggingface.co/spaces"
+ echo "   β†’ Click 'Create new Space'"
+ echo "   β†’ Name: $SPACE_NAME"
+ echo "   β†’ SDK: Docker"
+ echo "   β†’ License: MIT"
+ echo ""
+ echo "2. Push your code:"
+ echo "   β†’ cd $DEPLOY_DIR"
+ echo "   β†’ git push origin main"
+ echo ""
+ echo "3. Configure environment variables in Space settings:"
+ if [ -n "$GOOGLE_API_KEY" ]; then
+     echo "   β†’ MODEL_TYPE=google"
+     echo "   β†’ MODEL_NAME=gemma-2-9b-it"
+     echo "   β†’ GOOGLE_API_KEY=$GOOGLE_API_KEY"
+ else
+     echo "   β†’ MODEL_TYPE=google"
+     echo "   β†’ MODEL_NAME=gemma-2-9b-it"
+     echo "   β†’ GOOGLE_API_KEY=your_google_api_key_here"
+     echo ""
+     echo "   πŸ”‘ Get your Google AI API key from: https://aistudio.google.com/"
+ fi
+ echo "   β†’ DEBUG=false"
+ echo "   β†’ ENVIRONMENT=production"
+ echo ""
+ echo "4. Wait for build and test:"
+ echo "   β†’ Space URL: $SPACE_URL"
+ echo "   β†’ API Docs: $SPACE_URL/"
+ echo "   β†’ Health Check: $SPACE_URL/api/v1/health"
+ echo ""
+ echo "πŸš€ Your Sema Chat API will be live at:"
+ echo "   $SPACE_URL"
+ echo ""
+ echo "Happy deploying! πŸ’¬βœ¨"
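The script handles optional values such as `GOOGLE_API_KEY` with shell parameter expansion. The two operators behave differently, and note that `${VAR:-alt}` expands to the variable's own value whenever it is set, which is why masking a secret needs `:+` rather than `:-`. A standalone sketch:

```shell
#!/bin/sh
# ${VAR:+alt} β†’ "alt" when VAR is set and non-empty, else empty.
# ${VAR:-alt} β†’ "alt" when VAR is unset or empty, else VAR's own value.
KEY="sk-demo"
echo "${KEY:+[PROVIDED]}"      # prints: [PROVIDED]
echo "${KEY:-[NOT PROVIDED]}"  # prints: sk-demo
unset KEY
echo "${KEY:+[PROVIDED]}"      # prints an empty line
echo "${KEY:-[NOT PROVIDED]}"  # prints: [NOT PROVIDED]
```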
tests/__init__.py ADDED
@@ -0,0 +1 @@
+ # Tests package
tests/test_api.py ADDED
@@ -0,0 +1,313 @@
+ """
+ Test script for Sema Chat API
+ Tests all endpoints and functionality
+ """
+
+ import requests
+ import json
+ import time
+ import asyncio
+ import websockets
+ from typing import Dict, Any
+ import sys
+
+
+ class SemaChatAPITester:
+     """Test client for Sema Chat API"""
+
+     def __init__(self, base_url: str = "http://localhost:7860"):
+         self.base_url = base_url.rstrip("/")
+         self.session_id = f"test-session-{int(time.time())}"
+
+     def test_health_endpoints(self):
+         """Test health and status endpoints"""
+         print("πŸ₯ Testing health endpoints...")
+
+         # Test basic status
+         response = requests.get(f"{self.base_url}/status")
+         assert response.status_code == 200
+         print("βœ… Status endpoint working")
+
+         # Test app-level health
+         response = requests.get(f"{self.base_url}/health")
+         assert response.status_code == 200
+         print("βœ… App health endpoint working")
+
+         # Test detailed health
+         response = requests.get(f"{self.base_url}/api/v1/health")
+         assert response.status_code == 200
+         health_data = response.json()
+         print(f"βœ… Detailed health check: {health_data['status']}")
+         print(f"   Model: {health_data['model_name']} ({health_data['model_type']})")
+         print(f"   Model loaded: {health_data['model_loaded']}")
+
+         return health_data
+
+     def test_model_info(self):
+         """Test model information endpoint"""
+         print("\nπŸ€– Testing model info...")
+
+         response = requests.get(f"{self.base_url}/api/v1/model/info")
+         assert response.status_code == 200
+
+         model_info = response.json()
+         print("βœ… Model info retrieved")
+         print(f"   Name: {model_info['name']}")
+         print(f"   Type: {model_info['type']}")
+         print(f"   Loaded: {model_info['loaded']}")
+         print(f"   Capabilities: {model_info['capabilities']}")
+
+         return model_info
+
+     def test_regular_chat(self):
+         """Test regular (non-streaming) chat"""
+         print("\nπŸ’¬ Testing regular chat...")
+
+         chat_request = {
+             "message": "Hello! Can you introduce yourself?",
+             "session_id": self.session_id,
+             "temperature": 0.7,
+             "max_tokens": 100
+         }
+
+         start_time = time.time()
+         response = requests.post(
+             f"{self.base_url}/api/v1/chat",
+             json=chat_request,
+             headers={"Content-Type": "application/json"}
+         )
+         end_time = time.time()
+
+         assert response.status_code == 200
+         chat_response = response.json()
+
+         print("βœ… Regular chat working")
+         print(f"   Response time: {end_time - start_time:.2f}s")
+         print(f"   Generation time: {chat_response['generation_time']:.2f}s")
+         print(f"   Response: {chat_response['message'][:100]}...")
+         print(f"   Session ID: {chat_response['session_id']}")
+         print(f"   Message ID: {chat_response['message_id']}")
+
+         return chat_response
+
+     def test_streaming_chat(self):
+         """Test streaming chat via SSE"""
+         print("\nπŸ”„ Testing streaming chat...")
+
+         params = {
+             "message": "Tell me a short story about AI",
+             "session_id": self.session_id,
+             "temperature": 0.8,
+             "max_tokens": 150
+         }
+
+         start_time = time.time()
+         response = requests.get(
+             f"{self.base_url}/api/v1/chat/stream",
+             params=params,
+             headers={"Accept": "text/event-stream"},
+             stream=True
+         )
+
+         assert response.status_code == 200
+
+         chunks_received = 0
+         full_response = ""
+
+         for line in response.iter_lines():
+             if line:
+                 line_str = line.decode('utf-8')
+                 if line_str.startswith('data: '):
+                     try:
+                         data = json.loads(line_str[6:])  # Remove 'data: ' prefix
+                         if 'content' in data:
+                             full_response += data['content']
+                             chunks_received += 1
+
+                         if data.get('is_final'):
+                             break
+                     except json.JSONDecodeError:
+                         continue
+
+         end_time = time.time()
+
+         print("βœ… Streaming chat working")
+         print(f"   Total time: {end_time - start_time:.2f}s")
+         print(f"   Chunks received: {chunks_received}")
+         print(f"   Response: {full_response[:100]}...")
+
+         return full_response
+
+     def test_session_management(self):
+         """Test session management endpoints"""
+         print("\nπŸ“ Testing session management...")
+
+         # Get session history
+         response = requests.get(f"{self.base_url}/api/v1/sessions/{self.session_id}")
+         assert response.status_code == 200
+
+         session_data = response.json()
+         print("βœ… Session retrieval working")
+         print(f"   Messages in session: {session_data['message_count']}")
+         print(f"   Session created: {session_data['created_at']}")
+
+         # Get active sessions
+         response = requests.get(f"{self.base_url}/api/v1/sessions")
+         assert response.status_code == 200
+
+         sessions = response.json()
+         print("βœ… Active sessions list working")
+         print(f"   Total active sessions: {len(sessions)}")
+
+         return session_data
+
+     async def test_websocket_chat(self):
+         """Test WebSocket chat functionality"""
+         print("\nπŸ”Œ Testing WebSocket chat...")
+
+         ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
+         ws_url += "/api/v1/chat/ws"
+
+         try:
+             async with websockets.connect(ws_url) as websocket:
+                 # Send a message
+                 message = {
+                     "message": "Hello via WebSocket!",
+                     "session_id": f"{self.session_id}-ws",
+                     "temperature": 0.7,
+                     "max_tokens": 50
+                 }
+
+                 await websocket.send(json.dumps(message))
+
+                 # Receive response chunks
+                 chunks_received = 0
+                 full_response = ""
+
+                 while True:
+                     try:
+                         response = await asyncio.wait_for(websocket.recv(), timeout=30.0)
+                         data = json.loads(response)
+
+                         if data.get("type") == "chunk":
+                             full_response += data.get("content", "")
+                             chunks_received += 1
+
+                             if data.get("is_final"):
+                                 break
+                         elif data.get("type") == "error":
+                             print(f"❌ WebSocket error: {data.get('error')}")
+                             break
+
+                     except asyncio.TimeoutError:
+                         print("⚠️ WebSocket timeout")
+                         break
+
+                 print("βœ… WebSocket chat working")
+                 print(f"   Chunks received: {chunks_received}")
+                 print(f"   Response: {full_response[:100]}...")
+
+                 return full_response
+
+         except Exception as e:
+             print(f"❌ WebSocket test failed: {e}")
+             return None
+
+     def test_error_handling(self):
+         """Test error handling"""
+         print("\n🚨 Testing error handling...")
+
+         # Test empty message
+         response = requests.post(
+             f"{self.base_url}/api/v1/chat",
+             json={"message": "", "session_id": self.session_id}
+         )
+         assert response.status_code == 422  # Validation error
+         print("βœ… Empty message validation working")
+
+         # Test invalid session ID
+         response = requests.get(f"{self.base_url}/api/v1/sessions/invalid-session-id-that-does-not-exist")
+         assert response.status_code == 404
+         print("βœ… Invalid session handling working")
+
+         # Test rate limiting (if enabled)
+         print("βœ… Error handling tests passed")
+
+     def test_session_cleanup(self):
+         """Test session cleanup"""
+         print("\n🧹 Testing session cleanup...")
+
+         # Clear the test session
+         response = requests.delete(f"{self.base_url}/api/v1/sessions/{self.session_id}")
+         assert response.status_code == 200
+         print("βœ… Session cleanup working")
+
+         # Verify session is gone
+         response = requests.get(f"{self.base_url}/api/v1/sessions/{self.session_id}")
+         assert response.status_code == 404
+         print("βœ… Session deletion verified")
+
+     def run_all_tests(self):
+         """Run all tests"""
+         print("πŸš€ Starting Sema Chat API Tests")
+         print("=" * 50)
+
+         try:
+             # Test basic endpoints
+             health_data = self.test_health_endpoints()
+
+             if not health_data.get('model_loaded'):
+                 print("⚠️ Model not loaded, skipping chat tests")
+                 return False
+
+             model_info = self.test_model_info()
+
+             # Test chat functionality
+             self.test_regular_chat()
+             self.test_streaming_chat()
+
+             # Test session management
+             self.test_session_management()
+
+             # Test WebSocket (async)
+             asyncio.run(self.test_websocket_chat())
+
+             # Test error handling
+             self.test_error_handling()
+
+             # Cleanup
+             self.test_session_cleanup()
+
+             print("\n" + "=" * 50)
+             print("πŸŽ‰ All tests passed successfully!")
+             print(f"βœ… API is working correctly with {model_info['name']}")
+             return True
+
+         except Exception as e:
+             print(f"\n❌ Test failed: {e}")
+             import traceback
+             traceback.print_exc()
+             return False
+
+
+ def main():
+     """Main test function"""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Test Sema Chat API")
+     parser.add_argument(
+         "--url",
+         default="http://localhost:7860",
+         help="Base URL of the API (default: http://localhost:7860)"
+     )
+
+     args = parser.parse_args()
+
+     tester = SemaChatAPITester(args.url)
+     success = tester.run_all_tests()
+
+     sys.exit(0 if success else 1)
+
+
+ if __name__ == "__main__":
+     main()
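`test_streaming_chat` strips the SSE `data: ` prefix with `line_str[6:]` and skips malformed payloads. That parsing step can be isolated into a small helper; `parse_sse_line` below is an illustrative name, not part of the repo:

```python
import json

def parse_sse_line(line: str):
    """Return the decoded JSON payload of an SSE data line, or None."""
    if line.startswith("data: "):
        try:
            # Slice off the "data: " prefix, then decode the JSON payload.
            return json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            return None  # malformed payload - skip, as the test loop does
    return None  # comment lines, "event:" fields, keep-alives, etc.

chunk = parse_sse_line('data: {"content": "Hello", "is_final": false}')
print(chunk)  # β†’ {'content': 'Hello', 'is_final': False}
```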