Spaces:
Sleeping
Sleeping
suhail
commited on
Commit
·
676582c
1
Parent(s):
87238f5
chatbot
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +18 -0
- .env +16 -0
- .env.example +29 -4
- README.md +211 -8
- alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py +59 -0
- alembic/versions/20260114_1115_37ca2e18468d_description.py +24 -0
- alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py +27 -0
- alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py +26 -0
- alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py +28 -0
- alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc +0 -0
- alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc +0 -0
- alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc +0 -0
- alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc +0 -0
- alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc +0 -0
- alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc +0 -0
- alembic/versions/tmpclaude-82e1-cwd +1 -0
- package-lock.json +6 -0
- requirements.txt +7 -1
- src/__pycache__/main.cpython-313.pyc +0 -0
- src/agent/__init__.py +0 -0
- src/agent/__pycache__/__init__.cpython-313.pyc +0 -0
- src/agent/__pycache__/agent_config.cpython-313.pyc +0 -0
- src/agent/__pycache__/agent_runner.cpython-313.pyc +0 -0
- src/agent/agent_config.py +124 -0
- src/agent/agent_runner.py +281 -0
- src/agent/providers/__init__.py +0 -0
- src/agent/providers/__pycache__/__init__.cpython-313.pyc +0 -0
- src/agent/providers/__pycache__/base.cpython-313.pyc +0 -0
- src/agent/providers/__pycache__/cohere.cpython-313.pyc +0 -0
- src/agent/providers/__pycache__/gemini.cpython-313.pyc +0 -0
- src/agent/providers/__pycache__/openrouter.cpython-313.pyc +0 -0
- src/agent/providers/base.py +105 -0
- src/agent/providers/cohere.py +232 -0
- src/agent/providers/gemini.py +285 -0
- src/agent/providers/openrouter.py +264 -0
- src/api/routes/__pycache__/chat.cpython-313.pyc +0 -0
- src/api/routes/__pycache__/conversations.cpython-313.pyc +0 -0
- src/api/routes/chat.py +281 -0
- src/api/routes/conversations.py +291 -0
- src/core/__pycache__/config.cpython-313.pyc +0 -0
- src/core/__pycache__/exceptions.cpython-313.pyc +0 -0
- src/core/__pycache__/security.cpython-313.pyc +0 -0
- src/core/config.py +15 -0
- src/core/exceptions.py +125 -0
- src/core/security.py +52 -1
- src/main.py +67 -2
- src/mcp/__init__.py +130 -0
- src/mcp/__pycache__/__init__.cpython-313.pyc +0 -0
- src/mcp/__pycache__/tool_registry.cpython-313.pyc +0 -0
- src/mcp/tool_registry.py +140 -0
.dockerignore
CHANGED
|
@@ -43,3 +43,21 @@ alembic/versions/*.pyc
|
|
| 43 |
|
| 44 |
# Documentation
|
| 45 |
docs/_build/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Documentation
|
| 45 |
docs/_build/
|
| 46 |
+
# Python
|
| 47 |
+
__pycache__/
|
| 48 |
+
*.pyc
|
| 49 |
+
*.pyo
|
| 50 |
+
*.pyd
|
| 51 |
+
|
| 52 |
+
# Environment
|
| 53 |
+
.env
|
| 54 |
+
|
| 55 |
+
# Alembic cache
|
| 56 |
+
alembic/versions/__pycache__/
|
| 57 |
+
|
| 58 |
+
# Node
|
| 59 |
+
node_modules/
|
| 60 |
+
package-lock.json
|
| 61 |
+
|
| 62 |
+
# Temp / AI tools
|
| 63 |
+
tmpclaude-*
|
.env
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
# For local PostgreSQL: postgresql://user:password@localhost:5432/todo_db
|
| 3 |
# For Neon: Use your Neon connection string from the dashboard
|
| 4 |
DATABASE_URL=postgresql://neondb_owner:npg_MmFvJBHT8Y0k@ep-silent-thunder-ab0rbvrp-pooler.eu-west-2.aws.neon.tech/neondb?sslmode=require&channel_binding=require
|
|
|
|
| 5 |
# Application Settings
|
| 6 |
APP_NAME=Task CRUD API
|
| 7 |
DEBUG=True
|
|
@@ -11,3 +12,18 @@ CORS_ORIGINS=http://localhost:3000
|
|
| 11 |
BETTER_AUTH_SECRET=zMdW1P03wJvWJnLKzQ8YYO26vHeinqmR
|
| 12 |
JWT_ALGORITHM=HS256
|
| 13 |
JWT_EXPIRATION_DAYS=7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
# For local PostgreSQL: postgresql://user:password@localhost:5432/todo_db
|
| 3 |
# For Neon: Use your Neon connection string from the dashboard
|
| 4 |
DATABASE_URL=postgresql://neondb_owner:npg_MmFvJBHT8Y0k@ep-silent-thunder-ab0rbvrp-pooler.eu-west-2.aws.neon.tech/neondb?sslmode=require&channel_binding=require
|
| 5 |
+
|
| 6 |
# Application Settings
|
| 7 |
APP_NAME=Task CRUD API
|
| 8 |
DEBUG=True
|
|
|
|
| 12 |
BETTER_AUTH_SECRET=zMdW1P03wJvWJnLKzQ8YYO26vHeinqmR
|
| 13 |
JWT_ALGORITHM=HS256
|
| 14 |
JWT_EXPIRATION_DAYS=7
|
| 15 |
+
|
| 16 |
+
# LLM Provider Configuration
|
| 17 |
+
LLM_PROVIDER=gemini
|
| 18 |
+
# FALLBACK_PROVIDER=openrouter
|
| 19 |
+
GEMINI_API_KEY=AIzaSyCAlcHZxp5ELh1GqJwKqBLQziUNi0vnobU
|
| 20 |
+
# OPENROUTER_API_KEY=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
|
| 21 |
+
COHERE_API_KEY=
|
| 22 |
+
|
| 23 |
+
# Agent Configuration
|
| 24 |
+
AGENT_TEMPERATURE=0.7
|
| 25 |
+
AGENT_MAX_TOKENS=500
|
| 26 |
+
|
| 27 |
+
# Conversation Settings
|
| 28 |
+
CONVERSATION_MAX_MESSAGES=20
|
| 29 |
+
CONVERSATION_MAX_TOKENS=2000
|
.env.example
CHANGED
|
@@ -6,7 +6,32 @@ APP_NAME=Task CRUD API
|
|
| 6 |
DEBUG=True
|
| 7 |
CORS_ORIGINS=http://localhost:3000
|
| 8 |
|
| 9 |
-
# Authentication
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
DEBUG=True
|
| 7 |
CORS_ORIGINS=http://localhost:3000
|
| 8 |
|
| 9 |
+
# Authentication
|
| 10 |
+
BETTER_AUTH_SECRET=your-secret-key-here-min-32-characters
|
| 11 |
+
JWT_ALGORITHM=HS256
|
| 12 |
+
JWT_EXPIRATION_DAYS=7
|
| 13 |
+
|
| 14 |
+
# LLM Provider Configuration
|
| 15 |
+
# Primary provider: gemini, openrouter, cohere
|
| 16 |
+
LLM_PROVIDER=gemini
|
| 17 |
+
|
| 18 |
+
# Optional fallback provider (recommended for production)
|
| 19 |
+
FALLBACK_PROVIDER=openrouter
|
| 20 |
+
|
| 21 |
+
# API Keys (provide at least one for your primary provider)
|
| 22 |
+
GEMINI_API_KEY=your-gemini-api-key-here
|
| 23 |
+
OPENROUTER_API_KEY=your-openrouter-api-key-here
|
| 24 |
+
COHERE_API_KEY=your-cohere-api-key-here
|
| 25 |
+
|
| 26 |
+
# Agent Configuration
|
| 27 |
+
AGENT_TEMPERATURE=0.7
|
| 28 |
+
AGENT_MAX_TOKENS=8192
|
| 29 |
+
|
| 30 |
+
# Conversation Settings (for free-tier constraints)
|
| 31 |
+
CONVERSATION_MAX_MESSAGES=20
|
| 32 |
+
CONVERSATION_MAX_TOKENS=8000
|
| 33 |
+
|
| 34 |
+
# How to get API keys:
|
| 35 |
+
# - Gemini: https://makersuite.google.com/app/apikey (free, no credit card required)
|
| 36 |
+
# - OpenRouter: https://openrouter.ai/ (free models available)
|
| 37 |
+
# - Cohere: https://cohere.com/ (trial only, not recommended for production)
|
README.md
CHANGED
|
@@ -10,24 +10,227 @@ license: mit
|
|
| 10 |
|
| 11 |
# TaskFlow Backend API
|
| 12 |
|
| 13 |
-
FastAPI backend for TaskFlow task management application.
|
| 14 |
|
| 15 |
## Features
|
| 16 |
|
| 17 |
-
- User authentication with JWT
|
| 18 |
- Task CRUD operations
|
|
|
|
| 19 |
- PostgreSQL database with SQLModel ORM
|
| 20 |
- RESTful API design
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
## Environment Variables
|
| 23 |
|
| 24 |
-
Configure these in your
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
## API Documentation
|
| 32 |
|
| 33 |
-
Once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# TaskFlow Backend API
|
| 12 |
|
| 13 |
+
FastAPI backend for TaskFlow task management application with AI chatbot integration.
|
| 14 |
|
| 15 |
## Features
|
| 16 |
|
| 17 |
+
- User authentication with JWT and Better Auth
|
| 18 |
- Task CRUD operations
|
| 19 |
+
- **AI Chatbot Assistant** - Conversational AI for task management
|
| 20 |
- PostgreSQL database with SQLModel ORM
|
| 21 |
- RESTful API design
|
| 22 |
+
- Multi-turn conversation support with context management
|
| 23 |
+
- Intent recognition for todo-related requests
|
| 24 |
+
|
| 25 |
+
## Tech Stack
|
| 26 |
+
|
| 27 |
+
- **Framework**: FastAPI 0.104.1
|
| 28 |
+
- **ORM**: SQLModel 0.0.14
|
| 29 |
+
- **Database**: PostgreSQL (Neon Serverless)
|
| 30 |
+
- **Authentication**: Better Auth + JWT
|
| 31 |
+
- **AI Provider**: Google Gemini (gemini-pro)
|
| 32 |
+
- **Migrations**: Alembic 1.13.0
|
| 33 |
|
| 34 |
## Environment Variables
|
| 35 |
|
| 36 |
+
Configure these in your `.env` file:
|
| 37 |
+
|
| 38 |
+
### Database
|
| 39 |
+
- `DATABASE_URL`: PostgreSQL connection string (Neon or local)
|
| 40 |
+
|
| 41 |
+
### Application
|
| 42 |
+
- `APP_NAME`: Application name (default: "Task CRUD API")
|
| 43 |
+
- `DEBUG`: Debug mode (default: True)
|
| 44 |
+
- `CORS_ORIGINS`: Allowed CORS origins (default: "http://localhost:3000")
|
| 45 |
+
|
| 46 |
+
### Authentication
|
| 47 |
+
- `BETTER_AUTH_SECRET`: Secret key for Better Auth (required)
|
| 48 |
+
- `JWT_ALGORITHM`: JWT algorithm (default: "HS256")
|
| 49 |
+
- `JWT_EXPIRATION_DAYS`: Token expiration in days (default: 7)
|
| 50 |
+
|
| 51 |
+
### AI Provider Configuration
|
| 52 |
+
- `AI_PROVIDER`: AI provider to use (default: "gemini")
|
| 53 |
+
- `GEMINI_API_KEY`: Google Gemini API key (required if using Gemini)
|
| 54 |
+
- `OPENROUTER_API_KEY`: OpenRouter API key (optional)
|
| 55 |
+
- `COHERE_API_KEY`: Cohere API key (optional)
|
| 56 |
+
|
| 57 |
+
### Conversation Settings
|
| 58 |
+
- `MAX_CONVERSATION_MESSAGES`: Maximum messages to keep in history (default: 20)
|
| 59 |
+
- `MAX_CONVERSATION_TOKENS`: Maximum tokens to keep in history (default: 8000)
|
| 60 |
+
|
| 61 |
+
## Setup Instructions
|
| 62 |
+
|
| 63 |
+
### 1. Install Dependencies
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
pip install -r requirements.txt
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### 2. Configure Environment
|
| 70 |
+
|
| 71 |
+
Create a `.env` file in the `backend/` directory:
|
| 72 |
|
| 73 |
+
```env
|
| 74 |
+
# Database
|
| 75 |
+
DATABASE_URL=postgresql://user:password@localhost:5432/todo_db
|
| 76 |
+
|
| 77 |
+
# Application
|
| 78 |
+
APP_NAME=Task CRUD API
|
| 79 |
+
DEBUG=True
|
| 80 |
+
CORS_ORIGINS=http://localhost:3000
|
| 81 |
+
|
| 82 |
+
# Authentication
|
| 83 |
+
BETTER_AUTH_SECRET=your_secret_key_here
|
| 84 |
+
JWT_ALGORITHM=HS256
|
| 85 |
+
JWT_EXPIRATION_DAYS=7
|
| 86 |
+
|
| 87 |
+
# AI Provider
|
| 88 |
+
AI_PROVIDER=gemini
|
| 89 |
+
GEMINI_API_KEY=your_gemini_api_key_here
|
| 90 |
+
|
| 91 |
+
# Conversation Settings
|
| 92 |
+
MAX_CONVERSATION_MESSAGES=20
|
| 93 |
+
MAX_CONVERSATION_TOKENS=8000
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### 3. Run Database Migrations
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
alembic upgrade head
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### 4. Start the Server
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
The API will be available at `http://localhost:8000`
|
| 109 |
|
| 110 |
## API Documentation
|
| 111 |
|
| 112 |
+
Once running, visit:
|
| 113 |
+
- **Interactive Docs**: `http://localhost:8000/docs`
|
| 114 |
+
- **ReDoc**: `http://localhost:8000/redoc`
|
| 115 |
+
|
| 116 |
+
## API Endpoints
|
| 117 |
+
|
| 118 |
+
### Authentication
|
| 119 |
+
- `POST /api/auth/signup` - Register new user
|
| 120 |
+
- `POST /api/auth/login` - Login user
|
| 121 |
+
|
| 122 |
+
### Tasks
|
| 123 |
+
- `GET /api/{user_id}/tasks` - Get all tasks for user
|
| 124 |
+
- `POST /api/{user_id}/tasks` - Create new task
|
| 125 |
+
- `GET /api/{user_id}/tasks/{task_id}` - Get specific task
|
| 126 |
+
- `PUT /api/{user_id}/tasks/{task_id}` - Update task
|
| 127 |
+
- `DELETE /api/{user_id}/tasks/{task_id}` - Delete task
|
| 128 |
+
|
| 129 |
+
### AI Chat (New in Phase 1)
|
| 130 |
+
- `POST /api/{user_id}/chat` - Send message to AI assistant
|
| 131 |
+
|
| 132 |
+
#### Chat Request Body
|
| 133 |
+
```json
|
| 134 |
+
{
|
| 135 |
+
"message": "Can you help me organize my tasks?",
|
| 136 |
+
"conversation_id": 123, // Optional: null for new conversation
|
| 137 |
+
"temperature": 0.7 // Optional: 0.0 to 1.0
|
| 138 |
+
}
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
#### Chat Response
|
| 142 |
+
```json
|
| 143 |
+
{
|
| 144 |
+
"conversation_id": 123,
|
| 145 |
+
"message": "I'd be happy to help you organize your tasks!",
|
| 146 |
+
"role": "assistant",
|
| 147 |
+
"timestamp": "2026-01-14T10:30:00Z",
|
| 148 |
+
"token_count": 25,
|
| 149 |
+
"model": "gemini-pro"
|
| 150 |
+
}
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## AI Chatbot Features
|
| 154 |
+
|
| 155 |
+
### Phase 1 (Current)
|
| 156 |
+
- ✅ Natural conversation with AI assistant
|
| 157 |
+
- ✅ Multi-turn conversations with context retention
|
| 158 |
+
- ✅ Intent recognition for todo-related requests
|
| 159 |
+
- ✅ Conversation history persistence
|
| 160 |
+
- ✅ Automatic history trimming (20 messages / 8000 tokens)
|
| 161 |
+
- ✅ Free-tier AI provider support (Gemini)
|
| 162 |
+
|
| 163 |
+
### Phase 2 (Coming Soon)
|
| 164 |
+
- 🔄 MCP tools for task CRUD operations
|
| 165 |
+
- 🔄 AI can directly create, update, and delete tasks
|
| 166 |
+
- 🔄 Natural language task management
|
| 167 |
+
|
| 168 |
+
## Error Handling
|
| 169 |
+
|
| 170 |
+
The API returns standard HTTP status codes:
|
| 171 |
+
|
| 172 |
+
- `200 OK` - Request successful
|
| 173 |
+
- `400 Bad Request` - Invalid request data
|
| 174 |
+
- `401 Unauthorized` - Authentication required or failed
|
| 175 |
+
- `404 Not Found` - Resource not found
|
| 176 |
+
- `429 Too Many Requests` - Rate limit exceeded
|
| 177 |
+
- `500 Internal Server Error` - Server error
|
| 178 |
+
|
| 179 |
+
## Database Schema
|
| 180 |
+
|
| 181 |
+
### Users Table
|
| 182 |
+
- `id`: Primary key
|
| 183 |
+
- `email`: Unique email address
|
| 184 |
+
- `name`: User's name
|
| 185 |
+
- `password`: Hashed password
|
| 186 |
+
- `created_at`, `updated_at`: Timestamps
|
| 187 |
+
|
| 188 |
+
### Tasks Table
|
| 189 |
+
- `id`: Primary key
|
| 190 |
+
- `user_id`: Foreign key to users
|
| 191 |
+
- `title`: Task title
|
| 192 |
+
- `description`: Task description
|
| 193 |
+
- `completed`: Boolean status
|
| 194 |
+
- `created_at`, `updated_at`: Timestamps
|
| 195 |
+
|
| 196 |
+
### Conversation Table (New)
|
| 197 |
+
- `id`: Primary key
|
| 198 |
+
- `user_id`: Foreign key to users
|
| 199 |
+
- `title`: Conversation title
|
| 200 |
+
- `created_at`, `updated_at`: Timestamps
|
| 201 |
+
|
| 202 |
+
### Message Table (New)
|
| 203 |
+
- `id`: Primary key
|
| 204 |
+
- `conversation_id`: Foreign key to conversation
|
| 205 |
+
- `role`: "user" or "assistant"
|
| 206 |
+
- `content`: Message text
|
| 207 |
+
- `timestamp`: Message timestamp
|
| 208 |
+
- `token_count`: Token count for the message
|
| 209 |
+
|
| 210 |
+
## Development
|
| 211 |
+
|
| 212 |
+
### Running Tests
|
| 213 |
+
```bash
|
| 214 |
+
pytest
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
### Database Migrations
|
| 218 |
+
|
| 219 |
+
Create a new migration:
|
| 220 |
+
```bash
|
| 221 |
+
alembic revision -m "description"
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
Apply migrations:
|
| 225 |
+
```bash
|
| 226 |
+
alembic upgrade head
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
Rollback migration:
|
| 230 |
+
```bash
|
| 231 |
+
alembic downgrade -1
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
## License
|
| 235 |
+
|
| 236 |
+
MIT License
|
alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""add_conversation_and_message_tables
|
| 2 |
+
|
| 3 |
+
Revision ID: 48b10b49730f
|
| 4 |
+
Revises: 002_add_user_password
|
| 5 |
+
Create Date: 2026-01-14 10:44:27.010796
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from alembic import op
|
| 9 |
+
import sqlalchemy as sa
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# revision identifiers, used by Alembic.
|
| 13 |
+
revision = '48b10b49730f'
|
| 14 |
+
down_revision = '002_add_user_password'
|
| 15 |
+
branch_labels = None
|
| 16 |
+
depends_on = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def upgrade() -> None:
|
| 20 |
+
# Create conversation table
|
| 21 |
+
op.create_table(
|
| 22 |
+
'conversation',
|
| 23 |
+
sa.Column('id', sa.Integer(), nullable=False),
|
| 24 |
+
sa.Column('user_id', sa.Integer(), nullable=False),
|
| 25 |
+
sa.Column('title', sa.String(length=255), nullable=True),
|
| 26 |
+
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
| 27 |
+
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
| 28 |
+
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
| 29 |
+
sa.PrimaryKeyConstraint('id')
|
| 30 |
+
)
|
| 31 |
+
op.create_index('ix_conversation_user_id', 'conversation', ['user_id'])
|
| 32 |
+
op.create_index('ix_conversation_created_at', 'conversation', ['created_at'])
|
| 33 |
+
|
| 34 |
+
# Create message table
|
| 35 |
+
op.create_table(
|
| 36 |
+
'message',
|
| 37 |
+
sa.Column('id', sa.Integer(), nullable=False),
|
| 38 |
+
sa.Column('conversation_id', sa.Integer(), nullable=False),
|
| 39 |
+
sa.Column('role', sa.String(length=50), nullable=False),
|
| 40 |
+
sa.Column('content', sa.Text(), nullable=False),
|
| 41 |
+
sa.Column('timestamp', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
| 42 |
+
sa.Column('token_count', sa.Integer(), nullable=True),
|
| 43 |
+
sa.ForeignKeyConstraint(['conversation_id'], ['conversation.id'], ondelete='CASCADE'),
|
| 44 |
+
sa.PrimaryKeyConstraint('id')
|
| 45 |
+
)
|
| 46 |
+
op.create_index('ix_message_conversation_id', 'message', ['conversation_id'])
|
| 47 |
+
op.create_index('ix_message_timestamp', 'message', ['timestamp'])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def downgrade() -> None:
|
| 51 |
+
# Drop message table first (due to foreign key dependency)
|
| 52 |
+
op.drop_index('ix_message_timestamp', table_name='message')
|
| 53 |
+
op.drop_index('ix_message_conversation_id', table_name='message')
|
| 54 |
+
op.drop_table('message')
|
| 55 |
+
|
| 56 |
+
# Drop conversation table
|
| 57 |
+
op.drop_index('ix_conversation_created_at', table_name='conversation')
|
| 58 |
+
op.drop_index('ix_conversation_user_id', table_name='conversation')
|
| 59 |
+
op.drop_table('conversation')
|
alembic/versions/20260114_1115_37ca2e18468d_description.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""description
|
| 2 |
+
|
| 3 |
+
Revision ID: 37ca2e18468d
|
| 4 |
+
Revises: 48b10b49730f
|
| 5 |
+
Create Date: 2026-01-14 11:15:03.691055
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from alembic import op
|
| 9 |
+
import sqlalchemy as sa
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# revision identifiers, used by Alembic.
|
| 13 |
+
revision = '37ca2e18468d'
|
| 14 |
+
down_revision = '48b10b49730f'
|
| 15 |
+
branch_labels = None
|
| 16 |
+
depends_on = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def upgrade() -> None:
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def downgrade() -> None:
|
| 24 |
+
pass
|
alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""add metadata column to message table
|
| 2 |
+
|
| 3 |
+
Revision ID: a3c44bf7ddcb
|
| 4 |
+
Revises: 37ca2e18468d
|
| 5 |
+
Create Date: 2026-01-14 17:02:51.060200
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from alembic import op
|
| 9 |
+
import sqlalchemy as sa
|
| 10 |
+
from sqlalchemy.dialects import postgresql
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# revision identifiers, used by Alembic.
|
| 14 |
+
revision = 'a3c44bf7ddcb'
|
| 15 |
+
down_revision = '37ca2e18468d'
|
| 16 |
+
branch_labels = None
|
| 17 |
+
depends_on = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def upgrade() -> None:
|
| 21 |
+
# Add metadata column to message table
|
| 22 |
+
op.add_column('message', sa.Column('metadata', postgresql.JSON(astext_type=sa.Text()), nullable=True))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def downgrade() -> None:
|
| 26 |
+
# Remove metadata column from message table
|
| 27 |
+
op.drop_column('message', 'metadata')
|
alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""rename metadata to tool_metadata in message table
|
| 2 |
+
|
| 3 |
+
Revision ID: e8275e6c143c
|
| 4 |
+
Revises: a3c44bf7ddcb
|
| 5 |
+
Create Date: 2026-01-14 17:12:53.740315
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from alembic import op
|
| 9 |
+
import sqlalchemy as sa
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# revision identifiers, used by Alembic.
|
| 13 |
+
revision = 'e8275e6c143c'
|
| 14 |
+
down_revision = 'a3c44bf7ddcb'
|
| 15 |
+
branch_labels = None
|
| 16 |
+
depends_on = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def upgrade() -> None:
|
| 20 |
+
# Rename metadata column to tool_metadata
|
| 21 |
+
op.alter_column('message', 'metadata', new_column_name='tool_metadata')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def downgrade() -> None:
|
| 25 |
+
# Rename tool_metadata column back to metadata
|
| 26 |
+
op.alter_column('message', 'tool_metadata', new_column_name='metadata')
|
alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""add due_date and priority to task table
|
| 2 |
+
|
| 3 |
+
Revision ID: d34db62bd406
|
| 4 |
+
Revises: e8275e6c143c
|
| 5 |
+
Create Date: 2026-01-14 19:00:45.426280
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from alembic import op
|
| 9 |
+
import sqlalchemy as sa
|
| 10 |
+
from sqlalchemy.dialects import postgresql
|
| 11 |
+
|
| 12 |
+
# revision identifiers, used by Alembic.
|
| 13 |
+
revision = 'd34db62bd406'
|
| 14 |
+
down_revision = 'e8275e6c143c'
|
| 15 |
+
branch_labels = None
|
| 16 |
+
depends_on = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def upgrade() -> None:
|
| 20 |
+
# Add due_date and priority columns to tasks table
|
| 21 |
+
op.add_column('tasks', sa.Column('due_date', sa.Date(), nullable=True))
|
| 22 |
+
op.add_column('tasks', sa.Column('priority', sa.String(length=20), nullable=False, server_default='medium'))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def downgrade() -> None:
|
| 26 |
+
# Remove due_date and priority columns from tasks table
|
| 27 |
+
op.drop_column('tasks', 'priority')
|
| 28 |
+
op.drop_column('tasks', 'due_date')
|
alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc
ADDED
|
Binary file (771 Bytes). View file
|
|
|
alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc
ADDED
|
Binary file (4.64 kB). View file
|
|
|
alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc
ADDED
|
Binary file (1.25 kB). View file
|
|
|
alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
alembic/versions/tmpclaude-82e1-cwd
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/d/Agentic_ai_learning/hacathoon_2/evolution-of-todo/phase-2-full-stack-web-app/backend/alembic/versions
|
package-lock.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "backend",
|
| 3 |
+
"lockfileVersion": 3,
|
| 4 |
+
"requires": true,
|
| 5 |
+
"packages": {}
|
| 6 |
+
}
|
requirements.txt
CHANGED
|
@@ -13,4 +13,10 @@ passlib[bcrypt]==1.7.4
|
|
| 13 |
python-multipart==0.0.6
|
| 14 |
|
| 15 |
# Add this line (or replace if bcrypt is already listed)
|
| 16 |
-
bcrypt==4.3.0 # Last stable version before the 5.0 break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
python-multipart==0.0.6
|
| 14 |
|
| 15 |
# Add this line (or replace if bcrypt is already listed)
|
| 16 |
+
bcrypt==4.3.0 # Last stable version before the 5.0 break
|
| 17 |
+
|
| 18 |
+
# AI Chatbot dependencies
|
| 19 |
+
google-generativeai==0.3.2 # Gemini API client
|
| 20 |
+
tiktoken==0.5.2 # Token counting (optional)
|
| 21 |
+
mcp==1.20.0 # Official MCP SDK for tool server
|
| 22 |
+
cohere==4.37 # Cohere API client (optional provider)
|
src/__pycache__/main.cpython-313.pyc
CHANGED
|
Binary files a/src/__pycache__/main.cpython-313.pyc and b/src/__pycache__/main.cpython-313.pyc differ
|
|
|
src/agent/__init__.py
ADDED
|
File without changes
|
src/agent/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
src/agent/__pycache__/agent_config.cpython-313.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
src/agent/__pycache__/agent_runner.cpython-313.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
src/agent/agent_config.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Configuration
|
| 3 |
+
|
| 4 |
+
Configuration dataclass for agent behavior and LLM provider settings.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class AgentConfiguration:
|
| 13 |
+
"""
|
| 14 |
+
Configuration for the AI agent.
|
| 15 |
+
|
| 16 |
+
This dataclass defines all configurable parameters for agent behavior,
|
| 17 |
+
including LLM provider settings, conversation limits, and system prompts.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
# Provider settings
|
| 21 |
+
provider: str = "gemini" # Options: gemini, openrouter, cohere
|
| 22 |
+
fallback_provider: Optional[str] = None # Optional fallback provider
|
| 23 |
+
model: Optional[str] = None # Model name (provider-specific)
|
| 24 |
+
|
| 25 |
+
# API keys (loaded from environment)
|
| 26 |
+
gemini_api_key: Optional[str] = None
|
| 27 |
+
openrouter_api_key: Optional[str] = None
|
| 28 |
+
cohere_api_key: Optional[str] = None
|
| 29 |
+
|
| 30 |
+
# Generation parameters
|
| 31 |
+
temperature: float = 0.7 # Sampling temperature (0.0 to 1.0)
|
| 32 |
+
max_tokens: int = 8192 # Maximum tokens in response
|
| 33 |
+
|
| 34 |
+
# Conversation history limits (for free-tier constraints)
|
| 35 |
+
max_messages: int = 20 # Maximum messages to keep in history
|
| 36 |
+
max_conversation_tokens: int = 8000 # Maximum tokens in conversation history
|
| 37 |
+
|
| 38 |
+
# System prompt
|
| 39 |
+
system_prompt: str = """You are a helpful AI assistant for managing tasks.
|
| 40 |
+
You can help users create, view, complete, update, and delete tasks using natural language.
|
| 41 |
+
|
| 42 |
+
Available tools:
|
| 43 |
+
- add_task: Create a new task
|
| 44 |
+
- list_tasks: View all tasks (with optional filtering)
|
| 45 |
+
- complete_task: Mark a task as completed
|
| 46 |
+
- delete_task: Remove a task
|
| 47 |
+
- update_task: Modify task properties
|
| 48 |
+
|
| 49 |
+
Always respond in a friendly, conversational manner and confirm actions taken."""
|
| 50 |
+
|
| 51 |
+
# Retry settings
|
| 52 |
+
max_retries: int = 3 # Maximum retries on rate limit errors
|
| 53 |
+
retry_delay: float = 1.0 # Delay between retries (seconds)
|
| 54 |
+
|
| 55 |
+
def get_provider_api_key(self, provider_name: str) -> Optional[str]:
|
| 56 |
+
"""
|
| 57 |
+
Get API key for a specific provider.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
provider_name: Provider name (gemini, openrouter, cohere)
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
API key or None if not configured
|
| 64 |
+
"""
|
| 65 |
+
if provider_name == "gemini":
|
| 66 |
+
return self.gemini_api_key
|
| 67 |
+
elif provider_name == "openrouter":
|
| 68 |
+
return self.openrouter_api_key
|
| 69 |
+
elif provider_name == "cohere":
|
| 70 |
+
return self.cohere_api_key
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
def get_provider_model(self, provider_name: str) -> str:
|
| 74 |
+
"""
|
| 75 |
+
Get default model for a specific provider.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
provider_name: Provider name
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Model identifier
|
| 82 |
+
"""
|
| 83 |
+
if self.model:
|
| 84 |
+
return self.model
|
| 85 |
+
|
| 86 |
+
# Default models per provider
|
| 87 |
+
defaults = {
|
| 88 |
+
"gemini": "gemini-flash-latest",
|
| 89 |
+
"openrouter": "google/gemini-flash-1.5",
|
| 90 |
+
"cohere": "command-r-plus"
|
| 91 |
+
}
|
| 92 |
+
return defaults.get(provider_name, "gemini-flash-latest")
|
| 93 |
+
|
| 94 |
+
def validate(self) -> bool:
|
| 95 |
+
"""
|
| 96 |
+
Validate configuration.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
True if configuration is valid
|
| 100 |
+
|
| 101 |
+
Raises:
|
| 102 |
+
ValueError: If configuration is invalid
|
| 103 |
+
"""
|
| 104 |
+
# Check primary provider has API key
|
| 105 |
+
primary_key = self.get_provider_api_key(self.provider)
|
| 106 |
+
if not primary_key:
|
| 107 |
+
raise ValueError(f"API key not configured for primary provider: {self.provider}")
|
| 108 |
+
|
| 109 |
+
# Validate temperature range
|
| 110 |
+
if not 0.0 <= self.temperature <= 1.0:
|
| 111 |
+
raise ValueError(f"Temperature must be between 0.0 and 1.0, got: {self.temperature}")
|
| 112 |
+
|
| 113 |
+
# Validate max_tokens
|
| 114 |
+
if self.max_tokens <= 0:
|
| 115 |
+
raise ValueError(f"max_tokens must be positive, got: {self.max_tokens}")
|
| 116 |
+
|
| 117 |
+
# Validate conversation limits
|
| 118 |
+
if self.max_messages <= 0:
|
| 119 |
+
raise ValueError(f"max_messages must be positive, got: {self.max_messages}")
|
| 120 |
+
|
| 121 |
+
if self.max_conversation_tokens <= 0:
|
| 122 |
+
raise ValueError(f"max_conversation_tokens must be positive, got: {self.max_conversation_tokens}")
|
| 123 |
+
|
| 124 |
+
return True
|
src/agent/agent_runner.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Runner
|
| 3 |
+
|
| 4 |
+
Core orchestrator for AI agent execution with tool calling support.
|
| 5 |
+
Manages the full request cycle: LLM generation → tool execution → final response.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import asyncio
|
| 11 |
+
|
| 12 |
+
from .agent_config import AgentConfiguration
|
| 13 |
+
from .providers.base import LLMProvider
|
| 14 |
+
from .providers.gemini import GeminiProvider
|
| 15 |
+
from .providers.openrouter import OpenRouterProvider
|
| 16 |
+
from .providers.cohere import CohereProvider
|
| 17 |
+
from ..mcp.tool_registry import MCPToolRegistry, ToolExecutionResult
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AgentRunner:
    """
    Agent execution orchestrator with tool calling support.

    Manages the full agent request cycle:
    1. Generate an LLM response with tool definitions.
    2. If tool calls are requested, execute them with the backend-injected
       user context.
    3. Generate the final response from the tool results.
    4. On failure, retry once with the configured fallback provider.
    """

    def __init__(self, config: AgentConfiguration, tool_registry: MCPToolRegistry):
        """
        Initialize the agent runner.

        Args:
            config: Agent configuration (providers, model, prompts, limits)
            tool_registry: MCP tool registry used to resolve and run tools
        """
        self.config = config
        self.tool_registry = tool_registry
        self.primary_provider = self._create_provider(config.provider)
        self.fallback_provider = None

        if config.fallback_provider:
            self.fallback_provider = self._create_provider(config.fallback_provider)

        logger.info(f"Initialized AgentRunner with provider: {config.provider}")

    def _create_provider(self, provider_name: str) -> LLMProvider:
        """
        Create an LLM provider instance.

        Args:
            provider_name: Provider name (gemini, openrouter, cohere)

        Returns:
            LLMProvider instance

        Raises:
            ValueError: If the provider is not supported or its API key is missing
        """
        api_key = self.config.get_provider_api_key(provider_name)
        if not api_key:
            raise ValueError(f"API key not configured for provider: {provider_name}")

        # All providers share the same constructor signature, so dispatch
        # through a class map instead of an if/elif chain.
        provider_classes = {
            "gemini": GeminiProvider,
            "openrouter": OpenRouterProvider,
            "cohere": CohereProvider,
        }
        provider_cls = provider_classes.get(provider_name)
        if provider_cls is None:
            raise ValueError(f"Unsupported provider: {provider_name}")

        return provider_cls(
            api_key=api_key,
            model=self.config.get_provider_model(provider_name),
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
        )

    async def execute(
        self,
        messages: List[Dict[str, str]],
        user_id: int,
        system_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Execute agent request with tool calling support.

        SECURITY: user_id is injected by the backend, never from LLM output.

        Args:
            messages: Conversation history [{"role": "user", "content": "..."}]
            user_id: User ID (injected by backend for security)
            system_prompt: Optional system prompt (uses config default if not provided)

        Returns:
            Dict with keys: content, tool_calls, tool_results, provider
        """
        prompt = system_prompt or self.config.system_prompt

        # The primary and fallback paths previously duplicated the entire
        # generate -> execute-tools -> finalize cycle; both now share
        # _execute_with_provider so the logic cannot drift apart.
        try:
            return await self._execute_with_provider(
                provider=self.primary_provider,
                messages=messages,
                user_id=user_id,
                system_prompt=prompt
            )
        except Exception as e:
            logger.error(f"Agent execution failed with primary provider: {str(e)}")

            if self.fallback_provider:
                logger.info("Attempting fallback provider")
                try:
                    return await self._execute_with_provider(
                        provider=self.fallback_provider,
                        messages=messages,
                        user_id=user_id,
                        system_prompt=prompt
                    )
                except Exception as fallback_error:
                    logger.error(f"Fallback provider also failed: {str(fallback_error)}")
                    raise

            raise

    async def _execute_with_provider(
        self,
        provider: LLMProvider,
        messages: List[Dict[str, str]],
        user_id: int,
        system_prompt: str
    ) -> Dict[str, Any]:
        """
        Run one full generate/tool/finalize cycle with a specific provider.

        Args:
            provider: LLM provider to use
            messages: Conversation history
            user_id: User ID injected into every tool call (security boundary)
            system_prompt: System prompt

        Returns:
            Dict with keys: content, tool_calls, tool_results, provider
        """
        tool_definitions = self.tool_registry.get_tool_definitions()
        logger.info(f"Executing agent for user {user_id} with {len(tool_definitions)} tools")

        # First pass: let the model decide whether it needs tools.
        response = await provider.generate_response_with_tools(
            messages=messages,
            system_prompt=system_prompt,
            tools=tool_definitions
        )

        if not response.tool_calls:
            # No tool calls: return the model's direct answer.
            logger.info("Agent generated direct response (no tool calls)")
            return {
                "content": response.content,
                "tool_calls": None,
                "tool_results": None,
                "provider": provider.get_provider_name()
            }

        logger.info(f"Agent requested {len(response.tool_calls)} tool calls")

        # Execute every requested tool, injecting the trusted user context.
        tool_results = []
        for tool_call in response.tool_calls:
            result = await self.tool_registry.execute_tool(
                tool_name=tool_call["name"],
                arguments=tool_call["arguments"],
                user_id=user_id  # Inject user context for security
            )
            tool_results.append(result)

        # Second pass: turn the tool results into a natural-language answer.
        final_response = await provider.generate_response_with_tool_results(
            messages=messages,
            tool_calls=response.tool_calls,
            tool_results=tool_results
        )

        return {
            "content": final_response.content,
            "tool_calls": response.tool_calls,
            "tool_results": tool_results,
            "provider": provider.get_provider_name()
        }

    async def execute_simple(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Execute a simple agent request without tool calling.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt

        Returns:
            Response content as a string ("" when the provider returns None)
        """
        prompt = system_prompt or self.config.system_prompt

        try:
            response = await self.primary_provider.generate_simple_response(
                messages=messages,
                system_prompt=prompt
            )
            return response.content or ""
        except Exception as e:
            logger.error(f"Simple execution failed: {str(e)}")

            if self.fallback_provider:
                logger.info("Attempting fallback provider for simple execution")
                response = await self.fallback_provider.generate_simple_response(
                    messages=messages,
                    system_prompt=prompt
                )
                return response.content or ""

            raise
|
src/agent/providers/__init__.py
ADDED
|
File without changes
|
src/agent/providers/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (216 Bytes). View file
|
|
|
src/agent/providers/__pycache__/base.cpython-313.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
src/agent/providers/__pycache__/cohere.cpython-313.pyc
ADDED
|
Binary file (8.05 kB). View file
|
|
|
src/agent/providers/__pycache__/gemini.cpython-313.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
src/agent/providers/__pycache__/openrouter.cpython-313.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
src/agent/providers/base.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Provider Base Class
|
| 3 |
+
|
| 4 |
+
Abstract base class for LLM provider implementations.
|
| 5 |
+
Defines the interface for generating responses with function calling support.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class LLMResponse:
    """
    Normalized response returned by every LLM provider.

    A plain text answer carries ``content``; a function-calling turn
    carries ``tool_calls`` instead.
    """
    content: Optional[str] = None  # generated text, if any
    tool_calls: Optional[List[Dict[str, Any]]] = None  # requested tool invocations
    finish_reason: Optional[str] = None  # provider-reported stop reason
    usage: Optional[Dict[str, int]] = None  # token counts, when reported
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LLMProvider(ABC):
    """
    Abstract base class for LLM providers.

    Every concrete provider (Gemini, OpenRouter, Cohere) implements the
    three generation methods below so the agent runner can drive any of
    them through a single interface.
    """

    def __init__(self, api_key: str, model: str, temperature: float = 0.7, max_tokens: int = 8192):
        """
        Store the shared provider settings.

        Args:
            api_key: API key for the provider
            model: Model identifier (e.g., "gemini-1.5-flash")
            temperature: Sampling temperature (0.0 to 1.0)
            max_tokens: Maximum tokens in response
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    @abstractmethod
    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> "LLMResponse":
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history [{"role": "user", "content": "..."}]
            system_prompt: System instructions for the agent
            tools: Tool definitions for function calling

        Returns:
            LLMResponse with content and/or tool_calls
        """

    @abstractmethod
    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> "LLMResponse":
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history
            tool_calls: Tool calls that were made
            tool_results: Results from tool execution

        Returns:
            LLMResponse with final content
        """

    @abstractmethod
    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> "LLMResponse":
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history
            system_prompt: System instructions

        Returns:
            LLMResponse with content
        """

    def get_provider_name(self) -> str:
        """Derive the provider name from the class name, e.g. GeminiProvider -> 'gemini'."""
        return self.__class__.__name__.replace("Provider", "").lower()
|
src/agent/providers/cohere.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cohere Provider Implementation
|
| 3 |
+
|
| 4 |
+
Cohere API provider with function calling support.
|
| 5 |
+
Optional provider (trial only, not recommended for production).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
import cohere
|
| 11 |
+
|
| 12 |
+
from .base import LLMProvider, LLMResponse
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CohereProvider(LLMProvider):
    """
    Cohere API provider implementation.

    Features:
    - Native function calling support
    - Trial tier only (not recommended for production)
    - Model: command-r-plus (best for function calling)

    Note: Cohere requires a paid plan after trial expires.
    Use Gemini or OpenRouter for free-tier operation.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "command-r-plus",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        super().__init__(api_key, model, temperature, max_tokens)
        # NOTE(review): cohere.Client is the synchronous client, so the
        # async methods below block the event loop for the duration of
        # each HTTP call — consider cohere.AsyncClient. TODO confirm.
        self.client = cohere.Client(api_key)
        logger.info(f"Initialized CohereProvider with model: {model}")

    @staticmethod
    def _to_chat_history(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Map generic role/content messages onto Cohere's USER/CHATBOT history."""
        history = []
        for msg in messages:
            role = "USER" if msg["role"] == "user" else "CHATBOT"
            history.append({"role": role, "message": msg["content"]})
        return history

    def _convert_tools_to_cohere_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert MCP tool definitions to Cohere tool format.

        Args:
            tools: MCP tool definitions

        Returns:
            List of Cohere-formatted tool definitions
        """
        converted = []
        for tool in tools:
            converted.append({
                "name": tool["name"],
                "description": tool["description"],
                "parameter_definitions": tool["parameters"].get("properties", {}),
            })
        return converted

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history
            system_prompt: System instructions
            tools: Tool definitions

        Returns:
            LLMResponse with content and/or tool_calls
        """
        try:
            cohere_tools = self._convert_tools_to_cohere_format(tools)

            # Everything except the last message is prior history; the
            # last message is the turn being answered.
            chat_history = self._to_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=cohere_tools
            )

            if response.tool_calls:
                requested = [
                    {"name": tc.name, "arguments": tc.parameters}
                    for tc in response.tool_calls
                ]
                logger.info(f"Cohere requested function calls: {[tc['name'] for tc in requested]}")
                return LLMResponse(
                    content=None,
                    tool_calls=requested,
                    finish_reason="tool_calls"
                )

            logger.info("Cohere generated text response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error: {str(e)}")
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history
            tool_calls: Tool calls that were made
            tool_results: Results from tool execution

        Returns:
            LLMResponse with final content
        """
        try:
            chat_history = self._to_chat_history(messages)

            # Pair each call with its result in Cohere's expected shape.
            tool_results_formatted = []
            for call, result in zip(tool_calls, tool_results):
                tool_results_formatted.append({
                    "call": {"name": call["name"], "parameters": call["arguments"]},
                    "outputs": [{"result": str(result)}],
                })

            response = self.client.chat(
                message="Based on the tool results, provide a natural language response.",
                chat_history=chat_history,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tool_results=tool_results_formatted
            )

            logger.info("Cohere generated final response after tool execution")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error in tool results: {str(e)}")
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history
            system_prompt: System instructions

        Returns:
            LLMResponse with content
        """
        try:
            chat_history = self._to_chat_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = self.client.chat(
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            logger.info("Cohere generated simple response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error: {str(e)}")
            raise
|
src/agent/providers/gemini.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini Provider Implementation
|
| 3 |
+
|
| 4 |
+
Google Gemini API provider with function calling support.
|
| 5 |
+
Primary provider for free-tier operation (15 RPM, 1M token context).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
import google.generativeai as genai
|
| 11 |
+
from google.generativeai.types import FunctionDeclaration, Tool
|
| 12 |
+
|
| 13 |
+
from .base import LLMProvider, LLMResponse
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GeminiProvider(LLMProvider):
|
| 19 |
+
"""
|
| 20 |
+
Google Gemini API provider implementation.
|
| 21 |
+
|
| 22 |
+
Features:
|
| 23 |
+
- Native function calling support
|
| 24 |
+
- 1M token context window
|
| 25 |
+
- Free tier: 15 requests/minute
|
| 26 |
+
- Model: gemini-1.5-flash (recommended for free tier)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
    def __init__(self, api_key: str, model: str = "gemini-flash-latest", temperature: float = 0.7, max_tokens: int = 8192):
        # Store shared settings on the base class, then build the model
        # handle used for all generation calls.
        # NOTE(review): genai.configure sets a process-global API key, so
        # two GeminiProvider instances with different keys would clobber
        # each other — confirm this is acceptable for the deployment.
        super().__init__(api_key, model, temperature, max_tokens)
        genai.configure(api_key=api_key)
        self.client = genai.GenerativeModel(model)
        logger.info(f"Initialized GeminiProvider with model: {model}")
|
| 34 |
+
|
| 35 |
+
def _sanitize_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
|
| 36 |
+
"""
|
| 37 |
+
Sanitize JSON Schema to be Gemini-compatible.
|
| 38 |
+
|
| 39 |
+
Gemini only supports a subset of JSON Schema keywords:
|
| 40 |
+
- Supported: type, description, enum, required, properties, items
|
| 41 |
+
- NOT supported: maxLength, minLength, pattern, format, minimum, maximum, default, etc.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
schema: Original JSON Schema
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Gemini-compatible schema with unsupported fields removed
|
| 48 |
+
"""
|
| 49 |
+
# Fields that Gemini supports
|
| 50 |
+
ALLOWED_FIELDS = {
|
| 51 |
+
"type", "description", "enum", "required",
|
| 52 |
+
"properties", "items"
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
# Create a sanitized copy
|
| 56 |
+
sanitized = {}
|
| 57 |
+
|
| 58 |
+
for key, value in schema.items():
|
| 59 |
+
if key in ALLOWED_FIELDS:
|
| 60 |
+
# Recursively sanitize nested objects
|
| 61 |
+
if key == "properties" and isinstance(value, dict):
|
| 62 |
+
sanitized[key] = {
|
| 63 |
+
prop_name: self._sanitize_schema_for_gemini(prop_schema)
|
| 64 |
+
for prop_name, prop_schema in value.items()
|
| 65 |
+
}
|
| 66 |
+
elif key == "items" and isinstance(value, dict):
|
| 67 |
+
sanitized[key] = self._sanitize_schema_for_gemini(value)
|
| 68 |
+
else:
|
| 69 |
+
sanitized[key] = value
|
| 70 |
+
|
| 71 |
+
return sanitized
|
| 72 |
+
|
| 73 |
+
def _convert_tools_to_gemini_format(self, tools: List[Dict[str, Any]]) -> List[Tool]:
|
| 74 |
+
"""
|
| 75 |
+
Convert MCP tool definitions to Gemini function declarations.
|
| 76 |
+
|
| 77 |
+
Sanitizes schemas to remove unsupported JSON Schema keywords.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
tools: MCP tool definitions
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
List of Gemini Tool objects
|
| 84 |
+
"""
|
| 85 |
+
function_declarations = []
|
| 86 |
+
for tool in tools:
|
| 87 |
+
# Sanitize parameters to remove unsupported fields
|
| 88 |
+
sanitized_parameters = self._sanitize_schema_for_gemini(tool["parameters"])
|
| 89 |
+
|
| 90 |
+
function_declarations.append(
|
| 91 |
+
FunctionDeclaration(
|
| 92 |
+
name=tool["name"],
|
| 93 |
+
description=tool["description"],
|
| 94 |
+
parameters=sanitized_parameters
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
logger.debug(f"Sanitized tool schema for Gemini: {tool['name']}")
|
| 99 |
+
|
| 100 |
+
return [Tool(function_declarations=function_declarations)]
|
| 101 |
+
|
| 102 |
+
def _convert_messages_to_gemini_format(self, messages: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
|
| 103 |
+
"""
|
| 104 |
+
Convert standard message format to Gemini format.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
messages: Standard message format [{"role": "user", "content": "..."}]
|
| 108 |
+
system_prompt: System instructions
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Gemini-formatted messages
|
| 112 |
+
"""
|
| 113 |
+
gemini_messages = []
|
| 114 |
+
|
| 115 |
+
# Add system prompt as first user message if provided
|
| 116 |
+
if system_prompt:
|
| 117 |
+
gemini_messages.append({
|
| 118 |
+
"role": "user",
|
| 119 |
+
"parts": [{"text": system_prompt}]
|
| 120 |
+
})
|
| 121 |
+
gemini_messages.append({
|
| 122 |
+
"role": "model",
|
| 123 |
+
"parts": [{"text": "Understood. I'll follow these instructions."}]
|
| 124 |
+
})
|
| 125 |
+
|
| 126 |
+
# Convert messages
|
| 127 |
+
for msg in messages:
|
| 128 |
+
role = "user" if msg["role"] == "user" else "model"
|
| 129 |
+
gemini_messages.append({
|
| 130 |
+
"role": role,
|
| 131 |
+
"parts": [{"text": msg["content"]}]
|
| 132 |
+
})
|
| 133 |
+
|
| 134 |
+
return gemini_messages
|
| 135 |
+
|
| 136 |
+
async def generate_response_with_tools(
|
| 137 |
+
self,
|
| 138 |
+
messages: List[Dict[str, str]],
|
| 139 |
+
system_prompt: str,
|
| 140 |
+
tools: List[Dict[str, Any]]
|
| 141 |
+
) -> LLMResponse:
|
| 142 |
+
"""
|
| 143 |
+
Generate a response with function calling support.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
messages: Conversation history
|
| 147 |
+
system_prompt: System instructions
|
| 148 |
+
tools: Tool definitions
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
LLMResponse with content and/or tool_calls
|
| 152 |
+
"""
|
| 153 |
+
try:
|
| 154 |
+
# Convert tools to Gemini format
|
| 155 |
+
gemini_tools = self._convert_tools_to_gemini_format(tools)
|
| 156 |
+
|
| 157 |
+
# Convert messages to Gemini format
|
| 158 |
+
gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)
|
| 159 |
+
|
| 160 |
+
# Generate response with function calling
|
| 161 |
+
response = self.client.generate_content(
|
| 162 |
+
gemini_messages,
|
| 163 |
+
tools=gemini_tools,
|
| 164 |
+
generation_config={
|
| 165 |
+
"temperature": self.temperature,
|
| 166 |
+
"max_output_tokens": self.max_tokens
|
| 167 |
+
}
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Check if function calls were made
|
| 171 |
+
if response.candidates[0].content.parts:
|
| 172 |
+
first_part = response.candidates[0].content.parts[0]
|
| 173 |
+
|
| 174 |
+
# Check for function call
|
| 175 |
+
if hasattr(first_part, 'function_call') and first_part.function_call:
|
| 176 |
+
function_call = first_part.function_call
|
| 177 |
+
tool_calls = [{
|
| 178 |
+
"name": function_call.name,
|
| 179 |
+
"arguments": dict(function_call.args)
|
| 180 |
+
}]
|
| 181 |
+
logger.info(f"Gemini requested function call: {function_call.name}")
|
| 182 |
+
return LLMResponse(
|
| 183 |
+
content=None,
|
| 184 |
+
tool_calls=tool_calls,
|
| 185 |
+
finish_reason="function_call"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Regular text response
|
| 189 |
+
content = response.text if hasattr(response, 'text') else None
|
| 190 |
+
logger.info("Gemini generated text response")
|
| 191 |
+
return LLMResponse(
|
| 192 |
+
content=content,
|
| 193 |
+
finish_reason="stop"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Gemini API error: {str(e)}")
|
| 198 |
+
raise
|
| 199 |
+
|
| 200 |
+
    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """Generate a final response after tool execution.

        Gemini's chat format has no native "tool" role here, so the tool
        calls and their results are flattened into a plain-text assistant
        turn, followed by a user turn asking for a natural-language summary.

        Args:
            messages: Original conversation history
            tool_calls: Tool calls that were made (dicts with a 'name' key)
            tool_results: Results from tool execution, zipped positionally
                with tool_calls

        Returns:
            LLMResponse with final content

        Raises:
            Exception: Any Gemini API failure is logged and re-raised.
        """
        try:
            # Format tool results as one text blob; str(result) relies on the
            # result objects having a useful __str__ — TODO confirm.
            tool_results_text = "\n\n".join([
                f"Tool: {call['name']}\nResult: {result}"
                for call, result in zip(tool_calls, tool_results)
            ])

            # Append the tool transcript plus a follow-up instruction; the
            # original history is not mutated (list concatenation copies).
            messages_with_results = messages + [
                {"role": "assistant", "content": f"I called the following tools:\n{tool_results_text}"},
                {"role": "user", "content": "Based on these tool results, provide a natural language response to the user."}
            ]

            # Generate final response with an empty system prompt (the
            # instruction is carried in the appended user message instead).
            gemini_messages = self._convert_messages_to_gemini_format(messages_with_results, "")
            response = self.client.generate_content(
                gemini_messages,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens
                }
            )

            # NOTE(review): response.text is a property that can raise
            # ValueError when no text parts exist; hasattr() does not catch
            # ValueError, so this guard may not be sufficient — confirm.
            content = response.text if hasattr(response, 'text') else None
            logger.info("Gemini generated final response after tool execution")
            return LLMResponse(
                content=content,
                finish_reason="stop"
            )

        except Exception as e:
            logger.error(f"Gemini API error in tool results: {str(e)}")
            raise
|
| 250 |
+
|
| 251 |
+
async def generate_simple_response(
|
| 252 |
+
self,
|
| 253 |
+
messages: List[Dict[str, str]],
|
| 254 |
+
system_prompt: str
|
| 255 |
+
) -> LLMResponse:
|
| 256 |
+
"""
|
| 257 |
+
Generate a simple response without function calling.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
messages: Conversation history
|
| 261 |
+
system_prompt: System instructions
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
LLMResponse with content
|
| 265 |
+
"""
|
| 266 |
+
try:
|
| 267 |
+
gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)
|
| 268 |
+
response = self.client.generate_content(
|
| 269 |
+
gemini_messages,
|
| 270 |
+
generation_config={
|
| 271 |
+
"temperature": self.temperature,
|
| 272 |
+
"max_output_tokens": self.max_tokens
|
| 273 |
+
}
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
content = response.text if hasattr(response, 'text') else None
|
| 277 |
+
logger.info("Gemini generated simple response")
|
| 278 |
+
return LLMResponse(
|
| 279 |
+
content=content,
|
| 280 |
+
finish_reason="stop"
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
except Exception as e:
|
| 284 |
+
logger.error(f"Gemini API error: {str(e)}")
|
| 285 |
+
raise
|
src/agent/providers/openrouter.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenRouter Provider Implementation
|
| 3 |
+
|
| 4 |
+
OpenRouter API provider with function calling support.
|
| 5 |
+
Fallback provider for when Gemini rate limits are exceeded.
|
| 6 |
+
Uses OpenAI-compatible API format.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
import logging
from typing import List, Dict, Any

import httpx

from .base import LLMProvider, LLMResponse
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OpenRouterProvider(LLMProvider):
    """
    OpenRouter API provider implementation.

    Features:
    - OpenAI-compatible API
    - Access to multiple free models
    - Function calling support
    - Recommended free model: google/gemini-flash-1.5

    Serves as the fallback provider when Gemini rate limits are exceeded.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "google/gemini-flash-1.5",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        """Store credentials and build the static request headers.

        Args:
            api_key: OpenRouter API key, sent as a Bearer token.
            model: Model identifier in OpenRouter's "vendor/model" form.
            temperature: Sampling temperature for all completions.
            max_tokens: Maximum completion tokens per request.
        """
        super().__init__(api_key, model, temperature, max_tokens)
        self.base_url = "https://openrouter.ai/api/v1"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        logger.info(f"Initialized OpenRouterProvider with model: {model}")

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """Decode a tool-call "arguments" payload into a dict when possible.

        The OpenAI-compatible API returns arguments as a JSON-encoded string,
        while GeminiProvider returns a plain dict; parsing here keeps both
        providers consistent for downstream tool execution. Unparseable
        payloads are returned unchanged rather than raising.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except (ValueError, TypeError):
                return raw
        return raw

    def _convert_tools_to_openai_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert MCP tool definitions to OpenAI function format.

        Args:
            tools: MCP tool definitions with "name", "description" and
                "parameters" keys.

        Returns:
            List of OpenAI-formatted function definitions
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"]
                }
            }
            for tool in tools
        ]

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """Generate a response with function calling support.

        Args:
            messages: Conversation history
            system_prompt: System instructions
            tools: Tool definitions

        Returns:
            LLMResponse with content and/or tool_calls

        Raises:
            httpx.HTTPStatusError: On non-2xx API responses.
            Exception: On any other API failure (logged and re-raised).
        """
        try:
            # Prepend the system prompt as an OpenAI "system" message.
            formatted_messages = [{"role": "system", "content": system_prompt}] + messages

            openai_tools = self._convert_tools_to_openai_format(tools)

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": formatted_messages,
                        "tools": openai_tools,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            message = choice["message"]

            # Check for function calls requested by the model.
            if "tool_calls" in message and message["tool_calls"]:
                tool_calls = [
                    {
                        "name": tc["function"]["name"],
                        # The API sends arguments as a JSON string; decode to a
                        # dict so the shape matches GeminiProvider's output.
                        "arguments": self._parse_tool_arguments(tc["function"]["arguments"])
                    }
                    for tc in message["tool_calls"]
                ]
                logger.info(f"OpenRouter requested function calls: {[tc['name'] for tc in tool_calls]}")
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason=choice.get("finish_reason", "function_call")
                )

            # Regular text response
            content = message.get("content")
            logger.info("OpenRouter generated text response")
            return LLMResponse(
                content=content,
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error: {str(e)}")
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """Generate a final response after tool execution.

        Replays the tool calls as an OpenAI-style assistant message with
        synthetic call ids ("call_0", "call_1", ...) followed by one "tool"
        message per result, then asks the model for the final answer.

        Args:
            messages: Original conversation history
            tool_calls: Tool calls that were made
            tool_results: Results from tool execution (zipped with tool_calls)

        Returns:
            LLMResponse with final content

        Raises:
            httpx.HTTPStatusError: On non-2xx API responses.
            Exception: On any other API failure (logged and re-raised).
        """
        try:
            messages_with_results = messages.copy()

            # Add the assistant message that "made" the tool calls.
            messages_with_results.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": f"call_{i}",
                        "type": "function",
                        "function": {
                            "name": call["name"],
                            # json.dumps produces valid JSON; str() would emit
                            # Python repr (single quotes), which the OpenAI
                            # wire format does not accept for arguments.
                            "arguments": json.dumps(call["arguments"])
                        }
                    }
                    for i, call in enumerate(tool_calls)
                ]
            })

            # Add one tool-result message per call, matched by synthetic id.
            for i, (call, result) in enumerate(zip(tool_calls, tool_results)):
                messages_with_results.append({
                    "role": "tool",
                    "tool_call_id": f"call_{i}",
                    "content": str(result)
                })

            # Generate final response
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": messages_with_results,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            content = choice["message"].get("content")
            logger.info("OpenRouter generated final response after tool execution")
            return LLMResponse(
                content=content,
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error in tool results: {str(e)}")
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """Generate a simple response without function calling.

        Args:
            messages: Conversation history
            system_prompt: System instructions

        Returns:
            LLMResponse with content

        Raises:
            httpx.HTTPStatusError: On non-2xx API responses.
            Exception: On any other API failure (logged and re-raised).
        """
        try:
            # Prepend the system prompt as an OpenAI "system" message.
            formatted_messages = [{"role": "system", "content": system_prompt}] + messages

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": formatted_messages,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            content = choice["message"].get("content")
            logger.info("OpenRouter generated simple response")
            return LLMResponse(
                content=content,
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error: {str(e)}")
            raise
|
src/api/routes/__pycache__/chat.cpython-313.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/api/routes/__pycache__/conversations.cpython-313.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
src/api/routes/chat.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chat API endpoint for AI chatbot."""
|
| 2 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 3 |
+
from sqlmodel import Session
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
import logging
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
from src.core.database import get_session
|
| 9 |
+
from src.core.security import get_current_user
|
| 10 |
+
from src.core.config import settings
|
| 11 |
+
from src.schemas.chat_request import ChatRequest
|
| 12 |
+
from src.schemas.chat_response import ChatResponse
|
| 13 |
+
from src.services.conversation_service import ConversationService
|
| 14 |
+
from src.agent.agent_config import AgentConfiguration
|
| 15 |
+
from src.agent.agent_runner import AgentRunner
|
| 16 |
+
from src.mcp import tool_registry
|
| 17 |
+
from src.core.exceptions import (
|
| 18 |
+
classify_ai_error,
|
| 19 |
+
APIKeyMissingException,
|
| 20 |
+
APIKeyInvalidException
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Configure logging
logger = logging.getLogger(__name__)

# All chat routes are mounted under /api and tagged "chat" in OpenAPI docs.
router = APIRouter(prefix="/api", tags=["chat"])
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def generate_conversation_title(first_user_message: str) -> str:
    """Derive a short conversation title from the user's first message.

    The message is cut at the first delimiter found, checking ". ", "! ",
    "? ", then a newline — in that fixed order, not by position. Titles
    shorter than 10 characters fall back to a timestamped default; titles
    longer than 50 characters are truncated with an ellipsis.

    Args:
        first_user_message: The first message from the user

    Returns:
        A title string (max 50 characters)
    """
    message = first_user_message.strip()

    # Cut at the first sentence delimiter (checked in fixed order).
    title = None
    for delimiter in ('. ', '! ', '? ', '\n'):
        if delimiter in message:
            title, _, _ = message.partition(delimiter)
            break
    if title is None:
        # No delimiter present: take the leading 50 characters.
        title = message[:50]

    # Very short titles carry no information — use a timestamped default.
    if len(title) < 10:
        return f"Chat {datetime.now().strftime('%b %d, %I:%M %p')}"

    # Enforce the 50-character cap with a visible ellipsis.
    if len(title) > 50:
        title = title[:47] + "..."

    return title
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.post("/{user_id}/chat", response_model=ChatResponse)
async def chat(
    user_id: int,
    request: ChatRequest,
    db: Session = Depends(get_session),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ChatResponse:
    """Handle chat messages from users.

    Flow: authorize -> validate -> build agent -> resolve conversation ->
    persist user message -> run agent (with tools) -> persist assistant
    message (+ tool metadata) -> return ChatResponse.

    Args:
        user_id: ID of the user sending the message
        request: ChatRequest containing the user's message
        db: Database session
        current_user: Authenticated user from JWT token

    Returns:
        ChatResponse containing the AI's response

    Raises:
        HTTPException 401: If user is not authenticated or user_id doesn't match
        HTTPException 404: If conversation_id is provided but not found
        HTTPException 500: If AI provider fails to generate response
    """
    # Verify user authorization: the path user_id must match the JWT identity.
    if current_user["id"] != user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized to access this user's chat"
        )

    try:
        # Validate request message length
        if not request.message or len(request.message.strip()) == 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Message cannot be empty"
            )

        if len(request.message) > 10000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Message exceeds maximum length of 10,000 characters"
            )

        # Initialize services
        conversation_service = ConversationService(db)

        # Initialize agent configuration from settings (per-request; the
        # agent is rebuilt for every call).
        try:
            agent_config = AgentConfiguration(
                provider=settings.LLM_PROVIDER,
                fallback_provider=settings.FALLBACK_PROVIDER,
                gemini_api_key=settings.GEMINI_API_KEY,
                openrouter_api_key=settings.OPENROUTER_API_KEY,
                cohere_api_key=settings.COHERE_API_KEY,
                temperature=settings.AGENT_TEMPERATURE,
                max_tokens=settings.AGENT_MAX_TOKENS,
                max_messages=settings.CONVERSATION_MAX_MESSAGES,
                max_conversation_tokens=settings.CONVERSATION_MAX_TOKENS
            )
            agent_config.validate()

            # Create agent runner with tool registry
            agent_runner = AgentRunner(agent_config, tool_registry)
        except ValueError as e:
            logger.error(f"Agent initialization failed: {str(e)}")
            # Map API-key configuration errors to dedicated exceptions by
            # substring-matching the error text (brittle — depends on the
            # exact wording raised by AgentConfiguration.validate()).
            error_msg = str(e).lower()
            if "api key" in error_msg:
                if "not found" in error_msg or "missing" in error_msg:
                    raise APIKeyMissingException(provider=settings.LLM_PROVIDER)
                elif "invalid" in error_msg:
                    raise APIKeyInvalidException(provider=settings.LLM_PROVIDER)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AI service is not properly configured. Please contact support."
            )

        # Get or create conversation
        # NOTE(review): is_new_conversation is set but never read afterwards
        # — confirm whether it was meant to affect the response.
        is_new_conversation = False
        if request.conversation_id:
            conversation = conversation_service.get_conversation(
                request.conversation_id,
                user_id
            )
            if not conversation:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Conversation {request.conversation_id} not found or you don't have access to it"
                )
        else:
            # Create new conversation with auto-generated title
            try:
                # Generate title from first user message
                title = generate_conversation_title(request.message)
                conversation = conversation_service.create_conversation(
                    user_id=user_id,
                    title=title
                )
                is_new_conversation = True
                logger.info(f"Created new conversation {conversation.id} with title: {title}")
            except Exception as e:
                logger.error(f"Failed to create conversation: {str(e)}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to create conversation. Please try again."
                )

        # Add user message to conversation
        try:
            user_message = conversation_service.add_message(
                conversation_id=conversation.id,
                role="user",
                content=request.message,
                token_count=len(request.message) // 4  # Rough token estimate (~4 chars/token)
            )
        except Exception as e:
            logger.error(f"Failed to save user message: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save your message. Please try again."
            )

        # Get conversation history (includes the message just saved above).
        history_messages = conversation_service.get_conversation_messages(
            conversation_id=conversation.id
        )

        # Format messages for agent with trimming to the configured limits.
        formatted_messages = conversation_service.format_messages_for_agent(
            messages=history_messages,
            max_messages=agent_config.max_messages,
            max_tokens=agent_config.max_conversation_tokens
        )

        # Generate AI response with tool calling support; a per-request
        # system prompt overrides the configured default.
        system_prompt = request.system_prompt or agent_config.system_prompt

        try:
            agent_result = await agent_runner.execute(
                messages=formatted_messages,
                user_id=user_id,  # Inject user context for security
                system_prompt=system_prompt
            )
        except Exception as e:
            # Use classify_ai_error to determine the appropriate exception
            logger.error(f"AI service error for user {user_id}: {str(e)}")
            # NOTE(review): if execute() raised, the assignment above never
            # completed, so 'agent_result' cannot be in locals() here — this
            # always falls back to settings.LLM_PROVIDER. Confirm intent.
            provider = agent_result.get("provider") if 'agent_result' in locals() else settings.LLM_PROVIDER
            raise classify_ai_error(e, provider=provider)

        # Add AI response to conversation with tool call metadata
        try:
            # Prepare metadata if tools were used
            tool_metadata = None
            if agent_result.get("tool_calls"):
                # Convert ToolExecutionResult objects to dicts for JSON serialization
                tool_results = agent_result.get("tool_results", [])
                serializable_results = []
                for result in tool_results:
                    if hasattr(result, '__dict__'):
                        # Convert dataclass/object to dict (assumes the result
                        # exposes success/data/message/error — TODO confirm).
                        serializable_results.append({
                            "success": result.success,
                            "data": result.data,
                            "message": result.message,
                            "error": result.error
                        })
                    else:
                        # Already a dict
                        serializable_results.append(result)

                tool_metadata = {
                    "tool_calls": agent_result["tool_calls"],
                    "tool_results": serializable_results,
                    "provider": agent_result.get("provider")
                }

            assistant_message = conversation_service.add_message(
                conversation_id=conversation.id,
                role="assistant",
                content=agent_result["content"],
                token_count=len(agent_result["content"]) // 4  # Rough token estimate
            )

            # Update tool_metadata if tools were used (second commit after
            # add_message's own persistence).
            if tool_metadata:
                assistant_message.tool_metadata = tool_metadata
                db.add(assistant_message)
                db.commit()
        except Exception as e:
            logger.error(f"Failed to save AI response: {str(e)}")
            # Still return the response even if saving fails
            # User gets the response but it won't be in history
            logger.warning(f"Returning response without saving to database for conversation {conversation.id}")

        # Log tool usage if any
        if agent_result.get("tool_calls"):
            logger.info(f"Agent used {len(agent_result['tool_calls'])} tools for user {user_id}")

        # Return response. If saving the assistant message failed above,
        # fall back to the user message's timestamp.
        return ChatResponse(
            conversation_id=conversation.id,
            message=agent_result["content"],
            role="assistant",
            timestamp=assistant_message.timestamp if 'assistant_message' in locals() else user_message.timestamp,
            token_count=len(agent_result["content"]) // 4,
            model=agent_result.get("provider")
        )

    except HTTPException:
        # Re-raise HTTP exceptions
        raise
    except Exception as e:
        # Catch-all for unexpected errors
        logger.exception(f"Unexpected error in chat endpoint: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred. Please try again later."
        )
|
src/api/routes/conversations.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conversations API endpoints for managing chat conversations."""
|
| 2 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
| 3 |
+
from sqlmodel import Session
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from src.core.database import get_session
|
| 8 |
+
from src.core.security import get_current_user
|
| 9 |
+
from src.services.conversation_service import ConversationService
|
| 10 |
+
from src.schemas.conversation import (
|
| 11 |
+
ConversationListResponse,
|
| 12 |
+
ConversationSummary,
|
| 13 |
+
MessageListResponse,
|
| 14 |
+
MessageResponse,
|
| 15 |
+
UpdateConversationRequest,
|
| 16 |
+
UpdateConversationResponse
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Configure logging
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
router = APIRouter(prefix="/api", tags=["conversations"])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@router.get("/{user_id}/conversations", response_model=ConversationListResponse)
async def list_conversations(
    user_id: int,
    limit: int = Query(50, ge=1, le=100, description="Maximum number of conversations to return"),
    db: Session = Depends(get_session),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ConversationListResponse:
    """Return a summary listing of a user's conversations.

    Each summary carries the conversation's message count and a short
    (100-character) preview of its most recent message.

    Args:
        user_id: ID of the user whose conversations are listed.
        limit: Maximum number of conversations to return (default 50, max 100).
        db: Database session.
        current_user: Authenticated user decoded from the JWT token.

    Returns:
        ConversationListResponse with per-conversation summaries.

    Raises:
        HTTPException 401: authenticated user does not match ``user_id``.
        HTTPException 500: unexpected failure while reading conversations.
    """
    # A caller may only list their own conversations.
    if current_user["id"] != user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized to access this user's conversations"
        )

    try:
        service = ConversationService(db)

        # NOTE(review): loading every message of every conversation just for a
        # count and a preview is an N+1 pattern — acceptable for small
        # histories, worth a dedicated query if listings get slow.
        summaries: List[ConversationSummary] = []
        for conv in service.get_user_conversations(user_id, limit=limit):
            messages = service.get_conversation_messages(conv.id)

            preview = None
            if messages:
                last_text = messages[-1].content
                preview = last_text[:100]
                if len(last_text) > 100:
                    preview += "..."

            summaries.append(ConversationSummary(
                id=conv.id,
                title=conv.title,
                created_at=conv.created_at,
                updated_at=conv.updated_at,
                message_count=len(messages),
                last_message_preview=preview
            ))

        return ConversationListResponse(
            conversations=summaries,
            total=len(summaries)
        )

    except Exception as e:
        logger.exception(f"Failed to list conversations for user {user_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve conversations. Please try again."
        )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@router.get("/{user_id}/conversations/{conversation_id}/messages", response_model=MessageListResponse)
async def get_conversation_messages(
    user_id: int,
    conversation_id: int,
    offset: int = Query(0, ge=0, description="Number of messages to skip"),
    limit: int = Query(50, ge=1, le=200, description="Maximum number of messages to return"),
    db: Session = Depends(get_session),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> MessageListResponse:
    """Return one page of a conversation's message history.

    Args:
        user_id: ID of the user who owns the conversation.
        conversation_id: ID of the conversation.
        offset: Number of messages to skip (pagination).
        limit: Maximum number of messages to return (default 50, max 200).
        db: Database session.
        current_user: Authenticated user decoded from the JWT token.

    Returns:
        MessageListResponse holding the requested page plus the total count.

    Raises:
        HTTPException 401: authenticated user does not match ``user_id``.
        HTTPException 404: conversation missing or owned by another user.
        HTTPException 500: unexpected failure while reading messages.
    """
    # A caller may only read their own conversations.
    if current_user["id"] != user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized to access this user's conversations"
        )

    try:
        service = ConversationService(db)

        # Ownership check: 404 rather than 403 so existence is not leaked.
        if service.get_conversation(conversation_id, user_id) is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Conversation {conversation_id} not found or you don't have access to it"
            )

        # Pagination is applied in memory over the full message list.
        all_messages = service.get_conversation_messages(conversation_id)
        page = all_messages[offset:offset + limit]

        return MessageListResponse(
            conversation_id=conversation_id,
            messages=[
                MessageResponse(
                    id=msg.id,
                    role=msg.role,
                    content=msg.content,
                    timestamp=msg.timestamp,
                    token_count=msg.token_count
                )
                for msg in page
            ],
            total=len(all_messages)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to get messages for conversation {conversation_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve messages. Please try again."
        )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@router.patch("/{user_id}/conversations/{conversation_id}", response_model=UpdateConversationResponse)
async def update_conversation(
    user_id: int,
    conversation_id: int,
    request: UpdateConversationRequest,
    db: Session = Depends(get_session),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> UpdateConversationResponse:
    """Rename a conversation.

    Args:
        user_id: ID of the user who owns the conversation.
        conversation_id: ID of the conversation to rename.
        request: Payload carrying the new title.
        db: Database session.
        current_user: Authenticated user decoded from the JWT token.

    Returns:
        UpdateConversationResponse with the persisted title and timestamp.

    Raises:
        HTTPException 401: authenticated user does not match ``user_id``.
        HTTPException 404: conversation missing or owned by another user.
        HTTPException 500: unexpected failure while saving the change.
    """
    # A caller may only modify their own conversations.
    if current_user["id"] != user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized to access this user's conversations"
        )

    try:
        service = ConversationService(db)

        conversation = service.get_conversation(conversation_id, user_id)
        if conversation is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Conversation {conversation_id} not found or you don't have access to it"
            )

        from datetime import datetime

        conversation.title = request.title
        # NOTE(review): utcnow() returns a naive datetime and is deprecated in
        # Python 3.12+; switching to an aware timestamp needs a check against
        # the DB column type first.
        conversation.updated_at = datetime.utcnow()
        db.add(conversation)
        db.commit()
        db.refresh(conversation)

        return UpdateConversationResponse(
            id=conversation.id,
            title=conversation.title,
            updated_at=conversation.updated_at
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update conversation {conversation_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update conversation. Please try again."
        )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@router.delete("/{user_id}/conversations/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
    user_id: int,
    conversation_id: int,
    db: Session = Depends(get_session),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> None:
    """Delete a conversation together with all of its messages.

    Responds with 204 No Content on success.

    Args:
        user_id: ID of the user who owns the conversation.
        conversation_id: ID of the conversation to delete.
        db: Database session.
        current_user: Authenticated user decoded from the JWT token.

    Raises:
        HTTPException 401: authenticated user does not match ``user_id``.
        HTTPException 404: conversation missing or owned by another user.
        HTTPException 500: unexpected failure during deletion.
    """
    # A caller may only delete their own conversations.
    if current_user["id"] != user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized to access this user's conversations"
        )

    try:
        service = ConversationService(db)

        # The service performs its own ownership check and reports failure
        # (missing or foreign conversation) by returning a falsy value.
        if not service.delete_conversation(conversation_id, user_id):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Conversation {conversation_id} not found or you don't have access to it"
            )

        # 204 No Content — no response body.
        return None

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to delete conversation {conversation_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete conversation. Please try again."
        )
|
src/core/__pycache__/config.cpython-313.pyc
CHANGED
|
Binary files a/src/core/__pycache__/config.cpython-313.pyc and b/src/core/__pycache__/config.cpython-313.pyc differ
|
|
|
src/core/__pycache__/exceptions.cpython-313.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
src/core/__pycache__/security.cpython-313.pyc
CHANGED
|
Binary files a/src/core/__pycache__/security.cpython-313.pyc and b/src/core/__pycache__/security.cpython-313.pyc differ
|
|
|
src/core/config.py
CHANGED
|
@@ -19,6 +19,21 @@ class Settings(BaseSettings):
|
|
| 19 |
JWT_ALGORITHM: str = "HS256"
|
| 20 |
JWT_EXPIRATION_DAYS: int = 7
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
class Config:
|
| 23 |
env_file = ".env"
|
| 24 |
case_sensitive = True
|
|
|
|
| 19 |
JWT_ALGORITHM: str = "HS256"
|
| 20 |
JWT_EXPIRATION_DAYS: int = 7
|
| 21 |
|
| 22 |
+
# LLM Provider Configuration
|
| 23 |
+
LLM_PROVIDER: str = "gemini" # Primary provider: gemini, openrouter, cohere
|
| 24 |
+
FALLBACK_PROVIDER: str | None = None # Optional fallback provider
|
| 25 |
+
GEMINI_API_KEY: str | None = None # Required if using Gemini
|
| 26 |
+
OPENROUTER_API_KEY: str | None = None # Required if using OpenRouter
|
| 27 |
+
COHERE_API_KEY: str | None = None # Required if using Cohere
|
| 28 |
+
|
| 29 |
+
# Agent Configuration
|
| 30 |
+
AGENT_TEMPERATURE: float = 0.7 # Sampling temperature (0.0 to 1.0)
|
| 31 |
+
AGENT_MAX_TOKENS: int = 8192 # Maximum tokens in response
|
| 32 |
+
|
| 33 |
+
# Conversation Settings (for free-tier constraints)
|
| 34 |
+
CONVERSATION_MAX_MESSAGES: int = 20 # Maximum messages to keep in history
|
| 35 |
+
CONVERSATION_MAX_TOKENS: int = 8000 # Maximum tokens in conversation history
|
| 36 |
+
|
| 37 |
class Config:
|
| 38 |
env_file = ".env"
|
| 39 |
case_sensitive = True
|
src/core/exceptions.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom exception classes for structured error handling.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from fastapi import HTTPException, status
|
| 7 |
+
from src.schemas.error import ErrorCode
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AIProviderException(HTTPException):
    """Base class for structured AI-provider failures.

    Extends HTTPException with a machine-readable error code, a fixed
    source tag ("AI_PROVIDER"), and the name of the provider that failed,
    so the global exception handler can emit a structured ErrorResponse.
    """

    def __init__(
        self,
        error_code: str,
        detail: str,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        provider: Optional[str] = None
    ):
        super().__init__(status_code=status_code, detail=detail)
        # Extra fields consumed by the exception handler when building
        # the ErrorResponse payload.
        self.error_code = error_code
        self.source = "AI_PROVIDER"
        self.provider = provider
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RateLimitExceededException(AIProviderException):
    """Raised when the AI provider reports rate limiting (HTTP 429)."""

    def __init__(self, provider: Optional[str] = None):
        super().__init__(
            error_code=ErrorCode.RATE_LIMIT_EXCEEDED,
            detail="AI service rate limit exceeded. Please wait a moment and try again.",
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            provider=provider
        )


class APIKeyMissingException(AIProviderException):
    """Raised when no API key has been configured for the provider."""

    def __init__(self, provider: Optional[str] = None):
        super().__init__(
            error_code=ErrorCode.API_KEY_MISSING,
            detail="AI service is not configured. Please add an API key.",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            provider=provider
        )


class APIKeyInvalidException(AIProviderException):
    """Raised when the configured API key is rejected or expired."""

    def __init__(self, provider: Optional[str] = None):
        super().__init__(
            error_code=ErrorCode.API_KEY_INVALID,
            detail="Your API key is invalid or expired. Please check your configuration.",
            status_code=status.HTTP_401_UNAUTHORIZED,
            provider=provider
        )


class ProviderUnavailableException(AIProviderException):
    """Raised when the provider is down, unreachable, or timing out."""

    def __init__(self, provider: Optional[str] = None):
        super().__init__(
            error_code=ErrorCode.PROVIDER_UNAVAILABLE,
            detail="AI service is temporarily unavailable. Please try again later.",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            provider=provider
        )


class ProviderErrorException(AIProviderException):
    """Raised for provider failures that fit no more specific category."""

    def __init__(self, detail: str, provider: Optional[str] = None):
        super().__init__(
            error_code=ErrorCode.PROVIDER_ERROR,
            detail=detail,
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            provider=provider
        )
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def classify_ai_error(error: Exception, provider: Optional[str] = None) -> AIProviderException:
    """Map a raw provider exception onto a structured AIProviderException.

    Classification is keyword-based on the lowercased error text, checked
    in priority order: rate limiting, missing key, invalid key, provider
    unavailable, then a generic provider error as the fallback.

    Args:
        error: The original exception raised by the AI provider.
        provider: Name of the AI provider (gemini, openrouter, cohere).

    Returns:
        The matching AIProviderException subclass instance.
    """
    message = str(error).lower()

    rate_limit_markers = ("rate limit", "429", "quota exceeded", "too many requests")
    if any(marker in message for marker in rate_limit_markers):
        return RateLimitExceededException(provider=provider)

    missing_key_markers = ("api key not found", "api key is required", "missing api key")
    if any(marker in message for marker in missing_key_markers):
        return APIKeyMissingException(provider=provider)

    invalid_key_markers = (
        "invalid api key", "api key invalid", "unauthorized", "401",
        "authentication failed", "invalid credentials", "api key expired",
    )
    if any(marker in message for marker in invalid_key_markers):
        return APIKeyInvalidException(provider=provider)

    unavailable_markers = (
        "503", "service unavailable", "temporarily unavailable",
        "connection refused", "connection timeout", "timeout",
    )
    if any(marker in message for marker in unavailable_markers):
        return ProviderUnavailableException(provider=provider)

    # Nothing matched: report the raw message as a generic provider error.
    return ProviderErrorException(
        detail=f"AI service error: {str(error)}",
        provider=provider
    )
|
src/core/security.py
CHANGED
|
@@ -111,9 +111,13 @@
|
|
| 111 |
import jwt
|
| 112 |
from datetime import datetime, timedelta
|
| 113 |
from passlib.context import CryptContext
|
| 114 |
-
from fastapi import HTTPException, status
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
| 117 |
import hashlib
|
| 118 |
MAX_BCRYPT_BYTES = 72
|
| 119 |
|
|
@@ -165,3 +169,50 @@ def verify_jwt_token(token: str, secret: str) -> dict:
|
|
| 165 |
raise HTTPException(status_code=401, detail="Token expired")
|
| 166 |
except jwt.InvalidTokenError:
|
| 167 |
raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
import jwt
|
| 112 |
from datetime import datetime, timedelta
|
| 113 |
from passlib.context import CryptContext
|
| 114 |
+
from fastapi import HTTPException, status, Depends
|
| 115 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 116 |
+
from typing import Dict, Any
|
| 117 |
+
from src.core.config import settings
|
| 118 |
|
| 119 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 120 |
+
security = HTTPBearer()
|
| 121 |
import hashlib
|
| 122 |
MAX_BCRYPT_BYTES = 72
|
| 123 |
|
|
|
|
| 169 |
raise HTTPException(status_code=401, detail="Token expired")
|
| 170 |
except jwt.InvalidTokenError:
|
| 171 |
raise HTTPException(status_code=401, detail="Invalid token")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> Dict[str, Any]:
    """
    FastAPI dependency that validates the Bearer JWT and returns its claims.

    Args:
        credentials: HTTP Bearer token credentials from the request header.

    Returns:
        Dictionary containing user information from the token payload:
        - id: User ID (integer parsed from the 'sub' claim)
        - email: User email
        - iat: Token issued-at timestamp
        - exp: Token expiration timestamp

    Raises:
        HTTPException 401: If the token is missing, invalid, expired,
            or carries a missing/non-numeric 'sub' claim.
    """
    token = credentials.credentials

    try:
        payload = verify_jwt_token(token, settings.BETTER_AUTH_SECRET)

        # 'sub' is a string user ID in a well-formed token; a missing claim
        # makes int(None) raise TypeError, a non-numeric one raises ValueError.
        user_id = int(payload.get("sub"))

        return {
            "id": user_id,
            "email": payload.get("email"),
            "iat": payload.get("iat"),
            "exp": payload.get("exp")
        }
    except HTTPException:
        # verify_jwt_token already raises precise 401s ("Token expired",
        # "Invalid token"); propagate them unchanged instead of re-wrapping
        # them into a generic message that leaks internal details.
        raise
    except (TypeError, ValueError):
        # Missing or non-numeric 'sub' claim.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user ID in token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    except Exception as e:
        # Last-resort guard: any other failure is treated as unauthenticated.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Authentication failed: {str(e)}",
            headers={"WWW-Authenticate": "Bearer"}
        )
|
src/main.py
CHANGED
|
@@ -1,13 +1,76 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
| 3 |
from .core.config import settings
|
| 4 |
-
from .api.routes import tasks, auth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
app = FastAPI(
|
| 7 |
title=settings.APP_NAME,
|
| 8 |
debug=settings.DEBUG
|
| 9 |
)
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# Configure CORS
|
| 12 |
app.add_middleware(
|
| 13 |
CORSMiddleware,
|
|
@@ -20,6 +83,8 @@ app.add_middleware(
|
|
| 20 |
# Register routes
|
| 21 |
app.include_router(auth.router)
|
| 22 |
app.include_router(tasks.router)
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@app.get("/")
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import JSONResponse
|
| 4 |
+
import logging
|
| 5 |
from .core.config import settings
|
| 6 |
+
from .api.routes import tasks, auth, chat, conversations
|
| 7 |
+
from .mcp import register_all_tools
|
| 8 |
+
from .core.exceptions import AIProviderException
|
| 9 |
+
from .schemas.error import ErrorResponse
|
| 10 |
+
|
| 11 |
+
# Configure logging
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
app = FastAPI(
|
| 15 |
title=settings.APP_NAME,
|
| 16 |
debug=settings.DEBUG
|
| 17 |
)
|
| 18 |
|
| 19 |
+
|
| 20 |
+
# Global exception handler for AIProviderException
|
| 21 |
+
@app.exception_handler(AIProviderException)
async def ai_provider_exception_handler(request: Request, exc: AIProviderException):
    """Render an AIProviderException as a structured JSON ErrorResponse."""
    # Record the structured fields so provider failures are traceable in logs.
    logger.error(
        f"AI Provider Error: {exc.error_code} - {exc.detail} "
        f"(Provider: {exc.provider}, Status: {exc.status_code})"
    )
    payload = ErrorResponse(
        error_code=exc.error_code,
        detail=exc.detail,
        source=exc.source,
        provider=exc.provider
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=payload.model_dump()
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Global exception handler for generic HTTPException
|
| 41 |
+
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the traceback, hide details from clients."""
    # Full stack trace goes to the log; the client only sees a generic message.
    logger.exception(f"Unhandled exception: {str(exc)}")

    body = ErrorResponse(
        error_code="INTERNAL_ERROR",
        detail="An unexpected error occurred. Please try again later.",
        source="INTERNAL"
    )
    return JSONResponse(
        status_code=500,
        content=body.model_dump()
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@app.on_event("startup")
async def startup_event():
    """Register MCP tools before the app starts serving requests.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI versions in
    favour of lifespan handlers; left as-is to avoid restructuring app setup.
    """
    logger.info("Starting application initialization...")

    try:
        # Make all MCP tools available to the agent via the global registry.
        register_all_tools()
        logger.info("MCP tools registered successfully")
    except Exception as e:
        # A partially-registered toolset would break the agent; fail startup.
        logger.error(f"Failed to register MCP tools: {str(e)}")
        raise

    logger.info("Application initialization complete")
|
| 73 |
+
|
| 74 |
# Configure CORS
|
| 75 |
app.add_middleware(
|
| 76 |
CORSMiddleware,
|
|
|
|
| 83 |
# Register routes
|
| 84 |
app.include_router(auth.router)
|
| 85 |
app.include_router(tasks.router)
|
| 86 |
+
app.include_router(chat.router)
|
| 87 |
+
app.include_router(conversations.router)
|
| 88 |
|
| 89 |
|
| 90 |
@app.get("/")
|
src/mcp/__init__.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCP Tools Registration
|
| 3 |
+
|
| 4 |
+
Registers all MCP tools with the global tool registry.
|
| 5 |
+
Each tool is registered with its contract definition (name, description, parameters).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from .tool_registry import tool_registry
|
| 13 |
+
from .tools.add_task import add_task
|
| 14 |
+
from .tools.list_tasks import list_tasks
|
| 15 |
+
from .tools.complete_task import complete_task
|
| 16 |
+
from .tools.delete_task import delete_task
|
| 17 |
+
from .tools.update_task import update_task
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_tool_contract(tool_name: str) -> dict:
    """
    Load a tool contract definition from its JSON file.

    Args:
        tool_name: Name of the tool (e.g., "add_task")

    Returns:
        Tool contract dictionary parsed from the JSON file.

    Raises:
        FileNotFoundError: If the contract file does not exist.
        json.JSONDecodeError: If the contract file holds invalid JSON.
    """
    # Resolve the contracts directory relative to this module:
    # backend/src/mcp -> backend, then up once more to the repo root.
    project_root = Path(__file__).parent.parent.parent
    contract_path = (
        project_root.parent / "specs" / "001-openai-agent-mcp-tools" / "contracts" / f"{tool_name}.json"
    )

    if not contract_path.exists():
        raise FileNotFoundError(f"Contract file not found: {contract_path}")

    # Explicit encoding: contract files are UTF-8 regardless of the
    # platform's default locale encoding.
    return json.loads(contract_path.read_text(encoding="utf-8"))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def register_all_tools():
    """
    Register all MCP tools with the global tool registry.

    Call this during application startup so all tools are available to
    the agent. For each tool, the contract (name, description, parameter
    schema) is loaded from its JSON contract file and paired with its
    handler function.

    Raises:
        Exception: Re-raises any failure to load a contract or register
            a tool, after logging it — startup should fail loudly rather
            than run with a partial tool set.
    """
    logger.info("Registering MCP tools...")

    # Contract name -> handler. The key must match the contract file's
    # base name used by load_tool_contract().
    tool_handlers = {
        "add_task": add_task,
        "list_tasks": list_tasks,
        "complete_task": complete_task,
        "delete_task": delete_task,
        "update_task": update_task,
    }

    for tool_name, handler in tool_handlers.items():
        try:
            contract = load_tool_contract(tool_name)
            tool_registry.register_tool(
                name=contract["name"],
                description=contract["description"],
                parameters=contract["parameters"],
                handler=handler,
            )
            logger.info("Registered tool: %s", tool_name)
        except Exception:
            # logger.exception records the traceback, which the original
            # f-string logger.error calls dropped.
            logger.exception("Failed to register %s tool", tool_name)
            raise

    logger.info(
        "Successfully registered %d MCP tools", len(tool_registry.list_tools())
    )
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# Public API of this module: the shared registry instance and the
# startup registration hook.
__all__ = ["tool_registry", "register_all_tools"]
|
src/mcp/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
src/mcp/__pycache__/tool_registry.cpython-313.pyc
ADDED
|
Binary file (5.78 kB). View file
|
|
|
src/mcp/tool_registry.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCP Tool Registry
|
| 3 |
+
|
| 4 |
+
Manages registration and execution of MCP tools with user context injection.
|
| 5 |
+
Security: user_id is injected by the backend, never trusted from LLM output.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Any, Callable, Optional
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ToolDefinition:
    """
    Definition of an MCP tool in the shape used for LLM function calling.

    Mirrors the contract fields passed to MCPToolRegistry.register_tool
    and echoed back by get_tool_definitions().
    """
    # Tool name (e.g., "add_task"); used as the registry key.
    name: str
    # Human/LLM-readable description of what the tool does.
    description: str
    # JSON-schema-style definition of the tool's accepted arguments.
    parameters: Dict[str, Any]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class ToolExecutionResult:
    """
    Result of executing an MCP tool.

    On success, `data` and/or `message` carry the tool's output; on
    failure, `error` holds a human-readable description of what went wrong.
    """
    # True if the tool handler completed without error.
    success: bool
    # Structured payload returned by the tool, if any.
    data: Optional[Dict[str, Any]] = None
    # Optional human-readable status message for the caller/LLM.
    message: Optional[str] = None
    # Error description; populated when success is False.
    error: Optional[str] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MCPToolRegistry:
    """
    Registry for MCP tools with user context injection.

    Manages tool registration and execution, ensuring that user_id is
    always injected by the backend for security — never trusted from
    LLM output.
    """

    def __init__(self):
        # name -> async handler callable
        self._tools: Dict[str, Callable] = {}
        # name -> contract definition exposed to the LLM
        self._tool_definitions: Dict[str, ToolDefinition] = {}

    def register_tool(
        self,
        name: str,
        description: str,
        parameters: Dict[str, Any],
        handler: Callable
    ) -> None:
        """
        Register an MCP tool with its handler function.

        Re-registering an existing name overwrites the previous handler;
        this is now logged as a warning instead of happening silently.

        Args:
            name: Tool name (e.g., "add_task")
            description: Tool description for LLM
            parameters: JSON schema for tool parameters
            handler: Async function that executes the tool
        """
        if name in self._tools:
            # Overwriting may be intentional (e.g., reload) but should be
            # visible in logs rather than silent.
            logger.warning("Overwriting previously registered MCP tool: %s", name)
        self._tools[name] = handler
        self._tool_definitions[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters
        )
        logger.info("Registered MCP tool: %s", name)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """
        Get tool definitions in format suitable for LLM function calling.

        Returns:
            List of dicts with "name", "description", and "parameters",
            one per registered tool.
        """
        return [
            {
                "name": tool_def.name,
                "description": tool_def.description,
                "parameters": tool_def.parameters
            }
            for tool_def in self._tool_definitions.values()
        ]

    async def execute_tool(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        user_id: int
    ) -> ToolExecutionResult:
        """
        Execute an MCP tool with user context injection.

        SECURITY: user_id is injected by the backend, never from LLM
        output — any "user_id" key present in `arguments` is overwritten.

        Args:
            tool_name: Name of the tool to execute
            arguments: Tool arguments from LLM
            user_id: User ID (injected by backend, not from LLM)

        Returns:
            ToolExecutionResult with success status and data/error.
            (Assumes each registered handler itself returns a
            ToolExecutionResult — confirm against the tool modules.)
        """
        if tool_name not in self._tools:
            logger.error("Tool not found: %s", tool_name)
            return ToolExecutionResult(
                success=False,
                error=f"Tool '{tool_name}' not found"
            )

        try:
            # Inject user_id last so it wins over anything the LLM supplied.
            arguments_with_context = {**arguments, "user_id": user_id}

            logger.info("Executing tool: %s for user: %s", tool_name, user_id)

            # Execute the tool handler
            handler = self._tools[tool_name]
            result = await handler(**arguments_with_context)

            logger.info("Tool execution successful: %s", tool_name)
            return result

        except Exception as e:
            # logger.exception keeps the traceback, which the original
            # logger.error(f"...") call discarded.
            logger.exception("Tool execution failed: %s", tool_name)
            return ToolExecutionResult(
                success=False,
                error=f"Tool execution failed: {str(e)}"
            )

    def list_tools(self) -> List[str]:
        """Get list of registered tool names."""
        return list(self._tools.keys())

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is registered."""
        return tool_name in self._tools
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Global registry instance shared across the app; populated at startup
# (register_all_tools registers each tool's contract and handler here).
tool_registry = MCPToolRegistry()
|